diff options
1784 files changed, 35189 insertions, 219927 deletions
diff --git a/.bazelrc b/.bazelrc deleted file mode 100644 index ef214bcfa..000000000 --- a/.bazelrc +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Build with C++17. -build --cxxopt=-std=c++17 - -# Display the current git revision in the info block. -build --stamp --workspace_status_command tools/workspace_status.sh - -# Enable remote execution so actions are performed on the remote systems. -build:remote --remote_executor=grpcs://remotebuildexecution.googleapis.com -build:remote --project_id=gvisor-rbe -build:remote --remote_instance_name=projects/gvisor-rbe/instances/default_instance -# Enable authentication. This will pick up application default credentials by -# default. You can use --google_credentials=some_file.json to use a service -# account credential instead. -build:remote --google_default_credentials=true -build:remote --auth_scope="https://www.googleapis.com/auth/cloud-source-tools" - -# Add a custom platform and toolchain that builds in a privileged docker -# container, which is required by our syscall tests. -build:remote --host_platform=//:rbe_ubuntu1604 -build:remote --extra_toolchains=//:cc-toolchain-clang-x86_64-default -build:remote --extra_execution_platforms=//:rbe_ubuntu1604 -build:remote --platforms=//:rbe_ubuntu1604 -build:remote --crosstool_top=@rbe_default//cc:toolchain -build:remote --jobs=50 -build:remote --remote_timeout=3600 -# RBE requires a strong hash function, such as SHA256. -startup --host_jvm_args=-Dbazel.DigestFunction=SHA256 - -# Set flags for uploading to BES in order to view results in the Bazel Build -# Results UI. -build:results --bes_backend="buildeventservice.googleapis.com" -build:results --bes_timeout=60s -build:results --tls_enabled - -# Output BES results url -build:results --bes_results_url="https://source.cloud.google.com/results/invocations/" - -# Set flags for uploading to BES without Remote Build Execution. -build:results-local --bes_backend="buildeventservice.googleapis.com" -build:results-local --bes_timeout=60s -build:results-local --tls_enabled=true -build:results-local --auth_enabled=true -build:results-local --spawn_strategy=local -build:results-local --remote_cache=remotebuildexecution.googleapis.com -build:results-local --remote_timeout=3600 -build:results-local --bes_results_url="https://source.cloud.google.com/results/invocations/" diff --git a/.github/issue_template.md b/.github/issue_template.md deleted file mode 100644 index 77c401d22..000000000 --- a/.github/issue_template.md +++ /dev/null @@ -1,20 +0,0 @@ -Before filling an issue, please consult our FAQ: -https://gvisor.dev/docs/user_guide/faq/ - -Also check that the issue hasn't been reported before. - -If you have a question, please email gvisor-users@googlegroups.com rather than filing a bug. - -If you believe you've found a security issue, please email gvisor-security@googlegroups.com rather than filing a bug. - -If this is your first time compiling or running gVisor, please make sure that your system meets the minimum requirements: https://github.com/google/gvisor#requirements - -For all other issues, please attach debug logs. To get debug logs, follow the -instructions here: https://gvisor.dev/docs/user_guide/debugging/ - -Other useful information to include is: - -* `runsc -v` -* `docker version` or `docker info` if more relevant -* `uname -a` - `git describe` -* Detailed reproduction steps diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 13babef4d..000000000 --- a/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -# Generated bazel symlinks. -/bazel-* diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index a2a260538..000000000 --- a/.travis.yml +++ /dev/null @@ -1,19 +0,0 @@ -language: minimal -sudo: required -dist: xenial -cache: - directories: - - /home/travis/.cache/bazel/ -services: - - docker -matrix: - include: - - os: linux - arch: amd64 - env: RUNSC_PATH=./bazel-bin/runsc/linux_amd64_pure_stripped/runsc - - os: linux - arch: arm64 - env: RUNSC_PATH=./bazel-bin/runsc/linux_arm64_pure_stripped/runsc -script: - - uname -a - - make DOCKER_RUN_OPTIONS="" BAZEL_OPTIONS="build runsc:runsc" bazel && $RUNSC_PATH --alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do ls diff --git a/BUILD b/BUILD deleted file mode 100644 index 5fd929378..000000000 --- a/BUILD +++ /dev/null @@ -1,100 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_path", "nogo") -load("@bazel_gazelle//:def.bzl", "gazelle") - -package(licenses = ["notice"]) - -# The sandbox filegroup is used for sandbox-internal dependencies. -package_group( - name = "sandbox", - packages = [ - "//...", - ], -) - -# gopath defines a directory that is structured in a way that is compatible -# with standard Go tools. Things like godoc, editors and refactor tools should -# work as expected. -# -# The files in this tree are symlinks to the true sources. -go_path( - name = "gopath", - mode = "link", - deps = [ - "//runsc", - - # Packages that are not dependencies of //runsc. - "//pkg/sentry/kernel/memevent", - "//pkg/tcpip/adapters/gonet", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/muxed", - "//pkg/tcpip/link/sharedmem", - "//pkg/tcpip/link/sharedmem/pipe", - "//pkg/tcpip/link/sharedmem/queue", - "//pkg/tcpip/link/tun", - "//pkg/tcpip/link/waitable", - "//pkg/tcpip/sample/tun_tcp_connect", - "//pkg/tcpip/sample/tun_tcp_echo", - "//pkg/tcpip/transport/tcpconntrack", - ], -) - -# gazelle is a set of build tools. -# -# To update the WORKSPACE from go.mod, use: -# bazel run //:gazelle -- update-repos -from_file=go.mod -gazelle(name = "gazelle") - -# nogo applies checks to all Go source in this repository, enforcing code -# guidelines and restrictions. Note that the tool libraries themselves should -# live in the tools subdirectory (unless they are standard). -nogo( - name = "nogo", - config = "//tools:nogo.js", - visibility = ["//visibility:public"], - deps = [ - "//tools/checkunsafe", - ], -) - -# We need to define a bazel platform and toolchain to specify dockerPrivileged -# and dockerRunAsRoot options, they are required to run tests on the RBE -# cluster in Kokoro. -alias( - name = "rbe_ubuntu1604", - actual = ":rbe_ubuntu1604_r346485", -) - -platform( - name = "rbe_ubuntu1604_r346485", - constraint_values = [ - "@bazel_tools//platforms:x86_64", - "@bazel_tools//platforms:linux", - "@bazel_tools//tools/cpp:clang", - "@bazel_toolchains//constraints:xenial", - "@bazel_toolchains//constraints/sanitizers:support_msan", - ], - remote_execution_properties = """ - properties: { - name: "container-image" - value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:93f7e127196b9b653d39830c50f8b05d49ef6fd8739a9b5b8ab16e1df5399e50" - } - properties: { - name: "dockerAddCapabilities" - value: "SYS_ADMIN" - } - properties: { - name: "dockerPrivileged" - value: "true" - } - """, -) - -toolchain( - name = "cc-toolchain-clang-x86_64-default", - exec_compatible_with = [ - ], - target_compatible_with = [ - ], - toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/10.0.0/bazel_2.0.0/cc:cc-compiler-k8", - toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", -) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md deleted file mode 100644 index eb6c8edae..000000000 --- a/CODE_OF_CONDUCT.md +++ /dev/null @@ -1,92 +0,0 @@ -# Code of Conduct - -## Our Pledge - -In the interest of fostering an open and welcoming environment, we as -contributors and maintainers pledge to making participation in our project and -our community a harassment-free experience for everyone, regardless of age, body -size, disability, ethnicity, gender identity and expression, level of -experience, education, socio-economic status, nationality, personal appearance, -race, religion, or sexual identity and orientation. - -## Our Standards - -Examples of behavior that contributes to creating a positive environment -include: - -* Using welcoming and inclusive language -* Being respectful of differing viewpoints and experiences -* Gracefully accepting constructive criticism -* Focusing on what is best for the community -* Showing empathy towards other community members - -Examples of unacceptable behavior by participants include: - -* The use of sexualized language or imagery and unwelcome sexual attention or - advances -* Trolling, insulting/derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or electronic - address, without explicit permission -* Other conduct which could reasonably be considered inappropriate in a - professional setting - -## Our Responsibilities - -Project maintainers are responsible for clarifying the standards of acceptable -behavior and are expected to take appropriate and fair corrective action in -response to any instances of unacceptable behavior. - -Project maintainers have the right and responsibility to remove, edit, or reject -comments, commits, code, wiki edits, issues, and other contributions that are -not aligned to this Code of Conduct, or to ban temporarily or permanently any -contributor for other behaviors that they deem inappropriate, threatening, -offensive, or harmful. - -## Scope - -This Code of Conduct applies both within project spaces and in public spaces -when an individual is representing the project or its community. Examples of -representing a project or community include using an official project e-mail -address, posting via an official social media account, or acting as an appointed -representative at an online or offline event. Representation of a project may be -further defined and clarified by project maintainers. - -This Code of Conduct also applies outside the project spaces when the Project -Steward has a reasonable belief that an individual's behavior may have a -negative impact on the project or its community. - -## Conflict Resolution - -We do not believe that all conflict is bad; healthy debate and disagreement -often yield positive results. However, it is never okay to be disrespectful or -to engage in behavior that violates the project’s code of conduct. - -If you see someone violating the code of conduct, you are encouraged to address -the behavior directly with those involved. Many issues can be resolved quickly -and easily, and this gives people more control over the outcome of their -dispute. If you are unable to resolve the matter for any reason, or if the -behavior is threatening or harassing, report it. We are dedicated to providing -an environment where participants feel welcome and safe. - -Reports should be directed to Jaice Singer DuMars, jaice at google dot com, the -Project Steward for gVisor. It is the Project Steward’s duty to receive and -address reported violations of the code of conduct. They will then work with a -committee consisting of representatives from the Open Source Programs Office and -the Google Open Source Strategy team. If for any reason you are uncomfortable -reaching out the Project Steward, please email opensource@google.com. - -We will investigate every complaint, but you may not receive a direct response. -We will use our discretion in determining when and how to follow up on reported -incidents, which may range from not taking action to permanent expulsion from -the project and project-sponsored spaces. We will notify the accused of the -report and provide them an opportunity to discuss it before any action is taken. -The identity of the reporter will be omitted from the details of the report -supplied to the accused. In potentially harmful situations, such as ongoing -harassment or threats to anyone's safety, we may take action without notice. - -## Attribution - -This Code of Conduct is adapted from the Contributor Covenant, version 1.4, -available at -https://www.contributor-covenant.org/version/1/4/code-of-conduct.html diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md deleted file mode 100644 index ad8e710da..000000000 --- a/CONTRIBUTING.md +++ /dev/null @@ -1,151 +0,0 @@ -# Contributing - -Want to contribute? Great! First, read this page. - -### Contributor License Agreement - -Contributions to this project must be accompanied by a Contributor License -Agreement. You (or your employer) retain the copyright to your contribution; -this simply gives us permission to use and redistribute your contributions as -part of the project. Head over to <https://cla.developers.google.com/> to see -your current agreements on file or to sign a new one. - -You generally only need to submit a CLA once, so if you've already submitted one -(even if it was for a different project), you probably don't need to do it -again. - -### Using GOPATH - -Some editors may require the code to be structured in a `GOPATH` directory tree. -In this case, you may use the `:gopath` target to generate a directory tree with -symlinks to the original source files. - -``` -bazel build :gopath -``` - -You can then set the `GOPATH` in your editor to `bazel-bin/gopath`. - -If you use this mechanism, keep in mind that the generated tree is not the -canonical source. You will still need to build and test with `bazel`. New files -will need to be added to the appropriate `BUILD` files, and the `:gopath` target -will need to be re-run to generate appropriate symlinks in the `GOPATH` -directory tree. - -Dependencies can be added by using `go mod get`. In order to keep the -`WORKSPACE` file in sync, run `tools/go_mod.sh` in place of `go mod`. - -### Coding Guidelines - -All Go code should conform to the [Go style guidelines][gostyle]. C++ code -should conform to the [Google C++ Style Guide][cppstyle] and the guidelines -described for [tests][teststyle]. Note that code may be automatically formatted -per the guidelines when merged. - -As a secure runtime, we need to maintain the safety of all of code included in -gVisor. The following rules help mitigate issues. - -Definitions for the rules below: - -`core`: - -* `//pkg/sentry/...` -* Transitive dependencies in `//pkg/...`, etc. - -`runsc`: - -* `//runsc/...` - -Rules: - -* No cgo in `core` or `runsc`. The final binary must be a statically-linked - pure Go binary. - -* Any files importing "unsafe" must have a name ending in `_unsafe.go`. - -* `core` may only depend on the following packages: - - * Itself. - * Go standard library. - * Except (transitively) package "net" (this will result in a non-cgo - binary). Use `//pkg/unet` instead. - * `@org_golang_x_sys//unix:go_default_library` (Go import - `golang.org/x/sys/unix`). - * Generated Go protobuf packages. - * `@com_github_golang_protobuf//proto:go_default_library` (Go import - `github.com/golang/protobuf/proto`). - * `@com_github_golang_protobuf//ptypes:go_default_library` (Go import - `github.com/golang/protobuf/ptypes`). - -* `runsc` may only depend on the following packages: - - * All packages allowed for `core`. - * `@com_github_google_subcommands//:go_default_library` (Go import - `github.com/google/subcommands`). - * `@com_github_opencontainers_runtime_spec//specs_go:go_default_library` - (Go import `github.com/opencontainers/runtime-spec/specs_go`). - -### Code reviews - -Before sending code reviews, run `bazel test ...` to ensure tests are passing. - -Code changes are accepted via [pull request][github]. - -When approved, the change will be submitted by a team member and automatically -merged into the repository. - -### Presubmit checks - -Accessing check logs may require membership in the -[gvisor-dev mailing list][gvisor-dev-list], which is public. - -### Bug IDs - -Some TODOs and NOTEs sprinkled throughout the code have associated IDs of the -form `b/1234`. These correspond to bugs in our internal bug tracker. Eventually -these bugs will be moved to the GitHub Issues, but until then they can simply be -ignored. - -### Build and test with Docker - -`scripts/dev.sh` is a convenient script that builds and installs `runsc` as a -new Docker runtime for you. The scripts tries to extract the runtime name from -your local environment and will print it at the end. You can also customize it. -The script creates one regular runtime and another with debug flags enabled. -Here are a few examples: - -```bash -# Default case (inside branch my-branch) -$ scripts/dev.sh -... -Runtimes my-branch and my-branch-d (debug enabled) setup. -Use --runtime=my-branch with your Docker command. - docker run --rm --runtime=my-branch --rm hello-world - -If you rebuild, use scripts/dev.sh --refresh. -Logs are in: /tmp/my-branch/logs - -# --refresh just updates the runtime binary and doesn't restart docker. -$ git/my_branch> scripts/dev.sh --refresh - -# Using a custom runtime name -$ git/my_branch> scripts/dev.sh my-runtime -... -Runtimes my-runtime and my-runtime-d (debug enabled) setup. -Use --runtime=my-runtime with your Docker command. - docker run --rm --runtime=my-runtime --rm hello-world -``` - -### The small print - -Contributions made by corporations are covered by a different agreement than the -one above, the -[Software Grant and Corporate Contributor License Agreement][gccla]. - -[cppstyle]: https://google.github.io/styleguide/cppguide.html -[gcla]: https://cla.developers.google.com/about/google-individual -[gccla]: https://cla.developers.google.com/about/google-corporate -[github]: https://github.com/google/gvisor/compare -[gvisor-dev-list]: https://groups.google.com/forum/#!forum/gvisor-dev -[gostyle]: https://github.com/golang/go/wiki/CodeReviewComments -[teststyle]: ./test/ diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 2bfdfec6c..000000000 --- a/Dockerfile +++ /dev/null @@ -1,9 +0,0 @@ -FROM fedora:31 - -RUN dnf install -y dnf-plugins-core && dnf copr enable -y vbatts/bazel - -RUN dnf install -y bazel2 git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static - -RUN pip install pycparser - -WORKDIR /gvisor diff --git a/Makefile b/Makefile deleted file mode 100644 index d9531fbd5..000000000 --- a/Makefile +++ /dev/null @@ -1,50 +0,0 @@ -UID := $(shell id -u ${USER}) -GID := $(shell id -g ${USER}) -GVISOR_BAZEL_CACHE := $(shell readlink -f ~/.cache/bazel/) - -# The --privileged is required to run tests. -DOCKER_RUN_OPTIONS ?= --privileged - -all: runsc - -docker-build: - docker build -t gvisor-bazel . - -bazel-shutdown: - docker exec -i gvisor-bazel bazel shutdown && \ - docker kill gvisor-bazel - -bazel-server-start: docker-build - mkdir -p "$(GVISOR_BAZEL_CACHE)" && \ - docker run -d --rm --name gvisor-bazel \ - --user 0:0 \ - -v "$(GVISOR_BAZEL_CACHE):$(HOME)/.cache/bazel/" \ - -v "$(CURDIR):$(CURDIR)" \ - --workdir "$(CURDIR)" \ - --tmpfs /tmp:rw,exec \ - $(DOCKER_RUN_OPTIONS) \ - gvisor-bazel \ - sh -c "while :; do sleep 100; done" && \ - docker exec --user 0:0 -i gvisor-bazel sh -c "groupadd --gid $(GID) --non-unique gvisor && useradd --uid $(UID) --non-unique --gid $(GID) -d $(HOME) gvisor" - -bazel-server: - docker exec gvisor-bazel true || \ - $(MAKE) bazel-server-start - -BAZEL_OPTIONS := build runsc -bazel: bazel-server - docker exec -u $(UID):$(GID) -i gvisor-bazel bazel $(BAZEL_OPTIONS) - -bazel-alias: - @echo "alias bazel='docker exec -u $(UID):$(GID) -i gvisor-bazel bazel'" - -runsc: - $(MAKE) BAZEL_OPTIONS="build runsc" bazel - -tests: - $(MAKE) BAZEL_OPTIONS="test --test_tag_filters runsc_ptrace //test/syscalls/..." bazel - -unit-tests: - $(MAKE) BAZEL_OPTIONS="test //pkg/... //runsc/... //tools/..." bazel - -.PHONY: docker-build bazel-shutdown bazel-server-start bazel-server bazel runsc tests @@ -1,156 +1,5 @@ -![gVisor](g3doc/logo.png) +# gVisor -[![Status](https://storage.googleapis.com/gvisor-build-badges/build.svg)](https://storage.googleapis.com/gvisor-build-badges/build.html) -[![gVisor chat](https://badges.gitter.im/gvisor/community.png)](https://gitter.im/gvisor/community) - -## What is gVisor? - -**gVisor** is a user-space kernel, written in Go, that implements a substantial -portion of the Linux system surface. It includes an -[Open Container Initiative (OCI)][oci] runtime called `runsc` that provides an -isolation boundary between the application and the host kernel. The `runsc` -runtime integrates with Docker and Kubernetes, making it simple to run sandboxed -containers. - -## Why does gVisor exist? - -Containers are not a [**sandbox**][sandbox]. While containers have -revolutionized how we develop, package, and deploy applications, running -untrusted or potentially malicious code without additional isolation is not a -good idea. The efficiency and performance gains from using a single, shared -kernel also mean that container escape is possible with a single vulnerability. - -gVisor is a user-space kernel for containers. It limits the host kernel surface -accessible to the application while still giving the application access to all -the features it expects. Unlike most kernels, gVisor does not assume or require -a fixed set of physical resources; instead, it leverages existing host kernel -functionality and runs as a normal user-space process. In other words, gVisor -implements Linux by way of Linux. - -gVisor should not be confused with technologies and tools to harden containers -against external threats, provide additional integrity checks, or limit the -scope of access for a service. One should always be careful about what data is -made available to a container. - -## Documentation - -User documentation and technical architecture, including quick start guides, can -be found at [gvisor.dev][gvisor-dev]. - -## Installing from source - -gVisor currently requires x86\_64 Linux to build, though support for other -architectures may become available in the future. - -### Requirements - -Make sure the following dependencies are installed: - -* Linux 4.14.77+ ([older linux][old-linux]) -* [git][git] -* [Bazel][bazel] 1.2+ -* [Python][python] -* [Docker version 17.09.0 or greater][docker] -* C++ toolchain supporting C++17 (GCC 7+, Clang 5+) -* Gold linker (e.g. `binutils-gold` package on Ubuntu) - -### Building - -Build and install the `runsc` binary: - -``` -bazel build runsc -sudo cp ./bazel-bin/runsc/linux_amd64_pure_stripped/runsc /usr/local/bin -``` - -If you don't want to install bazel on your system, you can build runsc in a -Docker container: - -``` -make runsc -sudo cp ./bazel-bin/runsc/linux_amd64_pure_stripped/runsc /usr/local/bin -``` - -### Testing - -The test suite can be run with Bazel: - -``` -bazel test //... -``` - -or in a Docker container: - -``` -make unit-tests -make tests -``` - -### Using remote execution - -If you have a [Remote Build Execution][rbe] environment, you can use it to speed -up build and test cycles. - -You must authenticate with the project first: - -``` -gcloud auth application-default login --no-launch-browser -``` - -Then invoke bazel with the following flags: - -``` ---config=remote ---project_id=$PROJECT ---remote_instance_name=projects/$PROJECT/instances/default_instance -``` - -You can also add those flags to your local ~/.bazelrc to avoid needing to -specify them each time on the command line. - -### Using `go get` - -This project uses [bazel][bazel] to build and manage dependencies. A synthetic -`go` branch is maintained that is compatible with standard `go` tooling for -convenience. - -For example, to build `runsc` directly from this branch: - -``` -echo "module runsc" > go.mod -GO111MODULE=on go get gvisor.dev/gvisor/runsc@go -CGO_ENABLED=0 GO111MODULE=on go install gvisor.dev/gvisor/runsc -``` - -Note that this branch is supported in a best effort capacity, and direct -development on this branch is not supported. Development should occur on the -`master` branch, which is then reflected into the `go` branch. - -## Community & Governance - -The governance model is documented in our [community][community] repository. - -The [gvisor-users mailing list][gvisor-users-list] and -[gvisor-dev mailing list][gvisor-dev-list] are good starting points for -questions and discussion. - -## Security Policy - -See [SECURITY.md](SECURITY.md). - -## Contributing - -See [Contributing.md](CONTRIBUTING.md). - -[bazel]: https://bazel.build -[community]: https://gvisor.googlesource.com/community -[docker]: https://www.docker.com -[git]: https://git-scm.com -[gvisor-users-list]: https://groups.google.com/forum/#!forum/gvisor-users -[gvisor-dev-list]: https://groups.google.com/forum/#!forum/gvisor-dev -[oci]: https://www.opencontainers.org -[old-linux]: https://gvisor.dev/docs/user_guide/networking/#gso -[python]: https://python.org -[rbe]: https://blog.bazel.build/2018/10/05/remote-build-execution.html -[sandbox]: https://en.wikipedia.org/wiki/Sandbox_(computer_security) -[gvisor-dev]: https://gvisor.dev +This branch is a synthetic branch, containing only Go sources, that is +compatible with standard Go tools. See the master branch for authoritative +sources and tests. diff --git a/SECURITY.md b/SECURITY.md deleted file mode 100644 index 154d68cb3..000000000 --- a/SECURITY.md +++ /dev/null @@ -1,11 +0,0 @@ -# Security and Vulnerability Reporting - -Sensitive security-related questions, comments, and reports should be sent to -the [gvisor-security mailing list][gvisor-security-list]. You should receive a -prompt response, typically within 48 hours. - -Policies for security list access, vulnerability embargo, and vulnerability -disclosure are outlined in the [community][community] repository. - -[community]: https://gvisor.googlesource.com/community -[gvisor-security-list]: https://groups.google.com/forum/#!forum/gvisor-security diff --git a/WORKSPACE b/WORKSPACE deleted file mode 100644 index d2bbadc63..000000000 --- a/WORKSPACE +++ /dev/null @@ -1,383 +0,0 @@ -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") - -# Load go bazel rules and gazelle. -http_archive( - name = "io_bazel_rules_go", - sha256 = "94f90feaa65c9cdc840cd21f67d967870b5943d684966a47569da8073e42063d", - urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.22.0/rules_go-v0.22.0.tar.gz", - "https://github.com/bazelbuild/rules_go/releases/download/v0.22.0/rules_go-v0.22.0.tar.gz", - ], -) - -http_archive( - name = "bazel_gazelle", - sha256 = "d8c45ee70ec39a57e7a05e5027c32b1576cc7f16d9dd37135b0eddde45cf1b10", - urls = [ - "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/v0.20.0/bazel-gazelle-v0.20.0.tar.gz", - "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.20.0/bazel-gazelle-v0.20.0.tar.gz", - ], -) - -load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies") - -go_rules_dependencies() - -go_register_toolchains( - go_version = "1.14", - nogo = "@//:nogo", -) - -load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository") - -gazelle_dependencies() - -# TODO(gvisor.dev/issue/1876): Move the statement to "External repositories" -# block below once 1876 is fixed. -# -# The com_google_protobuf repository below would trigger downloading a older -# version of org_golang_x_sys. If putting this repository statment in a place -# after that of the com_google_protobuf, this statement will not work as -# expectd to download a new version of org_golang_x_sys. -go_repository( - name = "org_golang_x_sys", - importpath = "golang.org/x/sys", - sum = "h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=", - version = "v0.0.0-20200302150141-5c8b2ff67527", -) - -# Load C++ rules. -http_archive( - name = "rules_cc", - sha256 = "67412176974bfce3f4cf8bdaff39784a72ed709fc58def599d1f68710b58d68b", - strip_prefix = "rules_cc-b7fe9697c0c76ab2fd431a891dbb9a6a32ed7c3e", - urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_cc/archive/b7fe9697c0c76ab2fd431a891dbb9a6a32ed7c3e.zip", - "https://github.com/bazelbuild/rules_cc/archive/b7fe9697c0c76ab2fd431a891dbb9a6a32ed7c3e.zip", - ], -) - -# Load protobuf dependencies. -http_archive( - name = "rules_proto", - sha256 = "602e7161d9195e50246177e7c55b2f39950a9cf7366f74ed5f22fd45750cd208", - strip_prefix = "rules_proto-97d8af4dc474595af3900dd85cb3a29ad28cc313", - urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_proto/archive/97d8af4dc474595af3900dd85cb3a29ad28cc313.tar.gz", - "https://github.com/bazelbuild/rules_proto/archive/97d8af4dc474595af3900dd85cb3a29ad28cc313.tar.gz", - ], -) - -load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains") - -rules_proto_dependencies() - -rules_proto_toolchains() - -# Load python dependencies. -git_repository( - name = "rules_python", - commit = "abc4869e02fe9b3866942e89f07b7341f830e805", - remote = "https://github.com/bazelbuild/rules_python.git", - shallow_since = "1583341286 -0500", -) - -load("@rules_python//python:pip.bzl", "pip_import") - -pip_import( - name = "pydeps", - python_interpreter = "python3", - requirements = "//benchmarks:requirements.txt", -) - -load("@pydeps//:requirements.bzl", "pip_install") - -pip_install() - -# Load bazel_toolchain to support Remote Build Execution. -# See releases at https://releases.bazel.build/bazel-toolchains.html -http_archive( - name = "bazel_toolchains", - sha256 = "b5a8039df7119d618402472f3adff8a1bd0ae9d5e253f53fcc4c47122e91a3d2", - strip_prefix = "bazel-toolchains-2.1.1", - urls = [ - "https://github.com/bazelbuild/bazel-toolchains/releases/download/2.1.1/bazel-toolchains-2.1.1.tar.gz", - "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/2.1.1.tar.gz", - ], -) - -# Creates a default toolchain config for RBE. -load("@bazel_toolchains//rules:rbe_repo.bzl", "rbe_autoconfig") - -rbe_autoconfig(name = "rbe_default") - -http_archive( - name = "rules_pkg", - sha256 = "5bdc04987af79bd27bc5b00fe30f59a858f77ffa0bd2d8143d5b31ad8b1bd71c", - url = "https://github.com/bazelbuild/rules_pkg/releases/download/0.2.0/rules_pkg-0.2.0.tar.gz", -) - -load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies") - -rules_pkg_dependencies() - -# Container rules. -http_archive( - name = "io_bazel_rules_docker", - sha256 = "14ac30773fdb393ddec90e158c9ec7ebb3f8a4fd533ec2abbfd8789ad81a284b", - strip_prefix = "rules_docker-0.12.1", - urls = ["https://github.com/bazelbuild/rules_docker/releases/download/v0.12.1/rules_docker-v0.12.1.tar.gz"], -) - -load( - "@io_bazel_rules_docker//repositories:repositories.bzl", - container_repositories = "repositories", -) - -container_repositories() - -load("@io_bazel_rules_docker//repositories:deps.bzl", container_deps = "deps") - -container_deps() - -load( - "@io_bazel_rules_docker//container:container.bzl", - "container_pull", -) - -# This container is built from the Dockerfile in test/iptables/runner. -container_pull( - name = "iptables-test", - digest = "sha256:a137d692a2eb9fc7bf95c5f4a568da090e2c31098e93634421ed88f3a3f1db65", - registry = "gcr.io", - repository = "gvisor-presubmit/iptables-test", -) - -load( - "@io_bazel_rules_docker//go:image.bzl", - _go_image_repos = "repositories", -) - -_go_image_repos() - -# External repositories, in sorted order. -go_repository( - name = "com_github_cenkalti_backoff", - importpath = "github.com/cenkalti/backoff", - sum = "h1:+FKjzBIdfBHYDvxCv+djmDJdes/AoDtg8gpcxowBlF8=", - version = "v0.0.0-20190506075156-2146c9339422", -) - -go_repository( - name = "com_github_gofrs_flock", - importpath = "github.com/gofrs/flock", - sum = "h1:JFTFz3HZTGmgMz4E1TabNBNJljROSYgja1b4l50FNVs=", - version = "v0.6.1-0.20180915234121-886344bea079", -) - -go_repository( - name = "com_github_golang_mock", - importpath = "github.com/golang/mock", - sum = "h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s=", - version = "v1.3.1", -) - -go_repository( - name = "com_github_google_go-cmp", - importpath = "github.com/google/go-cmp", - sum = "h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=", - version = "v0.2.0", -) - -go_repository( - name = "com_github_google_subcommands", - importpath = "github.com/google/subcommands", - sum = "h1:GZGUPQiZfYrd9uOqyqwbQcHPkz/EZJVkZB1MkaO9UBI=", - version = "v0.0.0-20190508160503-636abe8753b8", -) - -go_repository( - name = "com_github_google_uuid", - importpath = "github.com/google/uuid", - sum = "h1:rXQlD9GXkjA/PQZhmEaF/8Pj/sJfdZJK7GJG0gkS8I0=", - version = "v0.0.0-20171129191014-dec09d789f3d", -) - -go_repository( - name = "com_github_kr_pretty", - importpath = "github.com/kr/pretty", - sum = "h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=", - version = "v0.2.0", -) - -go_repository( - name = "com_github_kr_pty", - importpath = "github.com/kr/pty", - sum = "h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw=", - version = "v1.1.1", -) - -go_repository( - name = "com_github_kr_text", - importpath = "github.com/kr/text", - sum = "h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=", - version = "v0.1.0", -) - -go_repository( - name = "com_github_opencontainers_runtime-spec", - importpath = "github.com/opencontainers/runtime-spec", - sum = "h1:d9F+LNYwMyi3BDN4GzZdaSiq4otb8duVEWyZjeUtOQI=", - version = "v0.1.2-0.20171211145439-b2d941ef6a78", -) - -go_repository( - name = "com_github_syndtr_gocapability", - importpath = "github.com/syndtr/gocapability", - sum = "h1:b6uOv7YOFK0TYG7HtkIgExQo+2RdLuwRft63jn2HWj8=", - version = "v0.0.0-20180916011248-d98352740cb2", -) - -go_repository( - name = "com_github_vishvananda_netlink", - importpath = "github.com/vishvananda/netlink", - sum = "h1:/Tdc23Arz1OtdIsBY2utWepGRQ9fEAJlhkdoLzWMK8Q=", - version = "v1.0.1-0.20190318003149-adb577d4a45e", -) - -go_repository( - name = "com_github_vishvananda_netns", - importpath = "github.com/vishvananda/netns", - sum = "h1:J9gO8RJCAFlln1jsvRba/CWVUnMHwObklfxxjErl1uk=", - version = "v0.0.0-20171111001504-be1fbeda1936", -) - -go_repository( - name = "in_gopkg_check_v1", - importpath = "gopkg.in/check.v1", - sum = "h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=", - version = "v1.0.0-20190902080502-41f04d3bba15", -) - -go_repository( - name = "org_golang_x_crypto", - importpath = "golang.org/x/crypto", - sum = "h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8=", - version = "v0.0.0-20191011191535-87dc89f01550", -) - -go_repository( - name = "org_golang_x_mod", - importpath = "golang.org/x/mod", - sum = "h1:p1YOIz9H/mGN8k1XkaV5VFAq9+zhN9Obefv439UwRhI=", - version = "v0.2.1-0.20200224194123-e5e73c1b9c72", -) - -go_repository( - name = "org_golang_x_net", - importpath = "golang.org/x/net", - sum = "h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI=", - version = "v0.0.0-20190620200207-3b0461eec859", -) - -go_repository( - name = "org_golang_x_sync", - importpath = "golang.org/x/sync", - sum = "h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU=", - version = "v0.0.0-20190423024810-112230192c58", -) - -go_repository( - name = "org_golang_x_text", - importpath = "golang.org/x/text", - sum = "h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=", - version = "v0.3.0", -) - -go_repository( - name = "org_golang_x_time", - importpath = "golang.org/x/time", - sum = "h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=", - version = "v0.0.0-20191024005414-555d28b269f0", -) - -go_repository( - name = "org_golang_x_tools", - importpath = "golang.org/x/tools", - sum = "h1:aZzprAO9/8oim3qStq3wc1Xuxx4QmAGriC4VU4ojemQ=", - version = "v0.0.0-20191119224855-298f0cb1881e", -) - -go_repository( - name = "org_golang_x_xerrors", - importpath = "golang.org/x/xerrors", - sum = "h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=", - version = "v0.0.0-20191204190536-9bdfabe68543", -) - -go_repository( - name = "com_github_google_btree", - importpath = "github.com/google/btree", - sum = "h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=", - version = "v1.0.0", -) - -go_repository( - name = "com_github_golang_protobuf", - importpath = "github.com/golang/protobuf", - sum = "h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=", - version = "v1.3.1", -) - -go_repository( - name = "com_github_google_go-github", - importpath = "github.com/google/go-github", - sum = "h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=", - version = "v17.0.0", -) - -go_repository( - name = "org_golang_x_oauth2", - importpath = "golang.org/x/oauth2", - sum = "h1:pE8b58s1HRDMi8RDc79m0HISf9D4TzseP40cEA6IGfs=", - version = "v0.0.0-20191202225959-858c2ad4c8b6", -) - -go_repository( - name = "com_github_google_go-querystring", - importpath = "github.com/google/go-querystring", - sum = "h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=", - version = "v1.0.0", -) - -# System Call test dependencies. -http_archive( - name = "com_google_absl", - sha256 = "56775f1283a59e6274c28d99981a9717ff4e0b1161e9129fdb2fcf22531d8d93", - strip_prefix = "abseil-cpp-a0d1e098c2f99694fa399b175a7ccf920762030e", - urls = [ - "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz", - "https://github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz", - ], -) - -http_archive( - name = "com_google_googletest", - sha256 = "0a10bea96d8670e5eef948d79d824162b1577bb7889539e49ec786bfc3e48912", - strip_prefix = "googletest-565f1b848215b77c3732bca345fe76a0431d8b34", - urls = [ - "https://mirror.bazel.build/github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz", - "https://github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz", - ], -) - -http_archive( - name = "com_google_benchmark", - sha256 = "3c6a165b6ecc948967a1ead710d4a181d7b0fbcaa183ef7ea84604994966221a", - strip_prefix = "benchmark-1.5.0", - urls = [ - "https://mirror.bazel.build/github.com/google/benchmark/archive/v1.5.0.tar.gz", - "https://github.com/google/benchmark/archive/v1.5.0.tar.gz", - ], -) diff --git a/benchmarks/BUILD b/benchmarks/BUILD deleted file mode 100644 index 2a2d15d7e..000000000 --- a/benchmarks/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -package(licenses = ["notice"]) - -config_setting( - name = "gcloud_rule", - values = { - "define": "gcloud=off", - }, -) - -py_binary( - name = "benchmarks", - srcs = ["run.py"], - data = select({ - ":gcloud_rule": [], - "//conditions:default": [ - "//tools/images:ubuntu1604", - "//tools/images:zone", - ], - }), - main = "run.py", - python_version = "PY3", - srcs_version = "PY3", - tags = [ - "local", - "manual", - ], - deps = ["//benchmarks/runner"], -) diff --git a/benchmarks/README.md b/benchmarks/README.md deleted file mode 100644 index 6d1ea3ae2..000000000 --- a/benchmarks/README.md +++ /dev/null @@ -1,182 +0,0 @@ -# Benchmark tools - -These scripts are tools for collecting performance data for Docker-based tests. - -## Setup - -The scripts assume the following: - -* There are two sets of machines: one where the scripts will be run - (controller) and one or more machines on which docker containers will be run - (environment). -* The controller machine must have bazel installed along with this source - code. You should be able to run a command like `bazel run :benchmarks -- - --list` -* Environment machines must have docker and the required runtimes installed. - More specifically, you should be able to run a command like: `docker run - --runtime=$RUNTIME your/image`. -* The controller has ssh private key which can be used to login to environment - machines and run docker commands without using `sudo`. This is not required - if running locally via the `run-local` command. -* The docker daemon on each of your environment machines is listening on - `unix:///var/run/docker.sock` (docker's default). - -For configuring the environment manually, consult the -[dockerd documentation][dockerd]. - -## Running benchmarks - -### Locally - -The tool is built to, by default, use Google Cloud Platform to run benchmarks, -but it does support GCP workflows. To run locally, run the following from the -benchmarks directory: - -```bash -bazel run --define gcloud=off :benchmarks -- run-local startup - -... -method,metric,result -startup.empty,startup_time_ms,652.5772 -startup.node,startup_time_ms,1654.4042000000002 -startup.ruby,startup_time_ms,1429.835 -``` - -The above command ran the startup benchmark locally, which consists of three -benchmarks (empty, node, and ruby). Benchmark tools ran it on the default -runtime, runc. Running on another installed runtime, like say runsc, is as -simple as: - -```bash -bazel run --define gcloud=off :benchmarks -- run-local startup --runtime=runsc -``` - -There is help: `bash bazel run --define gcloud=off :benchmarks -- --help bazel -run --define gcloud=off :benchmarks -- run-local --help` - -To list available benchmarks, use the `list` commmand: - -```bash -bazel --define gcloud=off run :benchmarks -- list - -... -Benchmark: sysbench.cpu -Metrics: events_per_second - Run sysbench CPU test. Additional arguments can be provided for sysbench. - - :param max_prime: The maximum prime number to search. -``` - -You can choose benchmarks by name or regex like: - -```bash -bazel run --define gcloud=off :benchmarks -- run-local startup.node -... -metric,result -startup_time_ms,1671.7178000000001 - -``` - -or - -```bash -bazel run --define gcloud=off :benchmarks -- run-local s -... -method,metric,result -startup.empty,startup_time_ms,1792.8292 -startup.node,startup_time_ms,3113.5274 -startup.ruby,startup_time_ms,3025.2424 -sysbench.cpu,cpu_events_per_second,12661.47 -sysbench.memory,memory_ops_per_second,7228268.44 -sysbench.mutex,mutex_time,17.4835 -sysbench.mutex,mutex_latency,3496.7 -sysbench.mutex,mutex_deviation,0.04 -syscall.syscall,syscall_time_ns,2065.0 -``` - -You can run parameterized benchmarks, for example to run with different -runtimes: - -```bash -bazel run --define gcloud=off :benchmarks -- run-local --runtime=runc --runtime=runsc sysbench.cpu -``` - -Or with different parameters: - -```bash -bazel run --define gcloud=off :benchmarks -- run-local --max_prime=10 --max_prime=100 sysbench.cpu -``` - -### On Google Compute Engine (GCE) - -Benchmarks may be run on GCE in an automated way. The default project configured -for `gcloud` will be used. - -An additional parameter `installers` may be provided to ensure that the latest -runtime is installed from the workspace. See the files in `tools/installers` for -supported install targets. - -```bash -bazel run :benchmarks -- run-gcp --installers=head --runtime=runsc sysbench.cpu -``` - -When running on GCE, the scripts generate a per run SSH key, which is added to -your project. The key is set to expire in GCE after 60 minutes and is stored in -a temporary directory on the local machine running the scripts. - -## Writing benchmarks - -To write new benchmarks, you should familiarize yourself with the structure of -the repository. There are three key components. - -## Harness - -The harness makes use of the [docker py SDK][docker-py]. It is advisable that -you familiarize yourself with that API when making changes, specifically: - -* clients -* containers -* images - -In general, benchmarks need only interact with the `Machine` objects provided to -the benchmark function, which are the machines defined in the environment. These -objects allow the benchmark to define the relationships between different -containers, and parse the output. - -## Workloads - -The harness requires workloads to run. These are all available in the -`workloads` directory. - -In general, a workload consists of a Dockerfile to build it (while these are not -hermetic, in general they should be as fixed and isolated as possible), some -parsers for output if required, parser tests and sample data. Provided the test -is named after the workload package and contains a function named `sample`, this -variable will be used to automatically mock workload output when the `--mock` -flag is provided to the main tool. - -## Writing benchmarks - -Benchmarks define the tests themselves. All benchmarks have the following -function signature: - -```python -def my_func(output) -> float: - return float(output) - -@benchmark(metrics = my_func, machines = 1) -def my_benchmark(machine: machine.Machine, arg: str): - return "3.4432" -``` - -Each benchmark takes a variable amount of position arguments as -`harness.Machine` objects and some set of keyword arguments. It is recommended -that you accept arbitrary keyword arguments and pass them through when -constructing the container under test. - -To write a new benchmark, open a module in the `suites` directory and use the -above signature. You should add a descriptive doc string to describe what your -benchmark is and any test centric arguments. - -[dockerd]: https://docs.docker.com/engine/reference/commandline/dockerd/ -[docker-py]: https://docker-py.readthedocs.io/en/stable/ diff --git a/benchmarks/defs.bzl b/benchmarks/defs.bzl deleted file mode 100644 index 56d28223e..000000000 --- a/benchmarks/defs.bzl +++ /dev/null @@ -1,14 +0,0 @@ -"""Provides attributes common to many workload tests.""" - -load("//tools:defs.bzl", "py_requirement") - -test_deps = [ - py_requirement("attrs", direct = False), - py_requirement("atomicwrites", direct = False), - py_requirement("more-itertools", direct = False), - py_requirement("pathlib2", direct = False), - py_requirement("pluggy", direct = False), - py_requirement("py", direct = False), - py_requirement("pytest"), - py_requirement("six", direct = False), -] diff --git a/benchmarks/examples/localhost.yaml b/benchmarks/examples/localhost.yaml deleted file mode 100644 index f70fe0fb7..000000000 --- a/benchmarks/examples/localhost.yaml +++ /dev/null @@ -1,2 +0,0 @@ -client: localhost -server: localhost diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD deleted file mode 100644 index 48c548d59..000000000 --- a/benchmarks/harness/BUILD +++ /dev/null @@ -1,202 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "installers", - srcs = [ - "//tools/installers:head", - "//tools/installers:master", - "//tools/installers:runsc", - ], - mode = "0755", -) - -filegroup( - name = "files", - srcs = [ - ":installers", - ], -) - -py_library( - name = "harness", - srcs = ["__init__.py"], - data = [ - ":files", - ], -) - -py_library( - name = "benchmark_driver", - srcs = ["benchmark_driver.py"], - deps = [ - "//benchmarks/harness/machine_mocks", - "//benchmarks/harness/machine_producers:machine_producer", - "//benchmarks/suites", - ], -) - -py_library( - name = "container", - srcs = ["container.py"], - deps = [ - "//benchmarks/workloads", - py_requirement( - "asn1crypto", - direct = False, - ), - py_requirement( - "chardet", - direct = False, - ), - py_requirement( - "certifi", - direct = False, - ), - py_requirement("docker"), - py_requirement( - "docker-pycreds", - direct = False, - ), - py_requirement( - "idna", - direct = False, - ), - py_requirement( - "ptyprocess", - direct = False, - ), - py_requirement( - "requests", - direct = False, - ), - py_requirement( - "urllib3", - direct = False, - ), - py_requirement( - "websocket-client", - direct = False, - ), - ], -) - -py_library( - name = "machine", - srcs = ["machine.py"], - deps = [ - "//benchmarks/harness", - "//benchmarks/harness:container", - "//benchmarks/harness:ssh_connection", - "//benchmarks/harness:tunnel_dispatcher", - "//benchmarks/harness/machine_mocks", - py_requirement( - "asn1crypto", - direct = False, - ), - py_requirement( - "chardet", - direct = False, - ), - py_requirement( - "certifi", - direct = False, - ), - py_requirement("docker"), - py_requirement( - "docker-pycreds", - direct = False, - ), - py_requirement( - "idna", - direct = False, - ), - py_requirement( - "ptyprocess", - direct = False, - ), - py_requirement( - "requests", - direct = False, - ), - py_requirement( - "six", - direct = False, - ), - py_requirement( - "urllib3", - direct = False, - ), - py_requirement( - "websocket-client", - direct = False, - ), - ], -) - -py_library( - name = "ssh_connection", - srcs = ["ssh_connection.py"], - deps = [ - "//benchmarks/harness", - py_requirement( - "bcrypt", - direct = False, - ), - py_requirement("cffi"), - py_requirement("paramiko"), - py_requirement( - "cryptography", - direct = False, - ), - ], -) - -py_library( - name = "tunnel_dispatcher", - srcs = ["tunnel_dispatcher.py"], - deps = [ - py_requirement( - "asn1crypto", - direct = False, - ), - py_requirement( - "chardet", - direct = False, - ), - py_requirement( - "certifi", - direct = False, - ), - py_requirement("docker"), - py_requirement( - "docker-pycreds", - direct = False, - ), - py_requirement( - "idna", - direct = False, - ), - py_requirement("pexpect"), - py_requirement( - "ptyprocess", - direct = False, - ), - py_requirement( - "requests", - direct = False, - ), - py_requirement( - "urllib3", - direct = False, - ), - py_requirement( - "websocket-client", - direct = False, - ), - ], -) diff --git a/benchmarks/harness/__init__.py b/benchmarks/harness/__init__.py deleted file mode 100644 index 15aa2a69a..000000000 --- a/benchmarks/harness/__init__.py +++ /dev/null @@ -1,62 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Core benchmark utilities.""" - -import getpass -import os -import subprocess -import tempfile - -# LOCAL_WORKLOADS_PATH defines the path to use for local workloads. This is a -# format string that accepts a single string parameter. -LOCAL_WORKLOADS_PATH = os.path.dirname(__file__) + "/../workloads/{}/tar.tar" - -# REMOTE_WORKLOADS_PATH defines the path to use for storing the workloads on the -# remote host. This is a format string that accepts a single string parameter. -REMOTE_WORKLOADS_PATH = "workloads/{}" - -# INSTALLER_ROOT is the set of files that needs to be copied. -INSTALLER_ARCHIVE = os.readlink(os.path.join( - os.path.dirname(__file__), "installers.tar")) - -# SSH_KEY_DIR holds SSH_PRIVATE_KEY for this run. bm-tools paramiko requires -# keys generated with the '-t rsa -m PEM' options from ssh-keygen. This is -# abstracted away from the user. -SSH_KEY_DIR = tempfile.TemporaryDirectory() -SSH_PRIVATE_KEY = "key" - -# DEFAULT_USER is the default user running this script. -DEFAULT_USER = getpass.getuser() - -# DEFAULT_USER_HOME is the home directory of the user running the script. -DEFAULT_USER_HOME = os.environ["HOME"] if "HOME" in os.environ else "" - -# Default directory to remotely installer "installer" targets. -REMOTE_INSTALLERS_PATH = "installers" - - -def make_key(): - """Wraps a valid ssh key in a temporary directory.""" - path = os.path.join(SSH_KEY_DIR.name, SSH_PRIVATE_KEY) - if not os.path.exists(path): - cmd = "ssh-keygen -t rsa -m PEM -b 4096 -f {key} -q -N".format( - key=path).split(" ") - cmd.append("") - subprocess.run(cmd, check=True) - return path - - -def delete_key(): - """Deletes temporary directory containing private key.""" - SSH_KEY_DIR.cleanup() diff --git a/benchmarks/harness/benchmark_driver.py b/benchmarks/harness/benchmark_driver.py deleted file mode 100644 index 9abc21b54..000000000 --- a/benchmarks/harness/benchmark_driver.py +++ /dev/null @@ -1,85 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Main driver for benchmarks.""" - -import copy -import statistics -import threading -import types - -from benchmarks import suites -from benchmarks.harness.machine_producers import machine_producer - - -# pylint: disable=too-many-instance-attributes -class BenchmarkDriver: - """Allocates machines and invokes a benchmark method.""" - - def __init__(self, - producer: machine_producer.MachineProducer, - method: types.FunctionType, - runs: int = 1, - **kwargs): - - self._producer = producer - self._method = method - self._kwargs = copy.deepcopy(kwargs) - self._threads = [] - self.lock = threading.RLock() - self._runs = runs - self._metric_results = {} - - def start(self): - """Starts a benchmark thread.""" - for _ in range(self._runs): - thread = threading.Thread(target=self._run_method) - thread.start() - self._threads.append(thread) - - def join(self): - """Joins the thread.""" - # pylint: disable=expression-not-assigned - [t.join() for t in self._threads] - - def _run_method(self): - """Runs all benchmarks.""" - machines = self._producer.get_machines( - suites.benchmark_machines(self._method)) - try: - result = self._method(*machines, **self._kwargs) - for name, res in result: - with self.lock: - if name in self._metric_results: - self._metric_results[name].append(res) - else: - self._metric_results[name] = [res] - finally: - # Always release. - self._producer.release_machines(machines) - - def median(self): - """Returns the median result, after join is finished.""" - for key, value in self._metric_results.items(): - yield key, [statistics.median(value)] - - def all(self): - """Returns all results.""" - for key, value in self._metric_results.items(): - yield key, value - - def meanstd(self): - """Returns all results.""" - for key, value in self._metric_results.items(): - mean = statistics.mean(value) - yield key, [mean, statistics.stdev(value, xbar=mean)] diff --git a/benchmarks/harness/container.py b/benchmarks/harness/container.py deleted file mode 100644 index 585436e20..000000000 --- a/benchmarks/harness/container.py +++ /dev/null @@ -1,181 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Container definitions.""" - -import contextlib -import logging -import pydoc -import types -from typing import Tuple - -import docker -import docker.errors - -from benchmarks import workloads - - -class Container: - """Abstract container. - - Must be a context manager. - - Usage: - - with Container(client, image, ...): - ... - """ - - def run(self, **env) -> str: - """Run the container synchronously.""" - raise NotImplementedError - - def detach(self, **env): - """Run the container asynchronously.""" - raise NotImplementedError - - def address(self) -> Tuple[str, int]: - """Return the bound address for the container.""" - raise NotImplementedError - - def get_names(self) -> types.GeneratorType: - """Return names of all containers.""" - raise NotImplementedError - - -# pylint: disable=too-many-instance-attributes -class DockerContainer(Container): - """Class that handles creating a docker container.""" - - # pylint: disable=too-many-arguments - def __init__(self, - client: docker.DockerClient, - host: str, - image: str, - count: int = 1, - runtime: str = "runc", - port: int = 0, - **kwargs): - """Trys to setup "count" containers. - - Args: - client: A docker client from dockerpy. - host: The host address the image is running on. - image: The name of the image to run. - count: The number of containers to setup. - runtime: The container runtime to use. - port: The port to reserve. - **kwargs: Additional container options. - """ - assert count >= 1 - assert port == 0 or count == 1 - self._client = client - self._host = host - self._containers = [] - self._count = count - self._image = image - self._runtime = runtime - self._port = port - self._kwargs = kwargs - if port != 0: - self._ports = {"%d/tcp" % port: None} - else: - self._ports = {} - - @contextlib.contextmanager - def detach(self, **env): - env = ["%s=%s" % (key, value) for (key, value) in env.items()] - # Start all containers. - for _ in range(self._count): - try: - # Start the container in a detached mode. - container = self._client.containers.run( - self._image, - detach=True, - remove=True, - runtime=self._runtime, - ports=self._ports, - environment=env, - **self._kwargs) - logging.info("Started detached container %s -> %s", self._image, - container.attrs["Id"]) - self._containers.append(container) - except Exception as exc: - self._clean_containers() - raise exc - try: - # Wait for all containers to be up. - for container in self._containers: - while not container.attrs["State"]["Running"]: - container = self._client.containers.get(container.attrs["Id"]) - yield self - finally: - self._clean_containers() - - def address(self) -> Tuple[str, int]: - assert self._count == 1 - assert self._port != 0 - container = self._client.containers.get(self._containers[0].attrs["Id"]) - port = container.attrs["NetworkSettings"]["Ports"][ - "%d/tcp" % self._port][0]["HostPort"] - return (self._host, port) - - def get_names(self) -> types.GeneratorType: - for container in self._containers: - yield container.name - - def run(self, **env) -> str: - env = ["%s=%s" % (key, value) for (key, value) in env.items()] - return self._client.containers.run( - self._image, - runtime=self._runtime, - ports=self._ports, - remove=True, - environment=env, - **self._kwargs).decode("utf-8") - - def _clean_containers(self): - """Kills all containers.""" - for container in self._containers: - try: - container.kill() - except docker.errors.NotFound: - pass - - -class MockContainer(Container): - """Mock of Container.""" - - def __init__(self, workload: str): - self._workload = workload - - def __enter__(self): - return self - - def run(self, **env): - # Lookup sample data if any exists for the workload module. We use a - # well-defined test locate and a well-defined sample function. - mod = pydoc.locate(workloads.__name__ + "." + self._workload) - if hasattr(mod, "sample"): - return mod.sample(**env) - return "" # No output. - - def address(self) -> Tuple[str, int]: - return ("example.com", 80) - - def get_names(self) -> types.GeneratorType: - yield "mock" - - @contextlib.contextmanager - def detach(self, **env): - yield self diff --git a/benchmarks/harness/machine.py b/benchmarks/harness/machine.py deleted file mode 100644 index 5bdc4aa85..000000000 --- a/benchmarks/harness/machine.py +++ /dev/null @@ -1,265 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Machine abstraction passed to benchmarks to run docker containers. - -Abstraction for interacting with test machines. Machines are produced -by Machine producers and represent a local or remote machine. Benchmark -methods in /benchmarks/suite are passed the required number of machines in order -to run the benchmark. Machines contain methods to run commands via bash, -possibly over ssh. Machines also hold a connection to the docker UNIX socket -to run contianers. - - Typical usage example: - - machine = Machine() - machine.run(cmd) - machine.pull(path) - container = machine.container() -""" - -import logging -import os -import re -import subprocess -import time -from typing import List, Tuple - -import docker - -from benchmarks import harness -from benchmarks.harness import container -from benchmarks.harness import machine_mocks -from benchmarks.harness import ssh_connection -from benchmarks.harness import tunnel_dispatcher - -log = logging.getLogger(__name__) - - -class Machine(object): - """The machine object is the primary object for benchmarks. - - Machine objects are passed to each metric function call and benchmarks use - machines to access real connections to those machines. - - Attributes: - _name: Name as a string - """ - _name = "" - - def run(self, cmd: str) -> Tuple[str, str]: - """Convenience method for running a bash command on a machine object. - - Some machines may point to the local machine, and thus, do not have ssh - connections. Run runs a command either local or over ssh and returns the - output stdout and stderr as strings. - - Args: - cmd: The command to run as a string. - - Returns: - The command output. - """ - raise NotImplementedError - - def read(self, path: str) -> str: - """Reads the contents of some file. - - This will be mocked. - - Args: - path: The path to the file to be read. - - Returns: - The file contents. - """ - raise NotImplementedError - - def pull(self, workload: str) -> str: - """Send the given workload to the machine, build and tag it. - - All images must be defined by the workloads directory. - - Args: - workload: The workload name. - - Returns: - The workload tag. - """ - raise NotImplementedError - - def container(self, image: str, **kwargs) -> container.Container: - """Returns a container object. - - Args: - image: The pulled image tag. - **kwargs: Additional container options. - - Returns: - :return: a container.Container object. - """ - raise NotImplementedError - - def sleep(self, amount: float): - """Sleeps the given amount of time.""" - time.sleep(amount) - - def __str__(self): - return self._name - - -class MockMachine(Machine): - """A mocked machine.""" - _name = "mock" - - def run(self, cmd: str) -> Tuple[str, str]: - return "", "" - - def read(self, path: str) -> str: - return machine_mocks.Readfile(path) - - def pull(self, workload: str) -> str: - return workload # Workload is the tag. - - def container(self, image: str, **kwargs) -> container.Container: - return container.MockContainer(image) - - def sleep(self, amount: float): - pass - - -def get_address(machine: Machine) -> str: - """Return a machine's default address.""" - default_route, _ = machine.run("ip route get 8.8.8.8") - return re.search(" src ([0-9.]+) ", default_route).group(1) - - -class LocalMachine(Machine): - """The local machine. - - Attributes: - _name: Name as a string - _docker_client: a pythonic connection to to the local dockerd unix socket. - See: https://github.com/docker/docker-py - """ - - def __init__(self, name): - self._name = name - self._docker_client = docker.from_env() - - def run(self, cmd: str) -> Tuple[str, str]: - process = subprocess.Popen( - cmd.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout, stderr = process.communicate() - return stdout.decode("utf-8"), stderr.decode("utf-8") - - def read(self, path: str) -> bytes: - # Read the exact path locally. - return open(path, "r").read() - - def pull(self, workload: str) -> str: - # Run the docker build command locally. - logging.info("Building %s@%s locally...", workload, self._name) - with open(harness.LOCAL_WORKLOADS_PATH.format(workload), - "rb") as dockerfile: - self._docker_client.images.build( - fileobj=dockerfile, tag=workload, custom_context=True) - return workload # Workload is the tag. - - def container(self, image: str, **kwargs) -> container.Container: - # Return a local docker container directly. - return container.DockerContainer(self._docker_client, get_address(self), - image, **kwargs) - - def sleep(self, amount: float): - time.sleep(amount) - - -class RemoteMachine(Machine): - """Remote machine accessible via an SSH connection. - - Attributes: - _name: Name as a string - _ssh_connection: a paramiko backed ssh connection which can be used to run - commands on this machine - _tunnel: a python wrapper around a port forwarded ssh connection between a - local unix socket and the remote machine's dockerd unix socket. - _docker_client: a pythonic wrapper backed by the _tunnel. Allows sending - docker commands: see https://github.com/docker/docker-py - """ - - def __init__(self, name, **kwargs): - self._name = name - self._ssh_connection = ssh_connection.SSHConnection(name, **kwargs) - self._tunnel = tunnel_dispatcher.Tunnel(name, **kwargs) - self._tunnel.connect() - self._docker_client = self._tunnel.get_docker_client() - self._has_installers = False - - def run(self, cmd: str) -> Tuple[str, str]: - return self._ssh_connection.run(cmd) - - def read(self, path: str) -> str: - # Just cat remotely. - stdout, stderr = self._ssh_connection.run("cat '{}'".format(path)) - return stdout + stderr - - def install(self, - installer: str, - results: List[bool] = None, - index: int = -1): - """Method unique to RemoteMachine to handle installation of installers. - - Handles installers, which install things that may change between runs (e.g. - runsc). Usually called from gcloud_producer, which expects this method to - to store results. - - Args: - installer: the installer target to run. - results: Passed by the caller of where to store success. - index: Index for this method to store the result in the passed results - list. - """ - # This generates a tarball of the full installer root (which will generate - # be the full bazel root directory) and sends it over. - if not self._has_installers: - archive = self._ssh_connection.send_installers() - self.run("tar -xvf {archive} -C {dir}".format( - archive=archive, dir=harness.REMOTE_INSTALLERS_PATH)) - self._has_installers = True - - # Execute the remote installer. - self.run("sudo {dir}/{file}".format( - dir=harness.REMOTE_INSTALLERS_PATH, file=installer)) - - if results: - results[index] = True - - def pull(self, workload: str) -> str: - # Push to the remote machine and build. - logging.info("Building %s@%s remotely...", workload, self._name) - remote_path = self._ssh_connection.send_workload(workload) - remote_dir = os.path.dirname(remote_path) - # Workloads are all tarballs. - self.run("tar -xvf {remote_path} -C {remote_dir}".format( - remote_path=remote_path, remote_dir=remote_dir)) - self.run("docker build --tag={} {}".format(workload, remote_dir)) - return workload # Workload is the tag. - - def container(self, image: str, **kwargs) -> container.Container: - # Return a remote docker container. - return container.DockerContainer(self._docker_client, get_address(self), - image, **kwargs) - - def sleep(self, amount: float): - time.sleep(amount) diff --git a/benchmarks/harness/machine_mocks/BUILD b/benchmarks/harness/machine_mocks/BUILD deleted file mode 100644 index c8ec4bc79..000000000 --- a/benchmarks/harness/machine_mocks/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "machine_mocks", - srcs = ["__init__.py"], -) diff --git a/benchmarks/harness/machine_mocks/__init__.py b/benchmarks/harness/machine_mocks/__init__.py deleted file mode 100644 index 00f0085d7..000000000 --- a/benchmarks/harness/machine_mocks/__init__.py +++ /dev/null @@ -1,81 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Machine mock files.""" - -MEMINFO = """\ -MemTotal: 7652344 kB -MemFree: 7174724 kB -MemAvailable: 7152008 kB -Buffers: 7544 kB -Cached: 178856 kB -SwapCached: 0 kB -Active: 270928 kB -Inactive: 68436 kB -Active(anon): 153124 kB -Inactive(anon): 880 kB -Active(file): 117804 kB -Inactive(file): 67556 kB -Unevictable: 0 kB -Mlocked: 0 kB -SwapTotal: 0 kB -SwapFree: 0 kB -Dirty: 900 kB -Writeback: 0 kB -AnonPages: 153000 kB -Mapped: 129120 kB -Shmem: 1044 kB -Slab: 60864 kB -SReclaimable: 22792 kB -SUnreclaim: 38072 kB -KernelStack: 2672 kB -PageTables: 5756 kB -NFS_Unstable: 0 kB -Bounce: 0 kB -WritebackTmp: 0 kB -CommitLimit: 3826172 kB -Committed_AS: 663836 kB -VmallocTotal: 34359738367 kB -VmallocUsed: 0 kB -VmallocChunk: 0 kB -HardwareCorrupted: 0 kB -AnonHugePages: 0 kB -ShmemHugePages: 0 kB -ShmemPmdMapped: 0 kB -CmaTotal: 0 kB -CmaFree: 0 kB -HugePages_Total: 0 -HugePages_Free: 0 -HugePages_Rsvd: 0 -HugePages_Surp: 0 -Hugepagesize: 2048 kB -DirectMap4k: 94196 kB -DirectMap2M: 4624384 kB -DirectMap1G: 3145728 kB -""" - -CONTENTS = { - "/proc/meminfo": MEMINFO, -} - - -def Readfile(path: str) -> str: - """Reads a mock file. - - Args: - path: The target path. - - Returns: - Mocked file contents or None. - """ - return CONTENTS.get(path, None) diff --git a/benchmarks/harness/machine_producers/BUILD b/benchmarks/harness/machine_producers/BUILD deleted file mode 100644 index 81f19bd08..000000000 --- a/benchmarks/harness/machine_producers/BUILD +++ /dev/null @@ -1,84 +0,0 @@ -load("//tools:defs.bzl", "py_library", "py_requirement") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "harness", - srcs = ["__init__.py"], -) - -py_library( - name = "machine_producer", - srcs = ["machine_producer.py"], -) - -py_library( - name = "mock_producer", - srcs = ["mock_producer.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/harness/machine_producers:gcloud_producer", - "//benchmarks/harness/machine_producers:machine_producer", - ], -) - -py_library( - name = "yaml_producer", - srcs = ["yaml_producer.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/harness/machine_producers:machine_producer", - py_requirement( - "PyYAML", - direct = False, - ), - ], -) - -py_library( - name = "gcloud_mock_recorder", - srcs = ["gcloud_mock_recorder.py"], -) - -py_library( - name = "gcloud_producer", - srcs = ["gcloud_producer.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/harness/machine_producers:gcloud_mock_recorder", - "//benchmarks/harness/machine_producers:machine_producer", - ], -) - -filegroup( - name = "test_data", - srcs = [ - "testdata/get_five.json", - "testdata/get_one.json", - ], -) - -py_library( - name = "gcloud_producer_test_lib", - srcs = ["gcloud_producer_test.py"], - deps = [ - "//benchmarks/harness/machine_producers:machine_producer", - "//benchmarks/harness/machine_producers:mock_producer", - ], -) - -py_test( - name = "gcloud_producer_test", - srcs = [":gcloud_producer_test_lib"], - data = [ - ":test_data", - ], - python_version = "PY3", - tags = [ - "local", - "manual", - ], -) diff --git a/benchmarks/harness/machine_producers/__init__.py b/benchmarks/harness/machine_producers/__init__.py deleted file mode 100644 index 634ef4843..000000000 --- a/benchmarks/harness/machine_producers/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/benchmarks/harness/machine_producers/gcloud_mock_recorder.py b/benchmarks/harness/machine_producers/gcloud_mock_recorder.py deleted file mode 100644 index fd9837a37..000000000 --- a/benchmarks/harness/machine_producers/gcloud_mock_recorder.py +++ /dev/null @@ -1,97 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A recorder and replay for testing the GCloudProducer. - -MockPrinter and MockReader handle printing and reading mock data for the -purposes of testing. MockPrinter is passed to GCloudProducer objects. The user -can then run scenarios and record them for playback in tests later. - -MockReader is passed to MockGcloudProducer objects and handles reading the -previously recorded mock data. - -It is left to the user to check if data printed is properly redacted for their -own use. The intended usecase for this class is data coming from gcloud -commands, which will contain public IPs and other instance data. - -The data format is json and printed/read from the ./test_data directory. The -data is the output of subprocess.CompletedProcess objects in json format. - - Typical usage example: - - recorder = MockPrinter() - producer = GCloudProducer(args, recorder) - machines = producer.get_machines(1) - with open("my_file.json") as fd: - recorder.write_out(fd) - - reader = MockReader(filename) - producer = MockGcloudProducer(args, mock) - machines = producer.get_machines(1) - assert len(machines) == 1 -""" - -import io -import json -import subprocess - - -class MockPrinter(object): - """Handles printing Mock data for MockGcloudProducer. - - Attributes: - _records: list of json object records for printing - """ - - def __init__(self): - self._records = [] - - def record(self, entry: subprocess.CompletedProcess): - """Records data and strips out ip addresses.""" - - record = { - "args": entry.args, - "stdout": entry.stdout.decode("utf-8"), - "returncode": str(entry.returncode) - } - self._records.append(record) - - def write_out(self, fd: io.FileIO): - """Prints out the data into the given filepath.""" - fd.write(json.dumps(self._records, indent=4)) - - -class MockReader(object): - """Handles reading Mock data for MockGcloudProducer. - - Attributes: - _records: List[json] records read from the passed in file. - """ - - def __init__(self, filepath: str): - with open(filepath, "rb") as file: - self._records = json.loads(file.read()) - self._i = 0 - - def __iter__(self): - return self - - def __next__(self, args) -> subprocess.CompletedProcess: - """Returns the next record as a CompletedProcess.""" - if self._i < len(self._records): - record = self._records[self._i] - stdout = record["stdout"].encode("ascii") - returncode = int(record["returncode"]) - return subprocess.CompletedProcess( - args=args, returncode=returncode, stdout=stdout, stderr=b"") - raise StopIteration() diff --git a/benchmarks/harness/machine_producers/gcloud_producer.py b/benchmarks/harness/machine_producers/gcloud_producer.py deleted file mode 100644 index 513d16e4f..000000000 --- a/benchmarks/harness/machine_producers/gcloud_producer.py +++ /dev/null @@ -1,244 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A machine producer which produces machine objects using `gcloud`. - -Machine producers produce valid harness.Machine objects which are backed by -real machines. This producer produces those machines on the given user's GCP -account using the `gcloud` tool. - -GCloudProducer creates instances on the given GCP account named like: -`machine-XXXXXXX-XXXX-XXXX-XXXXXXXXXXXX` in a randomized fashion such that name -collisions with user instances shouldn't happen. - - Typical usage example: - - producer = GCloudProducer(args) - machines = producer.get_machines(NUM_MACHINES) - # run stuff on machines with machines[i].run(CMD) - producer.release_machines(NUM_MACHINES) -""" -import datetime -import json -import subprocess -import threading -from typing import List, Dict, Any -import uuid - -from benchmarks.harness import machine -from benchmarks.harness.machine_producers import gcloud_mock_recorder -from benchmarks.harness.machine_producers import machine_producer - - -class GCloudProducer(machine_producer.MachineProducer): - """Implementation of MachineProducer backed by GCP. - - Produces Machine objects backed by GCP instances. - - Attributes: - image: image name as a string. - zone: string to a valid GCP zone. - machine_type: type of GCP to create (e.g. n1-standard-4). - installers: list of installers post-boot. - ssh_key_file: path to a valid ssh private key. See README on vaild ssh keys. - ssh_user: string of user name for ssh_key - ssh_password: string of password for ssh key - mock: a mock printer which will print mock data if required. Mock data is - recorded output from subprocess calls (returncode, stdout, args). - condition: mutex for this class around machine creation and deleteion. - """ - - def __init__(self, - image: str, - zone: str, - machine_type: str, - installers: List[str], - ssh_key_file: str, - ssh_user: str, - ssh_password: str, - mock: gcloud_mock_recorder.MockPrinter = None): - self.image = image - self.zone = zone - self.machine_type = machine_type - self.installers = installers - self.ssh_key_file = ssh_key_file - self.ssh_user = ssh_user - self.ssh_password = ssh_password - self.mock = mock - self.condition = threading.Condition() - - def get_machines(self, num_machines: int) -> List[machine.Machine]: - """Returns requested number of machines backed by GCP instances.""" - if num_machines <= 0: - raise ValueError( - "Cannot ask for {num} machines!".format(num=num_machines)) - with self.condition: - names = self._get_unique_names(num_machines) - instances = self._build_instances(names) - self._add_ssh_key_to_instances(names) - machines = self._machines_from_instances(instances) - - # Install all bits in lock-step. - # - # This will perform paralell installations for however many machines we - # have, but it's easy to track errors because if installing (a, b, c), we - # won't install "c" until "b" is installed on all machines. - for installer in self.installers: - threads = [None] * len(machines) - results = [False] * len(machines) - for i in range(len(machines)): - threads[i] = threading.Thread( - target=machines[i].install, args=(installer, results, i)) - threads[i].start() - for thread in threads: - thread.join() - for result in results: - if not result: - raise NotImplementedError( - "Installers failed on at least one machine!") - - # Add this user to each machine's docker group. - for m in machines: - m.run("sudo setfacl -m user:$USER:rw /var/run/docker.sock") - - return machines - - def release_machines(self, machine_list: List[machine.Machine]): - """Releases the requested number of machines, deleting the instances.""" - if not machine_list: - return - cmd = "gcloud compute instances delete --quiet".split(" ") - names = [str(m) for m in machine_list] - cmd.extend(names) - cmd.append("--zone={zone}".format(zone=self.zone)) - self._run_command(cmd, detach=True) - - def _machines_from_instances( - self, instances: List[Dict[str, Any]]) -> List[machine.Machine]: - """Creates Machine Objects from json data describing created instances.""" - machines = [] - for instance in instances: - name = instance["name"] - kwargs = { - "hostname": - instance["networkInterfaces"][0]["accessConfigs"][0]["natIP"], - "key_path": - self.ssh_key_file, - "username": - self.ssh_user, - "key_password": - self.ssh_password - } - machines.append(machine.RemoteMachine(name=name, **kwargs)) - return machines - - def _get_unique_names(self, num_names) -> List[str]: - """Returns num_names unique names based on data from the GCP project.""" - return ["machine-" + str(uuid.uuid4()) for _ in range(0, num_names)] - - def _build_instances(self, names: List[str]) -> List[Dict[str, Any]]: - """Creates instances using gcloud command. - - Runs the command `gcloud compute instances create` and returns json data - on created instances on success. Creates len(names) instances, one for each - name. - - Args: - names: list of names of instances to create. - - Returns: - List of json data describing created machines. - """ - if not names: - raise ValueError( - "_build_instances cannot create instances without names.") - cmd = "gcloud compute instances create".split(" ") - cmd.extend(names) - cmd.append("--image=" + self.image) - cmd.append("--zone=" + self.zone) - cmd.append("--machine-type=" + self.machine_type) - res = self._run_command(cmd) - return json.loads(res.stdout) - - def _add_ssh_key_to_instances(self, names: List[str]) -> None: - """Adds ssh key to instances by calling gcloud ssh command. - - Runs the command `gcloud compute ssh instance_name` on list of images by - name. Tries to ssh into given instance. - - Args: - names: list of machine names to which to add the ssh-key - self.ssh_key_file. - - Raises: - subprocess.CalledProcessError: when underlying subprocess call returns an - error other than 255 (Connection closed by remote host). - TimeoutError: when 3 unsuccessful tries to ssh into the host return 255. - """ - for name in names: - cmd = "gcloud compute ssh {name}".format(name=name).split(" ") - cmd.append("--ssh-key-file={key}".format(key=self.ssh_key_file)) - cmd.append("--zone={zone}".format(zone=self.zone)) - cmd.append("--command=uname") - cmd.append("--ssh-key-expire-after=60m") - timeout = datetime.timedelta(seconds=5 * 60) - start = datetime.datetime.now() - while datetime.datetime.now() <= timeout + start: - try: - self._run_command(cmd) - break - except subprocess.CalledProcessError: - if datetime.datetime.now() > timeout + start: - raise TimeoutError( - "Could not SSH into instance after 5 min: {name}".format( - name=name)) - - def _run_command(self, - cmd: List[str], - detach: bool = False) -> [None, subprocess.CompletedProcess]: - """Runs command as a subprocess. - - Runs command as subprocess and returns the result. - If this has a mock recorder, use the record method to record the subprocess - call. - - Args: - cmd: command to be run as a list of strings. - detach: if True, run the child process and don't wait for it to return. - - Returns: - Completed process object to be parsed by caller or None if detach=True. - - Raises: - CalledProcessError: if subprocess.run returns an error. - """ - cmd = cmd + ["--format=json"] - if detach: - p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - if self.mock: - out, _ = p.communicate() - self.mock.record( - subprocess.CompletedProcess( - returncode=p.returncode, stdout=out, args=p.args)) - return - - res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - if self.mock: - self.mock.record(res) - if res.returncode != 0: - raise subprocess.CalledProcessError( - cmd=" ".join(res.args), - output=res.stdout, - stderr=res.stderr, - returncode=res.returncode) - return res diff --git a/benchmarks/harness/machine_producers/gcloud_producer_test.py b/benchmarks/harness/machine_producers/gcloud_producer_test.py deleted file mode 100644 index c8adb2bdc..000000000 --- a/benchmarks/harness/machine_producers/gcloud_producer_test.py +++ /dev/null @@ -1,48 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests GCloudProducer using mock data. - -GCloudProducer produces machines using 'get_machines' and 'release_machines' -methods. The tests check recorded data (jsonified subprocess.CompletedProcess -objects) of the producer producing one and five machines. -""" -import os -import types - -from benchmarks.harness.machine_producers import machine_producer -from benchmarks.harness.machine_producers import mock_producer - -TEST_DIR = os.path.dirname(__file__) - - -def run_get_release(producer: machine_producer.MachineProducer, - num_machines: int, - validator: types.FunctionType = None): - machines = producer.get_machines(num_machines) - assert len(machines) == num_machines - if validator: - validator(machines=machines, cmd="uname -a", workload=None) - producer.release_machines(machines) - - -def test_run_one(): - mock = mock_producer.MockReader(TEST_DIR + "get_one.json") - producer = mock_producer.MockGCloudProducer(mock) - run_get_release(producer, 1) - - -def test_run_five(): - mock = mock_producer.MockReader(TEST_DIR + "get_five.json") - producer = mock_producer.MockGCloudProducer(mock) - run_get_release(producer, 5) diff --git a/benchmarks/harness/machine_producers/machine_producer.py b/benchmarks/harness/machine_producers/machine_producer.py deleted file mode 100644 index f5591c026..000000000 --- a/benchmarks/harness/machine_producers/machine_producer.py +++ /dev/null @@ -1,51 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Abstract types.""" - -import threading -from typing import List - -from benchmarks.harness import machine - - -class MachineProducer: - """Abstract Machine producer.""" - - def get_machines(self, num_machines: int) -> List[machine.Machine]: - """Returns the requested number of machines.""" - raise NotImplementedError - - def release_machines(self, machine_list: List[machine.Machine]): - """Releases the given set of machines.""" - raise NotImplementedError - - -class LocalMachineProducer(MachineProducer): - """Produces Local Machines.""" - - def __init__(self, limit: int): - self.limit_sem = threading.Semaphore(value=limit) - - def get_machines(self, num_machines: int) -> List[machine.Machine]: - """Returns the request number of MockMachines.""" - - self.limit_sem.acquire() - return [machine.LocalMachine("local") for _ in range(num_machines)] - - def release_machines(self, machine_list: List[machine.MockMachine]): - """No-op.""" - if not machine_list: - raise ValueError("Cannot release an empty list!") - self.limit_sem.release() - machine_list.clear() diff --git a/benchmarks/harness/machine_producers/mock_producer.py b/benchmarks/harness/machine_producers/mock_producer.py deleted file mode 100644 index 37e9cb4b7..000000000 --- a/benchmarks/harness/machine_producers/mock_producer.py +++ /dev/null @@ -1,52 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Producers of mocks.""" - -from typing import List, Any - -from benchmarks.harness import machine -from benchmarks.harness.machine_producers import gcloud_mock_recorder -from benchmarks.harness.machine_producers import gcloud_producer -from benchmarks.harness.machine_producers import machine_producer - - -class MockMachineProducer(machine_producer.MachineProducer): - """Produces MockMachine objects.""" - - def get_machines(self, num_machines: int) -> List[machine.MockMachine]: - """Returns the request number of MockMachines.""" - return [machine.MockMachine() for i in range(num_machines)] - - def release_machines(self, machine_list: List[machine.MockMachine]): - """No-op.""" - return - - -class MockGCloudProducer(gcloud_producer.GCloudProducer): - """Mocks GCloudProducer for testing purposes.""" - - def __init__(self, mock: gcloud_mock_recorder.MockReader, **kwargs): - gcloud_producer.GCloudProducer.__init__( - self, project="mock", ssh_private_key_path="mock", **kwargs) - self.mock = mock - - def _validate_ssh_file(self): - pass - - def _run_command(self, cmd): - return self.mock.pop(cmd) - - def _machines_from_instances( - self, instances: List[Any]) -> List[machine.MockMachine]: - return [machine.MockMachine() for _ in instances] diff --git a/benchmarks/harness/machine_producers/testdata/get_five.json b/benchmarks/harness/machine_producers/testdata/get_five.json deleted file mode 100644 index 32bad1b06..000000000 --- a/benchmarks/harness/machine_producers/testdata/get_five.json +++ /dev/null @@ -1,211 +0,0 @@ -[ - { - "args": [ - "gcloud", - "compute", - "instances", - "list", - "--project", - "project", - "--format=json" - ], - "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":{\"natIP\":\"0.0.0.0\"}]}]}]", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "instances", - "create", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92", - "machine-da5859b5-bae6-435d-8005-0202d6f6e065", - "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05", - "machine-1149147d-71e2-43ea-8fe1-49256e5c441c", - "--preemptible", - "--image=ubuntu-1910-eoan-v20191204", - "--zone=us-west1-b", - "--image-project=ubuntu-os-cloud", - "--format=json" - ], - "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "instances", - "start", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92", - "machine-da5859b5-bae6-435d-8005-0202d6f6e065", - "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05", - "machine-1149147d-71e2-43ea-8fe1-49256e5c441c", - "--zone=us-west1-b", - "--project=project", - "--format=json" - ], - "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "Linux\n[]\n", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "Linux\n[]\n", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-da5859b5-bae6-435d-8005-0202d6f6e065", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "Linux\n[]\n", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "Linux\n[]\n", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-1149147d-71e2-43ea-8fe1-49256e5c441c", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "Linux\n[]\n", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "instances", - "delete", - "--quiet", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92", - "machine-da5859b5-bae6-435d-8005-0202d6f6e065", - "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05", - "machine-1149147d-71e2-43ea-8fe1-49256e5c441c", - "--zone=us-west1-b", - "--format=json" - ], - "stdout": "[]\n", - "returncode": "0" - } -] diff --git a/benchmarks/harness/machine_producers/testdata/get_one.json b/benchmarks/harness/machine_producers/testdata/get_one.json deleted file mode 100644 index c359c19c8..000000000 --- a/benchmarks/harness/machine_producers/testdata/get_one.json +++ /dev/null @@ -1,145 +0,0 @@ -[ - { - "args": [ - "gcloud", - "compute", - "instances", - "list", - "--project", - "linux-testing-user", - "--format=json" - ], - "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]", - - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "instances", - "create", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--preemptible", - "--image=ubuntu-1910-eoan-v20191204", - "--zone=us-west1-b", - "--image-project=ubuntu-os-cloud", - "--format=json" - ], - "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "instances", - "start", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--zone=us-west1-b", - "--project=linux-testing-user", - "--format=json" - ], - "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]", - - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "Linux\n[]\n", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "instances", - "delete", - "--quiet", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--zone=us-west1-b", - "--format=json" - ], - "stdout": "[]\n", - "returncode": "0" - } -] diff --git a/benchmarks/harness/machine_producers/yaml_producer.py b/benchmarks/harness/machine_producers/yaml_producer.py deleted file mode 100644 index 5d334e480..000000000 --- a/benchmarks/harness/machine_producers/yaml_producer.py +++ /dev/null @@ -1,106 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Producers based on yaml files.""" - -import os -import threading -from typing import Dict -from typing import List - -import yaml - -from benchmarks.harness import machine -from benchmarks.harness.machine_producers import machine_producer - - -class YamlMachineProducer(machine_producer.MachineProducer): - """Loads machines from a yaml file.""" - - def __init__(self, path: str): - self.machines = build_machines(path) - self.max_machines = len(self.machines) - self.machine_condition = threading.Condition() - - def get_machines(self, num_machines: int) -> List[machine.Machine]: - if num_machines > self.max_machines: - raise ValueError( - "Insufficient Ammount of Machines. {ask} asked for and have {max_num} max." - .format(ask=num_machines, max_num=self.max_machines)) - - with self.machine_condition: - while not self._enough_machines(num_machines): - self.machine_condition.wait(timeout=1) - return [self.machines.pop(0) for _ in range(num_machines)] - - def release_machines(self, machine_list: List[machine.Machine]): - with self.machine_condition: - while machine_list: - next_machine = machine_list.pop() - self.machines.append(next_machine) - self.machine_condition.notify() - - def _enough_machines(self, ask: int): - return ask <= len(self.machines) - - -def build_machines(path: str, num_machines: str = -1) -> List[machine.Machine]: - """Builds machine objects defined by the yaml file "path". - - Args: - path: The path to a yaml file which defines machines. - num_machines: Optional limit on how many machine objects to build. - - Returns: - Machine objects in a list. - - If num_machines is set, len(machines) <= num_machines. - """ - data = parse_yaml(path) - machines = [] - for key, value in data.items(): - if len(machines) == num_machines: - return machines - if isinstance(value, dict): - machines.append(machine.RemoteMachine(key, **value)) - else: - machines.append(machine.LocalMachine(key)) - return machines - - -def parse_yaml(path: str) -> Dict[str, Dict[str, str]]: - """Parse the yaml file pointed by path. - - Args: - path: The path to yaml file. - - Returns: - The contents of the yaml file as a dictionary. - """ - data = get_file_contents(path) - return yaml.load(data, Loader=yaml.Loader) - - -def get_file_contents(path: str) -> str: - """Dumps the file contents to a string and returns them. - - Args: - path: The path to dump. - - Returns: - The file contents as a string. - """ - if not os.path.isabs(path): - path = os.path.abspath(path) - with open(path) as input_file: - return input_file.read() diff --git a/benchmarks/harness/ssh_connection.py b/benchmarks/harness/ssh_connection.py deleted file mode 100644 index b8c8e42d4..000000000 --- a/benchmarks/harness/ssh_connection.py +++ /dev/null @@ -1,126 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""SSHConnection handles the details of SSH connections.""" - -import logging -import os -import warnings - -import paramiko - -from benchmarks import harness - -# Get rid of paramiko Cryptography Warnings. -warnings.filterwarnings(action="ignore", module=".*paramiko.*") - -log = logging.getLogger(__name__) - - -def send_one_file(client: paramiko.SSHClient, path: str, - remote_dir: str) -> str: - """Sends a single file via an SSH client. - - Args: - client: The existing SSH client. - path: The local path. - remote_dir: The remote directory. - - Returns: - :return: The remote path as a string. - """ - filename = path.split("/").pop() - if remote_dir != ".": - client.exec_command("mkdir -p " + remote_dir) - with client.open_sftp() as ftp_client: - ftp_client.put(path, os.path.join(remote_dir, filename)) - return os.path.join(remote_dir, filename) - - -class SSHConnection: - """SSH connection to a remote machine.""" - - def __init__(self, name: str, hostname: str, key_path: str, username: str, - **kwargs): - """Sets up a paramiko ssh connection to the given hostname.""" - self._name = name # Unused. - self._hostname = hostname - self._username = username - self._key_path = key_path # RSA Key path - self._kwargs = kwargs - # SSHConnection wraps paramiko. paramiko supports RSA, ECDSA, and Ed25519 - # keys, and we've chosen to only suport and require RSA keys. paramiko - # supports RSA keys that begin with '----BEGIN RSAKEY----'. - # https://stackoverflow.com/questions/53600581/ssh-key-generated-by-ssh-keygen-is-not-recognized-by-paramiko - self.rsa_key = self._rsa() - self.run("true") # Validate. - - def _client(self) -> paramiko.SSHClient: - """Returns a connected SSH client.""" - client = paramiko.SSHClient() - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - client.connect( - hostname=self._hostname, - port=22, - username=self._username, - pkey=self.rsa_key, - allow_agent=False, - look_for_keys=False) - return client - - def _rsa(self): - if "key_password" in self._kwargs: - password = self._kwargs["key_password"] - else: - password = None - rsa = paramiko.RSAKey.from_private_key_file(self._key_path, password) - return rsa - - def run(self, cmd: str) -> (str, str): - """Runs a command via ssh. - - Args: - cmd: The shell command to run. - - Returns: - The contents of stdout and stderr. - """ - with self._client() as client: - log.info("running command: %s", cmd) - _, stdout, stderr = client.exec_command(command=cmd) - log.info("returned status: %d", stdout.channel.recv_exit_status()) - stdout = stdout.read().decode("utf-8") - stderr = stderr.read().decode("utf-8") - log.info("stdout: %s", stdout) - log.info("stderr: %s", stderr) - return stdout, stderr - - def send_workload(self, name: str) -> str: - """Sends a workload tarball to the remote machine. - - Args: - name: The workload name. - - Returns: - The remote path. - """ - with self._client() as client: - return send_one_file(client, harness.LOCAL_WORKLOADS_PATH.format(name), - harness.REMOTE_WORKLOADS_PATH.format(name)) - - def send_installers(self) -> str: - with self._client() as client: - return send_one_file( - client, - path=harness.INSTALLER_ARCHIVE, - remote_dir=harness.REMOTE_INSTALLERS_PATH) diff --git a/benchmarks/harness/tunnel_dispatcher.py b/benchmarks/harness/tunnel_dispatcher.py deleted file mode 100644 index c56fd022a..000000000 --- a/benchmarks/harness/tunnel_dispatcher.py +++ /dev/null @@ -1,122 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tunnel handles setting up connections to remote machines. - -Tunnel dispatcher is a wrapper around the connection from a local UNIX socket -and a remote UNIX socket via SSH with port forwarding. This is done to -initialize the pythonic dockerpy client to run containers on the remote host by -connecting to /var/run/docker.sock (where Docker is listening). Tunnel -dispatcher sets up the local UNIX socket and calls the `ssh` command as a -subprocess, and holds a reference to that subprocess. It manages clean-up on -exit as best it can by killing the ssh subprocess and deleting the local UNIX -socket,stored in /tmp for easy cleanup in most systems if this fails. - - Typical usage example: - - t = Tunnel(name, **kwargs) - t.connect() - client = t.get_docker_client() # - client.containers.run("ubuntu", "echo hello world") - -""" - -import os -import tempfile -import time - -import docker -import pexpect - -SSH_TUNNEL_COMMAND = """ssh - -o GlobalKnownHostsFile=/dev/null - -o UserKnownHostsFile=/dev/null - -o StrictHostKeyChecking=no - -o IdentitiesOnly=yes - -nNT -L {filename}:/var/run/docker.sock - -i {key_path} - {username}@{hostname}""" - - -class Tunnel(object): - """The tunnel object represents the tunnel via ssh. - - This connects a local unix domain socket with a remote socket. - - Attributes: - _filename: a temporary name of the UNIX socket prefixed by the name - argument. - _hostname: the IP or resolvable hostname of the remote host. - _username: the username of the ssh_key used to run ssh. - _key_path: path to a valid key. - _key_password: optional password to the ssh key in _key_path - _process: holds reference to the ssh subprocess created. - - Returns: - The new minimum port. - - Raises: - ConnectionError: If no available port is found. - """ - - def __init__(self, - name: str, - hostname: str, - username: str, - key_path: str, - key_password: str = "", - **kwargs): - self._filename = tempfile.NamedTemporaryFile(prefix=name).name - self._hostname = hostname - self._username = username - self._key_path = key_path - self._key_password = key_password - self._kwargs = kwargs - self._process = None - - def connect(self): - """Connects the SSH tunnel and stores the subprocess reference in _process.""" - cmd = SSH_TUNNEL_COMMAND.format( - filename=self._filename, - key_path=self._key_path, - username=self._username, - hostname=self._hostname) - self._process = pexpect.spawn(cmd, timeout=10) - - # If given a password, assume we'll be asked for it. - if self._key_password: - self._process.expect(["Enter passphrase for key .*: "]) - self._process.sendline(self._key_password) - - while True: - # Wait for the tunnel to appear. - if self._process.exitstatus is not None: - raise ConnectionError("Error in setting up ssh tunnel") - if os.path.exists(self._filename): - return - time.sleep(0.1) - - def path(self): - """Return the socket file.""" - return self._filename - - def get_docker_client(self): - """Returns a docker client for this Tunnel.""" - return docker.DockerClient(base_url="unix:/" + self._filename) - - def __del__(self): - """Closes the ssh connection process and deletes the socket file.""" - if self._process: - self._process.close() - if os.path.exists(self._filename): - os.remove(self._filename) diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt deleted file mode 100644 index 577eb1a2e..000000000 --- a/benchmarks/requirements.txt +++ /dev/null @@ -1,32 +0,0 @@ -asn1crypto==1.2.0 -atomicwrites==1.3.0 -attrs==19.3.0 -bcrypt==3.1.7 -certifi==2019.9.11 -cffi==1.13.2 -chardet==3.0.4 -Click==7.0 -cryptography==2.8 -docker==3.7.0 -docker-pycreds==0.4.0 -idna==2.8 -importlib-metadata==0.23 -more-itertools==7.2.0 -packaging==19.2 -paramiko==2.6.0 -pathlib2==2.3.5 -pexpect==4.7.0 -pluggy==0.9.0 -ptyprocess==0.6.0 -py==1.8.0 -pycparser==2.19 -PyNaCl==1.3.0 -pyparsing==2.4.5 -pytest==4.3.0 -PyYAML==5.1.2 -requests==2.22.0 -six==1.13.0 -urllib3==1.25.7 -wcwidth==0.1.7 -websocket-client==0.56.0 -zipp==0.6.0 diff --git a/benchmarks/run.py b/benchmarks/run.py deleted file mode 100644 index a22eb8641..000000000 --- a/benchmarks/run.py +++ /dev/null @@ -1,19 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Benchmark runner.""" - -from benchmarks import runner - -if __name__ == "__main__": - runner.runner() diff --git a/benchmarks/runner/BUILD b/benchmarks/runner/BUILD deleted file mode 100644 index 471debfdf..000000000 --- a/benchmarks/runner/BUILD +++ /dev/null @@ -1,56 +0,0 @@ -load("//tools:defs.bzl", "py_library", "py_requirement", "py_test") -load("//benchmarks:defs.bzl", "test_deps") - -package(licenses = ["notice"]) - -py_library( - name = "runner", - srcs = ["__init__.py"], - data = [ - "//benchmarks/workloads:files", - ], - visibility = ["//benchmarks:__pkg__"], - deps = [ - ":commands", - "//benchmarks/harness:benchmark_driver", - "//benchmarks/harness/machine_producers:machine_producer", - "//benchmarks/harness/machine_producers:mock_producer", - "//benchmarks/harness/machine_producers:yaml_producer", - "//benchmarks/suites", - "//benchmarks/suites:absl", - "//benchmarks/suites:density", - "//benchmarks/suites:fio", - "//benchmarks/suites:helpers", - "//benchmarks/suites:http", - "//benchmarks/suites:media", - "//benchmarks/suites:ml", - "//benchmarks/suites:network", - "//benchmarks/suites:redis", - "//benchmarks/suites:startup", - "//benchmarks/suites:sysbench", - "//benchmarks/suites:syscall", - py_requirement("click"), - ], -) - -py_library( - name = "commands", - srcs = ["commands.py"], - deps = [ - py_requirement("click"), - ], -) - -py_test( - name = "runner_test", - srcs = ["runner_test.py"], - python_version = "PY3", - tags = [ - "local", - "manual", - ], - deps = test_deps + [ - ":runner", - py_requirement("click"), - ], -) diff --git a/benchmarks/runner/__init__.py b/benchmarks/runner/__init__.py deleted file mode 100644 index ba27dc69f..000000000 --- a/benchmarks/runner/__init__.py +++ /dev/null @@ -1,307 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""High-level benchmark utility.""" - -import copy -import csv -import logging -import pkgutil -import pydoc -import re -import sys -import types -from typing import List -from typing import Tuple - -import click - -from benchmarks import harness -from benchmarks import suites -from benchmarks.harness import benchmark_driver -from benchmarks.harness.machine_producers import gcloud_producer -from benchmarks.harness.machine_producers import machine_producer -from benchmarks.harness.machine_producers import mock_producer -from benchmarks.harness.machine_producers import yaml_producer -from benchmarks.runner import commands - - -@click.group() -@click.option( - "--verbose/--no-verbose", default=False, help="Enable verbose logging.") -@click.option("--debug/--no-debug", default=False, help="Enable debug logging.") -def runner(verbose: bool = False, debug: bool = False): - """Run distributed benchmarks. - - See the run and list commands for details. - - Args: - verbose: Enable verbose logging. - debug: Enable debug logging (supercedes verbose). - """ - if debug: - logging.basicConfig(level=logging.DEBUG) - elif verbose: - logging.basicConfig(level=logging.INFO) - - -def find_benchmarks( - regex: str) -> List[Tuple[str, types.ModuleType, types.FunctionType]]: - """Finds all available benchmarks. - - Args: - regex: A regular expression to match. - - Returns: - A (short_name, module, function) tuple for each match. - """ - pkgs = pkgutil.walk_packages(suites.__path__, suites.__name__ + ".") - found = [] - for _, name, _ in pkgs: - mod = pydoc.locate(name) - funcs = [ - getattr(mod, x) - for x in dir(mod) - if suites.is_benchmark(getattr(mod, x)) - ] - for func in funcs: - # Use the short_name with the benchmarks. prefix stripped. - prefix_len = len(suites.__name__ + ".") - short_name = mod.__name__[prefix_len:] + "." + func.__name__ - # Add to the list if a pattern is provided. - if re.compile(regex).match(short_name): - found.append((short_name, mod, func)) - return found - - -@runner.command("list") -@click.argument("method", nargs=-1) -def list_all(method): - """Lists available benchmarks.""" - if not method: - method = ".*" - else: - method = "(" + ",".join(method) + ")" - for (short_name, _, func) in find_benchmarks(method): - print("Benchmark %s:" % short_name) - metrics = suites.benchmark_metrics(func) - if func.__doc__: - print(" " + func.__doc__.lstrip().rstrip()) - if metrics: - print("\n Metrics:") - for metric in metrics: - print("\t{name}: {doc}".format(name=metric[0], doc=metric[1])) - print("\n") - - -@runner.command("run-local", commands.LocalCommand) -@click.pass_context -def run_local(ctx, limit: float, **kwargs): - """Runs benchmarks locally.""" - run(ctx, machine_producer.LocalMachineProducer(limit=limit), **kwargs) - - -@runner.command("run-mock", commands.RunCommand) -@click.pass_context -def run_mock(ctx, **kwargs): - """Runs benchmarks on Mock machines. Used for testing.""" - run(ctx, mock_producer.MockMachineProducer(), **kwargs) - - -@runner.command("run-gcp", commands.GCPCommand) -@click.pass_context -def run_gcp(ctx, image_file: str, zone_file: str, machine_type: str, - installers: List[str], **kwargs): - """Runs all benchmarks on GCP instances.""" - - # Resolve all files. - image = open(image_file).read().rstrip() - zone = open(zone_file).read().rstrip() - - key_file = harness.make_key() - - producer = gcloud_producer.GCloudProducer( - image, - zone, - machine_type, - installers, - ssh_key_file=key_file, - ssh_user=harness.DEFAULT_USER, - ssh_password="") - - try: - run(ctx, producer, **kwargs) - finally: - harness.delete_key() - - -def run(ctx, producer: machine_producer.MachineProducer, method: str, runs: int, - runtime: List[str], metric: List[str], stat: str, **kwargs): - """Runs arbitrary benchmarks. - - All unknown command line flags are passed through to the underlying benchmark - method. Flags may be specified multiple times, in which case it is considered - a "dimension" for the test, and a comma-separated table will be emitted - instead of a single result. - - See the output of list to see available metrics for any given benchmark - method. The method parameter is a regular expression that will match against - available benchmarks. If multiple benchmarks match, then that is considered a - distinct "dimension" for the test. - - All benchmarks are run in parallel where possible, but have exclusive - ownership over the individual machines. - - Every benchmark method will be run the times indicated by --runs. - - Args: - ctx: Click context. - producer: A Machine Producer from which to get Machines. - method: A regular expression for methods to be run. - runs: Number of runs. - runtime: A list of runtimes to test. - metric: A list of metrics to extract. - stat: The class of statistics to extract. - **kwargs: Dimensions to test. - """ - # First, calculate additional arguments. - # - # This essentially calculates any arguments that appear multiple times, and - # moves those to the "dimensions" dictionary, which maps to lists. These - # dimensions are then iterated over to generate the relevant csv output. - dimensions = {} - - if stat not in ["median", "all", "meanstd"]: - raise ValueError("Illegal value for --result, see help.") - - def squish(key: str, value: str): - """Collapse an argument into kwargs or dimensions.""" - if key in dimensions: - # Extend an existing dimension. - dimensions[key].append(value) - elif key in kwargs: - # Create a new dimension. - dimensions[key] = [kwargs[key], value] - del kwargs[key] - else: - # A single value. - kwargs[key] = value - - for item in ctx.args: - if "=" in method: - # This must be the method. The method is simply set to the first - # non-matching argument, which we're also parsing here. - item, method = method, item - if "=" not in item: - logging.error("illegal argument: %s", item) - sys.exit(1) - (key, value) = item.lstrip("-").split("=", 1) - squish(key, value) - - # Convert runtime and metric to dimensions. - # - # They exist only in the arguments above for documentation purposes. - # Essentially here we are treating them like anything else. Note however, - # that an empty set here will result in a dimension. This is important for - # metrics, where an empty set actually means all metrics. - def fold(key: str, value, allow_flatten=False): - """Collapse a list value into kwargs or dimensions.""" - if len(value) == 1 and allow_flatten: - kwargs[key] = value[0] - else: - dimensions[key] = value - - fold("runtime", runtime, allow_flatten=True) - fold("metric", metric) - - # Lookup the methods. - # - # We match the method parameter to a regular expression. This allows you to - # do things like `run --mock .*` for a broad test. Note that we track the - # short_names in the dimensions here, and look up again in the recursion. - methods = { - short_name: func for (short_name, _, func) in find_benchmarks(method) - } - if not methods: - # Must match at least one method. - logging.error("no matching benchmarks for %s: try list.", method) - sys.exit(1) - fold("method", list(methods.keys()), allow_flatten=True) - - # Spin up the drivers. - # - # We ensure that metric is the last entry, because we have special behavior. - # They actually run the test once and the benchmark is a generator that - # produces all viable metrics. - dimension_keys = list(dimensions.keys()) - if "metric" in dimension_keys: - dimension_keys.remove("metric") - dimension_keys.append("metric") - drivers = [] - - def _start(keywords, finished, left): - """Runs a test across dimensions recursively.""" - # Resolve the method fully, it starts as a string. - if "method" in keywords and isinstance(keywords["method"], str): - keywords["method"] = methods[keywords["method"]] - # Is this a non-recursive case? - if not left: - driver = benchmark_driver.BenchmarkDriver(producer, runs=runs, **keywords) - driver.start() - drivers.append((finished, driver)) - else: - # Recurse on the next dimension. - current, left = left[0], left[1:] - keywords = copy.deepcopy(keywords) - if current == "metric": - # We use a generator, popped below. Note that metric is - # guaranteed to be the last element here, and we will provide - # the value for 'done' below when generating the csv. - keywords[current] = dimensions[current] - _start(keywords, finished, left) - else: - # Generate manually. - for value in dimensions[current]: - keywords[current] = value - _start(keywords, finished + [value], left) - - # Start all the drivers, recursively. - _start(kwargs, [], dimension_keys) - - # Finish all tests, write results. - output = csv.writer(sys.stdout) - output.writerow(dimension_keys + ["result"]) - for (done, driver) in drivers: - driver.join() - for (metric_name, result) in getattr(driver, stat)(): - output.writerow([ # Collapse the method name. - hasattr(x, "__name__") and x.__name__ or x for x in done - ] + [metric_name] + result) - - -@runner.command() -@click.argument("env") -@click.option( - "--cmd", default="uname -a", help="command to run on all found machines") -@click.option( - "--workload", default="true", help="workload to run all found machines") -def validate(env, cmd, workload): - """Validates an environment described by yaml file.""" - producer = yaml_producer.YamlMachineProducer(env) - for machine in producer.machines: - print("Machine %s:" % machine) - stdout, _ = machine.run(cmd) - print(" Output of '%s': %s" % (cmd, stdout.lstrip().rstrip())) - image = machine.pull(workload) - stdout = machine.container(image).run() - print(" Container %s: %s" % (workload, stdout.lstrip().rstrip())) diff --git a/benchmarks/runner/commands.py b/benchmarks/runner/commands.py deleted file mode 100644 index 0fccb2fad..000000000 --- a/benchmarks/runner/commands.py +++ /dev/null @@ -1,129 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Module with the guts of `click` commands. - -Overrides of the click.core.Command. This is done so flags are inherited between -similar commands (the run command). The classes below are meant to be used in -click templates like so. - -@runner.command("run-mock", RunCommand) -def run_mock(**kwargs): - # mock implementation - -""" -import os - -import click - - -class RunCommand(click.core.Command): - """Base Run Command with flags. - - Attributes: - method: regex of which suite to choose (e.g. sysbench would run - sysbench.cpu, sysbench.memory, and sysbench.mutex) See list command for - details. - metric: metric(s) to extract. See list command for details. - runtime: the runtime(s) on which to run. - runs: the number of runs to do of each method. - stat: how to compile results in the case of multiple run (e.g. median). - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - method = click.core.Argument(("method",)) - - metric = click.core.Option(("--metric",), - help="The metric to extract.", - multiple=True) - - runtime = click.core.Option(("--runtime",), - default=["runc"], - help="The runtime to use.", - multiple=True) - runs = click.core.Option(("--runs",), - default=1, - help="The number of times to run each benchmark.") - stat = click.core.Option( - ("--stat",), - default="median", - help="How to aggregate the data from all runs." - "\nmedian - returns the median of all runs (default)" - "\nall - returns all results comma separated" - "\nmeanstd - returns result as mean,std") - self.params.extend([method, runtime, runs, stat, metric]) - self.ignore_unknown_options = True - self.allow_extra_args = True - - -class LocalCommand(RunCommand): - """LocalCommand inherits all flags from RunCommand. - - Attributes: - limit: limits the number of machines on which to run benchmarks. This limits - for local how many benchmarks may run at a time. e.g. "startup" requires - one machine -- passing two machines would limit two startup jobs at a - time. Default is infinity. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.params.append( - click.core.Option( - ("--limit",), - default=1, - help="Limit of number of benchmarks that can run at a given time.")) - - -class GCPCommand(RunCommand): - """GCPCommand inherits all flags from RunCommand and adds flags for run_gcp method. - - Attributes: - image_file: name of the image to build machines from - zone_file: a GCP zone (e.g. us-west1-b) - installers: named installers for post-create - machine_type: type of machine to create (e.g. n1-standard-4) - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - image_file = click.core.Option( - ("--image_file",), - help="The file containing the image for VMs.", - default=os.path.join( - os.path.dirname(__file__), "../../tools/images/ubuntu1604.txt"), - ) - zone_file = click.core.Option( - ("--zone_file",), - help="The file containing the GCP zone.", - default=os.path.join( - os.path.dirname(__file__), "../../tools/images/zone.txt"), - ) - installers = click.core.Option( - ("--installers",), - help="The set of installers to use.", - multiple=True, - ) - machine_type = click.core.Option( - ("--machine_type",), - help="Type to make all machines.", - default="n1-standard-4", - ) - self.params.extend([ - image_file, - zone_file, - machine_type, - installers, - ]) diff --git a/benchmarks/runner/runner_test.py b/benchmarks/runner/runner_test.py deleted file mode 100644 index 7818d631a..000000000 --- a/benchmarks/runner/runner_test.py +++ /dev/null @@ -1,59 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Top-level tests.""" - -import os -import subprocess -import sys - -from click import testing -import pytest - -from benchmarks import runner - - -def _get_locale(): - output = subprocess.check_output(["locale", "-a"]) - locales = output.split() - if b"en_US.utf8" in locales: - return "en_US.UTF-8" - else: - return "C.UTF-8" - - -def _set_locale(): - locale = _get_locale() - if os.getenv("LANG") != locale: - os.environ["LANG"] = locale - os.environ["LC_ALL"] = locale - os.execv("/proc/self/exe", ["python"] + sys.argv) - - -def test_list(): - cli_runner = testing.CliRunner() - result = cli_runner.invoke(runner.runner, ["list"]) - print(result.output) - assert result.exit_code == 0 - - -def test_run(): - cli_runner = testing.CliRunner() - result = cli_runner.invoke(runner.runner, ["run-mock", "."]) - print(result.output) - assert result.exit_code == 0 - - -if __name__ == "__main__": - _set_locale() - sys.exit(pytest.main([__file__])) diff --git a/benchmarks/suites/BUILD b/benchmarks/suites/BUILD deleted file mode 100644 index 04fc23261..000000000 --- a/benchmarks/suites/BUILD +++ /dev/null @@ -1,130 +0,0 @@ -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "suites", - srcs = ["__init__.py"], -) - -py_library( - name = "absl", - srcs = ["absl.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/workloads/absl", - ], -) - -py_library( - name = "density", - srcs = ["density.py"], - deps = [ - "//benchmarks/harness:container", - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/suites:helpers", - ], -) - -py_library( - name = "fio", - srcs = ["fio.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/suites:helpers", - "//benchmarks/workloads/fio", - ], -) - -py_library( - name = "helpers", - srcs = ["helpers.py"], - deps = ["//benchmarks/harness:machine"], -) - -py_library( - name = "http", - srcs = ["http.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/workloads/ab", - ], -) - -py_library( - name = "media", - srcs = ["media.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/suites:helpers", - "//benchmarks/workloads/ffmpeg", - ], -) - -py_library( - name = "ml", - srcs = ["ml.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/suites:startup", - "//benchmarks/workloads/tensorflow", - ], -) - -py_library( - name = "network", - srcs = ["network.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/suites:helpers", - "//benchmarks/workloads/iperf", - ], -) - -py_library( - name = "redis", - srcs = ["redis.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/workloads/redisbenchmark", - ], -) - -py_library( - name = "startup", - srcs = ["startup.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/suites:helpers", - ], -) - -py_library( - name = "sysbench", - srcs = ["sysbench.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/workloads/sysbench", - ], -) - -py_library( - name = "syscall", - srcs = ["syscall.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/workloads/syscall", - ], -) diff --git a/benchmarks/suites/__init__.py b/benchmarks/suites/__init__.py deleted file mode 100644 index 360736cc3..000000000 --- a/benchmarks/suites/__init__.py +++ /dev/null @@ -1,119 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Core benchmark annotations.""" - -import functools -import inspect -import types -from typing import List -from typing import Tuple - -BENCHMARK_METRICS = '__benchmark_metrics__' -BENCHMARK_MACHINES = '__benchmark_machines__' - - -def is_benchmark(func: types.FunctionType) -> bool: - """Returns true if the given function is a benchmark.""" - return isinstance(func, types.FunctionType) and \ - hasattr(func, BENCHMARK_METRICS) and \ - hasattr(func, BENCHMARK_MACHINES) - - -def benchmark_metrics(func: types.FunctionType) -> List[Tuple[str, str]]: - """Returns the list of available metrics.""" - return [(metric.__name__, metric.__doc__) - for metric in getattr(func, BENCHMARK_METRICS)] - - -def benchmark_machines(func: types.FunctionType) -> int: - """Returns the number of machines required.""" - return getattr(func, BENCHMARK_MACHINES) - - -# pylint: disable=unused-argument -def default(value, **kwargs): - """Returns the passed value.""" - return value - - -def benchmark(metrics: List[types.FunctionType] = None, - machines: int = 1) -> types.FunctionType: - """Define a benchmark function with metrics. - - Args: - metrics: A list of metric functions. - machines: The number of machines required. - - Returns: - A function that accepts the given number of machines, and iteratively - returns a set of (metric_name, metric_value) pairs when called repeatedly. - """ - if not metrics: - # The default passes through. - metrics = [default] - - def decorator(func: types.FunctionType) -> types.FunctionType: - """Decorator function.""" - # Every benchmark should accept at least two parameters: - # runtime: The runtime to use for the benchmark (str, required). - # metrics: The metrics to use, if not the default (str, optional). - @functools.wraps(func) - def wrapper(*args, runtime: str, metric: list = None, **kwargs): - """Wrapper function.""" - # First -- ensure that we marshall all types appropriately. In - # general, we will call this with only strings. These strings will - # need to be converted to their underlying types/classes. - sig = inspect.signature(func) - for param in sig.parameters.values(): - if param.annotation != inspect.Parameter.empty and \ - param.name in kwargs and not isinstance(kwargs[param.name], param.annotation): - try: - # Marshall to the appropriate type. - kwargs[param.name] = param.annotation(kwargs[param.name]) - except Exception as exc: - raise ValueError( - 'illegal type for %s(%s=%s): %s' % - (func.__name__, param.name, kwargs[param.name], exc)) - elif param.default != inspect.Parameter.empty and \ - param.name not in kwargs: - # Ensure that we have the value set, because it will - # be passed to the metric function for evaluation. - kwargs[param.name] = param.default - - # Next, figure out how to apply a metric. We do this prior to - # running the underlying function to prevent having to wait a few - # minutes for a result just to see some error. - if not metric: - # Return all metrics in the iterator. - result = func(*args, runtime=runtime, **kwargs) - for metric_func in metrics: - yield (metric_func.__name__, metric_func(result, **kwargs)) - else: - result = None - for single_metric in metric: - for metric_func in metrics: - # Is this a function that matches the name? - # Apply this function to the result. - if metric_func.__name__ == single_metric: - if not result: - # Lazy evaluation: only if metric matches. - result = func(*args, runtime=runtime, **kwargs) - yield single_metric, metric_func(result, **kwargs) - - # Set metadata on the benchmark (used above). - setattr(wrapper, BENCHMARK_METRICS, metrics) - setattr(wrapper, BENCHMARK_MACHINES, machines) - return wrapper - - return decorator diff --git a/benchmarks/suites/absl.py b/benchmarks/suites/absl.py deleted file mode 100644 index 5d9b57a09..000000000 --- a/benchmarks/suites/absl.py +++ /dev/null @@ -1,37 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""absl build benchmark.""" - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.workloads import absl - - -@suites.benchmark(metrics=[absl.elapsed_time], machines=1) -def build(target: machine.Machine, **kwargs) -> str: - """Runs the absl workload and report the absl build time. - - Runs the 'bazel build //absl/...' in a clean bazel directory and - monitors time elapsed. - - Args: - target: A machine object. - **kwargs: Additional container options. - - Returns: - Container output. - """ - image = target.pull("absl") - return target.container(image, **kwargs).run() diff --git a/benchmarks/suites/density.py b/benchmarks/suites/density.py deleted file mode 100644 index 89d29fb26..000000000 --- a/benchmarks/suites/density.py +++ /dev/null @@ -1,121 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Density tests.""" - -import re -import types - -from benchmarks import suites -from benchmarks.harness import container -from benchmarks.harness import machine -from benchmarks.suites import helpers - - -# pylint: disable=unused-argument -def memory_usage(value, **kwargs): - """Returns the passed value.""" - return value - - -def density(target: machine.Machine, - workload: str, - count: int = 50, - wait: float = 0, - load_func: types.FunctionType = None, - **kwargs): - """Calculate the average memory usage per container. - - Args: - target: A machine object. - workload: The workload to run. - count: The number of containers to start. - wait: The time to wait after starting. - load_func: Callback that is called after count images have been started on - the given machine. - **kwargs: Additional container options. - - Returns: - The average usage in Kb per container. - """ - count = int(count) - - # Drop all caches. - helpers.drop_caches(target) - before = target.read("/proc/meminfo") - - # Load the workload. - image = target.pull(workload) - - with target.container( - image=image, count=count, **kwargs).detach() as containers: - # Call the optional load function callback if given. - if load_func: - load_func(target, containers) - # Wait 'wait' time before taking a measurement. - target.sleep(wait) - - # Drop caches again. - helpers.drop_caches(target) - after = target.read("/proc/meminfo") - - # Calculate the memory used. - available_re = re.compile(r"MemAvailable:\s*(\d+)\skB\n") - before_available = available_re.findall(before) - after_available = available_re.findall(after) - return 1024 * float(int(before_available[0]) - - int(after_available[0])) / float(count) - - -def load_redis(target: machine.Machine, containers: container.Container): - """Use redis-benchmark "LPUSH" to load each container with 1G of data. - - Args: - target: A machine object. - containers: A set of containers. - """ - target.pull("redisbenchmark") - for name in containers.get_names(): - flags = "-d 10000 -t LPUSH" - target.container( - "redisbenchmark", links={ - name: name - }).run( - host=name, flags=flags) - - -@suites.benchmark(metrics=[memory_usage], machines=1) -def empty(target: machine.Machine, **kwargs) -> float: - """Run trivial containers in a density test.""" - return density(target, workload="sleep", wait=1.0, **kwargs) - - -@suites.benchmark(metrics=[memory_usage], machines=1) -def node(target: machine.Machine, **kwargs) -> float: - """Run node containers in a density test.""" - return density(target, workload="node", wait=3.0, **kwargs) - - -@suites.benchmark(metrics=[memory_usage], machines=1) -def ruby(target: machine.Machine, **kwargs) -> float: - """Run ruby containers in a density test.""" - return density(target, workload="ruby", wait=3.0, **kwargs) - - -@suites.benchmark(metrics=[memory_usage], machines=1) -def redis(target: machine.Machine, **kwargs) -> float: - """Run redis containers in a density test.""" - if "count" not in kwargs: - kwargs["count"] = 5 - return density( - target, workload="redis", wait=3.0, load_func=load_redis, **kwargs) diff --git a/benchmarks/suites/fio.py b/benchmarks/suites/fio.py deleted file mode 100644 index 2171790c5..000000000 --- a/benchmarks/suites/fio.py +++ /dev/null @@ -1,165 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""File I/O tests.""" - -import os - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.suites import helpers -from benchmarks.workloads import fio - - -# pylint: disable=too-many-arguments -# pylint: disable=too-many-locals -def run_fio(target: machine.Machine, - test: str, - ioengine: str = "sync", - size: int = 1024 * 1024 * 1024, - iodepth: int = 4, - blocksize: int = 1024 * 1024, - time: int = -1, - mount_dir: str = "", - filename: str = "file.dat", - tmpfs: bool = False, - ramp_time: int = 0, - **kwargs) -> str: - """FIO benchmarks. - - For more on fio see: - https://media.readthedocs.org/pdf/fio/latest/fio.pdf - - Args: - target: A machine object. - test: The test to run (read, write, randread, randwrite, etc.) - ioengine: The engine for I/O. - size: The size of the generated file in bytes (if an integer) or 5g, 16k, - etc. - iodepth: The I/O for certain engines. - blocksize: The blocksize for reads and writes in bytes (if an integer) or - 4k, etc. - time: If test is time based, how long to run in seconds. - mount_dir: The absolute path on the host to mount a bind mount. - filename: The name of the file to creat inside container. For a path of - /dir/dir/file, the script setup a volume like 'docker run -v - mount_dir:/dir/dir fio' and fio will create (and delete) the file - /dir/dir/file. If tmpfs is set, this /dir/dir will be a tmpfs. - tmpfs: If true, mount on tmpfs. - ramp_time: The time to run before recording statistics - **kwargs: Additional container options. - - Returns: - The output of fio as a string. - """ - # Pull the image before dropping caches. - image = target.pull("fio") - - if not mount_dir: - stdout, _ = target.run("pwd") - mount_dir = stdout.rstrip() - - # Setup the volumes. - volumes = {mount_dir: {"bind": "/disk", "mode": "rw"}} if not tmpfs else None - tmpfs = {"/disk": ""} if tmpfs else None - - # Construct a file in the volume. - filepath = os.path.join("/disk", filename) - - # If we are running a read test, us fio to write a file and then flush file - # data from memory. - if "read" in test: - target.container( - image, volumes=volumes, tmpfs=tmpfs, **kwargs).run( - test="write", - ioengine="sync", - size=size, - iodepth=iodepth, - blocksize=blocksize, - path=filepath) - helpers.drop_caches(target) - - # Run the test. - time_str = "--time_base --runtime={time}".format( - time=time) if int(time) > 0 else "" - res = target.container( - image, volumes=volumes, tmpfs=tmpfs, **kwargs).run( - test=test, - ioengine=ioengine, - size=size, - iodepth=iodepth, - blocksize=blocksize, - time=time_str, - path=filepath, - ramp_time=ramp_time) - - target.run( - "rm {path}".format(path=os.path.join(mount_dir.rstrip(), filename))) - - return res - - -@suites.benchmark(metrics=[fio.read_bandwidth, fio.read_io_ops], machines=1) -def read(*args, **kwargs): - """Read test. - - Args: - *args: None. - **kwargs: Additional container options. - - Returns: - The output of fio. - """ - return run_fio(*args, test="read", **kwargs) - - -@suites.benchmark(metrics=[fio.read_bandwidth, fio.read_io_ops], machines=1) -def randread(*args, **kwargs): - """Random read test. - - Args: - *args: None. - **kwargs: Additional container options. - - Returns: - The output of fio. - """ - return run_fio(*args, test="randread", **kwargs) - - -@suites.benchmark(metrics=[fio.write_bandwidth, fio.write_io_ops], machines=1) -def write(*args, **kwargs): - """Write test. - - Args: - *args: None. - **kwargs: Additional container options. - - Returns: - The output of fio. - """ - return run_fio(*args, test="write", **kwargs) - - -@suites.benchmark(metrics=[fio.write_bandwidth, fio.write_io_ops], machines=1) -def randwrite(*args, **kwargs): - """Random write test. - - Args: - *args: None. - **kwargs: Additional container options. - - Returns: - The output of fio. - """ - return run_fio(*args, test="randwrite", **kwargs) diff --git a/benchmarks/suites/helpers.py b/benchmarks/suites/helpers.py deleted file mode 100644 index b3c7360ab..000000000 --- a/benchmarks/suites/helpers.py +++ /dev/null @@ -1,57 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Benchmark helpers.""" - -import datetime -from benchmarks.harness import machine - - -class Timer: - """Helper to time runtime of some call. - - Usage: - - with Timer as t: - # do something. - t.get_time_in_seconds() - """ - - def __init__(self): - self._start = datetime.datetime.now() - - def __enter__(self): - self.start() - return self - - def start(self): - """Starts the timer.""" - self._start = datetime.datetime.now() - - def elapsed(self) -> float: - """Returns the elapsed time in seconds.""" - return (datetime.datetime.now() - self._start).total_seconds() - - def __exit__(self, exception_type, exception_value, exception_traceback): - pass - - -def drop_caches(target: machine.Machine): - """Drops caches on the machine. - - Args: - target: A machine object. - """ - target.run("sudo sync") - target.run("sudo sysctl vm.drop_caches=3") - target.run("sudo sysctl vm.drop_caches=3") diff --git a/benchmarks/suites/http.py b/benchmarks/suites/http.py deleted file mode 100644 index 6efea938c..000000000 --- a/benchmarks/suites/http.py +++ /dev/null @@ -1,138 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""HTTP benchmarks.""" - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.workloads import ab - - -# pylint: disable=too-many-arguments -def http(server: machine.Machine, - client: machine.Machine, - workload: str, - requests: int = 5000, - connections: int = 10, - port: int = 80, - path: str = "notfound", - **kwargs) -> str: - """Run apachebench (ab) against an http server. - - Args: - server: A machine object. - client: A machine object. - workload: The http-serving workload. - requests: Number of requests to send the server. Default is 5000. - connections: Number of concurent connections to use. Default is 10. - port: The port to access in benchmarking. - path: File to download, generally workload-specific. - **kwargs: Additional container options. - - Returns: - The full apachebench output. - """ - # Pull the client & server. - apachebench = client.pull("ab") - netcat = client.pull("netcat") - image = server.pull(workload) - - with server.container(image, port=port, **kwargs).detach() as container: - (host, port) = container.address() - # Wait for the server to come up. - client.container(netcat).run(host=host, port=port) - # Run the benchmark, no arguments. - return client.container(apachebench).run( - host=host, - port=port, - requests=requests, - connections=connections, - path=path) - - -# pylint: disable=too-many-arguments -# pylint: disable=too-many-locals -def http_app(server: machine.Machine, - client: machine.Machine, - workload: str, - requests: int = 5000, - connections: int = 10, - port: int = 80, - path: str = "notfound", - **kwargs) -> str: - """Run apachebench (ab) against an http application. - - Args: - server: A machine object. - client: A machine object. - workload: The http-serving workload. - requests: Number of requests to send the server. Default is 5000. - connections: Number of concurent connections to use. Default is 10. - port: The port to use for benchmarking. - path: File to download, generally workload-specific. - **kwargs: Additional container options. - - Returns: - The full apachebench output. - """ - # Pull the client & server. - apachebench = client.pull("ab") - netcat = client.pull("netcat") - server_netcat = server.pull("netcat") - redis = server.pull("redis") - image = server.pull(workload) - redis_port = 6379 - redis_name = "{workload}_redis_server".format(workload=workload) - - with server.container(redis, name=redis_name).detach(): - server.container(server_netcat, links={redis_name: redis_name})\ - .run(host=redis_name, port=redis_port) - with server.container(image, port=port, links={redis_name: redis_name}, **kwargs)\ - .detach(host=redis_name) as container: - (host, port) = container.address() - # Wait for the server to come up. - client.container(netcat).run(host=host, port=port) - # Run the benchmark, no arguments. - return client.container(apachebench).run( - host=host, - port=port, - requests=requests, - connections=connections, - path=path) - - -@suites.benchmark(metrics=[ab.transfer_rate, ab.latency], machines=2) -def httpd(*args, **kwargs) -> str: - """Apache2 benchmark.""" - return http(*args, workload="httpd", port=80, **kwargs) - - -@suites.benchmark( - metrics=[ab.transfer_rate, ab.latency, ab.requests_per_second], machines=2) -def nginx(*args, **kwargs) -> str: - """Nginx benchmark.""" - return http(*args, workload="nginx", port=80, **kwargs) - - -@suites.benchmark( - metrics=[ab.transfer_rate, ab.latency, ab.requests_per_second], machines=2) -def node(*args, **kwargs) -> str: - """Node benchmark.""" - return http_app(*args, workload="node_template", path="", port=8080, **kwargs) - - -@suites.benchmark( - metrics=[ab.transfer_rate, ab.latency, ab.requests_per_second], machines=2) -def ruby(*args, **kwargs) -> str: - """Ruby benchmark.""" - return http_app(*args, workload="ruby_template", path="", port=9292, **kwargs) diff --git a/benchmarks/suites/media.py b/benchmarks/suites/media.py deleted file mode 100644 index 9cbffdaa1..000000000 --- a/benchmarks/suites/media.py +++ /dev/null @@ -1,42 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Media processing benchmarks.""" - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.suites import helpers -from benchmarks.workloads import ffmpeg - - -@suites.benchmark(metrics=[ffmpeg.run_time], machines=1) -def transcode(target: machine.Machine, **kwargs) -> float: - """Runs a video transcoding workload and times it. - - Args: - target: A machine object. - **kwargs: Additional container options. - - Returns: - Total workload runtime. - """ - # Load before timing. - image = target.pull("ffmpeg") - - # Drop caches. - helpers.drop_caches(target) - - # Time startup + transcoding. - with helpers.Timer() as timer: - target.container(image, **kwargs).run() - return timer.elapsed() diff --git a/benchmarks/suites/ml.py b/benchmarks/suites/ml.py deleted file mode 100644 index a394d1f69..000000000 --- a/benchmarks/suites/ml.py +++ /dev/null @@ -1,33 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Machine Learning tests.""" - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.suites import startup -from benchmarks.workloads import tensorflow - - -@suites.benchmark(metrics=[tensorflow.run_time], machines=1) -def train(target: machine.Machine, **kwargs): - """Run the tensorflow benchmark and return the runtime in seconds of workload. - - Args: - target: A machine object. - **kwargs: Additional container options. - - Returns: - The total runtime. - """ - return startup.startup(target, workload="tensorflow", count=1, **kwargs) diff --git a/benchmarks/suites/network.py b/benchmarks/suites/network.py deleted file mode 100644 index f973cf3f1..000000000 --- a/benchmarks/suites/network.py +++ /dev/null @@ -1,101 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Network microbenchmarks.""" - -from typing import Dict - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.suites import helpers -from benchmarks.workloads import iperf - - -def run_iperf(client: machine.Machine, - server: machine.Machine, - client_kwargs: Dict[str, str] = None, - server_kwargs: Dict[str, str] = None) -> str: - """Measure iperf performance. - - Args: - client: A machine object. - server: A machine object. - client_kwargs: Additional client container options. - server_kwargs: Additional server container options. - - Returns: - The output of iperf. - """ - if not client_kwargs: - client_kwargs = dict() - if not server_kwargs: - server_kwargs = dict() - - # Pull images. - netcat = client.pull("netcat") - iperf_client_image = client.pull("iperf") - iperf_server_image = server.pull("iperf") - - # Set this due to a bug in the kernel that resets connections. - client.run("sudo /sbin/sysctl -w net.netfilter.nf_conntrack_tcp_be_liberal=1") - server.run("sudo /sbin/sysctl -w net.netfilter.nf_conntrack_tcp_be_liberal=1") - - with server.container( - iperf_server_image, port=5001, **server_kwargs).detach() as iperf_server: - (host, port) = iperf_server.address() - # Wait until the service is available. - client.container(netcat).run(host=host, port=port) - # Run a warm-up run. - client.container( - iperf_client_image, stderr=True, **client_kwargs).run( - host=host, port=port) - # Run the client with relevant arguments. - res = client.container(iperf_client_image, stderr=True, **client_kwargs)\ - .run(host=host, port=port) - helpers.drop_caches(client) - return res - - -@suites.benchmark(metrics=[iperf.bandwidth], machines=2) -def upload(client: machine.Machine, server: machine.Machine, **kwargs) -> str: - """Measure upload performance. - - Args: - client: A machine object. - server: A machine object. - **kwargs: Client container options. - - Returns: - The output of iperf. - """ - if kwargs["runtime"] == "runc": - kwargs["network_mode"] = "host" - return run_iperf(client, server, client_kwargs=kwargs) - - -@suites.benchmark(metrics=[iperf.bandwidth], machines=2) -def download(client: machine.Machine, server: machine.Machine, **kwargs) -> str: - """Measure download performance. - - Args: - client: A machine object. - server: A machine object. - **kwargs: Server container options. - - Returns: - The output of iperf. - """ - - client_kwargs = {"network_mode": "host"} - return run_iperf( - client, server, client_kwargs=client_kwargs, server_kwargs=kwargs) diff --git a/benchmarks/suites/redis.py b/benchmarks/suites/redis.py deleted file mode 100644 index b84dd073d..000000000 --- a/benchmarks/suites/redis.py +++ /dev/null @@ -1,46 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Redis benchmarks.""" - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.workloads import redisbenchmark - - -@suites.benchmark(metrics=list(redisbenchmark.METRICS.values()), machines=2) -def redis(server: machine.Machine, - client: machine.Machine, - flags: str = "", - **kwargs) -> str: - """Run redis-benchmark on client pointing at server machine. - - Args: - server: A machine object. - client: A machine object. - flags: Flags to pass redis-benchmark. - **kwargs: Additional container options. - - Returns: - Output from redis-benchmark. - """ - redis_server = server.pull("redis") - redis_client = client.pull("redisbenchmark") - netcat = client.pull("netcat") - with server.container( - redis_server, port=6379, **kwargs).detach() as container: - (host, port) = container.address() - # Wait for the container to be up. - client.container(netcat).run(host=host, port=port) - # Run all redis benchmarks. - return client.container(redis_client).run(host=host, port=port, flags=flags) diff --git a/benchmarks/suites/startup.py b/benchmarks/suites/startup.py deleted file mode 100644 index a1b6c5753..000000000 --- a/benchmarks/suites/startup.py +++ /dev/null @@ -1,110 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Start-up benchmarks.""" - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.suites import helpers - - -# pylint: disable=unused-argument -def startup_time_ms(value, **kwargs): - """Returns average startup time per container in milliseconds. - - Args: - value: The floating point time in seconds. - **kwargs: Ignored. - - Returns: - The time given in milliseconds. - """ - return value * 1000 - - -def startup(target: machine.Machine, - workload: str, - count: int = 5, - port: int = 0, - **kwargs): - """Time the startup of some workload. - - Args: - target: A machine object. - workload: The workload to run. - count: Number of containers to start. - port: The port to check for liveness, if provided. - **kwargs: Additional container options. - - Returns: - The mean start-up time in seconds. - """ - # Load before timing. - image = target.pull(workload) - netcat = target.pull("netcat") - count = int(count) - port = int(port) - - with helpers.Timer() as timer: - for _ in range(count): - if not port: - # Run the container synchronously. - target.container(image, **kwargs).run() - else: - # Run a detached container until httpd available. - with target.container(image, port=port, **kwargs).detach() as server: - (server_host, server_port) = server.address() - target.container(netcat).run(host=server_host, port=server_port) - return timer.elapsed() / float(count) - - -@suites.benchmark(metrics=[startup_time_ms], machines=1) -def empty(target: machine.Machine, **kwargs) -> float: - """Time the startup of a trivial container. - - Args: - target: A machine object. - **kwargs: Additional startup options. - - Returns: - The time to run the container. - """ - return startup(target, workload="true", **kwargs) - - -@suites.benchmark(metrics=[startup_time_ms], machines=1) -def node(target: machine.Machine, **kwargs) -> float: - """Time the startup of the node container. - - Args: - target: A machine object. - **kwargs: Additional statup options. - - Returns: - The time to run the container. - """ - return startup(target, workload="node", port=8080, **kwargs) - - -@suites.benchmark(metrics=[startup_time_ms], machines=1) -def ruby(target: machine.Machine, **kwargs) -> float: - """Time the startup of the ruby container. - - Args: - target: A machine object. - **kwargs: Additional startup options. - - Returns: - The time to run the container. - """ - return startup(target, workload="ruby", port=3000, **kwargs) diff --git a/benchmarks/suites/sysbench.py b/benchmarks/suites/sysbench.py deleted file mode 100644 index 2a6e2126c..000000000 --- a/benchmarks/suites/sysbench.py +++ /dev/null @@ -1,119 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Sysbench-based benchmarks.""" - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.workloads import sysbench - - -def run_sysbench(target: machine.Machine, - test: str = "cpu", - threads: int = 8, - time: int = 5, - options: str = "", - **kwargs) -> str: - """Run sysbench container with arguments. - - Args: - target: A machine object. - test: Relevant sysbench test to run (e.g. cpu, memory). - threads: The number of threads to use for tests. - time: The time to run tests. - options: Additional sysbench options. - **kwargs: Additional container options. - - Returns: - The output of the command as a string. - """ - image = target.pull("sysbench") - return target.container(image, **kwargs).run( - test=test, threads=threads, time=time, options=options) - - -@suites.benchmark(metrics=[sysbench.cpu_events_per_second], machines=1) -def cpu(target: machine.Machine, max_prime: int = 5000, **kwargs) -> str: - """Run sysbench CPU test. - - Additional arguments can be provided for sysbench. - - Args: - target: A machine object. - max_prime: The maximum prime number to search. - **kwargs: - - threads: The number of threads to use for tests. - - time: The time to run tests. - - options: Additional sysbench options. See sysbench tool: - https://github.com/akopytov/sysbench - - Returns: - Sysbench output. - """ - options = kwargs.pop("options", "") - options += " --cpu-max-prime={}".format(max_prime) - return run_sysbench(target, test="cpu", options=options, **kwargs) - - -@suites.benchmark(metrics=[sysbench.memory_ops_per_second], machines=1) -def memory(target: machine.Machine, **kwargs) -> str: - """Run sysbench memory test. - - Additional arguments can be provided per sysbench. - - Args: - target: A machine object. - **kwargs: - - threads: The number of threads to use for tests. - - time: The time to run tests. - - options: Additional sysbench options. See sysbench tool: - https://github.com/akopytov/sysbench - - Returns: - Sysbench output. - """ - return run_sysbench(target, test="memory", **kwargs) - - -@suites.benchmark( - metrics=[ - sysbench.mutex_time, sysbench.mutex_latency, sysbench.mutex_deviation - ], - machines=1) -def mutex(target: machine.Machine, - locks: int = 4, - count: int = 10000000, - threads: int = 8, - **kwargs) -> str: - """Run sysbench mutex test. - - Additional arguments can be provided per sysbench. - - Args: - target: A machine object. - locks: The number of locks to use. - count: The number of mutexes. - threads: The number of threads to use for tests. - **kwargs: - - time: The time to run tests. - - options: Additional sysbench options. See sysbench tool: - https://github.com/akopytov/sysbench - - Returns: - Sysbench output. - """ - options = kwargs.pop("options", "") - options += " --mutex-loops=1 --mutex-locks={} --mutex-num={}".format( - count, locks) - return run_sysbench( - target, test="mutex", options=options, threads=threads, **kwargs) diff --git a/benchmarks/suites/syscall.py b/benchmarks/suites/syscall.py deleted file mode 100644 index fa7665b00..000000000 --- a/benchmarks/suites/syscall.py +++ /dev/null @@ -1,37 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Syscall microbenchmark.""" - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.workloads.syscall import syscall_time_ns - - -@suites.benchmark(metrics=[syscall_time_ns], machines=1) -def syscall(target: machine.Machine, count: int = 1000000, **kwargs) -> str: - """Runs the syscall workload and report the syscall time. - - Runs the syscall 'SYS_gettimeofday(0,0)' 'count' times and monitors time - elapsed based on the runtime's MONOTONIC clock. - - Args: - target: A machine object. - count: The number of syscalls to execute. - **kwargs: Additional container options. - - Returns: - Container output. - """ - image = target.pull("syscall") - return target.container(image, **kwargs).run(count=count) diff --git a/benchmarks/tcp/BUILD b/benchmarks/tcp/BUILD deleted file mode 100644 index d5e401acc..000000000 --- a/benchmarks/tcp/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools:defs.bzl", "cc_binary", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "tcp_proxy", - srcs = ["tcp_proxy.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/adapters/gonet", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -# nsjoin is a trivial replacement for nsenter. This is used because nsenter is -# not available on all systems where this benchmark is run (and we aim to -# minimize external dependencies.) - -cc_binary( - name = "nsjoin", - srcs = ["nsjoin.c"], - visibility = ["//:sandbox"], -) - -sh_binary( - name = "tcp_benchmark", - srcs = ["tcp_benchmark.sh"], - data = [ - ":nsjoin", - ":tcp_proxy", - ], - visibility = ["//:sandbox"], -) diff --git a/benchmarks/tcp/README.md b/benchmarks/tcp/README.md deleted file mode 100644 index 38e6e69f0..000000000 --- a/benchmarks/tcp/README.md +++ /dev/null @@ -1,87 +0,0 @@ -# TCP Benchmarks - -This directory contains a standardized TCP benchmark. This helps to evaluate the -performance of netstack and native networking stacks under various conditions. - -## `tcp_benchmark` - -This benchmark allows TCP throughput testing under various conditions. The setup -consists of an iperf client, a client proxy, a server proxy and an iperf server. -The client proxy and server proxy abstract the network mechanism used to -communicate between the iperf client and server. - -The setup looks like the following: - -``` - +--------------+ (native) +--------------+ - | iperf client |[lo @ 10.0.0.1]------>| client proxy | - +--------------+ +--------------+ - [client.0 @ 10.0.0.2] - (netstack) | | (native) - +------+-----+ - | - [br0] - | - Network emulation applied ---> [wan.0:wan.1] - | - [br1] - | - +------+-----+ - (netstack) | | (native) - [server.0 @ 10.0.0.3] - +--------------+ +--------------+ - | iperf server |<------[lo @ 10.0.0.4]| server proxy | - +--------------+ (native) +--------------+ -``` - -Different configurations can be run using different arguments. For example: - -* Native test under normal internet conditions: `tcp_benchmark` -* Native test under ideal conditions: `tcp_benchmark --ideal` -* Netstack client under ideal conditions: `tcp_benchmark --client --ideal` -* Netstack client with 5% packet loss: `tcp_benchmark --client --ideal --loss - 5` - -Use `tcp_benchmark --help` for full arguments. - -This tool may be used to easily generate data for graphing. For example, to -generate a CSV for various latencies, you might do: - -``` -rm -f /tmp/netstack_latency.csv /tmp/native_latency.csv -latencies=$(seq 0 5 50; - seq 60 10 100; - seq 125 25 250; - seq 300 50 500) -for latency in $latencies; do - read throughput client_cpu server_cpu <<< \ - $(./tcp_benchmark --duration 30 --client --ideal --latency $latency) - echo $latency,$throughput,$client_cpu >> /tmp/netstack_latency.csv -done -for latency in $latencies; do - read throughput client_cpu server_cpu <<< \ - $(./tcp_benchmark --duration 30 --ideal --latency $latency) - echo $latency,$throughput,$client_cpu >> /tmp/native_latency.csv -done -``` - -Similarly, to generate a CSV for various levels of packet loss, the following -would be appropriate: - -``` -rm -f /tmp/netstack_loss.csv /tmp/native_loss.csv -losses=$(seq 0 0.1 1.0; - seq 1.2 0.2 2.0; - seq 2.5 0.5 5.0; - seq 6.0 1.0 10.0) -for loss in $losses; do - read throughput client_cpu server_cpu <<< \ - $(./tcp_benchmark --duration 30 --client --ideal --latency 10 --loss $loss) - echo $loss,$throughput,$client_cpu >> /tmp/netstack_loss.csv -done -for loss in $losses; do - read throughput client_cpu server_cpu <<< \ - $(./tcp_benchmark --duration 30 --ideal --latency 10 --loss $loss) - echo $loss,$throughput,$client_cpu >> /tmp/native_loss.csv -done -``` diff --git a/benchmarks/tcp/nsjoin.c b/benchmarks/tcp/nsjoin.c deleted file mode 100644 index 524b4d549..000000000 --- a/benchmarks/tcp/nsjoin.c +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef _GNU_SOURCE -#define _GNU_SOURCE -#endif - -#include <errno.h> -#include <fcntl.h> -#include <sched.h> -#include <stdio.h> -#include <string.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -int main(int argc, char** argv) { - if (argc <= 2) { - fprintf(stderr, "error: must provide a namespace file.\n"); - fprintf(stderr, "usage: %s <file> [arguments...]\n", argv[0]); - return 1; - } - - int fd = open(argv[1], O_RDONLY); - if (fd < 0) { - fprintf(stderr, "error opening %s: %s\n", argv[1], strerror(errno)); - return 1; - } - if (setns(fd, 0) < 0) { - fprintf(stderr, "error joining %s: %s\n", argv[1], strerror(errno)); - return 1; - } - - execvp(argv[2], &argv[2]); - return 1; -} diff --git a/benchmarks/tcp/tcp_benchmark.sh b/benchmarks/tcp/tcp_benchmark.sh deleted file mode 100755 index e65801a7b..000000000 --- a/benchmarks/tcp/tcp_benchmark.sh +++ /dev/null @@ -1,388 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TCP benchmark; see README.md for documentation. - -# Fixed parameters. -iperf_port=45201 # Not likely to be privileged. -proxy_port=44000 # Ditto. -client_addr=10.0.0.1 -client_proxy_addr=10.0.0.2 -server_proxy_addr=10.0.0.3 -server_addr=10.0.0.4 -mask=8 - -# Defaults; this provides a reasonable approximation of a decent internet link. -# Parameters can be varied independently from this set to see response to -# various changes in the kind of link available. -client=false -server=false -verbose=false -gso=0 -swgso=false -mtu=1280 # 1280 is a reasonable lowest-common-denominator. -latency=10 # 10ms approximates a fast, dedicated connection. -latency_variation=1 # +/- 1ms is a relatively low amount of jitter. -loss=0.1 # 0.1% loss is non-zero, but not extremely high. -duplicate=0.1 # 0.1% means duplicates are 1/10x as frequent as losses. -duration=30 # 30s is enough time to consistent results (experimentally). -helper_dir=$(dirname $0) -netstack_opts= -disable_linux_gso= -num_client_threads=1 - -# Check for netem support. -lsmod_output=$(lsmod | grep sch_netem) -if [ "$?" != "0" ]; then - echo "warning: sch_netem may not be installed." >&2 -fi - -while [ $# -gt 0 ]; do - case "$1" in - --client) - client=true - ;; - --client_tcp_probe_file) - shift - netstack_opts="${netstack_opts} -client_tcp_probe_file=$1" - ;; - --server) - server=true - ;; - --verbose) - verbose=true - ;; - --gso) - shift - gso=$1 - ;; - --swgso) - swgso=true - ;; - --server_tcp_probe_file) - shift - netstack_opts="${netstack_opts} -server_tcp_probe_file=$1" - ;; - --ideal) - mtu=1500 # Standard ethernet. - latency=0 # No latency. - latency_variation=0 # No jitter. - loss=0 # No loss. - duplicate=0 # No duplicates. - ;; - --mtu) - shift - [ "$#" -le 0 ] && echo "no mtu provided" && exit 1 - mtu=$1 - ;; - --sack) - netstack_opts="${netstack_opts} -sack" - ;; - --cubic) - netstack_opts="${netstack_opts} -cubic" - ;; - --duration) - shift - [ "$#" -le 0 ] && echo "no duration provided" && exit 1 - duration=$1 - ;; - --latency) - shift - [ "$#" -le 0 ] && echo "no latency provided" && exit 1 - latency=$1 - ;; - --latency-variation) - shift - [ "$#" -le 0 ] && echo "no latency variation provided" && exit 1 - latency_variation=$1 - ;; - --loss) - shift - [ "$#" -le 0 ] && echo "no loss probability provided" && exit 1 - loss=$1 - ;; - --duplicate) - shift - [ "$#" -le 0 ] && echo "no duplicate provided" && exit 1 - duplicate=$1 - ;; - --cpuprofile) - shift - netstack_opts="${netstack_opts} -cpuprofile=$1" - ;; - --memprofile) - shift - netstack_opts="${netstack_opts} -memprofile=$1" - ;; - --disable-linux-gso) - disable_linux_gso=1 - ;; - --num-client-threads) - shift - num_client_threads=$1 - ;; - --helpers) - shift - [ "$#" -le 0 ] && echo "no helper dir provided" && exit 1 - helper_dir=$1 - ;; - *) - echo "usage: $0 [options]" - echo "options:" - echo " --help show this message" - echo " --verbose verbose output" - echo " --client use netstack as the client" - echo " --ideal reset all network emulation" - echo " --server use netstack as the server" - echo " --mtu set the mtu (bytes)" - echo " --sack enable SACK support" - echo " --cubic enable CUBIC congestion control for Netstack" - echo " --duration set the test duration (s)" - echo " --latency set the latency (ms)" - echo " --latency-variation set the latency variation" - echo " --loss set the loss probability (%)" - echo " --duplicate set the duplicate probability (%)" - echo " --helpers set the helper directory" - echo " --num-client-threads number of parallel client threads to run" - echo " --disable-linux-gso disable segmentation offload in the Linux network stack" - echo "" - echo "The output will of the script will be:" - echo " <throughput> <client-cpu-usage> <server-cpu-usage>" - exit 1 - esac - shift -done - -if [ ${verbose} == "true" ]; then - set -x -fi - -# Latency needs to be halved, since it's applied on both ways. -half_latency=$(echo ${latency}/2 | bc -l | awk '{printf "%1.2f", $0}') -half_loss=$(echo ${loss}/2 | bc -l | awk '{printf "%1.6f", $0}') -half_duplicate=$(echo ${duplicate}/2 | bc -l | awk '{printf "%1.6f", $0}') -helper_dir=${helper_dir#$(pwd)/} # Use relative paths. -proxy_binary=${helper_dir}/tcp_proxy -nsjoin_binary=${helper_dir}/nsjoin - -if [ ! -e ${proxy_binary} ]; then - echo "Could not locate ${proxy_binary}, please make sure you've built the binary" - exit 1 -fi - -if [ ! -e ${nsjoin_binary} ]; then - echo "Could not locate ${nsjoin_binary}, please make sure you've built the binary" - exit 1 -fi - -if [ $(echo ${latency_variation} | awk '{printf "%1.2f", $0}') != "0.00" ]; then - # As long as there's some jitter, then we use the paretonormal distribution. - # This will preserve the minimum RTT, but add a realistic amount of jitter to - # the connection and cause re-ordering, etc. The regular pareto distribution - # appears to an unreasonable level of delay (we want only small spikes.) - distribution="distribution paretonormal" -else - distribution="" -fi - -# Client proxy that will listen on the client's iperf target forward traffic -# using the host networking stack. -client_args="${proxy_binary} -port ${proxy_port} -forward ${server_proxy_addr}:${proxy_port}" -if ${client}; then - # Client proxy that will listen on the client's iperf target - # and forward traffic using netstack. - client_args="${proxy_binary} ${netstack_opts} -port ${proxy_port} -client \\ - -mtu ${mtu} -iface client.0 -addr ${client_proxy_addr} -mask ${mask} \\ - -forward ${server_proxy_addr}:${proxy_port} -gso=${gso} -swgso=${swgso}" -fi - -# Server proxy that will listen on the proxy port and forward to the server's -# iperf server using the host networking stack. -server_args="${proxy_binary} -port ${proxy_port} -forward ${server_addr}:${iperf_port}" -if ${server}; then - # Server proxy that will listen on the proxy port and forward to the servers' - # iperf server using netstack. - server_args="${proxy_binary} ${netstack_opts} -port ${proxy_port} -server \\ - -mtu ${mtu} -iface server.0 -addr ${server_proxy_addr} -mask ${mask} \\ - -forward ${server_addr}:${iperf_port} -gso=${gso} -swgso=${swgso}" -fi - -# Specify loss and duplicate parameters only if they are non-zero -loss_opt="" -if [ "$(echo $half_loss | bc -q)" != "0" ]; then - loss_opt="loss random ${half_loss}%" -fi -duplicate_opt="" -if [ "$(echo $half_duplicate | bc -q)" != "0" ]; then - duplicate_opt="duplicate ${half_duplicate}%" -fi - -exec unshare -U -m -n -r -f -p --mount-proc /bin/bash << EOF -set -e -m - -if [ ${verbose} == "true" ]; then - set -x -fi - -mount -t tmpfs netstack-bench /tmp - -# We may have reset the path in the unshare if the shell loaded some public -# profiles. Ensure that tools are discoverable via the parent's PATH. -export PATH=${PATH} - -# Add client, server interfaces. -ip link add client.0 type veth peer name client.1 -ip link add server.0 type veth peer name server.1 - -# Add network emulation devices. -ip link add wan.0 type veth peer name wan.1 -ip link set wan.0 up -ip link set wan.1 up - -# Enroll on the bridge. -ip link add name br0 type bridge -ip link add name br1 type bridge -ip link set client.1 master br0 -ip link set server.1 master br1 -ip link set wan.0 master br0 -ip link set wan.1 master br1 -ip link set br0 up -ip link set br1 up - -# Set the MTU appropriately. -ip link set client.0 mtu ${mtu} -ip link set server.0 mtu ${mtu} -ip link set wan.0 mtu ${mtu} -ip link set wan.1 mtu ${mtu} - -# Add appropriate latency, loss and duplication. -# -# This is added in at the point of bridge connection. -for device in wan.0 wan.1; do - # NOTE: We don't support a loss correlation as testing has shown that it - # actually doesn't work. The man page actually has a small comment about this - # "It is also possible to add a correlation, but this option is now deprecated - # due to the noticed bad behavior." For more information see netem(8). - tc qdisc add dev \$device root netem \\ - delay ${half_latency}ms ${latency_variation}ms ${distribution} \\ - ${loss_opt} ${duplicate_opt} -done - -# Start a client proxy. -touch /tmp/client.netns -unshare -n mount --bind /proc/self/ns/net /tmp/client.netns - -# Move the endpoint into the namespace. -while ip link | grep client.0 > /dev/null; do - ip link set dev client.0 netns /tmp/client.netns -done - -if ! ${client}; then - # Only add the address to NIC if netstack is not in use. Otherwise the host - # will also process the inbound SYN and send a RST back. - ${nsjoin_binary} /tmp/client.netns ip addr add ${client_proxy_addr}/${mask} dev client.0 -fi - -# Start a server proxy. -touch /tmp/server.netns -unshare -n mount --bind /proc/self/ns/net /tmp/server.netns -# Move the endpoint into the namespace. -while ip link | grep server.0 > /dev/null; do - ip link set dev server.0 netns /tmp/server.netns -done -if ! ${server}; then - # Only add the address to NIC if netstack is not in use. Otherwise the host - # will also process the inbound SYN and send a RST back. - ${nsjoin_binary} /tmp/server.netns ip addr add ${server_proxy_addr}/${mask} dev server.0 -fi - -# Add client and server addresses, and bring everything up. -${nsjoin_binary} /tmp/client.netns ip addr add ${client_addr}/${mask} dev client.0 -${nsjoin_binary} /tmp/server.netns ip addr add ${server_addr}/${mask} dev server.0 -if [ "${disable_linux_gso}" == "1" ]; then - ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 tso off - ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 gro off - ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 gso off - ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 tso off - ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 gso off - ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 gro off -fi -${nsjoin_binary} /tmp/client.netns ip link set client.0 up -${nsjoin_binary} /tmp/client.netns ip link set lo up -${nsjoin_binary} /tmp/server.netns ip link set server.0 up -${nsjoin_binary} /tmp/server.netns ip link set lo up -ip link set dev client.1 up -ip link set dev server.1 up - -${nsjoin_binary} /tmp/client.netns ${client_args} & -client_pid=\$! -${nsjoin_binary} /tmp/server.netns ${server_args} & -server_pid=\$! - -# Start the iperf server. -${nsjoin_binary} /tmp/server.netns iperf -p ${iperf_port} -s >&2 & -iperf_pid=\$! - -# Show traffic information. -if ! ${client} && ! ${server}; then - ${nsjoin_binary} /tmp/client.netns ping -c 100 -i 0.001 -W 1 ${server_addr} >&2 || true -fi - -results_file=\$(mktemp) -function cleanup { - rm -f \$results_file - kill -TERM \$client_pid - kill -TERM \$server_pid - wait \$client_pid - wait \$server_pid - kill -9 \$iperf_pid 2>/dev/null -} - -# Allow failure from this point. -set +e -trap cleanup EXIT - -# Run the benchmark, recording the results file. -while ${nsjoin_binary} /tmp/client.netns iperf \\ - -p ${proxy_port} -c ${client_addr} -t ${duration} -f m -P ${num_client_threads} 2>&1 \\ - | tee \$results_file \\ - | grep "connect failed" >/dev/null; do - sleep 0.1 # Wait for all services. -done - -# Unlink all relevant devices from the bridge. This is because when the bridge -# is deleted, the kernel may hang. It appears that this problem is fixed in -# upstream commit 1ce5cce895309862d2c35d922816adebe094fe4a. -ip link set client.1 nomaster -ip link set server.1 nomaster -ip link set wan.0 nomaster -ip link set wan.1 nomaster - -# Emit raw results. -cat \$results_file >&2 - -# Emit a useful result (final throughput). -mbits=\$(grep Mbits/sec \$results_file \\ - | sed -n -e 's/^.*[[:space:]]\\([[:digit:]]\\+\\(\\.[[:digit:]]\\+\\)\\?\\)[[:space:]]*Mbits\\/sec.*/\\1/p') -client_cpu_ticks=\$(cat /proc/\$client_pid/stat \\ - | awk '{print (\$14+\$15);}') -server_cpu_ticks=\$(cat /proc/\$server_pid/stat \\ - | awk '{print (\$14+\$15);}') -ticks_per_sec=\$(getconf CLK_TCK) -client_cpu_load=\$(bc -l <<< \$client_cpu_ticks/\$ticks_per_sec/${duration}) -server_cpu_load=\$(bc -l <<< \$server_cpu_ticks/\$ticks_per_sec/${duration}) -echo \$mbits \$client_cpu_load \$server_cpu_load -EOF diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go deleted file mode 100644 index 73b7c4f5b..000000000 --- a/benchmarks/tcp/tcp_proxy.go +++ /dev/null @@ -1,444 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Binary tcp_proxy is a simple TCP proxy. -package main - -import ( - "encoding/gob" - "flag" - "fmt" - "io" - "log" - "math/rand" - "net" - "os" - "os/signal" - "regexp" - "runtime" - "runtime/pprof" - "strconv" - "syscall" - "time" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -var ( - port = flag.Int("port", 0, "bind port (all addresses)") - forward = flag.String("forward", "", "forwarding target") - client = flag.Bool("client", false, "use netstack for listen") - server = flag.Bool("server", false, "use netstack for dial") - - // Netstack-specific options. - mtu = flag.Int("mtu", 1280, "mtu for network stack") - addr = flag.String("addr", "", "address for tap-based netstack") - mask = flag.Int("mask", 8, "mask size for address") - iface = flag.String("iface", "", "network interface name to bind for netstack") - sack = flag.Bool("sack", false, "enable SACK support for netstack") - cubic = flag.Bool("cubic", false, "enable use of CUBIC congestion control for netstack") - gso = flag.Int("gso", 0, "GSO maximum size") - swgso = flag.Bool("swgso", false, "software-level GSO") - clientTCPProbeFile = flag.String("client_tcp_probe_file", "", "if specified, installs a tcp probe to dump endpoint state to the specified file.") - serverTCPProbeFile = flag.String("server_tcp_probe_file", "", "if specified, installs a tcp probe to dump endpoint state to the specified file.") - cpuprofile = flag.String("cpuprofile", "", "write cpu profile to the specified file.") - memprofile = flag.String("memprofile", "", "write memory profile to the specified file.") -) - -type impl interface { - dial(address string) (net.Conn, error) - listen(port int) (net.Listener, error) - printStats() -} - -type netImpl struct{} - -func (netImpl) dial(address string) (net.Conn, error) { - return net.Dial("tcp", address) -} - -func (netImpl) listen(port int) (net.Listener, error) { - return net.Listen("tcp", fmt.Sprintf(":%d", port)) -} - -func (netImpl) printStats() { -} - -const ( - nicID = 1 // Fixed. - bufSize = 4 << 20 // 4MB. -) - -type netstackImpl struct { - s *stack.Stack - addr tcpip.Address - mode string -} - -func setupNetwork(ifaceName string, numChannels int) (fds []int, err error) { - // Get all interfaces in the namespace. - ifaces, err := net.Interfaces() - if err != nil { - return nil, fmt.Errorf("querying interfaces: %v", err) - } - - for _, iface := range ifaces { - if iface.Name != ifaceName { - continue - } - // Create the socket. - const protocol = 0x0300 // htons(ETH_P_ALL) - fds := make([]int, numChannels) - for i := range fds { - fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol) - if err != nil { - return nil, fmt.Errorf("unable to create raw socket: %v", err) - } - - // Bind to the appropriate device. - ll := syscall.SockaddrLinklayer{ - Protocol: protocol, - Ifindex: iface.Index, - Pkttype: syscall.PACKET_HOST, - } - if err := syscall.Bind(fd, &ll); err != nil { - return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err) - } - - // RAW Sockets by default have a very small SO_RCVBUF of 256KB, - // up it to at least 4MB to reduce packet drops. - if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bufSize); err != nil { - return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", bufSize, err) - } - - if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bufSize); err != nil { - return nil, fmt.Errorf("setsockopt(..., SO_SNDBUF, %v,..) = %v", bufSize, err) - } - - if !*swgso && *gso != 0 { - if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { - return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err) - } - } - fds[i] = fd - } - return fds, nil - } - return nil, fmt.Errorf("failed to find interface: %v", ifaceName) -} - -func newNetstackImpl(mode string) (impl, error) { - fds, err := setupNetwork(*iface, runtime.GOMAXPROCS(-1)) - if err != nil { - return nil, err - } - - // Parse details. - parsedAddr := tcpip.Address(net.ParseIP(*addr).To4()) - parsedDest := tcpip.Address("") // Filled in below. - parsedMask := tcpip.AddressMask("") // Filled in below. - switch *mask { - case 8: - parsedDest = tcpip.Address([]byte{parsedAddr[0], 0, 0, 0}) - parsedMask = tcpip.AddressMask([]byte{0xff, 0, 0, 0}) - case 16: - parsedDest = tcpip.Address([]byte{parsedAddr[0], parsedAddr[1], 0, 0}) - parsedMask = tcpip.AddressMask([]byte{0xff, 0xff, 0, 0}) - case 24: - parsedDest = tcpip.Address([]byte{parsedAddr[0], parsedAddr[1], parsedAddr[2], 0}) - parsedMask = tcpip.AddressMask([]byte{0xff, 0xff, 0xff, 0}) - default: - // This is just laziness; we don't expect a different mask. - return nil, fmt.Errorf("mask %d not supported", mask) - } - - // Create a new network stack. - netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()} - transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()} - s := stack.New(stack.Options{ - NetworkProtocols: netProtos, - TransportProtocols: transProtos, - }) - - // Generate a new mac for the eth device. - mac := make(net.HardwareAddr, 6) - rand.Read(mac) // Fill with random data. - mac[0] &^= 0x1 // Clear multicast bit. - mac[0] |= 0x2 // Set local assignment bit (IEEE802). - ep, err := fdbased.New(&fdbased.Options{ - FDs: fds, - MTU: uint32(*mtu), - EthernetHeader: true, - Address: tcpip.LinkAddress(mac), - // Enable checksum generation as we need to generate valid - // checksums for the veth device to deliver our packets to the - // peer. But we do want to disable checksum verification as veth - // devices do perform GRO and the linux host kernel may not - // regenerate valid checksums after GRO. - TXChecksumOffload: false, - RXChecksumOffload: true, - PacketDispatchMode: fdbased.RecvMMsg, - GSOMaxSize: uint32(*gso), - SoftwareGSOEnabled: *swgso, - }) - if err != nil { - return nil, fmt.Errorf("failed to create FD endpoint: %v", err) - } - if err := s.CreateNIC(nicID, ep); err != nil { - return nil, fmt.Errorf("error creating NIC %q: %v", *iface, err) - } - if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - return nil, fmt.Errorf("error adding ARP address to %q: %v", *iface, err) - } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, parsedAddr); err != nil { - return nil, fmt.Errorf("error adding IP address to %q: %v", *iface, err) - } - - subnet, err := tcpip.NewSubnet(parsedDest, parsedMask) - if err != nil { - return nil, fmt.Errorf("tcpip.Subnet(%s, %s): %s", parsedDest, parsedMask, err) - } - // Add default route; we only support - s.SetRouteTable([]tcpip.Route{ - { - Destination: subnet, - NIC: nicID, - }, - }) - - // Set protocol options. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(*sack)); err != nil { - return nil, fmt.Errorf("SetTransportProtocolOption for SACKEnabled failed: %v", err) - } - - // Set Congestion Control to cubic if requested. - if *cubic { - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.CongestionControlOption("cubic")); err != nil { - return nil, fmt.Errorf("SetTransportProtocolOption for CongestionControlOption(cubic) failed: %v", err) - } - } - - return netstackImpl{ - s: s, - addr: parsedAddr, - mode: mode, - }, nil -} - -func (n netstackImpl) dial(address string) (net.Conn, error) { - host, port, err := net.SplitHostPort(address) - if err != nil { - return nil, err - } - if host == "" { - // A host must be provided for the dial. - return nil, fmt.Errorf("no host provided") - } - portNumber, err := strconv.Atoi(port) - if err != nil { - return nil, err - } - addr := tcpip.FullAddress{ - NIC: nicID, - Addr: tcpip.Address(net.ParseIP(host).To4()), - Port: uint16(portNumber), - } - conn, err := gonet.DialTCP(n.s, addr, ipv4.ProtocolNumber) - if err != nil { - return nil, err - } - return conn, nil -} - -func (n netstackImpl) listen(port int) (net.Listener, error) { - addr := tcpip.FullAddress{ - NIC: nicID, - Port: uint16(port), - } - listener, err := gonet.ListenTCP(n.s, addr, ipv4.ProtocolNumber) - if err != nil { - return nil, err - } - return listener, nil -} - -var zeroFieldsRegexp = regexp.MustCompile(`\s*[a-zA-Z0-9]*:0`) - -func (n netstackImpl) printStats() { - // Don't show zero fields. - stats := zeroFieldsRegexp.ReplaceAllString(fmt.Sprintf("%+v", n.s.Stats()), "") - log.Printf("netstack %s Stats: %+v\n", n.mode, stats) -} - -// installProbe installs a TCP Probe function that will dump endpoint -// state to the specified file. It also returns a close func() that -// can be used to close the probeFile. -func (n netstackImpl) installProbe(probeFileName string) (close func()) { - // Install Probe to dump out end point state. - probeFile, err := os.Create(probeFileName) - if err != nil { - log.Fatalf("failed to create tcp_probe file %s: %v", probeFileName, err) - } - probeEncoder := gob.NewEncoder(probeFile) - // Install a TCP Probe. - n.s.AddTCPProbe(func(state stack.TCPEndpointState) { - probeEncoder.Encode(state) - }) - return func() { probeFile.Close() } -} - -func main() { - flag.Parse() - if *port == 0 { - log.Fatalf("no port provided") - } - if *forward == "" { - log.Fatalf("no forward provided") - } - // Seed the random number generator to ensure that we are given MAC addresses that don't - // for the case of the client and server stack. - rand.Seed(time.Now().UTC().UnixNano()) - - if *cpuprofile != "" { - f, err := os.Create(*cpuprofile) - if err != nil { - log.Fatal("could not create CPU profile: ", err) - } - defer func() { - if err := f.Close(); err != nil { - log.Print("error closing CPU profile: ", err) - } - }() - if err := pprof.StartCPUProfile(f); err != nil { - log.Fatal("could not start CPU profile: ", err) - } - defer pprof.StopCPUProfile() - } - - var ( - in impl - out impl - err error - ) - if *server { - in, err = newNetstackImpl("server") - if *serverTCPProbeFile != "" { - defer in.(netstackImpl).installProbe(*serverTCPProbeFile)() - } - - } else { - in = netImpl{} - } - if err != nil { - log.Fatalf("netstack error: %v", err) - } - if *client { - out, err = newNetstackImpl("client") - if *clientTCPProbeFile != "" { - defer out.(netstackImpl).installProbe(*clientTCPProbeFile)() - } - } else { - out = netImpl{} - } - if err != nil { - log.Fatalf("netstack error: %v", err) - } - - // Dial forward before binding. - var next net.Conn - for { - next, err = out.dial(*forward) - if err == nil { - break - } - time.Sleep(50 * time.Millisecond) - log.Printf("connect failed retrying: %v", err) - } - - // Bind once to the server socket. - listener, err := in.listen(*port) - if err != nil { - // Should not happen, everything must be bound by this time - // this proxy is started. - log.Fatalf("unable to listen: %v", err) - } - log.Printf("client=%v, server=%v, ready.", *client, *server) - - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGTERM) - go func() { - <-sigs - if *cpuprofile != "" { - pprof.StopCPUProfile() - } - if *memprofile != "" { - f, err := os.Create(*memprofile) - if err != nil { - log.Fatal("could not create memory profile: ", err) - } - defer func() { - if err := f.Close(); err != nil { - log.Print("error closing memory profile: ", err) - } - }() - runtime.GC() // get up-to-date statistics - if err := pprof.WriteHeapProfile(f); err != nil { - log.Fatalf("Unable to write heap profile: %v", err) - } - } - os.Exit(0) - }() - - for { - // Forward all connections. - inConn, err := listener.Accept() - if err != nil { - // This should not happen; we are listening - // successfully. Exhausted all available FDs? - log.Fatalf("accept error: %v", err) - } - log.Printf("incoming connection established.") - - // Copy both ways. - go io.Copy(inConn, next) - go io.Copy(next, inConn) - - // Print stats every second. - go func() { - t := time.NewTicker(time.Second) - defer t.Stop() - for { - <-t.C - in.printStats() - out.printStats() - } - }() - - for { - // Dial again. - next, err = out.dial(*forward) - if err == nil { - break - } - } - } -} diff --git a/benchmarks/workloads/BUILD b/benchmarks/workloads/BUILD deleted file mode 100644 index ccb86af5b..000000000 --- a/benchmarks/workloads/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "workloads", - srcs = ["__init__.py"], -) - -filegroup( - name = "files", - srcs = [ - "//benchmarks/workloads/ab:tar", - "//benchmarks/workloads/absl:tar", - "//benchmarks/workloads/curl:tar", - "//benchmarks/workloads/ffmpeg:tar", - "//benchmarks/workloads/fio:tar", - "//benchmarks/workloads/httpd:tar", - "//benchmarks/workloads/iperf:tar", - "//benchmarks/workloads/netcat:tar", - "//benchmarks/workloads/nginx:tar", - "//benchmarks/workloads/node:tar", - "//benchmarks/workloads/node_template:tar", - "//benchmarks/workloads/redis:tar", - "//benchmarks/workloads/redisbenchmark:tar", - "//benchmarks/workloads/ruby:tar", - "//benchmarks/workloads/ruby_template:tar", - "//benchmarks/workloads/sleep:tar", - "//benchmarks/workloads/sysbench:tar", - "//benchmarks/workloads/syscall:tar", - "//benchmarks/workloads/tensorflow:tar", - "//benchmarks/workloads/true:tar", - ], -) diff --git a/benchmarks/workloads/__init__.py b/benchmarks/workloads/__init__.py deleted file mode 100644 index e12651e76..000000000 --- a/benchmarks/workloads/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Workloads, parsers and test data.""" diff --git a/benchmarks/workloads/ab/BUILD b/benchmarks/workloads/ab/BUILD deleted file mode 100644 index 945ac7026..000000000 --- a/benchmarks/workloads/ab/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") -load("//benchmarks:defs.bzl", "test_deps") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "ab", - srcs = ["__init__.py"], -) - -py_test( - name = "ab_test", - srcs = ["ab_test.py"], - python_version = "PY3", - deps = test_deps + [ - ":ab", - ], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/ab/Dockerfile b/benchmarks/workloads/ab/Dockerfile deleted file mode 100644 index 0d0b6e2eb..000000000 --- a/benchmarks/workloads/ab/Dockerfile +++ /dev/null @@ -1,15 +0,0 @@ -FROM ubuntu:18.04 - -RUN set -x \ - && apt-get update \ - && apt-get install -y \ - apache2-utils \ - && rm -rf /var/lib/apt/lists/* - -# Parameterized workload. -ENV requests 5000 -ENV connections 10 -ENV host localhost -ENV port 8080 -ENV path notfound -CMD ["sh", "-c", "ab -n ${requests} -c ${connections} http://${host}:${port}/${path}"] diff --git a/benchmarks/workloads/ab/__init__.py b/benchmarks/workloads/ab/__init__.py deleted file mode 100644 index eedf8e083..000000000 --- a/benchmarks/workloads/ab/__init__.py +++ /dev/null @@ -1,88 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Apachebench tool.""" - -import re - -SAMPLE_DATA = """This is ApacheBench, Version 2.3 <$Revision: 1826891 $> -Copyright 1996 Adam Twiss, Zeus Technology Ltd, http://www.zeustech.net/ -Licensed to The Apache Software Foundation, http://www.apache.org/ - -Benchmarking 10.10.10.10 (be patient).....done - - -Server Software: Apache/2.4.38 -Server Hostname: 10.10.10.10 -Server Port: 80 - -Document Path: /latin10k.txt -Document Length: 210 bytes - -Concurrency Level: 1 -Time taken for tests: 0.180 seconds -Complete requests: 100 -Failed requests: 0 -Non-2xx responses: 100 -Total transferred: 38800 bytes -HTML transferred: 21000 bytes -Requests per second: 556.44 [#/sec] (mean) -Time per request: 1.797 [ms] (mean) -Time per request: 1.797 [ms] (mean, across all concurrent requests) -Transfer rate: 210.84 [Kbytes/sec] received - -Connection Times (ms) - min mean[+/-sd] median max -Connect: 0 0 0.2 0 2 -Processing: 1 2 1.0 1 8 -Waiting: 1 1 1.0 1 7 -Total: 1 2 1.2 1 10 - -Percentage of the requests served within a certain time (ms) - 50% 1 - 66% 2 - 75% 2 - 80% 2 - 90% 2 - 95% 3 - 98% 7 - 99% 10 - 100% 10 (longest request)""" - - -# pylint: disable=unused-argument -def sample(**kwargs) -> str: - return SAMPLE_DATA - - -# pylint: disable=unused-argument -def transfer_rate(data: str, **kwargs) -> float: - """Mean transfer rate in Kbytes/sec.""" - regex = r"Transfer rate:\s+(\d+\.?\d+?)\s+\[Kbytes/sec\]\s+received" - return float(re.compile(regex).search(data).group(1)) - - -# pylint: disable=unused-argument -def latency(data: str, **kwargs) -> float: - """Mean latency in milliseconds.""" - regex = r"Total:\s+\d+\s+(\d+)\s+(\d+\.?\d+?)\s+\d+\s+\d+\s" - res = re.compile(regex).search(data) - return float(res.group(1)) - - -# pylint: disable=unused-argument -def requests_per_second(data: str, **kwargs) -> float: - """Requests per second.""" - regex = r"Requests per second:\s+(\d+\.?\d+?)\s+" - res = re.compile(regex).search(data) - return float(res.group(1)) diff --git a/benchmarks/workloads/ab/ab_test.py b/benchmarks/workloads/ab/ab_test.py deleted file mode 100644 index 4afac2996..000000000 --- a/benchmarks/workloads/ab/ab_test.py +++ /dev/null @@ -1,42 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Parser test.""" - -import sys - -import pytest - -from benchmarks.workloads import ab - - -def test_transfer_rate_parser(): - """Test transfer rate parser.""" - res = ab.transfer_rate(ab.sample()) - assert res == 210.84 - - -def test_latency_parser(): - """Test latency parser.""" - res = ab.latency(ab.sample()) - assert res == 2 - - -def test_requests_per_second(): - """Test requests per second parser.""" - res = ab.requests_per_second(ab.sample()) - assert res == 556.44 - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/absl/BUILD b/benchmarks/workloads/absl/BUILD deleted file mode 100644 index bb1a308bf..000000000 --- a/benchmarks/workloads/absl/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") -load("//benchmarks:defs.bzl", "test_deps") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "absl", - srcs = ["__init__.py"], -) - -py_test( - name = "absl_test", - srcs = ["absl_test.py"], - python_version = "PY3", - deps = test_deps + [ - ":absl", - ], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/absl/Dockerfile b/benchmarks/workloads/absl/Dockerfile deleted file mode 100644 index e935c5ddc..000000000 --- a/benchmarks/workloads/absl/Dockerfile +++ /dev/null @@ -1,24 +0,0 @@ -FROM ubuntu:18.04 - -RUN set -x \ - && apt-get update \ - && apt-get install -y \ - wget \ - git \ - pkg-config \ - zip \ - g++ \ - zlib1g-dev \ - unzip \ - python3 \ - && rm -rf /var/lib/apt/lists/* -RUN wget https://github.com/bazelbuild/bazel/releases/download/0.27.0/bazel-0.27.0-installer-linux-x86_64.sh -RUN chmod +x bazel-0.27.0-installer-linux-x86_64.sh -RUN ./bazel-0.27.0-installer-linux-x86_64.sh - -RUN git clone https://github.com/abseil/abseil-cpp.git -WORKDIR abseil-cpp -RUN git checkout 43ef2148c0936ebf7cb4be6b19927a9d9d145b8f -RUN bazel clean -ENV path "absl/base/..." -CMD bazel build ${path} 2>&1 diff --git a/benchmarks/workloads/absl/__init__.py b/benchmarks/workloads/absl/__init__.py deleted file mode 100644 index b40e3f915..000000000 --- a/benchmarks/workloads/absl/__init__.py +++ /dev/null @@ -1,63 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""ABSL build benchmark.""" - -import re - -SAMPLE_BAZEL_OUTPUT = """Extracting Bazel installation... -Starting local Bazel server and connecting to it... -Loading: -Loading: 0 packages loaded -Loading: 0 packages loaded - currently loading: absl/algorithm ... (11 packages) -Analyzing: 241 targets (16 packages loaded, 0 targets configured) -Analyzing: 241 targets (21 packages loaded, 617 targets configured) -Analyzing: 241 targets (27 packages loaded, 687 targets configured) -Analyzing: 241 targets (32 packages loaded, 1105 targets configured) -Analyzing: 241 targets (32 packages loaded, 1294 targets configured) -Analyzing: 241 targets (35 packages loaded, 1575 targets configured) -Analyzing: 241 targets (35 packages loaded, 1575 targets configured) -Analyzing: 241 targets (36 packages loaded, 1603 targets configured) -Analyzing: 241 targets (36 packages loaded, 1603 targets configured) -INFO: Analyzed 241 targets (37 packages loaded, 1864 targets configured). -INFO: Found 241 targets... -[0 / 5] [Prepa] BazelWorkspaceStatusAction stable-status.txt -[16 / 50] [Analy] Compiling absl/base/dynamic_annotations.cc ... (20 actions, 10 running) -[60 / 77] Compiling external/com_google_googletest/googletest/src/gtest.cc; 5s processwrapper-sandbox ... (12 actions, 11 running) -[158 / 174] Compiling absl/container/internal/raw_hash_set_test.cc; 2s processwrapper-sandbox ... (12 actions, 11 running) -[278 / 302] Compiling absl/container/internal/raw_hash_set_test.cc; 6s processwrapper-sandbox ... (12 actions, 11 running) -[384 / 406] Compiling absl/container/internal/raw_hash_set_test.cc; 10s processwrapper-sandbox ... (12 actions, 11 running) -[581 / 604] Compiling absl/container/flat_hash_set_test.cc; 11s processwrapper-sandbox ... (12 actions, 11 running) -[722 / 745] Compiling absl/container/node_hash_set_test.cc; 9s processwrapper-sandbox ... (12 actions, 11 running) -[846 / 867] Compiling absl/hash/hash_test.cc; 11s processwrapper-sandbox ... (12 actions, 11 running) -INFO: From Compiling absl/debugging/symbolize_test.cc: -/tmp/cclCVipU.s: Assembler messages: -/tmp/cclCVipU.s:1662: Warning: ignoring changed section attributes for .text -[999 / 1,022] Compiling absl/hash/hash_test.cc; 19s processwrapper-sandbox ... (12 actions, 11 running) -[1,082 / 1,084] Compiling absl/container/flat_hash_map_test.cc; 7s processwrapper-sandbox -INFO: Elapsed time: 81.861s, Critical Path: 23.81s -INFO: 515 processes: 515 processwrapper-sandbox. -INFO: Build completed successfully, 1084 total actions -INFO: Build completed successfully, 1084 total actions""" - - -def sample(): - return SAMPLE_BAZEL_OUTPUT - - -# pylint: disable=unused-argument -def elapsed_time(data: str, **kwargs) -> float: - """Returns the elapsed time for running an absl build.""" - return float(re.compile(r"Elapsed time: (\d*.?\d*)s").search(data).group(1)) diff --git a/benchmarks/workloads/absl/absl_test.py b/benchmarks/workloads/absl/absl_test.py deleted file mode 100644 index 41f216999..000000000 --- a/benchmarks/workloads/absl/absl_test.py +++ /dev/null @@ -1,31 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""ABSL build test.""" - -import sys - -import pytest - -from benchmarks.workloads import absl - - -def test_elapsed_time(): - """Test elapsed_time.""" - res = absl.elapsed_time(absl.sample()) - assert res == 81.861 - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/curl/BUILD b/benchmarks/workloads/curl/BUILD deleted file mode 100644 index a70873065..000000000 --- a/benchmarks/workloads/curl/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/curl/Dockerfile b/benchmarks/workloads/curl/Dockerfile deleted file mode 100644 index 336cb088a..000000000 --- a/benchmarks/workloads/curl/Dockerfile +++ /dev/null @@ -1,14 +0,0 @@ -FROM ubuntu:18.04 - -RUN set -x \ - && apt-get update \ - && apt-get install -y \ - curl \ - && rm -rf /var/lib/apt/lists/* - -# Accept a host and port parameter. -ENV host localhost -ENV port 8080 - -# Spin until we make a successful request. -CMD ["sh", "-c", "while ! curl -v -i http://$host:$port; do true; done"] diff --git a/benchmarks/workloads/ffmpeg/BUILD b/benchmarks/workloads/ffmpeg/BUILD deleted file mode 100644 index 7c41ba631..000000000 --- a/benchmarks/workloads/ffmpeg/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "ffmpeg", - srcs = ["__init__.py"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/ffmpeg/Dockerfile b/benchmarks/workloads/ffmpeg/Dockerfile deleted file mode 100644 index f2f530d7c..000000000 --- a/benchmarks/workloads/ffmpeg/Dockerfile +++ /dev/null @@ -1,10 +0,0 @@ -FROM ubuntu:18.04 - -RUN set -x \ - && apt-get update \ - && apt-get install -y \ - ffmpeg \ - && rm -rf /var/lib/apt/lists/* -WORKDIR /media -ADD https://samples.ffmpeg.org/MPEG-4/video.mp4 video.mp4 -CMD ["ffmpeg", "-i", "video.mp4", "-c:v", "libx264", "-preset", "veryslow", "output.mp4"] diff --git a/benchmarks/workloads/ffmpeg/__init__.py b/benchmarks/workloads/ffmpeg/__init__.py deleted file mode 100644 index 7578a443b..000000000 --- a/benchmarks/workloads/ffmpeg/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Simple ffmpeg workload.""" - - -# pylint: disable=unused-argument -def run_time(value, **kwargs): - """Returns the startup and runtime of the ffmpeg workload in seconds.""" - return value diff --git a/benchmarks/workloads/fio/BUILD b/benchmarks/workloads/fio/BUILD deleted file mode 100644 index 24d909c53..000000000 --- a/benchmarks/workloads/fio/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") -load("//benchmarks:defs.bzl", "test_deps") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "fio", - srcs = ["__init__.py"], -) - -py_test( - name = "fio_test", - srcs = ["fio_test.py"], - python_version = "PY3", - deps = test_deps + [ - ":fio", - ], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/fio/Dockerfile b/benchmarks/workloads/fio/Dockerfile deleted file mode 100644 index b3cf864eb..000000000 --- a/benchmarks/workloads/fio/Dockerfile +++ /dev/null @@ -1,23 +0,0 @@ -FROM ubuntu:18.04 - -RUN set -x \ - && apt-get update \ - && apt-get install -y \ - fio \ - && rm -rf /var/lib/apt/lists/* - -# Parameterized test. -ENV test write -ENV ioengine sync -ENV size 5000000 -ENV iodepth 4 -ENV blocksize "1m" -ENV time "" -ENV path "/disk/file.dat" -ENV ramp_time 0 - -CMD ["sh", "-c", "fio --output-format=json --name=test --ramp_time=${ramp_time} --ioengine=${ioengine} --size=${size} \ ---filename=${path} --iodepth=${iodepth} --bs=${blocksize} --rw=${test} ${time}"] - - - diff --git a/benchmarks/workloads/fio/__init__.py b/benchmarks/workloads/fio/__init__.py deleted file mode 100644 index 52711e956..000000000 --- a/benchmarks/workloads/fio/__init__.py +++ /dev/null @@ -1,369 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""FIO benchmark tool.""" - -import json - -SAMPLE_DATA = """ -{ - "fio version" : "fio-3.1", - "timestamp" : 1554837456, - "timestamp_ms" : 1554837456621, - "time" : "Tue Apr 9 19:17:36 2019", - "jobs" : [ - { - "jobname" : "test", - "groupid" : 0, - "error" : 0, - "eta" : 2147483647, - "elapsed" : 1, - "job options" : { - "name" : "test", - "ioengine" : "sync", - "size" : "1073741824", - "filename" : "/disk/file.dat", - "iodepth" : "4", - "bs" : "4096", - "rw" : "write" - }, - "read" : { - "io_bytes" : 0, - "io_kbytes" : 0, - "bw" : 0, - "iops" : 0.000000, - "runtime" : 0, - "total_ios" : 0, - "short_ios" : 0, - "drop_ios" : 0, - "slat_ns" : { - "min" : 0, - "max" : 0, - "mean" : 0.000000, - "stddev" : 0.000000 - }, - "clat_ns" : { - "min" : 0, - "max" : 0, - "mean" : 0.000000, - "stddev" : 0.000000, - "percentile" : { - "1.000000" : 0, - "5.000000" : 0, - "10.000000" : 0, - "20.000000" : 0, - "30.000000" : 0, - "40.000000" : 0, - "50.000000" : 0, - "60.000000" : 0, - "70.000000" : 0, - "80.000000" : 0, - "90.000000" : 0, - "95.000000" : 0, - "99.000000" : 0, - "99.500000" : 0, - "99.900000" : 0, - "99.950000" : 0, - "99.990000" : 0, - "0.00" : 0, - "0.00" : 0, - "0.00" : 0 - } - }, - "lat_ns" : { - "min" : 0, - "max" : 0, - "mean" : 0.000000, - "stddev" : 0.000000 - }, - "bw_min" : 0, - "bw_max" : 0, - "bw_agg" : 0.000000, - "bw_mean" : 0.000000, - "bw_dev" : 0.000000, - "bw_samples" : 0, - "iops_min" : 0, - "iops_max" : 0, - "iops_mean" : 0.000000, - "iops_stddev" : 0.000000, - "iops_samples" : 0 - }, - "write" : { - "io_bytes" : 1073741824, - "io_kbytes" : 1048576, - "bw" : 1753471, - "iops" : 438367.892977, - "runtime" : 598, - "total_ios" : 262144, - "short_ios" : 0, - "drop_ios" : 0, - "slat_ns" : { - "min" : 0, - "max" : 0, - "mean" : 0.000000, - "stddev" : 0.000000 - }, - "clat_ns" : { - "min" : 1693, - "max" : 754733, - "mean" : 2076.404373, - "stddev" : 1724.195529, - "percentile" : { - "1.000000" : 1736, - "5.000000" : 1752, - "10.000000" : 1768, - "20.000000" : 1784, - "30.000000" : 1800, - "40.000000" : 1800, - "50.000000" : 1816, - "60.000000" : 1816, - "70.000000" : 1848, - "80.000000" : 1928, - "90.000000" : 2512, - "95.000000" : 2992, - "99.000000" : 6176, - "99.500000" : 6304, - "99.900000" : 11328, - "99.950000" : 15168, - "99.990000" : 17792, - "0.00" : 0, - "0.00" : 0, - "0.00" : 0 - } - }, - "lat_ns" : { - "min" : 1731, - "max" : 754770, - "mean" : 2117.878979, - "stddev" : 1730.290512 - }, - "bw_min" : 1731120, - "bw_max" : 1731120, - "bw_agg" : 98.725328, - "bw_mean" : 1731120.000000, - "bw_dev" : 0.000000, - "bw_samples" : 1, - "iops_min" : 432780, - "iops_max" : 432780, - "iops_mean" : 432780.000000, - "iops_stddev" : 0.000000, - "iops_samples" : 1 - }, - "trim" : { - "io_bytes" : 0, - "io_kbytes" : 0, - "bw" : 0, - "iops" : 0.000000, - "runtime" : 0, - "total_ios" : 0, - "short_ios" : 0, - "drop_ios" : 0, - "slat_ns" : { - "min" : 0, - "max" : 0, - "mean" : 0.000000, - "stddev" : 0.000000 - }, - "clat_ns" : { - "min" : 0, - "max" : 0, - "mean" : 0.000000, - "stddev" : 0.000000, - "percentile" : { - "1.000000" : 0, - "5.000000" : 0, - "10.000000" : 0, - "20.000000" : 0, - "30.000000" : 0, - "40.000000" : 0, - "50.000000" : 0, - "60.000000" : 0, - "70.000000" : 0, - "80.000000" : 0, - "90.000000" : 0, - "95.000000" : 0, - "99.000000" : 0, - "99.500000" : 0, - "99.900000" : 0, - "99.950000" : 0, - "99.990000" : 0, - "0.00" : 0, - "0.00" : 0, - "0.00" : 0 - } - }, - "lat_ns" : { - "min" : 0, - "max" : 0, - "mean" : 0.000000, - "stddev" : 0.000000 - }, - "bw_min" : 0, - "bw_max" : 0, - "bw_agg" : 0.000000, - "bw_mean" : 0.000000, - "bw_dev" : 0.000000, - "bw_samples" : 0, - "iops_min" : 0, - "iops_max" : 0, - "iops_mean" : 0.000000, - "iops_stddev" : 0.000000, - "iops_samples" : 0 - }, - "usr_cpu" : 17.922948, - "sys_cpu" : 81.574539, - "ctx" : 3, - "majf" : 0, - "minf" : 10, - "iodepth_level" : { - "1" : 100.000000, - "2" : 0.000000, - "4" : 0.000000, - "8" : 0.000000, - "16" : 0.000000, - "32" : 0.000000, - ">=64" : 0.000000 - }, - "latency_ns" : { - "2" : 0.000000, - "4" : 0.000000, - "10" : 0.000000, - "20" : 0.000000, - "50" : 0.000000, - "100" : 0.000000, - "250" : 0.000000, - "500" : 0.000000, - "750" : 0.000000, - "1000" : 0.000000 - }, - "latency_us" : { - "2" : 82.737350, - "4" : 12.605286, - "10" : 4.543686, - "20" : 0.107956, - "50" : 0.010000, - "100" : 0.000000, - "250" : 0.000000, - "500" : 0.000000, - "750" : 0.000000, - "1000" : 0.010000 - }, - "latency_ms" : { - "2" : 0.000000, - "4" : 0.000000, - "10" : 0.000000, - "20" : 0.000000, - "50" : 0.000000, - "100" : 0.000000, - "250" : 0.000000, - "500" : 0.000000, - "750" : 0.000000, - "1000" : 0.000000, - "2000" : 0.000000, - ">=2000" : 0.000000 - }, - "latency_depth" : 4, - "latency_target" : 0, - "latency_percentile" : 100.000000, - "latency_window" : 0 - } - ], - "disk_util" : [ - { - "name" : "dm-1", - "read_ios" : 0, - "write_ios" : 3, - "read_merges" : 0, - "write_merges" : 0, - "read_ticks" : 0, - "write_ticks" : 0, - "in_queue" : 0, - "util" : 0.000000, - "aggr_read_ios" : 0, - "aggr_write_ios" : 3, - "aggr_read_merges" : 0, - "aggr_write_merge" : 0, - "aggr_read_ticks" : 0, - "aggr_write_ticks" : 0, - "aggr_in_queue" : 0, - "aggr_util" : 0.000000 - }, - { - "name" : "dm-0", - "read_ios" : 0, - "write_ios" : 3, - "read_merges" : 0, - "write_merges" : 0, - "read_ticks" : 0, - "write_ticks" : 0, - "in_queue" : 0, - "util" : 0.000000, - "aggr_read_ios" : 0, - "aggr_write_ios" : 3, - "aggr_read_merges" : 0, - "aggr_write_merge" : 0, - "aggr_read_ticks" : 0, - "aggr_write_ticks" : 2, - "aggr_in_queue" : 0, - "aggr_util" : 0.000000 - }, - { - "name" : "nvme0n1", - "read_ios" : 0, - "write_ios" : 3, - "read_merges" : 0, - "write_merges" : 0, - "read_ticks" : 0, - "write_ticks" : 2, - "in_queue" : 0, - "util" : 0.000000 - } - ] -} -""" - - -# pylint: disable=unused-argument -def sample(**kwargs) -> str: - return SAMPLE_DATA - - -# pylint: disable=unused-argument -def read_bandwidth(data: str, **kwargs) -> int: - """File I/O bandwidth.""" - return json.loads(data)["jobs"][0]["read"]["bw"] * 1024 - - -# pylint: disable=unused-argument -def write_bandwidth(data: str, **kwargs) -> int: - """File I/O bandwidth.""" - return json.loads(data)["jobs"][0]["write"]["bw"] * 1024 - - -# pylint: disable=unused-argument -def read_io_ops(data: str, **kwargs) -> float: - """File I/O operations per second.""" - return float(json.loads(data)["jobs"][0]["read"]["iops"]) - - -# pylint: disable=unused-argument -def write_io_ops(data: str, **kwargs) -> float: - """File I/O operations per second.""" - return float(json.loads(data)["jobs"][0]["write"]["iops"]) - - -# Change function names so we just print "bandwidth" and "io_ops". -read_bandwidth.__name__ = "bandwidth" -write_bandwidth.__name__ = "bandwidth" -read_io_ops.__name__ = "io_ops" -write_io_ops.__name__ = "io_ops" diff --git a/benchmarks/workloads/fio/fio_test.py b/benchmarks/workloads/fio/fio_test.py deleted file mode 100644 index 04a6eeb7e..000000000 --- a/benchmarks/workloads/fio/fio_test.py +++ /dev/null @@ -1,44 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Parser tests.""" - -import sys - -import pytest - -from benchmarks.workloads import fio - - -def test_read_io_ops(): - """Test read ops parser.""" - assert fio.read_io_ops(fio.sample()) == 0.0 - - -def test_write_io_ops(): - """Test write ops parser.""" - assert fio.write_io_ops(fio.sample()) == 438367.892977 - - -def test_read_bandwidth(): - """Test read bandwidth parser.""" - assert fio.read_bandwidth(fio.sample()) == 0.0 - - -def test_write_bandwith(): - """Test write bandwidth parser.""" - assert fio.write_bandwidth(fio.sample()) == 1753471 * 1024 - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/httpd/BUILD b/benchmarks/workloads/httpd/BUILD deleted file mode 100644 index a70873065..000000000 --- a/benchmarks/workloads/httpd/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/httpd/Dockerfile b/benchmarks/workloads/httpd/Dockerfile deleted file mode 100644 index 5259c8f4f..000000000 --- a/benchmarks/workloads/httpd/Dockerfile +++ /dev/null @@ -1,27 +0,0 @@ -FROM ubuntu:18.04 - -RUN set -x \ - && apt-get update \ - && apt-get install -y \ - apache2 \ - && rm -rf /var/lib/apt/lists/* - -# Link the htdoc directory to tmp. -RUN mkdir -p /usr/local/apache2/htdocs && \ - cd /usr/local/apache2 && ln -s /tmp htdocs - -# Generate a bunch of relevant files. -RUN mkdir -p /local && \ - for size in 1 10 100 1000 1024 10240; do \ - dd if=/dev/zero of=/local/latin${size}k.txt count=${size} bs=1024; \ - done - -# Standard settings. -ENV APACHE_RUN_DIR /tmp -ENV APACHE_RUN_USER nobody -ENV APACHE_RUN_GROUP nogroup -ENV APACHE_LOG_DIR /tmp -ENV APACHE_PID_FILE /tmp/apache.pid - -# Copy on start-up; serve everything from /tmp (including the configuration). -CMD ["sh", "-c", "cp -a /local/* /tmp && apache2 -c \"ServerName localhost\" -c \"DocumentRoot /tmp\" -X"] diff --git a/benchmarks/workloads/iperf/BUILD b/benchmarks/workloads/iperf/BUILD deleted file mode 100644 index 91b953718..000000000 --- a/benchmarks/workloads/iperf/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") -load("//benchmarks:defs.bzl", "test_deps") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "iperf", - srcs = ["__init__.py"], -) - -py_test( - name = "iperf_test", - srcs = ["iperf_test.py"], - python_version = "PY3", - deps = test_deps + [ - ":iperf", - ], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/iperf/Dockerfile b/benchmarks/workloads/iperf/Dockerfile deleted file mode 100644 index 9704c506c..000000000 --- a/benchmarks/workloads/iperf/Dockerfile +++ /dev/null @@ -1,14 +0,0 @@ -FROM ubuntu:18.04 - -RUN set -x \ - && apt-get update \ - && apt-get install -y \ - iperf \ - && rm -rf /var/lib/apt/lists/* - -# Accept a host parameter. -ENV host "" -ENV port 5001 - -# Start as client if the host is provided. -CMD ["sh", "-c", "test -z \"${host}\" && iperf -s || iperf -f K --realtime -c ${host} -p ${port}"] diff --git a/benchmarks/workloads/iperf/__init__.py b/benchmarks/workloads/iperf/__init__.py deleted file mode 100644 index 3817a7ade..000000000 --- a/benchmarks/workloads/iperf/__init__.py +++ /dev/null @@ -1,40 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""iperf.""" - -import re - -SAMPLE_DATA = """ ------------------------------------------------------------- -Client connecting to 10.138.15.215, TCP port 32779 -TCP window size: 45.0 KByte (default) ------------------------------------------------------------- -[ 3] local 10.138.15.216 port 32866 connected with 10.138.15.215 port 32779 -[ ID] Interval Transfer Bandwidth -[ 3] 0.0-10.0 sec 459520 KBytes 45900 KBytes/sec - -""" - - -# pylint: disable=unused-argument -def sample(**kwargs) -> str: - return SAMPLE_DATA - - -# pylint: disable=unused-argument -def bandwidth(data: str, **kwargs) -> float: - """Calculate the bandwidth.""" - regex = r"\[\s*\d+\][^\n]+\s+(\d+\.?\d*)\s+KBytes/sec" - res = re.compile(regex).search(data) - return float(res.group(1)) * 1000 diff --git a/benchmarks/workloads/iperf/iperf_test.py b/benchmarks/workloads/iperf/iperf_test.py deleted file mode 100644 index 6959b7e8a..000000000 --- a/benchmarks/workloads/iperf/iperf_test.py +++ /dev/null @@ -1,28 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for iperf.""" - -import sys - -import pytest - -from benchmarks.workloads import iperf - - -def test_bandwidth(): - assert iperf.bandwidth(iperf.sample()) == 45900 * 1000 - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/netcat/BUILD b/benchmarks/workloads/netcat/BUILD deleted file mode 100644 index a70873065..000000000 --- a/benchmarks/workloads/netcat/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/netcat/Dockerfile b/benchmarks/workloads/netcat/Dockerfile deleted file mode 100644 index d8548d89a..000000000 --- a/benchmarks/workloads/netcat/Dockerfile +++ /dev/null @@ -1,14 +0,0 @@ -FROM ubuntu:18.04 - -RUN set -x \ - && apt-get update \ - && apt-get install -y \ - netcat \ - && rm -rf /var/lib/apt/lists/* - -# Accept a host and port parameter. -ENV host localhost -ENV port 8080 - -# Spin until we make a successful request. -CMD ["sh", "-c", "while ! nc -zv $host $port; do true; done"] diff --git a/benchmarks/workloads/nginx/BUILD b/benchmarks/workloads/nginx/BUILD deleted file mode 100644 index a70873065..000000000 --- a/benchmarks/workloads/nginx/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/nginx/Dockerfile b/benchmarks/workloads/nginx/Dockerfile deleted file mode 100644 index b64eb52ae..000000000 --- a/benchmarks/workloads/nginx/Dockerfile +++ /dev/null @@ -1 +0,0 @@ -FROM nginx:1.15.10 diff --git a/benchmarks/workloads/node/BUILD b/benchmarks/workloads/node/BUILD deleted file mode 100644 index bfcf78cf9..000000000 --- a/benchmarks/workloads/node/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - "index.js", - "package.json", - ], -) diff --git a/benchmarks/workloads/node/Dockerfile b/benchmarks/workloads/node/Dockerfile deleted file mode 100644 index 139a38bf5..000000000 --- a/benchmarks/workloads/node/Dockerfile +++ /dev/null @@ -1,2 +0,0 @@ -FROM node:onbuild -CMD ["node", "index.js"] diff --git a/benchmarks/workloads/node/index.js b/benchmarks/workloads/node/index.js deleted file mode 100644 index 584158462..000000000 --- a/benchmarks/workloads/node/index.js +++ /dev/null @@ -1,28 +0,0 @@ -'use strict'; - -var start = new Date().getTime(); - -// Load dependencies to simulate an average nodejs app. -var req_0 = require('async'); -var req_1 = require('bluebird'); -var req_2 = require('firebase'); -var req_3 = require('firebase-admin'); -var req_4 = require('@google-cloud/container'); -var req_5 = require('@google-cloud/logging'); -var req_6 = require('@google-cloud/monitoring'); -var req_7 = require('@google-cloud/spanner'); -var req_8 = require('lodash'); -var req_9 = require('mailgun-js'); -var req_10 = require('request'); -var express = require('express'); -var app = express(); - -var loaded = new Date().getTime() - start; -app.get('/', function(req, res) { - res.send('Hello World!<br>Loaded in ' + loaded + 'ms'); -}); - -console.log('Loaded in ' + loaded + ' ms'); -app.listen(8080, function() { - console.log('Listening on port 8080...'); -}); diff --git a/benchmarks/workloads/node/package.json b/benchmarks/workloads/node/package.json deleted file mode 100644 index c00b9b3cb..000000000 --- a/benchmarks/workloads/node/package.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "name": "node", - "version": "1.0.0", - "main": "index.js", - "dependencies": { - "@google-cloud/container": "^0.3.0", - "@google-cloud/logging": "^4.2.0", - "@google-cloud/monitoring": "^0.6.0", - "@google-cloud/spanner": "^2.2.1", - "async": "^2.6.1", - "bluebird": "^3.5.3", - "express": "^4.16.4", - "firebase": "^5.7.2", - "firebase-admin": "^6.4.0", - "lodash": "^4.17.11", - "mailgun-js": "^0.22.0", - "request": "^2.88.0" - } -} diff --git a/benchmarks/workloads/node_template/BUILD b/benchmarks/workloads/node_template/BUILD deleted file mode 100644 index e142f082a..000000000 --- a/benchmarks/workloads/node_template/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - "index.hbs", - "index.js", - "package.json", - "package-lock.json", - ], -) diff --git a/benchmarks/workloads/node_template/Dockerfile b/benchmarks/workloads/node_template/Dockerfile deleted file mode 100644 index 7eb065d54..000000000 --- a/benchmarks/workloads/node_template/Dockerfile +++ /dev/null @@ -1,5 +0,0 @@ -FROM node:onbuild - -ENV host "127.0.0.1" - -CMD ["sh", "-c", "node index.js ${host}"] diff --git a/benchmarks/workloads/node_template/index.hbs b/benchmarks/workloads/node_template/index.hbs deleted file mode 100644 index 03feceb75..000000000 --- a/benchmarks/workloads/node_template/index.hbs +++ /dev/null @@ -1,8 +0,0 @@ -<!DOCTYPE html> -<html> -<body> - {{#each text}} - <p>{{this}}</p> - {{/each}} -</body> -</html> diff --git a/benchmarks/workloads/node_template/index.js b/benchmarks/workloads/node_template/index.js deleted file mode 100644 index 04a27f356..000000000 --- a/benchmarks/workloads/node_template/index.js +++ /dev/null @@ -1,43 +0,0 @@ -const app = require('express')(); -const path = require('path'); -const redis = require('redis'); -const srs = require('secure-random-string'); - -// The hostname is the first argument. -const host_name = process.argv[2]; - -var client = redis.createClient({host: host_name, detect_buffers: true}); - -app.set('views', __dirname); -app.set('view engine', 'hbs'); - -app.get('/', (req, res) => { - var tmp = []; - /* Pull four random keys from the redis server. */ - for (i = 0; i < 4; i++) { - client.get(Math.floor(Math.random() * (100)), function(err, reply) { - tmp.push(reply.toString()); - }); - } - - res.render('index', {text: tmp}); -}); - -/** - * Securely generate a random string. - * @param {number} len - * @return {string} - */ -function randomBody(len) { - return srs({alphanumeric: true, length: len}); -} - -/** Mutates one hundred keys randomly. */ -function generateText() { - for (i = 0; i < 100; i++) { - client.set(i, randomBody(1024)); - } -} - -generateText(); -app.listen(8080); diff --git a/benchmarks/workloads/node_template/package-lock.json b/benchmarks/workloads/node_template/package-lock.json deleted file mode 100644 index 580e68aa5..000000000 --- a/benchmarks/workloads/node_template/package-lock.json +++ /dev/null @@ -1,486 +0,0 @@ -{ - "name": "nodedum", - "version": "1.0.0", - "lockfileVersion": 1, - "requires": true, - "dependencies": { - "accepts": { - "version": "1.3.5", - "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.5.tgz", - "integrity": "sha1-63d99gEXI6OxTopywIBcjoZ0a9I=", - "requires": { - "mime-types": "~2.1.18", - "negotiator": "0.6.1" - } - }, - "array-flatten": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", - "integrity": "sha1-ml9pkFGx5wczKPKgCJaLZOopVdI=" - }, - "async": { - "version": "2.6.2", - "resolved": "https://registry.npmjs.org/async/-/async-2.6.2.tgz", - "integrity": "sha512-H1qVYh1MYhEEFLsP97cVKqCGo7KfCyTt6uEWqsTBr9SO84oK9Uwbyd/yCW+6rKJLHksBNUVWZDAjfS+Ccx0Bbg==", - "requires": { - "lodash": "^4.17.11" - } - }, - "body-parser": { - "version": "1.18.3", - "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.18.3.tgz", - "integrity": "sha1-WykhmP/dVTs6DyDe0FkrlWlVyLQ=", - "requires": { - "bytes": "3.0.0", - "content-type": "~1.0.4", - "debug": "2.6.9", - "depd": "~1.1.2", - "http-errors": "~1.6.3", - "iconv-lite": "0.4.23", - "on-finished": "~2.3.0", - "qs": "6.5.2", - "raw-body": "2.3.3", - "type-is": "~1.6.16" - } - }, - "bytes": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.0.0.tgz", - "integrity": "sha1-0ygVQE1olpn4Wk6k+odV3ROpYEg=" - }, - "commander": { - "version": "2.20.0", - "resolved": "https://registry.npmjs.org/commander/-/commander-2.20.0.tgz", - "integrity": "sha512-7j2y+40w61zy6YC2iRNpUe/NwhNyoXrYpHMrSunaMG64nRnaf96zO/KMQR4OyN/UnE5KLyEBnKHd4aG3rskjpQ==", - "optional": true - }, - "content-disposition": { - "version": "0.5.2", - "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-0.5.2.tgz", - "integrity": "sha1-DPaLud318r55YcOoUXjLhdunjLQ=" - }, - "content-type": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/content-type/-/content-type-1.0.4.tgz", - "integrity": "sha512-hIP3EEPs8tB9AT1L+NUqtwOAps4mk2Zob89MWXMHjHWg9milF/j4osnnQLXBCBFBk/tvIG/tUc9mOUJiPBhPXA==" - }, - "cookie": { - "version": "0.3.1", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.3.1.tgz", - "integrity": "sha1-5+Ch+e9DtMi6klxcWpboBtFoc7s=" - }, - "cookie-signature": { - "version": "1.0.6", - "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz", - "integrity": "sha1-4wOogrNCzD7oylE6eZmXNNqzriw=" - }, - "debug": { - "version": "2.6.9", - "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", - "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", - "requires": { - "ms": "2.0.0" - } - }, - "depd": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/depd/-/depd-1.1.2.tgz", - "integrity": "sha1-m81S4UwJd2PnSbJ0xDRu0uVgtak=" - }, - "destroy": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/destroy/-/destroy-1.0.4.tgz", - "integrity": "sha1-l4hXRCxEdJ5CBmE+N5RiBYJqvYA=" - }, - "double-ended-queue": { - "version": "2.1.0-0", - "resolved": "https://registry.npmjs.org/double-ended-queue/-/double-ended-queue-2.1.0-0.tgz", - "integrity": "sha1-ED01J/0xUo9AGIEwyEHv3XgmTlw=" - }, - "ee-first": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz", - "integrity": "sha1-WQxhFWsK4vTwJVcyoViyZrxWsh0=" - }, - "encodeurl": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", - "integrity": "sha1-rT/0yG7C0CkyL1oCw6mmBslbP1k=" - }, - "escape-html": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/escape-html/-/escape-html-1.0.3.tgz", - "integrity": "sha1-Aljq5NPQwJdN4cFpGI7wBR0dGYg=" - }, - "etag": { - "version": "1.8.1", - "resolved": "https://registry.npmjs.org/etag/-/etag-1.8.1.tgz", - "integrity": "sha1-Qa4u62XvpiJorr/qg6x9eSmbCIc=" - }, - "express": { - "version": "4.16.4", - "resolved": "https://registry.npmjs.org/express/-/express-4.16.4.tgz", - "integrity": "sha512-j12Uuyb4FMrd/qQAm6uCHAkPtO8FDTRJZBDd5D2KOL2eLaz1yUNdUB/NOIyq0iU4q4cFarsUCrnFDPBcnksuOg==", - "requires": { - "accepts": "~1.3.5", - "array-flatten": "1.1.1", - "body-parser": "1.18.3", - "content-disposition": "0.5.2", - "content-type": "~1.0.4", - "cookie": "0.3.1", - "cookie-signature": "1.0.6", - "debug": "2.6.9", - "depd": "~1.1.2", - "encodeurl": "~1.0.2", - "escape-html": "~1.0.3", - "etag": "~1.8.1", - "finalhandler": "1.1.1", - "fresh": "0.5.2", - "merge-descriptors": "1.0.1", - "methods": "~1.1.2", - "on-finished": "~2.3.0", - "parseurl": "~1.3.2", - "path-to-regexp": "0.1.7", - "proxy-addr": "~2.0.4", - "qs": "6.5.2", - "range-parser": "~1.2.0", - "safe-buffer": "5.1.2", - "send": "0.16.2", - "serve-static": "1.13.2", - "setprototypeof": "1.1.0", - "statuses": "~1.4.0", - "type-is": "~1.6.16", - "utils-merge": "1.0.1", - "vary": "~1.1.2" - } - }, - "finalhandler": { - "version": "1.1.1", - "resolved": "http://registry.npmjs.org/finalhandler/-/finalhandler-1.1.1.tgz", - "integrity": "sha512-Y1GUDo39ez4aHAw7MysnUD5JzYX+WaIj8I57kO3aEPT1fFRL4sr7mjei97FgnwhAyyzRYmQZaTHb2+9uZ1dPtg==", - "requires": { - "debug": "2.6.9", - "encodeurl": "~1.0.2", - "escape-html": "~1.0.3", - "on-finished": "~2.3.0", - "parseurl": "~1.3.2", - "statuses": "~1.4.0", - "unpipe": "~1.0.0" - } - }, - "foreachasync": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/foreachasync/-/foreachasync-3.0.0.tgz", - "integrity": "sha1-VQKYfchxS+M5IJfzLgBxyd7gfPY=" - }, - "forwarded": { - "version": "0.1.2", - "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.1.2.tgz", - "integrity": "sha1-mMI9qxF1ZXuMBXPozszZGw/xjIQ=" - }, - "fresh": { - "version": "0.5.2", - "resolved": "https://registry.npmjs.org/fresh/-/fresh-0.5.2.tgz", - "integrity": "sha1-PYyt2Q2XZWn6g1qx+OSyOhBWBac=" - }, - "handlebars": { - "version": "4.0.14", - "resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.0.14.tgz", - "integrity": "sha512-E7tDoyAA8ilZIV3xDJgl18sX3M8xB9/fMw8+mfW4msLW8jlX97bAnWgT3pmaNXuvzIEgSBMnAHfuXsB2hdzfow==", - "requires": { - "async": "^2.5.0", - "optimist": "^0.6.1", - "source-map": "^0.6.1", - "uglify-js": "^3.1.4" - } - }, - "hbs": { - "version": "4.0.4", - "resolved": "https://registry.npmjs.org/hbs/-/hbs-4.0.4.tgz", - "integrity": "sha512-esVlyV/V59mKkwFai5YmPRSNIWZzhqL5YMN0++ueMxyK1cCfPa5f6JiHtapPKAIVAhQR6rpGxow0troav9WMEg==", - "requires": { - "handlebars": "4.0.14", - "walk": "2.3.9" - } - }, - "http-errors": { - "version": "1.6.3", - "resolved": "http://registry.npmjs.org/http-errors/-/http-errors-1.6.3.tgz", - "integrity": "sha1-i1VoC7S+KDoLW/TqLjhYC+HZMg0=", - "requires": { - "depd": "~1.1.2", - "inherits": "2.0.3", - "setprototypeof": "1.1.0", - "statuses": ">= 1.4.0 < 2" - } - }, - "iconv-lite": { - "version": "0.4.23", - "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.23.tgz", - "integrity": "sha512-neyTUVFtahjf0mB3dZT77u+8O0QB89jFdnBkd5P1JgYPbPaia3gXXOVL2fq8VyU2gMMD7SaN7QukTB/pmXYvDA==", - "requires": { - "safer-buffer": ">= 2.1.2 < 3" - } - }, - "inherits": { - "version": "2.0.3", - "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.3.tgz", - "integrity": "sha1-Yzwsg+PaQqUC9SRmAiSA9CCCYd4=" - }, - "ipaddr.js": { - "version": "1.8.0", - "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.8.0.tgz", - "integrity": "sha1-6qM9bd16zo9/b+DJygRA5wZzix4=" - }, - "lodash": { - "version": "4.17.15", - "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.15.tgz", - "integrity": "sha512-8xOcRHvCjnocdS5cpwXQXVzmmh5e5+saE2QGoeQmbKmRS6J3VQppPOIt0MnmE+4xlZoumy0GPG0D0MVIQbNA1A==" - }, - "media-typer": { - "version": "0.3.0", - "resolved": "http://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz", - "integrity": "sha1-hxDXrwqmJvj/+hzgAWhUUmMlV0g=" - }, - "merge-descriptors": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.1.tgz", - "integrity": "sha1-sAqqVW3YtEVoFQ7J0blT8/kMu2E=" - }, - "methods": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/methods/-/methods-1.1.2.tgz", - "integrity": "sha1-VSmk1nZUE07cxSZmVoNbD4Ua/O4=" - }, - "mime": { - "version": "1.4.1", - "resolved": "https://registry.npmjs.org/mime/-/mime-1.4.1.tgz", - "integrity": "sha512-KI1+qOZu5DcW6wayYHSzR/tXKCDC5Om4s1z2QJjDULzLcmf3DvzS7oluY4HCTrc+9FiKmWUgeNLg7W3uIQvxtQ==" - }, - "mime-db": { - "version": "1.37.0", - "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.37.0.tgz", - "integrity": "sha512-R3C4db6bgQhlIhPU48fUtdVmKnflq+hRdad7IyKhtFj06VPNVdk2RhiYL3UjQIlso8L+YxAtFkobT0VK+S/ybg==" - }, - "mime-types": { - "version": "2.1.21", - "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.21.tgz", - "integrity": "sha512-3iL6DbwpyLzjR3xHSFNFeb9Nz/M8WDkX33t1GFQnFOllWk8pOrh/LSrB5OXlnlW5P9LH73X6loW/eogc+F5lJg==", - "requires": { - "mime-db": "~1.37.0" - } - }, - "minimist": { - "version": "0.0.10", - "resolved": "https://registry.npmjs.org/minimist/-/minimist-0.0.10.tgz", - "integrity": "sha1-3j+YVD2/lggr5IrRoMfNqDYwHc8=" - }, - "ms": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", - "integrity": "sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g=" - }, - "negotiator": { - "version": "0.6.1", - "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.1.tgz", - "integrity": "sha1-KzJxhOiZIQEXeyhWP7XnECrNDKk=" - }, - "on-finished": { - "version": "2.3.0", - "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.3.0.tgz", - "integrity": "sha1-IPEzZIGwg811M3mSoWlxqi2QaUc=", - "requires": { - "ee-first": "1.1.1" - } - }, - "optimist": { - "version": "0.6.1", - "resolved": "https://registry.npmjs.org/optimist/-/optimist-0.6.1.tgz", - "integrity": "sha1-2j6nRob6IaGaERwybpDrFaAZZoY=", - "requires": { - "minimist": "~0.0.1", - "wordwrap": "~0.0.2" - } - }, - "parseurl": { - "version": "1.3.2", - "resolved": "https://registry.npmjs.org/parseurl/-/parseurl-1.3.2.tgz", - "integrity": "sha1-/CidTtiZMRlGDBViUyYs3I3mW/M=" - }, - "path-to-regexp": { - "version": "0.1.7", - "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.7.tgz", - "integrity": "sha1-32BBeABfUi8V60SQ5yR6G/qmf4w=" - }, - "proxy-addr": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.4.tgz", - "integrity": "sha512-5erio2h9jp5CHGwcybmxmVqHmnCBZeewlfJ0pex+UW7Qny7OOZXTtH56TGNyBizkgiOwhJtMKrVzDTeKcySZwA==", - "requires": { - "forwarded": "~0.1.2", - "ipaddr.js": "1.8.0" - } - }, - "qs": { - "version": "6.5.2", - "resolved": "https://registry.npmjs.org/qs/-/qs-6.5.2.tgz", - "integrity": "sha512-N5ZAX4/LxJmF+7wN74pUD6qAh9/wnvdQcjq9TZjevvXzSUo7bfmw91saqMjzGS2xq91/odN2dW/WOl7qQHNDGA==" - }, - "range-parser": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.0.tgz", - "integrity": "sha1-9JvmtIeJTdxA3MlKMi9hEJLgDV4=" - }, - "raw-body": { - "version": "2.3.3", - "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.3.3.tgz", - "integrity": "sha512-9esiElv1BrZoI3rCDuOuKCBRbuApGGaDPQfjSflGxdy4oyzqghxu6klEkkVIvBje+FF0BX9coEv8KqW6X/7njw==", - "requires": { - "bytes": "3.0.0", - "http-errors": "1.6.3", - "iconv-lite": "0.4.23", - "unpipe": "1.0.0" - } - }, - "redis": { - "version": "2.8.0", - "resolved": "https://registry.npmjs.org/redis/-/redis-2.8.0.tgz", - "integrity": "sha512-M1OkonEQwtRmZv4tEWF2VgpG0JWJ8Fv1PhlgT5+B+uNq2cA3Rt1Yt/ryoR+vQNOQcIEgdCdfH0jr3bDpihAw1A==", - "requires": { - "double-ended-queue": "^2.1.0-0", - "redis-commands": "^1.2.0", - "redis-parser": "^2.6.0" - }, - "dependencies": { - "redis-commands": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/redis-commands/-/redis-commands-1.4.0.tgz", - "integrity": "sha512-cu8EF+MtkwI4DLIT0x9P8qNTLFhQD4jLfxLR0cCNkeGzs87FN6879JOJwNQR/1zD7aSYNbU0hgsV9zGY71Itvw==" - }, - "redis-parser": { - "version": "2.6.0", - "resolved": "https://registry.npmjs.org/redis-parser/-/redis-parser-2.6.0.tgz", - "integrity": "sha1-Uu0J2srBCPGmMcB+m2mUHnoZUEs=" - } - } - }, - "redis-commands": { - "version": "1.5.0", - "resolved": "https://registry.npmjs.org/redis-commands/-/redis-commands-1.5.0.tgz", - "integrity": "sha512-6KxamqpZ468MeQC3bkWmCB1fp56XL64D4Kf0zJSwDZbVLLm7KFkoIcHrgRvQ+sk8dnhySs7+yBg94yIkAK7aJg==" - }, - "redis-parser": { - "version": "2.6.0", - "resolved": "https://registry.npmjs.org/redis-parser/-/redis-parser-2.6.0.tgz", - "integrity": "sha1-Uu0J2srBCPGmMcB+m2mUHnoZUEs=" - }, - "safe-buffer": { - "version": "5.1.2", - "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz", - "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==" - }, - "safer-buffer": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", - "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==" - }, - "secure-random-string": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/secure-random-string/-/secure-random-string-1.1.0.tgz", - "integrity": "sha512-V/h8jqoz58zklNGybVhP++cWrxEPXlLM/6BeJ4e0a8zlb4BsbYRzFs16snrxByPa5LUxCVTD3M6EYIVIHR1fAg==" - }, - "send": { - "version": "0.16.2", - "resolved": "https://registry.npmjs.org/send/-/send-0.16.2.tgz", - "integrity": "sha512-E64YFPUssFHEFBvpbbjr44NCLtI1AohxQ8ZSiJjQLskAdKuriYEP6VyGEsRDH8ScozGpkaX1BGvhanqCwkcEZw==", - "requires": { - "debug": "2.6.9", - "depd": "~1.1.2", - "destroy": "~1.0.4", - "encodeurl": "~1.0.2", - "escape-html": "~1.0.3", - "etag": "~1.8.1", - "fresh": "0.5.2", - "http-errors": "~1.6.2", - "mime": "1.4.1", - "ms": "2.0.0", - "on-finished": "~2.3.0", - "range-parser": "~1.2.0", - "statuses": "~1.4.0" - } - }, - "serve-static": { - "version": "1.13.2", - "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.13.2.tgz", - "integrity": "sha512-p/tdJrO4U387R9oMjb1oj7qSMaMfmOyd4j9hOFoxZe2baQszgHcSWjuya/CiT5kgZZKRudHNOA0pYXOl8rQ5nw==", - "requires": { - "encodeurl": "~1.0.2", - "escape-html": "~1.0.3", - "parseurl": "~1.3.2", - "send": "0.16.2" - } - }, - "setprototypeof": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.1.0.tgz", - "integrity": "sha512-BvE/TwpZX4FXExxOxZyRGQQv651MSwmWKZGqvmPcRIjDqWub67kTKuIMx43cZZrS/cBBzwBcNDWoFxt2XEFIpQ==" - }, - "source-map": { - "version": "0.6.1", - "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", - "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==" - }, - "statuses": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/statuses/-/statuses-1.4.0.tgz", - "integrity": "sha512-zhSCtt8v2NDrRlPQpCNtw/heZLtfUDqxBM1udqikb/Hbk52LK4nQSwr10u77iopCW5LsyHpuXS0GnEc48mLeew==" - }, - "type-is": { - "version": "1.6.16", - "resolved": "https://registry.npmjs.org/type-is/-/type-is-1.6.16.tgz", - "integrity": "sha512-HRkVv/5qY2G6I8iab9cI7v1bOIdhm94dVjQCPFElW9W+3GeDOSHmy2EBYe4VTApuzolPcmgFTN3ftVJRKR2J9Q==", - "requires": { - "media-typer": "0.3.0", - "mime-types": "~2.1.18" - } - }, - "uglify-js": { - "version": "3.5.9", - "resolved": "https://registry.npmjs.org/uglify-js/-/uglify-js-3.5.9.tgz", - "integrity": "sha512-WpT0RqsDtAWPNJK955DEnb6xjymR8Fn0OlK4TT4pS0ASYsVPqr5ELhgwOwLCP5J5vHeJ4xmMmz3DEgdqC10JeQ==", - "optional": true, - "requires": { - "commander": "~2.20.0", - "source-map": "~0.6.1" - } - }, - "unpipe": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz", - "integrity": "sha1-sr9O6FFKrmFltIF4KdIbLvSZBOw=" - }, - "utils-merge": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/utils-merge/-/utils-merge-1.0.1.tgz", - "integrity": "sha1-n5VxD1CiZ5R7LMwSR0HBAoQn5xM=" - }, - "vary": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/vary/-/vary-1.1.2.tgz", - "integrity": "sha1-IpnwLG3tMNSllhsLn3RSShj2NPw=" - }, - "walk": { - "version": "2.3.9", - "resolved": "https://registry.npmjs.org/walk/-/walk-2.3.9.tgz", - "integrity": "sha1-MbTbZnjyrgHDnqn7hyWpAx5Vins=", - "requires": { - "foreachasync": "^3.0.0" - } - }, - "wordwrap": { - "version": "0.0.3", - "resolved": "https://registry.npmjs.org/wordwrap/-/wordwrap-0.0.3.tgz", - "integrity": "sha1-o9XabNXAvAAI03I0u68b7WMFkQc=" - } - } -} diff --git a/benchmarks/workloads/node_template/package.json b/benchmarks/workloads/node_template/package.json deleted file mode 100644 index 7dcadd523..000000000 --- a/benchmarks/workloads/node_template/package.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "name": "nodedum", - "version": "1.0.0", - "description": "", - "main": "index.js", - "scripts": { - "test": "echo \"Error: no test specified\" && exit 1" - }, - "author": "", - "license": "ISC", - "dependencies": { - "express": "^4.16.4", - "hbs": "^4.0.4", - "redis": "^2.8.0", - "redis-commands": "^1.2.0", - "redis-parser": "^2.6.0", - "secure-random-string": "^1.1.0" - } -} diff --git a/benchmarks/workloads/redis/BUILD b/benchmarks/workloads/redis/BUILD deleted file mode 100644 index a70873065..000000000 --- a/benchmarks/workloads/redis/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/redis/Dockerfile b/benchmarks/workloads/redis/Dockerfile deleted file mode 100644 index 0f17249af..000000000 --- a/benchmarks/workloads/redis/Dockerfile +++ /dev/null @@ -1 +0,0 @@ -FROM redis:5.0.4 diff --git a/benchmarks/workloads/redisbenchmark/BUILD b/benchmarks/workloads/redisbenchmark/BUILD deleted file mode 100644 index 147cfedd2..000000000 --- a/benchmarks/workloads/redisbenchmark/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") -load("//benchmarks:defs.bzl", "test_deps") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "redisbenchmark", - srcs = ["__init__.py"], -) - -py_test( - name = "redisbenchmark_test", - srcs = ["redisbenchmark_test.py"], - python_version = "PY3", - deps = test_deps + [ - ":redisbenchmark", - ], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/redisbenchmark/Dockerfile b/benchmarks/workloads/redisbenchmark/Dockerfile deleted file mode 100644 index f94f6442e..000000000 --- a/benchmarks/workloads/redisbenchmark/Dockerfile +++ /dev/null @@ -1,4 +0,0 @@ -FROM redis:5.0.4 -ENV host localhost -ENV port 6379 -CMD ["sh", "-c", "redis-benchmark --csv -h ${host} -p ${port} ${flags}"] diff --git a/benchmarks/workloads/redisbenchmark/__init__.py b/benchmarks/workloads/redisbenchmark/__init__.py deleted file mode 100644 index 229cef5fa..000000000 --- a/benchmarks/workloads/redisbenchmark/__init__.py +++ /dev/null @@ -1,85 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Redis-benchmark tool.""" - -import re - -OPERATIONS = [ - "PING_INLINE", - "PING_BULK", - "SET", - "GET", - "INCR", - "LPUSH", - "RPUSH", - "LPOP", - "RPOP", - "SADD", - "HSET", - "SPOP", - "LRANGE_100", - "LRANGE_300", - "LRANGE_500", - "LRANGE_600", - "MSET", -] - -METRICS = dict() - -SAMPLE_DATA = """ -"PING_INLINE","48661.80" -"PING_BULK","50301.81" -"SET","48923.68" -"GET","49382.71" -"INCR","49975.02" -"LPUSH","49875.31" -"RPUSH","50276.52" -"LPOP","50327.12" -"RPOP","50556.12" -"SADD","49504.95" -"HSET","49504.95" -"SPOP","50025.02" -"LPUSH (needed to benchmark LRANGE)","48875.86" -"LRANGE_100 (first 100 elements)","33955.86" -"LRANGE_300 (first 300 elements)","16550.81" -"LRANGE_500 (first 450 elements)","13653.74" -"LRANGE_600 (first 600 elements)","11219.57" -"MSET (10 keys)","44682.75" -""" - - -# pylint: disable=unused-argument -def sample(**kwargs) -> str: - return SAMPLE_DATA - - -# Bind a metric for each operation noted above. -for op in OPERATIONS: - - def bind(metric): - """Bind op to a new scope.""" - - # pylint: disable=unused-argument - def parse(data: str, **kwargs) -> float: - """Operation throughput in requests/sec.""" - regex = r"\"" + metric + r"( .*)?\",\"(\d*.\d*)" - res = re.compile(regex).search(data) - if res: - return float(res.group(2)) - return 0.0 - - parse.__name__ = metric - return parse - - METRICS[op] = bind(op) diff --git a/benchmarks/workloads/redisbenchmark/redisbenchmark_test.py b/benchmarks/workloads/redisbenchmark/redisbenchmark_test.py deleted file mode 100644 index 419ced059..000000000 --- a/benchmarks/workloads/redisbenchmark/redisbenchmark_test.py +++ /dev/null @@ -1,51 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Parser test.""" - -import sys - -import pytest - -from benchmarks.workloads import redisbenchmark - -RESULTS = { - "PING_INLINE": 48661.80, - "PING_BULK": 50301.81, - "SET": 48923.68, - "GET": 49382.71, - "INCR": 49975.02, - "LPUSH": 49875.31, - "RPUSH": 50276.52, - "LPOP": 50327.12, - "RPOP": 50556.12, - "SADD": 49504.95, - "HSET": 49504.95, - "SPOP": 50025.02, - "LRANGE_100": 33955.86, - "LRANGE_300": 16550.81, - "LRANGE_500": 13653.74, - "LRANGE_600": 11219.57, - "MSET": 44682.75 -} - - -def test_metrics(): - """Test all metrics.""" - for (metric, func) in redisbenchmark.METRICS.items(): - res = func(redisbenchmark.sample()) - assert float(res) == RESULTS[metric] - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/ruby/BUILD b/benchmarks/workloads/ruby/BUILD deleted file mode 100644 index a3be4fe92..000000000 --- a/benchmarks/workloads/ruby/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -filegroup( - name = "files", - srcs = [ - "Dockerfile", - "Gemfile", - "Gemfile.lock", - "config.ru", - "index.rb", - ], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - "Gemfile", - "Gemfile.lock", - "config.ru", - "index.rb", - ], -) diff --git a/benchmarks/workloads/ruby/Dockerfile b/benchmarks/workloads/ruby/Dockerfile deleted file mode 100644 index a9a7a7086..000000000 --- a/benchmarks/workloads/ruby/Dockerfile +++ /dev/null @@ -1,28 +0,0 @@ -# example based on https://github.com/errm/fib - -FROM ruby:2.5 - -RUN apt-get update -qq && apt-get install -y build-essential libpq-dev nodejs libsodium-dev - -# Set an environment variable where the Rails app is installed to inside of Docker image -ENV RAILS_ROOT /var/www/app_name -RUN mkdir -p $RAILS_ROOT - -# Set working directory -WORKDIR $RAILS_ROOT - -# Setting env up -ENV RAILS_ENV='production' -ENV RACK_ENV='production' - -# Adding gems -COPY Gemfile Gemfile -COPY Gemfile.lock Gemfile.lock -RUN bundle install --jobs 20 --retry 5 --without development test - -# Adding project files -COPY . . - -EXPOSE $PORT -STOPSIGNAL SIGINT -CMD ["bundle", "exec", "puma", "config.ru"] diff --git a/benchmarks/workloads/ruby/Gemfile b/benchmarks/workloads/ruby/Gemfile deleted file mode 100644 index 8f1bdad6e..000000000 --- a/benchmarks/workloads/ruby/Gemfile +++ /dev/null @@ -1,12 +0,0 @@ -source "https://rubygems.org" -# load a bunch of dependencies to take up memory -gem "sinatra" -gem "puma" -gem "redis" -gem 'rake' -gem 'squid', '~> 1.4' -gem 'cassandra-driver' -gem 'ruby-fann' -gem 'rbnacl' -gem 'bcrypt' -gem "activemerchant"
\ No newline at end of file diff --git a/benchmarks/workloads/ruby/Gemfile.lock b/benchmarks/workloads/ruby/Gemfile.lock deleted file mode 100644 index ea9f0ea85..000000000 --- a/benchmarks/workloads/ruby/Gemfile.lock +++ /dev/null @@ -1,71 +0,0 @@ -GEM - remote: https://rubygems.org/ - specs: - activemerchant (1.105.0) - activesupport (>= 4.2) - builder (>= 2.1.2, < 4.0.0) - i18n (>= 0.6.9) - nokogiri (~> 1.4) - activesupport (5.2.3) - concurrent-ruby (~> 1.0, >= 1.0.2) - i18n (>= 0.7, < 2) - minitest (~> 5.1) - tzinfo (~> 1.1) - bcrypt (3.1.13) - builder (3.2.4) - cassandra-driver (3.2.3) - ione (~> 1.2) - concurrent-ruby (1.1.5) - ffi (1.12.2) - i18n (1.6.0) - concurrent-ruby (~> 1.0) - ione (1.2.4) - mini_portile2 (2.4.0) - minitest (5.11.3) - mustermann (1.0.3) - nokogiri (1.10.8) - mini_portile2 (~> 2.4.0) - pdf-core (0.7.0) - prawn (2.2.2) - pdf-core (~> 0.7.0) - ttfunk (~> 1.5) - puma (3.12.4) - rack (2.2.2) - rack-protection (2.0.5) - rack - rake (12.3.3) - rbnacl (7.1.1) - ffi - redis (4.1.1) - ruby-fann (1.2.6) - sinatra (2.0.5) - mustermann (~> 1.0) - rack (~> 2.0) - rack-protection (= 2.0.5) - tilt (~> 2.0) - squid (1.4.1) - activesupport (>= 4.0) - prawn (~> 2.2) - thread_safe (0.3.6) - tilt (2.0.9) - ttfunk (1.5.1) - tzinfo (1.2.5) - thread_safe (~> 0.1) - -PLATFORMS - ruby - -DEPENDENCIES - activemerchant - bcrypt - cassandra-driver - puma - rake - rbnacl - redis - ruby-fann - sinatra - squid (~> 1.4) - -BUNDLED WITH - 1.17.1 diff --git a/benchmarks/workloads/ruby/config.ru b/benchmarks/workloads/ruby/config.ru deleted file mode 100755 index fbd5acc82..000000000 --- a/benchmarks/workloads/ruby/config.ru +++ /dev/null @@ -1,2 +0,0 @@ -require './index' -run Sinatra::Application
\ No newline at end of file diff --git a/benchmarks/workloads/ruby/index.rb b/benchmarks/workloads/ruby/index.rb deleted file mode 100755 index 5fa85af93..000000000 --- a/benchmarks/workloads/ruby/index.rb +++ /dev/null @@ -1,14 +0,0 @@ -require "sinatra" -require "puma" -require "redis" -require "rake" -require "squid" -require "cassandra" -require "ruby-fann" -require "rbnacl" -require "bcrypt" -require "activemerchant" - -get "/" do - "Hello World!" -end
\ No newline at end of file diff --git a/benchmarks/workloads/ruby_template/BUILD b/benchmarks/workloads/ruby_template/BUILD deleted file mode 100644 index 72ed9403d..000000000 --- a/benchmarks/workloads/ruby_template/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - "Gemfile", - "Gemfile.lock", - "config.ru", - "index.erb", - "main.rb", - ], -) diff --git a/benchmarks/workloads/ruby_template/Dockerfile b/benchmarks/workloads/ruby_template/Dockerfile deleted file mode 100755 index a06d68bf4..000000000 --- a/benchmarks/workloads/ruby_template/Dockerfile +++ /dev/null @@ -1,38 +0,0 @@ -# example based on https://github.com/errm/fib - -FROM alpine:3.9 as build - -COPY Gemfile Gemfile.lock ./ - -RUN apk add --no-cache ruby ruby-dev ruby-bundler ruby-json build-base bash \ - && bundle install --frozen -j4 -r3 --no-cache --without development \ - && apk del --no-cache ruby-bundler \ - && rm -rf /usr/lib/ruby/gems/*/cache - -FROM alpine:3.9 as prod - -COPY --from=build /usr/lib/ruby/gems /usr/lib/ruby/gems -RUN apk add --no-cache ruby ruby-json ruby-etc redis apache2-utils \ - && ruby -e "Gem::Specification.map.each do |spec| \ - Gem::Installer.for_spec( \ - spec, \ - wrappers: true, \ - force: true, \ - install_dir: spec.base_dir, \ - build_args: spec.build_args, \ - ).generate_bin \ - end" - -WORKDIR /app -COPY . /app/. - -ENV PORT=9292 \ - WEB_CONCURRENCY=20 \ - WEB_MAX_THREADS=20 \ - RACK_ENV=production - -ENV host localhost -EXPOSE $PORT -USER nobody -STOPSIGNAL SIGINT -CMD ["sh", "-c", "/usr/bin/puma", "${host}"] diff --git a/benchmarks/workloads/ruby_template/Gemfile b/benchmarks/workloads/ruby_template/Gemfile deleted file mode 100755 index ac521b32c..000000000 --- a/benchmarks/workloads/ruby_template/Gemfile +++ /dev/null @@ -1,5 +0,0 @@ -source "https://rubygems.org" - -gem "sinatra" -gem "puma" -gem "redis"
\ No newline at end of file diff --git a/benchmarks/workloads/ruby_template/Gemfile.lock b/benchmarks/workloads/ruby_template/Gemfile.lock deleted file mode 100644 index f637b6081..000000000 --- a/benchmarks/workloads/ruby_template/Gemfile.lock +++ /dev/null @@ -1,26 +0,0 @@ -GEM - remote: https://rubygems.org/ - specs: - mustermann (1.0.3) - puma (3.12.4) - rack (2.0.6) - rack-protection (2.0.5) - rack - redis (4.1.0) - sinatra (2.0.5) - mustermann (~> 1.0) - rack (~> 2.0) - rack-protection (= 2.0.5) - tilt (~> 2.0) - tilt (2.0.9) - -PLATFORMS - ruby - -DEPENDENCIES - puma - redis - sinatra - -BUNDLED WITH - 1.17.1
\ No newline at end of file diff --git a/benchmarks/workloads/ruby_template/config.ru b/benchmarks/workloads/ruby_template/config.ru deleted file mode 100755 index b2d135cc0..000000000 --- a/benchmarks/workloads/ruby_template/config.ru +++ /dev/null @@ -1,2 +0,0 @@ -require './main' -run Sinatra::Application
\ No newline at end of file diff --git a/benchmarks/workloads/ruby_template/index.erb b/benchmarks/workloads/ruby_template/index.erb deleted file mode 100755 index 7f7300e80..000000000 --- a/benchmarks/workloads/ruby_template/index.erb +++ /dev/null @@ -1,8 +0,0 @@ -<!DOCTYPE html> -<html> -<body> - <% text.each do |t| %> - <p><%= t %></p> - <% end %> -</body> -</html> diff --git a/benchmarks/workloads/ruby_template/main.rb b/benchmarks/workloads/ruby_template/main.rb deleted file mode 100755 index 35c239377..000000000 --- a/benchmarks/workloads/ruby_template/main.rb +++ /dev/null @@ -1,27 +0,0 @@ -require "sinatra" -require "securerandom" -require "redis" - -redis_host = ENV["host"] -$redis = Redis.new(host: redis_host) - -def generateText - for i in 0..99 - $redis.set(i, randomBody(1024)) - end -end - -def randomBody(length) - return SecureRandom.alphanumeric(length) -end - -generateText -template = ERB.new(File.read('./index.erb')) - -get "/" do - texts = Array.new - for i in 0..4 - texts.push($redis.get(rand(0..99))) - end - template.result_with_hash(text: texts) -end
\ No newline at end of file diff --git a/benchmarks/workloads/sleep/BUILD b/benchmarks/workloads/sleep/BUILD deleted file mode 100644 index a70873065..000000000 --- a/benchmarks/workloads/sleep/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/sleep/Dockerfile b/benchmarks/workloads/sleep/Dockerfile deleted file mode 100644 index 24c72e07a..000000000 --- a/benchmarks/workloads/sleep/Dockerfile +++ /dev/null @@ -1,3 +0,0 @@ -FROM alpine:latest - -CMD ["sleep", "315360000"] diff --git a/benchmarks/workloads/sysbench/BUILD b/benchmarks/workloads/sysbench/BUILD deleted file mode 100644 index ab2556064..000000000 --- a/benchmarks/workloads/sysbench/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") -load("//benchmarks:defs.bzl", "test_deps") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "sysbench", - srcs = ["__init__.py"], -) - -py_test( - name = "sysbench_test", - srcs = ["sysbench_test.py"], - python_version = "PY3", - deps = test_deps + [ - ":sysbench", - ], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/sysbench/Dockerfile b/benchmarks/workloads/sysbench/Dockerfile deleted file mode 100644 index 8225e0e14..000000000 --- a/benchmarks/workloads/sysbench/Dockerfile +++ /dev/null @@ -1,16 +0,0 @@ -FROM ubuntu:18.04 - -RUN set -x \ - && apt-get update \ - && apt-get install -y \ - sysbench \ - && rm -rf /var/lib/apt/lists/* - -# Parameterize the tests. -ENV test cpu -ENV threads 1 -ENV options "" - -# run sysbench once as a warm-up and take the second result -CMD ["sh", "-c", "sysbench --threads=8 --memory-total-size=5G memory run > /dev/null && \ -sysbench --threads=${threads} ${options} ${test} run"] diff --git a/benchmarks/workloads/sysbench/__init__.py b/benchmarks/workloads/sysbench/__init__.py deleted file mode 100644 index de357b4db..000000000 --- a/benchmarks/workloads/sysbench/__init__.py +++ /dev/null @@ -1,167 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Sysbench.""" - -import re - -STD_REGEX = r"events per second:\s*(\d*.?\d*)\n" -MEM_REGEX = r"Total\soperations:\s+\d*\s*\((\d*\.\d*)\sper\ssecond\)" -ALT_REGEX = r"execution time \(avg/stddev\):\s*(\d*.?\d*)/(\d*.?\d*)" -AVG_REGEX = r"avg:[^\n^\d]*(\d*\.?\d*)" - -SAMPLE_CPU_DATA = """ -sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) - -Running the test with following options: -Number of threads: 8 -Initializing random number generator from current time - - -Prime numbers limit: 10000 - -Initializing worker threads... - -Threads started! - -CPU speed: - events per second: 9093.38 - -General statistics: - total time: 10.0007s - total number of events: 90949 - -Latency (ms): - min: 0.64 - avg: 0.88 - max: 24.65 - 95th percentile: 1.55 - sum: 79936.91 - -Threads fairness: - events (avg/stddev): 11368.6250/831.38 - execution time (avg/stddev): 9.9921/0.01 -""" - -SAMPLE_MEMORY_DATA = """ -sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) - -Running the test with following options: -Number of threads: 8 -Initializing random number generator from current time - - -Running memory speed test with the following options: - block size: 1KiB - total size: 102400MiB - operation: write - scope: global - -Initializing worker threads... - -Threads started! - -Total operations: 47999046 (9597428.64 per second) - -46874.07 MiB transferred (9372.49 MiB/sec) - - -General statistics: - total time: 5.0001s - total number of events: 47999046 - -Latency (ms): - min: 0.00 - avg: 0.00 - max: 0.21 - 95th percentile: 0.00 - sum: 33165.91 - -Threads fairness: - events (avg/stddev): 5999880.7500/111242.52 - execution time (avg/stddev): 4.1457/0.09 -""" - -SAMPLE_MUTEX_DATA = """ -sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) - -Running the test with following options: -Number of threads: 8 -Initializing random number generator from current time - - -Initializing worker threads... - -Threads started! - - -General statistics: - total time: 3.7869s - total number of events: 8 - -Latency (ms): - min: 3688.56 - avg: 3754.03 - max: 3780.94 - 95th percentile: 3773.42 - sum: 30032.28 - -Threads fairness: - events (avg/stddev): 1.0000/0.00 - execution time (avg/stddev): 3.7540/0.03 -""" - - -# pylint: disable=unused-argument -def sample(test, **kwargs): - switch = { - "cpu": SAMPLE_CPU_DATA, - "memory": SAMPLE_MEMORY_DATA, - "mutex": SAMPLE_MUTEX_DATA, - "randwr": SAMPLE_CPU_DATA - } - return switch[test] - - -# pylint: disable=unused-argument -def cpu_events_per_second(data: str, **kwargs) -> float: - """Returns events per second.""" - return float(re.compile(STD_REGEX).search(data).group(1)) - - -# pylint: disable=unused-argument -def memory_ops_per_second(data: str, **kwargs) -> float: - """Returns memory operations per second.""" - return float(re.compile(MEM_REGEX).search(data).group(1)) - - -# pylint: disable=unused-argument -def mutex_time(data: str, count: int, locks: int, threads: int, - **kwargs) -> float: - """Returns normalized mutex time (lower is better).""" - value = float(re.compile(ALT_REGEX).search(data).group(1)) - contention = float(threads) / float(locks) - scale = contention * float(count) / 100000000.0 - return value / scale - - -# pylint: disable=unused-argument -def mutex_deviation(data: str, **kwargs) -> float: - """Returns deviation for threads.""" - return float(re.compile(ALT_REGEX).search(data).group(2)) - - -# pylint: disable=unused-argument -def mutex_latency(data: str, **kwargs) -> float: - """Returns average mutex latency.""" - return float(re.compile(AVG_REGEX).search(data).group(1)) diff --git a/benchmarks/workloads/sysbench/sysbench_test.py b/benchmarks/workloads/sysbench/sysbench_test.py deleted file mode 100644 index 3fb541fd2..000000000 --- a/benchmarks/workloads/sysbench/sysbench_test.py +++ /dev/null @@ -1,34 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Parser test.""" - -import sys - -import pytest - -from benchmarks.workloads import sysbench - - -def test_sysbench_parser(): - """Test the basic parser.""" - assert sysbench.cpu_events_per_second(sysbench.sample("cpu")) == 9093.38 - assert sysbench.memory_ops_per_second(sysbench.sample("memory")) == 9597428.64 - assert sysbench.mutex_time(sysbench.sample("mutex"), 1, 1, - 100000000.0) == 3.754 - assert sysbench.mutex_deviation(sysbench.sample("mutex")) == 0.03 - assert sysbench.mutex_latency(sysbench.sample("mutex")) == 3754.03 - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/syscall/BUILD b/benchmarks/workloads/syscall/BUILD deleted file mode 100644 index f8c43bca1..000000000 --- a/benchmarks/workloads/syscall/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") -load("//benchmarks:defs.bzl", "test_deps") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "syscall", - srcs = ["__init__.py"], -) - -py_test( - name = "syscall_test", - srcs = ["syscall_test.py"], - python_version = "PY3", - deps = test_deps + [ - ":syscall", - ], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - "syscall.c", - ], -) diff --git a/benchmarks/workloads/syscall/Dockerfile b/benchmarks/workloads/syscall/Dockerfile deleted file mode 100644 index a2088d953..000000000 --- a/benchmarks/workloads/syscall/Dockerfile +++ /dev/null @@ -1,6 +0,0 @@ -FROM gcc:latest -COPY . /usr/src/syscall -WORKDIR /usr/src/syscall -RUN gcc -O2 -o syscall syscall.c -ENV count 1000000 -CMD ["sh", "-c", "./syscall ${count}"] diff --git a/benchmarks/workloads/syscall/__init__.py b/benchmarks/workloads/syscall/__init__.py deleted file mode 100644 index dc9028faa..000000000 --- a/benchmarks/workloads/syscall/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Simple syscall test.""" - -import re - -SAMPLE_DATA = "Called getpid syscall 1000000 times: 1117 ms, 500 ns each." - - -# pylint: disable=unused-argument -def sample(**kwargs) -> str: - return SAMPLE_DATA - - -# pylint: disable=unused-argument -def syscall_time_ns(data: str, **kwargs) -> int: - """Returns average system call time.""" - return float(re.compile(r"(\d+)\sns each.").search(data).group(1)) diff --git a/benchmarks/workloads/syscall/syscall.c b/benchmarks/workloads/syscall/syscall.c deleted file mode 100644 index ded030397..000000000 --- a/benchmarks/workloads/syscall/syscall.c +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#define _GNU_SOURCE -#include <stdio.h> -#include <stdlib.h> -#include <sys/syscall.h> -#include <sys/types.h> -#include <time.h> -#include <unistd.h> - -// Short program that calls getpid() a number of times and outputs time -// diference from the MONOTONIC clock. -int main(int argc, char** argv) { - struct timespec start, stop; - long result; - char buf[80]; - - if (argc < 2) { - printf("Usage:./syscall NUM_TIMES_TO_CALL"); - return 1; - } - - if (clock_gettime(CLOCK_MONOTONIC, &start)) return 1; - - long loops = atoi(argv[1]); - for (long i = 0; i < loops; i++) { - syscall(SYS_gettimeofday, 0, 0); - } - - if (clock_gettime(CLOCK_MONOTONIC, &stop)) return 1; - - if ((stop.tv_nsec - start.tv_nsec) < 0) { - result = (stop.tv_sec - start.tv_sec - 1) * 1000; - result += (stop.tv_nsec - start.tv_nsec + 1000000000) / (1000 * 1000); - } else { - result = (stop.tv_sec - start.tv_sec) * 1000; - result += (stop.tv_nsec - start.tv_nsec) / (1000 * 1000); - } - - printf("Called getpid syscall %d times: %lu ms, %lu ns each.\n", loops, - result, result * 1000000 / loops); - - return 0; -} diff --git a/benchmarks/workloads/syscall/syscall_test.py b/benchmarks/workloads/syscall/syscall_test.py deleted file mode 100644 index 72f027de1..000000000 --- a/benchmarks/workloads/syscall/syscall_test.py +++ /dev/null @@ -1,27 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys - -import pytest - -from benchmarks.workloads import syscall - - -def test_syscall_time_ns(): - assert syscall.syscall_time_ns(syscall.sample()) == 500 - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/tensorflow/BUILD b/benchmarks/workloads/tensorflow/BUILD deleted file mode 100644 index a7b7742f4..000000000 --- a/benchmarks/workloads/tensorflow/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "tensorflow", - srcs = ["__init__.py"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/tensorflow/Dockerfile b/benchmarks/workloads/tensorflow/Dockerfile deleted file mode 100644 index 262643b98..000000000 --- a/benchmarks/workloads/tensorflow/Dockerfile +++ /dev/null @@ -1,14 +0,0 @@ -FROM tensorflow/tensorflow:1.13.2 - -RUN apt-get update \ - && apt-get install -y git -RUN git clone https://github.com/aymericdamien/TensorFlow-Examples.git -RUN python -m pip install -U pip setuptools -RUN python -m pip install matplotlib - -WORKDIR /TensorFlow-Examples/examples - -ENV PYTHONPATH="$PYTHONPATH:/TensorFlow-Examples/examples" - -ENV workload "3_NeuralNetworks/convolutional_network.py" -CMD python ${workload} diff --git a/benchmarks/workloads/tensorflow/__init__.py b/benchmarks/workloads/tensorflow/__init__.py deleted file mode 100644 index b5ec213f8..000000000 --- a/benchmarks/workloads/tensorflow/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A Tensorflow example.""" - - -# pylint: disable=unused-argument -def run_time(value, **kwargs): - """Returns the startup and runtime of the Tensorflow workload in seconds.""" - return value diff --git a/benchmarks/workloads/true/BUILD b/benchmarks/workloads/true/BUILD deleted file mode 100644 index eba23d325..000000000 --- a/benchmarks/workloads/true/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], - extension = "tar", -) diff --git a/benchmarks/workloads/true/Dockerfile b/benchmarks/workloads/true/Dockerfile deleted file mode 100644 index 2e97c921e..000000000 --- a/benchmarks/workloads/true/Dockerfile +++ /dev/null @@ -1,3 +0,0 @@ -FROM alpine:latest - -CMD ["true"] diff --git a/g3doc/README.md b/g3doc/README.md deleted file mode 100644 index 49d58cdae..000000000 --- a/g3doc/README.md +++ /dev/null @@ -1,2 +0,0 @@ -The gVisor logo files are licensed under CC BY-SA 4.0 (Creative Commons -Attribution-ShareAlike 4.0 International). diff --git a/g3doc/logo.png b/g3doc/logo.png Binary files differdeleted file mode 100644 index bd1a1e4b7..000000000 --- a/g3doc/logo.png +++ /dev/null diff --git a/kokoro/benchmark_tests.cfg b/kokoro/benchmark_tests.cfg deleted file mode 100644 index f85cc9681..000000000 --- a/kokoro/benchmark_tests.cfg +++ /dev/null @@ -1,26 +0,0 @@ -build_file : 'repo/scripts/benchmark.sh' - - -before_action { - fetch_keystore { - keystore_resource { - keystore_config_id : 73898 - keyname : 'gvisor-benchmarks-service-account' - }, - } -} - -env_vars { - key : 'PROJECT' - value : 'gvisor-benchmarks' -} - -env_vars { - key : 'ZONE' - value : 'us-central1-b' -} - -env_vars { - key : 'GCLOUD_CREDENTIALS' - value : '73898_gvisor-benchmarks-service-account' -} diff --git a/kokoro/build.cfg b/kokoro/build.cfg deleted file mode 100644 index c9ceda947..000000000 --- a/kokoro/build.cfg +++ /dev/null @@ -1,24 +0,0 @@ -build_file: "repo/scripts/build.sh" - -before_action { - fetch_keystore { - keystore_resource { - keystore_config_id: 73898 - keyname: "kokoro-repo-key" - } - } -} - -env_vars { - key: "KOKORO_REPO_KEY" - value: "73898_kokoro-repo-key" -} - -action { - define_artifacts { - regex: "**/runsc" - regex: "**/runsc.*" - regex: "**/dists/**" - regex: "**/pool/**" - } -} diff --git a/kokoro/build_tests.cfg b/kokoro/build_tests.cfg deleted file mode 100644 index c64b7e679..000000000 --- a/kokoro/build_tests.cfg +++ /dev/null @@ -1 +0,0 @@ -build_file: "repo/scripts/build.sh" diff --git a/kokoro/common.cfg b/kokoro/common.cfg deleted file mode 100644 index 669a2e458..000000000 --- a/kokoro/common.cfg +++ /dev/null @@ -1,29 +0,0 @@ -# Give Kokoro access to Remote Build Executor (RBE) service account key. -before_action { - fetch_keystore { - keystore_resource { - keystore_config_id: 73898 - keyname: "kokoro-rbe-service-account" - } - } -} - -# Configure bazel to access RBE. -bazel_setting { - # Our GCP project name. - project_id: "gvisor-rbe" - - # Use RBE for execution as well as caching. - local_execution: false - - # This must match the values in the job config. - auth_credential: { - keystore_config_id: 73898 - keyname: "kokoro-rbe-service-account" - } - - # Do not change unless you know what you are doing. - bes_backend_address: "buildeventservice.googleapis.com" - foundry_backend_address: "remotebuildexecution.googleapis.com" - upsalite_frontend_address: "https://source.cloud.google.com" -} diff --git a/kokoro/do_tests.cfg b/kokoro/do_tests.cfg deleted file mode 100644 index b45ec0b42..000000000 --- a/kokoro/do_tests.cfg +++ /dev/null @@ -1,9 +0,0 @@ -build_file: "repo/scripts/do_tests.sh" - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - } -} diff --git a/kokoro/docker_tests.cfg b/kokoro/docker_tests.cfg deleted file mode 100644 index 0a0ef87ed..000000000 --- a/kokoro/docker_tests.cfg +++ /dev/null @@ -1,10 +0,0 @@ -build_file: "repo/scripts/docker_tests.sh" - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - regex: "**/runsc_logs_*.tar.gz" - } -} diff --git a/kokoro/go.cfg b/kokoro/go.cfg deleted file mode 100644 index b9c1fcb12..000000000 --- a/kokoro/go.cfg +++ /dev/null @@ -1,20 +0,0 @@ -build_file: "repo/scripts/go.sh" - -before_action { - fetch_keystore { - keystore_resource { - keystore_config_id: 73898 - keyname: "kokoro-github-access-token" - } - } -} - -env_vars { - key: "KOKORO_GITHUB_ACCESS_TOKEN" - value: "73898_kokoro-github-access-token" -} - -env_vars { - key: "KOKORO_GO_PUSH" - value: "true" -} diff --git a/kokoro/go_tests.cfg b/kokoro/go_tests.cfg deleted file mode 100644 index 5eb51041a..000000000 --- a/kokoro/go_tests.cfg +++ /dev/null @@ -1 +0,0 @@ -build_file: "repo/scripts/go.sh" diff --git a/kokoro/hostnet_tests.cfg b/kokoro/hostnet_tests.cfg deleted file mode 100644 index 520dc55a3..000000000 --- a/kokoro/hostnet_tests.cfg +++ /dev/null @@ -1,10 +0,0 @@ -build_file: "repo/scripts/hostnet_tests.sh" - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - regex: "**/runsc_logs_*.tar.gz" - } -} diff --git a/kokoro/iptables_tests.cfg b/kokoro/iptables_tests.cfg deleted file mode 100644 index 7af20629a..000000000 --- a/kokoro/iptables_tests.cfg +++ /dev/null @@ -1,10 +0,0 @@ -build_file: "repo/scripts/iptables_test.sh" - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - regex: "**/runsc_logs_*.tar.gz" - } -} diff --git a/kokoro/issue_reviver.cfg b/kokoro/issue_reviver.cfg deleted file mode 100644 index 2370d9250..000000000 --- a/kokoro/issue_reviver.cfg +++ /dev/null @@ -1,15 +0,0 @@ -build_file: "repo/scripts/issue_reviver.sh" - -before_action { - fetch_keystore { - keystore_resource { - keystore_config_id: 73898 - keyname: "kokoro-github-access-token" - } - } -} - -env_vars { - key: "KOKORO_GITHUB_ACCESS_TOKEN" - value: "73898_kokoro-github-access-token" -} diff --git a/kokoro/kvm_tests.cfg b/kokoro/kvm_tests.cfg deleted file mode 100644 index 1feb60c8a..000000000 --- a/kokoro/kvm_tests.cfg +++ /dev/null @@ -1,10 +0,0 @@ -build_file: "repo/scripts/kvm_tests.sh" - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - regex: "**/runsc_logs_*.tar.gz" - } -} diff --git a/kokoro/kythe/generate_xrefs.cfg b/kokoro/kythe/generate_xrefs.cfg deleted file mode 100644 index ccf657983..000000000 --- a/kokoro/kythe/generate_xrefs.cfg +++ /dev/null @@ -1,29 +0,0 @@ -build_file: "gvisor/kokoro/kythe/generate_xrefs.sh" - -before_action { - fetch_keystore { - keystore_resource { - keystore_config_id: 73898 - keyname: "kokoro-rbe-service-account" - } - } -} - -bazel_setting { - project_id: "gvisor-rbe" - local_execution: false - auth_credential: { - keystore_config_id: 73898 - keyname: "kokoro-rbe-service-account" - } - bes_backend_address: "buildeventservice.googleapis.com" - foundry_backend_address: "remotebuildexecution.googleapis.com" - upsalite_frontend_address: "https://source.cloud.google.com" -} - -action { - define_artifacts { - regex: "**/*.kzip" - fail_if_no_artifacts: true - } -} diff --git a/kokoro/kythe/generate_xrefs.sh b/kokoro/kythe/generate_xrefs.sh deleted file mode 100644 index 2f531aa72..000000000 --- a/kokoro/kythe/generate_xrefs.sh +++ /dev/null @@ -1,54 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -ex - -if command -v use_bazel.sh >/dev/null; then - use_bazel.sh latest -fi -bazel version - -python3 -V - -readonly KYTHE_VERSION='v0.0.43' -readonly WORKDIR="$(mktemp -d)" -readonly KYTHE_DIR="${WORKDIR}/kythe-${KYTHE_VERSION}" -if [[ -n "$KOKORO_GIT_COMMIT" ]]; then - readonly KZIP_FILENAME="${KOKORO_ARTIFACTS_DIR}/${KOKORO_GIT_COMMIT}.kzip" -else - readonly KZIP_FILENAME="$(git rev-parse HEAD).kzip" -fi - -wget -q -O "${WORKDIR}/kythe.tar.gz" \ - "https://github.com/kythe/kythe/releases/download/${KYTHE_VERSION}/kythe-${KYTHE_VERSION}.tar.gz" -tar --no-same-owner -xzf "${WORKDIR}/kythe.tar.gz" --directory "$WORKDIR" - -if [[ -n "$KOKORO_ARTIFACTS_DIR" ]]; then - cd "${KOKORO_ARTIFACTS_DIR}/github/gvisor" -fi -bazel \ - --bazelrc="${KYTHE_DIR}/extractors.bazelrc" \ - build \ - --override_repository kythe_release="${KYTHE_DIR}" \ - --define=kythe_corpus=github.com/google/gvisor \ - --cxxopt=-std=c++17 \ - --config=remote \ - --auth_credentials="${KOKORO_BAZEL_AUTH_CREDENTIAL}" \ - //... - -"${KYTHE_DIR}/tools/kzip" merge \ - --output "$KZIP_FILENAME" \ - $(find -L bazel-out/*/extra_actions/ -name '*.kzip') diff --git a/kokoro/make_tests.cfg b/kokoro/make_tests.cfg deleted file mode 100644 index d973130ff..000000000 --- a/kokoro/make_tests.cfg +++ /dev/null @@ -1,9 +0,0 @@ -build_file: "repo/scripts/make_tests.sh" - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - } -} diff --git a/kokoro/overlay_tests.cfg b/kokoro/overlay_tests.cfg deleted file mode 100644 index 6a2ddbd03..000000000 --- a/kokoro/overlay_tests.cfg +++ /dev/null @@ -1,10 +0,0 @@ -build_file: "repo/scripts/overlay_tests.sh" - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - regex: "**/runsc_logs_*.tar.gz" - } -} diff --git a/kokoro/packetdrill_tests.cfg b/kokoro/packetdrill_tests.cfg deleted file mode 100644 index 258d7deb4..000000000 --- a/kokoro/packetdrill_tests.cfg +++ /dev/null @@ -1,9 +0,0 @@ -build_file: "repo/scripts/packetdrill_tests.sh" - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - } -} diff --git a/kokoro/release.cfg b/kokoro/release.cfg deleted file mode 100644 index 5cec1790a..000000000 --- a/kokoro/release.cfg +++ /dev/null @@ -1,15 +0,0 @@ -build_file: "repo/scripts/release.sh" - -before_action { - fetch_keystore { - keystore_resource { - keystore_config_id: 73898 - keyname: "kokoro-github-access-token" - } - } -} - -env_vars { - key: "KOKORO_GITHUB_ACCESS_TOKEN" - value: "73898_kokoro-github-access-token" -} diff --git a/kokoro/root_tests.cfg b/kokoro/root_tests.cfg deleted file mode 100644 index 28351695c..000000000 --- a/kokoro/root_tests.cfg +++ /dev/null @@ -1,10 +0,0 @@ -build_file: "repo/scripts/root_tests.sh" - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - regex: "**/runsc_logs_*.tar.gz" - } -} diff --git a/kokoro/runtime_tests.cfg b/kokoro/runtime_tests.cfg deleted file mode 100644 index 7d56d5aca..000000000 --- a/kokoro/runtime_tests.cfg +++ /dev/null @@ -1 +0,0 @@ -build_file: "repo/scripts/runtime_tests.sh" diff --git a/kokoro/runtime_tests/go1.12.cfg b/kokoro/runtime_tests/go1.12.cfg deleted file mode 100644 index fd4911e88..000000000 --- a/kokoro/runtime_tests/go1.12.cfg +++ /dev/null @@ -1,16 +0,0 @@ -build_file: "github/github/kokoro/runtime_tests/runtime_tests.sh" - -env_vars { - key: "RUNTIME_TEST_NAME" - value: "go1.12" -} - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - regex: "**/runsc" - regex: "**/runsc.*" - } -}
\ No newline at end of file diff --git a/kokoro/runtime_tests/java11.cfg b/kokoro/runtime_tests/java11.cfg deleted file mode 100644 index 7f8611a08..000000000 --- a/kokoro/runtime_tests/java11.cfg +++ /dev/null @@ -1,16 +0,0 @@ -build_file: "github/github/kokoro/runtime_tests/runtime_tests.sh" - -env_vars { - key: "RUNTIME_TEST_NAME" - value: "java11" -} - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - regex: "**/runsc" - regex: "**/runsc.*" - } -}
\ No newline at end of file diff --git a/kokoro/runtime_tests/nodejs12.4.0.cfg b/kokoro/runtime_tests/nodejs12.4.0.cfg deleted file mode 100644 index c67ad5567..000000000 --- a/kokoro/runtime_tests/nodejs12.4.0.cfg +++ /dev/null @@ -1,16 +0,0 @@ -build_file: "github/github/kokoro/runtime_tests/runtime_tests.sh" - -env_vars { - key: "RUNTIME_TEST_NAME" - value: "nodejs12.4.0" -} - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - regex: "**/runsc" - regex: "**/runsc.*" - } -}
\ No newline at end of file diff --git a/kokoro/runtime_tests/php7.3.6.cfg b/kokoro/runtime_tests/php7.3.6.cfg deleted file mode 100644 index f266c5e26..000000000 --- a/kokoro/runtime_tests/php7.3.6.cfg +++ /dev/null @@ -1,16 +0,0 @@ -build_file: "github/github/kokoro/runtime_tests/runtime_tests.sh" - -env_vars { - key: "RUNTIME_TEST_NAME" - value: "php7.3.6" -} - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - regex: "**/runsc" - regex: "**/runsc.*" - } -}
\ No newline at end of file diff --git a/kokoro/runtime_tests/python3.7.3.cfg b/kokoro/runtime_tests/python3.7.3.cfg deleted file mode 100644 index 574add152..000000000 --- a/kokoro/runtime_tests/python3.7.3.cfg +++ /dev/null @@ -1,16 +0,0 @@ -build_file: "github/github/kokoro/runtime_tests/runtime_tests.sh" - -env_vars { - key: "RUNTIME_TEST_NAME" - value: "python3.7.3" -} - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - regex: "**/runsc" - regex: "**/runsc.*" - } -}
\ No newline at end of file diff --git a/kokoro/runtime_tests/runtime_tests.sh b/kokoro/runtime_tests/runtime_tests.sh deleted file mode 100755 index 73a58f806..000000000 --- a/kokoro/runtime_tests/runtime_tests.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Run in the root of the repo. -cd "$(dirname "$0")" -cd "$(git rev-parse --show-toplevel)" - -source scripts/common.sh - -if [ ! -v RUNTIME_TEST_NAME ]; then - echo 'Must set $RUNTIME_TEST_NAME' >&2 - exit 1 -fi - -install_runsc_for_test runtimes -test_runsc "//test/runtimes:${RUNTIME_TEST_NAME}_test" diff --git a/kokoro/simple_tests.cfg b/kokoro/simple_tests.cfg deleted file mode 100644 index 32e0a9431..000000000 --- a/kokoro/simple_tests.cfg +++ /dev/null @@ -1,9 +0,0 @@ -build_file: "repo/scripts/simple_tests.sh" - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - } -} diff --git a/kokoro/swgso_tests.cfg b/kokoro/swgso_tests.cfg deleted file mode 100644 index 101a9c607..000000000 --- a/kokoro/swgso_tests.cfg +++ /dev/null @@ -1,9 +0,0 @@ -build_file: "repo/scripts/swgso_tests.sh" - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - } -} diff --git a/kokoro/syscall_kvm_tests.cfg b/kokoro/syscall_kvm_tests.cfg deleted file mode 100644 index 3b99e9c13..000000000 --- a/kokoro/syscall_kvm_tests.cfg +++ /dev/null @@ -1,9 +0,0 @@ -build_file: "repo/scripts/syscall_kvm_tests.sh" - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - } -} diff --git a/kokoro/syscall_tests.cfg b/kokoro/syscall_tests.cfg deleted file mode 100644 index ee6e4a3a4..000000000 --- a/kokoro/syscall_tests.cfg +++ /dev/null @@ -1,9 +0,0 @@ -build_file: "repo/scripts/syscall_tests.sh" - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - } -} diff --git a/pkg/abi/BUILD b/pkg/abi/BUILD deleted file mode 100644 index 839f822eb..000000000 --- a/pkg/abi/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "abi", - srcs = [ - "abi.go", - "abi_linux.go", - "flag.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/abi/abi_linux_state_autogen.go b/pkg/abi/abi_linux_state_autogen.go new file mode 100755 index 000000000..327ef0e5c --- /dev/null +++ b/pkg/abi/abi_linux_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build linux + +package abi diff --git a/pkg/abi/abi_state_autogen.go b/pkg/abi/abi_state_autogen.go new file mode 100755 index 000000000..d54002c3b --- /dev/null +++ b/pkg/abi/abi_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package abi diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD deleted file mode 100644 index 322d1ccc4..000000000 --- a/pkg/abi/linux/BUILD +++ /dev/null @@ -1,82 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -# Package linux contains the constants and types needed to interface with a -# Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix -# when the host OS may not be Linux. - -package(licenses = ["notice"]) - -go_library( - name = "linux", - srcs = [ - "aio.go", - "audit.go", - "bpf.go", - "capability.go", - "clone.go", - "dev.go", - "elf.go", - "epoll.go", - "epoll_amd64.go", - "epoll_arm64.go", - "errors.go", - "eventfd.go", - "exec.go", - "fcntl.go", - "file.go", - "file_amd64.go", - "file_arm64.go", - "fs.go", - "futex.go", - "inotify.go", - "ioctl.go", - "ioctl_tun.go", - "ip.go", - "ipc.go", - "limits.go", - "linux.go", - "mm.go", - "netdevice.go", - "netfilter.go", - "netlink.go", - "netlink_route.go", - "poll.go", - "prctl.go", - "ptrace.go", - "rseq.go", - "rusage.go", - "sched.go", - "seccomp.go", - "sem.go", - "shm.go", - "signal.go", - "signalfd.go", - "socket.go", - "splice.go", - "tcp.go", - "time.go", - "timer.go", - "tty.go", - "uio.go", - "utsname.go", - "wait.go", - "xattr.go", - ], - marshal = True, - visibility = ["//visibility:public"], - deps = [ - "//pkg/abi", - "//pkg/binary", - "//pkg/bits", - ], -) - -go_test( - name = "linux_test", - size = "small", - srcs = ["netfilter_test.go"], - library = ":linux", - deps = [ - "//pkg/binary", - ], -) diff --git a/pkg/abi/linux/epoll_amd64.go b/pkg/abi/linux/epoll_amd64.go index 34ff18009..34ff18009 100644..100755 --- a/pkg/abi/linux/epoll_amd64.go +++ b/pkg/abi/linux/epoll_amd64.go diff --git a/pkg/abi/linux/epoll_arm64.go b/pkg/abi/linux/epoll_arm64.go index f86c35329..f86c35329 100644..100755 --- a/pkg/abi/linux/epoll_arm64.go +++ b/pkg/abi/linux/epoll_arm64.go diff --git a/pkg/abi/linux/file_amd64.go b/pkg/abi/linux/file_amd64.go index 6b72364ea..6b72364ea 100644..100755 --- a/pkg/abi/linux/file_amd64.go +++ b/pkg/abi/linux/file_amd64.go diff --git a/pkg/abi/linux/file_arm64.go b/pkg/abi/linux/file_arm64.go index 6492c9038..6492c9038 100644..100755 --- a/pkg/abi/linux/file_arm64.go +++ b/pkg/abi/linux/file_arm64.go diff --git a/pkg/abi/linux/ioctl_tun.go b/pkg/abi/linux/ioctl_tun.go index c59c9c136..c59c9c136 100644..100755 --- a/pkg/abi/linux/ioctl_tun.go +++ b/pkg/abi/linux/ioctl_tun.go diff --git a/pkg/abi/linux/linux_abi_autogen_unsafe.go b/pkg/abi/linux/linux_abi_autogen_unsafe.go new file mode 100755 index 000000000..46f7c1197 --- /dev/null +++ b/pkg/abi/linux/linux_abi_autogen_unsafe.go @@ -0,0 +1,1034 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/safecopy" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "io" + "reflect" + "runtime" + "unsafe" +) + +// Marshallable types used by this file. +var _ marshal.Marshallable = (*RSeqCriticalSection)(nil) +var _ marshal.Marshallable = (*SignalSet)(nil) +var _ marshal.Marshallable = (*Statfs)(nil) +var _ marshal.Marshallable = (*Statx)(nil) +var _ marshal.Marshallable = (*StatxTimestamp)(nil) +var _ marshal.Marshallable = (*Timespec)(nil) +var _ marshal.Marshallable = (*Timeval)(nil) +var _ marshal.Marshallable = (*Utime)(nil) + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *Statx) SizeBytes() int { + return 80 + + (*StatxTimestamp)(nil).SizeBytes() + + (*StatxTimestamp)(nil).SizeBytes() + + (*StatxTimestamp)(nil).SizeBytes() + + (*StatxTimestamp)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *Statx) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.Mask)) + dst = dst[4:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.Blksize)) + dst = dst[4:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Attributes)) + dst = dst[8:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.Nlink)) + dst = dst[4:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.UID)) + dst = dst[4:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.GID)) + dst = dst[4:] + usermem.ByteOrder.PutUint16(dst[:2], uint16(s.Mode)) + dst = dst[2:] + // Padding: dst[:sizeof(uint16)] ~= uint16(0) + dst = dst[2:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Ino)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Size)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Blocks)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.AttributesMask)) + dst = dst[8:] + s.Atime.MarshalBytes(dst[:s.Atime.SizeBytes()]) + dst = dst[s.Atime.SizeBytes():] + s.Btime.MarshalBytes(dst[:s.Btime.SizeBytes()]) + dst = dst[s.Btime.SizeBytes():] + s.Ctime.MarshalBytes(dst[:s.Ctime.SizeBytes()]) + dst = dst[s.Ctime.SizeBytes():] + s.Mtime.MarshalBytes(dst[:s.Mtime.SizeBytes()]) + dst = dst[s.Mtime.SizeBytes():] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.RdevMajor)) + dst = dst[4:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.RdevMinor)) + dst = dst[4:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.DevMajor)) + dst = dst[4:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.DevMinor)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *Statx) UnmarshalBytes(src []byte) { + s.Mask = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Blksize = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Attributes = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Nlink = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.UID = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.GID = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Mode = uint16(usermem.ByteOrder.Uint16(src[:2])) + src = src[2:] + // Padding: var _ uint16 ~= src[:sizeof(uint16)] + src = src[2:] + s.Ino = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Size = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Blocks = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.AttributesMask = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Atime.UnmarshalBytes(src[:s.Atime.SizeBytes()]) + src = src[s.Atime.SizeBytes():] + s.Btime.UnmarshalBytes(src[:s.Btime.SizeBytes()]) + src = src[s.Btime.SizeBytes():] + s.Ctime.UnmarshalBytes(src[:s.Ctime.SizeBytes()]) + src = src[s.Ctime.SizeBytes():] + s.Mtime.UnmarshalBytes(src[:s.Mtime.SizeBytes()]) + src = src[s.Mtime.SizeBytes():] + s.RdevMajor = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.RdevMinor = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.DevMajor = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.DevMinor = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +func (s *Statx) Packed() bool { + return s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *Statx) MarshalUnsafe(dst []byte) { + if s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() { + safecopy.CopyIn(dst, unsafe.Pointer(s)) + } else { + s.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *Statx) UnmarshalUnsafe(src []byte) { + if s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() { + safecopy.CopyOut(unsafe.Pointer(s), src) + } else { + s.UnmarshalBytes(src) + } +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (s *Statx) CopyOut(task marshal.Task, addr usermem.Addr) error { + if !s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() { + // Type Statx doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := task.CopyScratchBuffer(s.SizeBytes()) + s.MarshalBytes(buf) + _, err := task.CopyOutBytes(addr, buf) + return err + } + + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + _, err := task.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the CopyOutBytes. + runtime.KeepAlive(s) + return err +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (s *Statx) CopyIn(task marshal.Task, addr usermem.Addr) error { + if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() { + // Type Statx doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := task.CopyScratchBuffer(s.SizeBytes()) + _, err := task.CopyInBytes(addr, buf) + if err != nil { + return err + } + s.UnmarshalBytes(buf) + return nil + } + + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + _, err := task.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the CopyInBytes. + runtime.KeepAlive(s) + return err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *Statx) WriteTo(w io.Writer) (int64, error) { + if !s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() { + // Type Statx doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + n, err := w.Write(buf) + return int64(n), err + } + + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + len, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the Write. + runtime.KeepAlive(s) + return int64(len), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *Statfs) SizeBytes() int { + return 120 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *Statfs) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Type)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.BlockSize)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Blocks)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.BlocksFree)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.BlocksAvailable)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Files)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.FilesFree)) + dst = dst[8:] + for idx := 0; idx < 2; idx++ { + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.FSID[idx])) + dst = dst[4:] + } + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.NameLength)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.FragmentSize)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Flags)) + dst = dst[8:] + for idx := 0; idx < 4; idx++ { + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Spare[idx])) + dst = dst[8:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *Statfs) UnmarshalBytes(src []byte) { + s.Type = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.BlockSize = int64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Blocks = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.BlocksFree = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.BlocksAvailable = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Files = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.FilesFree = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + for idx := 0; idx < 2; idx++ { + s.FSID[idx] = int32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + } + s.NameLength = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.FragmentSize = int64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Flags = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + for idx := 0; idx < 4; idx++ { + s.Spare[idx] = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + } +} + +// Packed implements marshal.Marshallable.Packed. +func (s *Statfs) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *Statfs) MarshalUnsafe(dst []byte) { + safecopy.CopyIn(dst, unsafe.Pointer(s)) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *Statfs) UnmarshalUnsafe(src []byte) { + safecopy.CopyOut(unsafe.Pointer(s), src) +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (s *Statfs) CopyOut(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + _, err := task.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the CopyOutBytes. + runtime.KeepAlive(s) + return err +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (s *Statfs) CopyIn(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + _, err := task.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the CopyInBytes. + runtime.KeepAlive(s) + return err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *Statfs) WriteTo(w io.Writer) (int64, error) { + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + len, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the Write. + runtime.KeepAlive(s) + return int64(len), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *RSeqCriticalSection) SizeBytes() int { + return 32 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *RSeqCriticalSection) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint32(dst[:4], uint32(r.Version)) + dst = dst[4:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(r.Flags)) + dst = dst[4:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(r.Start)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(r.PostCommitOffset)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(r.Abort)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *RSeqCriticalSection) UnmarshalBytes(src []byte) { + r.Version = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + r.Flags = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + r.Start = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.PostCommitOffset = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.Abort = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +func (r *RSeqCriticalSection) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (r *RSeqCriticalSection) MarshalUnsafe(dst []byte) { + safecopy.CopyIn(dst, unsafe.Pointer(r)) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (r *RSeqCriticalSection) UnmarshalUnsafe(src []byte) { + safecopy.CopyOut(unsafe.Pointer(r), src) +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (r *RSeqCriticalSection) CopyOut(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on r. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on r. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(r) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by r's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + _, err := task.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until after the CopyOutBytes. + runtime.KeepAlive(r) + return err +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (r *RSeqCriticalSection) CopyIn(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on r. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on r. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(r) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by r's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + _, err := task.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until after the CopyInBytes. + runtime.KeepAlive(r) + return err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (r *RSeqCriticalSection) WriteTo(w io.Writer) (int64, error) { + // Bypass escape analysis on r. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on r. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(r) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by r's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + len, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until after the Write. + runtime.KeepAlive(r) + return int64(len), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SignalSet) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SignalSet) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint64(dst[:8], uint64(*s)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SignalSet) UnmarshalBytes(src []byte) { + *s = SignalSet(uint64(usermem.ByteOrder.Uint64(src[:8]))) +} + +// Packed implements marshal.Marshallable.Packed. +func (s *SignalSet) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SignalSet) MarshalUnsafe(dst []byte) { + safecopy.CopyIn(dst, unsafe.Pointer(s)) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SignalSet) UnmarshalUnsafe(src []byte) { + safecopy.CopyOut(unsafe.Pointer(s), src) +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (s *SignalSet) CopyOut(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + _, err := task.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the CopyOutBytes. + runtime.KeepAlive(s) + return err +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (s *SignalSet) CopyIn(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + _, err := task.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the CopyInBytes. + runtime.KeepAlive(s) + return err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SignalSet) WriteTo(w io.Writer) (int64, error) { + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + len, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the Write. + runtime.KeepAlive(s) + return int64(len), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (t *Timespec) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (t *Timespec) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint64(dst[:8], uint64(t.Sec)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(t.Nsec)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (t *Timespec) UnmarshalBytes(src []byte) { + t.Sec = int64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + t.Nsec = int64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +func (t *Timespec) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (t *Timespec) MarshalUnsafe(dst []byte) { + safecopy.CopyIn(dst, unsafe.Pointer(t)) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (t *Timespec) UnmarshalUnsafe(src []byte) { + safecopy.CopyOut(unsafe.Pointer(t), src) +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (t *Timespec) CopyOut(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on t. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on t. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(t) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by t's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + _, err := task.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until after the CopyOutBytes. + runtime.KeepAlive(t) + return err +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (t *Timespec) CopyIn(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on t. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on t. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(t) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by t's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + _, err := task.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until after the CopyInBytes. + runtime.KeepAlive(t) + return err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (t *Timespec) WriteTo(w io.Writer) (int64, error) { + // Bypass escape analysis on t. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on t. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(t) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by t's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + len, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until after the Write. + runtime.KeepAlive(t) + return int64(len), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (t *Timeval) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (t *Timeval) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint64(dst[:8], uint64(t.Sec)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(t.Usec)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (t *Timeval) UnmarshalBytes(src []byte) { + t.Sec = int64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + t.Usec = int64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +func (t *Timeval) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (t *Timeval) MarshalUnsafe(dst []byte) { + safecopy.CopyIn(dst, unsafe.Pointer(t)) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (t *Timeval) UnmarshalUnsafe(src []byte) { + safecopy.CopyOut(unsafe.Pointer(t), src) +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (t *Timeval) CopyOut(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on t. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on t. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(t) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by t's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + _, err := task.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until after the CopyOutBytes. + runtime.KeepAlive(t) + return err +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (t *Timeval) CopyIn(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on t. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on t. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(t) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by t's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + _, err := task.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until after the CopyInBytes. + runtime.KeepAlive(t) + return err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (t *Timeval) WriteTo(w io.Writer) (int64, error) { + // Bypass escape analysis on t. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on t. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(t) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by t's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + len, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until after the Write. + runtime.KeepAlive(t) + return int64(len), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *StatxTimestamp) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *StatxTimestamp) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Sec)) + dst = dst[8:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.Nsec)) + dst = dst[4:] + // Padding: dst[:sizeof(int32)] ~= int32(0) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *StatxTimestamp) UnmarshalBytes(src []byte) { + s.Sec = int64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Nsec = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ int32 ~= src[:sizeof(int32)] + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +func (s *StatxTimestamp) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *StatxTimestamp) MarshalUnsafe(dst []byte) { + safecopy.CopyIn(dst, unsafe.Pointer(s)) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *StatxTimestamp) UnmarshalUnsafe(src []byte) { + safecopy.CopyOut(unsafe.Pointer(s), src) +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (s *StatxTimestamp) CopyOut(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + _, err := task.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the CopyOutBytes. + runtime.KeepAlive(s) + return err +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (s *StatxTimestamp) CopyIn(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + _, err := task.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the CopyInBytes. + runtime.KeepAlive(s) + return err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *StatxTimestamp) WriteTo(w io.Writer) (int64, error) { + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + len, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the Write. + runtime.KeepAlive(s) + return int64(len), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (u *Utime) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (u *Utime) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint64(dst[:8], uint64(u.Actime)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(u.Modtime)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (u *Utime) UnmarshalBytes(src []byte) { + u.Actime = int64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + u.Modtime = int64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +func (u *Utime) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (u *Utime) MarshalUnsafe(dst []byte) { + safecopy.CopyIn(dst, unsafe.Pointer(u)) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (u *Utime) UnmarshalUnsafe(src []byte) { + safecopy.CopyOut(unsafe.Pointer(u), src) +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (u *Utime) CopyOut(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on u. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on u. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(u) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by u's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + _, err := task.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until after the CopyOutBytes. + runtime.KeepAlive(u) + return err +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (u *Utime) CopyIn(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on u. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on u. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(u) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by u's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + _, err := task.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until after the CopyInBytes. + runtime.KeepAlive(u) + return err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (u *Utime) WriteTo(w io.Writer) (int64, error) { + // Bypass escape analysis on u. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on u. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(u) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by u's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + len, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until after the Write. + runtime.KeepAlive(u) + return int64(len), err +} + diff --git a/pkg/abi/linux/linux_amd64_abi_autogen_unsafe.go b/pkg/abi/linux/linux_amd64_abi_autogen_unsafe.go new file mode 100755 index 000000000..9b9faaa36 --- /dev/null +++ b/pkg/abi/linux/linux_amd64_abi_autogen_unsafe.go @@ -0,0 +1,325 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// +build amd64 + +package linux + +import ( + "gvisor.dev/gvisor/pkg/safecopy" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "io" + "reflect" + "runtime" + "unsafe" +) + +// Marshallable types used by this file. +var _ marshal.Marshallable = (*EpollEvent)(nil) +var _ marshal.Marshallable = (*Stat)(nil) +var _ marshal.Marshallable = (*Timespec)(nil) + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (e *EpollEvent) SizeBytes() int { + return 12 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (e *EpollEvent) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint32(dst[:4], uint32(e.Events)) + dst = dst[4:] + for idx := 0; idx < 2; idx++ { + usermem.ByteOrder.PutUint32(dst[:4], uint32(e.Data[idx])) + dst = dst[4:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (e *EpollEvent) UnmarshalBytes(src []byte) { + e.Events = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + for idx := 0; idx < 2; idx++ { + e.Data[idx] = int32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + } +} + +// Packed implements marshal.Marshallable.Packed. +func (e *EpollEvent) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (e *EpollEvent) MarshalUnsafe(dst []byte) { + safecopy.CopyIn(dst, unsafe.Pointer(e)) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (e *EpollEvent) UnmarshalUnsafe(src []byte) { + safecopy.CopyOut(unsafe.Pointer(e), src) +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (e *EpollEvent) CopyOut(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on e. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on e. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(e) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by e's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + _, err := task.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until after the CopyOutBytes. + runtime.KeepAlive(e) + return err +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (e *EpollEvent) CopyIn(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on e. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on e. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(e) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by e's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + _, err := task.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until after the CopyInBytes. + runtime.KeepAlive(e) + return err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (e *EpollEvent) WriteTo(w io.Writer) (int64, error) { + // Bypass escape analysis on e. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on e. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(e) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by e's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + len, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until after the Write. + runtime.KeepAlive(e) + return int64(len), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *Stat) SizeBytes() int { + return 96 + + (*Timespec)(nil).SizeBytes() + + (*Timespec)(nil).SizeBytes() + + (*Timespec)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *Stat) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Dev)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Ino)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Nlink)) + dst = dst[8:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.Mode)) + dst = dst[4:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.UID)) + dst = dst[4:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.GID)) + dst = dst[4:] + // Padding: dst[:sizeof(int32)] ~= int32(0) + dst = dst[4:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Rdev)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Size)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Blksize)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Blocks)) + dst = dst[8:] + s.ATime.MarshalBytes(dst[:s.ATime.SizeBytes()]) + dst = dst[s.ATime.SizeBytes():] + s.MTime.MarshalBytes(dst[:s.MTime.SizeBytes()]) + dst = dst[s.MTime.SizeBytes():] + s.CTime.MarshalBytes(dst[:s.CTime.SizeBytes()]) + dst = dst[s.CTime.SizeBytes():] + // Padding: dst[:sizeof(int64)*3] ~= [3]int64{0} + dst = dst[24:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *Stat) UnmarshalBytes(src []byte) { + s.Dev = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Ino = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Nlink = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Mode = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.UID = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.GID = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ int32 ~= src[:sizeof(int32)] + src = src[4:] + s.Rdev = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Size = int64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Blksize = int64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Blocks = int64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.ATime.UnmarshalBytes(src[:s.ATime.SizeBytes()]) + src = src[s.ATime.SizeBytes():] + s.MTime.UnmarshalBytes(src[:s.MTime.SizeBytes()]) + src = src[s.MTime.SizeBytes():] + s.CTime.UnmarshalBytes(src[:s.CTime.SizeBytes()]) + src = src[s.CTime.SizeBytes():] + // Padding: ~ copy([3]int64(s._), src[:sizeof(int64)*3]) + src = src[24:] +} + +// Packed implements marshal.Marshallable.Packed. +func (s *Stat) Packed() bool { + return s.ATime.Packed() && s.MTime.Packed() && s.CTime.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *Stat) MarshalUnsafe(dst []byte) { + if s.ATime.Packed() && s.MTime.Packed() && s.CTime.Packed() { + safecopy.CopyIn(dst, unsafe.Pointer(s)) + } else { + s.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *Stat) UnmarshalUnsafe(src []byte) { + if s.ATime.Packed() && s.MTime.Packed() && s.CTime.Packed() { + safecopy.CopyOut(unsafe.Pointer(s), src) + } else { + s.UnmarshalBytes(src) + } +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (s *Stat) CopyOut(task marshal.Task, addr usermem.Addr) error { + if !s.ATime.Packed() && s.MTime.Packed() && s.CTime.Packed() { + // Type Stat doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := task.CopyScratchBuffer(s.SizeBytes()) + s.MarshalBytes(buf) + _, err := task.CopyOutBytes(addr, buf) + return err + } + + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + _, err := task.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the CopyOutBytes. + runtime.KeepAlive(s) + return err +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (s *Stat) CopyIn(task marshal.Task, addr usermem.Addr) error { + if !s.ATime.Packed() && s.MTime.Packed() && s.CTime.Packed() { + // Type Stat doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := task.CopyScratchBuffer(s.SizeBytes()) + _, err := task.CopyInBytes(addr, buf) + if err != nil { + return err + } + s.UnmarshalBytes(buf) + return nil + } + + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + _, err := task.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the CopyInBytes. + runtime.KeepAlive(s) + return err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *Stat) WriteTo(w io.Writer) (int64, error) { + if !s.ATime.Packed() && s.MTime.Packed() && s.CTime.Packed() { + // Type Stat doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + n, err := w.Write(buf) + return int64(n), err + } + + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + len, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the Write. + runtime.KeepAlive(s) + return int64(len), err +} + diff --git a/pkg/abi/linux/linux_amd64_state_autogen.go b/pkg/abi/linux/linux_amd64_state_autogen.go new file mode 100755 index 000000000..a5f55a80b --- /dev/null +++ b/pkg/abi/linux/linux_amd64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build amd64 + +package linux diff --git a/pkg/abi/linux/linux_arm64_abi_autogen_unsafe.go b/pkg/abi/linux/linux_arm64_abi_autogen_unsafe.go new file mode 100755 index 000000000..63dd339c2 --- /dev/null +++ b/pkg/abi/linux/linux_arm64_abi_autogen_unsafe.go @@ -0,0 +1,333 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// +build arm64 + +package linux + +import ( + "gvisor.dev/gvisor/pkg/safecopy" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "io" + "reflect" + "runtime" + "unsafe" +) + +// Marshallable types used by this file. +var _ marshal.Marshallable = (*EpollEvent)(nil) +var _ marshal.Marshallable = (*Stat)(nil) +var _ marshal.Marshallable = (*Timespec)(nil) + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (e *EpollEvent) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (e *EpollEvent) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint32(dst[:4], uint32(e.Events)) + dst = dst[4:] + // Padding: dst[:sizeof(int32)] ~= int32(0) + dst = dst[4:] + for idx := 0; idx < 2; idx++ { + usermem.ByteOrder.PutUint32(dst[:4], uint32(e.Data[idx])) + dst = dst[4:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (e *EpollEvent) UnmarshalBytes(src []byte) { + e.Events = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ int32 ~= src[:sizeof(int32)] + src = src[4:] + for idx := 0; idx < 2; idx++ { + e.Data[idx] = int32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + } +} + +// Packed implements marshal.Marshallable.Packed. +func (e *EpollEvent) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (e *EpollEvent) MarshalUnsafe(dst []byte) { + safecopy.CopyIn(dst, unsafe.Pointer(e)) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (e *EpollEvent) UnmarshalUnsafe(src []byte) { + safecopy.CopyOut(unsafe.Pointer(e), src) +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (e *EpollEvent) CopyOut(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on e. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on e. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(e) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by e's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + _, err := task.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until after the CopyOutBytes. + runtime.KeepAlive(e) + return err +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (e *EpollEvent) CopyIn(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on e. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on e. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(e) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by e's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + _, err := task.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until after the CopyInBytes. + runtime.KeepAlive(e) + return err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (e *EpollEvent) WriteTo(w io.Writer) (int64, error) { + // Bypass escape analysis on e. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on e. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(e) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by e's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + len, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until after the Write. + runtime.KeepAlive(e) + return int64(len), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *Stat) SizeBytes() int { + return 80 + + (*Timespec)(nil).SizeBytes() + + (*Timespec)(nil).SizeBytes() + + (*Timespec)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *Stat) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Dev)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Ino)) + dst = dst[8:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.Mode)) + dst = dst[4:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.Nlink)) + dst = dst[4:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.UID)) + dst = dst[4:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.GID)) + dst = dst[4:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Rdev)) + dst = dst[8:] + // Padding: dst[:sizeof(uint64)] ~= uint64(0) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Size)) + dst = dst[8:] + usermem.ByteOrder.PutUint32(dst[:4], uint32(s.Blksize)) + dst = dst[4:] + // Padding: dst[:sizeof(int32)] ~= int32(0) + dst = dst[4:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.Blocks)) + dst = dst[8:] + s.ATime.MarshalBytes(dst[:s.ATime.SizeBytes()]) + dst = dst[s.ATime.SizeBytes():] + s.MTime.MarshalBytes(dst[:s.MTime.SizeBytes()]) + dst = dst[s.MTime.SizeBytes():] + s.CTime.MarshalBytes(dst[:s.CTime.SizeBytes()]) + dst = dst[s.CTime.SizeBytes():] + // Padding: dst[:sizeof(int32)*2] ~= [2]int32{0} + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *Stat) UnmarshalBytes(src []byte) { + s.Dev = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Ino = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Mode = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Nlink = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.UID = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.GID = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Rdev = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + // Padding: var _ uint64 ~= src[:sizeof(uint64)] + src = src[8:] + s.Size = int64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Blksize = int32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ int32 ~= src[:sizeof(int32)] + src = src[4:] + s.Blocks = int64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.ATime.UnmarshalBytes(src[:s.ATime.SizeBytes()]) + src = src[s.ATime.SizeBytes():] + s.MTime.UnmarshalBytes(src[:s.MTime.SizeBytes()]) + src = src[s.MTime.SizeBytes():] + s.CTime.UnmarshalBytes(src[:s.CTime.SizeBytes()]) + src = src[s.CTime.SizeBytes():] + // Padding: ~ copy([2]int32(s._), src[:sizeof(int32)*2]) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +func (s *Stat) Packed() bool { + return s.ATime.Packed() && s.MTime.Packed() && s.CTime.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *Stat) MarshalUnsafe(dst []byte) { + if s.ATime.Packed() && s.MTime.Packed() && s.CTime.Packed() { + safecopy.CopyIn(dst, unsafe.Pointer(s)) + } else { + s.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *Stat) UnmarshalUnsafe(src []byte) { + if s.ATime.Packed() && s.MTime.Packed() && s.CTime.Packed() { + safecopy.CopyOut(unsafe.Pointer(s), src) + } else { + s.UnmarshalBytes(src) + } +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (s *Stat) CopyOut(task marshal.Task, addr usermem.Addr) error { + if !s.ATime.Packed() && s.MTime.Packed() && s.CTime.Packed() { + // Type Stat doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := task.CopyScratchBuffer(s.SizeBytes()) + s.MarshalBytes(buf) + _, err := task.CopyOutBytes(addr, buf) + return err + } + + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + _, err := task.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the CopyOutBytes. + runtime.KeepAlive(s) + return err +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (s *Stat) CopyIn(task marshal.Task, addr usermem.Addr) error { + if !s.ATime.Packed() && s.MTime.Packed() && s.CTime.Packed() { + // Type Stat doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := task.CopyScratchBuffer(s.SizeBytes()) + _, err := task.CopyInBytes(addr, buf) + if err != nil { + return err + } + s.UnmarshalBytes(buf) + return nil + } + + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + _, err := task.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the CopyInBytes. + runtime.KeepAlive(s) + return err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *Stat) WriteTo(w io.Writer) (int64, error) { + if !s.ATime.Packed() && s.MTime.Packed() && s.CTime.Packed() { + // Type Stat doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + n, err := w.Write(buf) + return int64(n), err + } + + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + len, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the Write. + runtime.KeepAlive(s) + return int64(len), err +} + diff --git a/pkg/abi/linux/linux_arm64_state_autogen.go b/pkg/abi/linux/linux_arm64_state_autogen.go new file mode 100755 index 000000000..7b31374fe --- /dev/null +++ b/pkg/abi/linux/linux_arm64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build arm64 + +package linux diff --git a/pkg/abi/linux/linux_state_autogen.go b/pkg/abi/linux/linux_state_autogen.go new file mode 100755 index 000000000..b8e488a11 --- /dev/null +++ b/pkg/abi/linux/linux_state_autogen.go @@ -0,0 +1,68 @@ +// automatically generated by stateify. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *BPFInstruction) beforeSave() {} +func (x *BPFInstruction) save(m state.Map) { + x.beforeSave() + m.Save("OpCode", &x.OpCode) + m.Save("JumpIfTrue", &x.JumpIfTrue) + m.Save("JumpIfFalse", &x.JumpIfFalse) + m.Save("K", &x.K) +} + +func (x *BPFInstruction) afterLoad() {} +func (x *BPFInstruction) load(m state.Map) { + m.Load("OpCode", &x.OpCode) + m.Load("JumpIfTrue", &x.JumpIfTrue) + m.Load("JumpIfFalse", &x.JumpIfFalse) + m.Load("K", &x.K) +} + +func (x *KernelTermios) beforeSave() {} +func (x *KernelTermios) save(m state.Map) { + x.beforeSave() + m.Save("InputFlags", &x.InputFlags) + m.Save("OutputFlags", &x.OutputFlags) + m.Save("ControlFlags", &x.ControlFlags) + m.Save("LocalFlags", &x.LocalFlags) + m.Save("LineDiscipline", &x.LineDiscipline) + m.Save("ControlCharacters", &x.ControlCharacters) + m.Save("InputSpeed", &x.InputSpeed) + m.Save("OutputSpeed", &x.OutputSpeed) +} + +func (x *KernelTermios) afterLoad() {} +func (x *KernelTermios) load(m state.Map) { + m.Load("InputFlags", &x.InputFlags) + m.Load("OutputFlags", &x.OutputFlags) + m.Load("ControlFlags", &x.ControlFlags) + m.Load("LocalFlags", &x.LocalFlags) + m.Load("LineDiscipline", &x.LineDiscipline) + m.Load("ControlCharacters", &x.ControlCharacters) + m.Load("InputSpeed", &x.InputSpeed) + m.Load("OutputSpeed", &x.OutputSpeed) +} + +func (x *WindowSize) beforeSave() {} +func (x *WindowSize) save(m state.Map) { + x.beforeSave() + m.Save("Rows", &x.Rows) + m.Save("Cols", &x.Cols) +} + +func (x *WindowSize) afterLoad() {} +func (x *WindowSize) load(m state.Map) { + m.Load("Rows", &x.Rows) + m.Load("Cols", &x.Cols) +} + +func init() { + state.Register("pkg/abi/linux.BPFInstruction", (*BPFInstruction)(nil), state.Fns{Save: (*BPFInstruction).save, Load: (*BPFInstruction).load}) + state.Register("pkg/abi/linux.KernelTermios", (*KernelTermios)(nil), state.Fns{Save: (*KernelTermios).save, Load: (*KernelTermios).load}) + state.Register("pkg/abi/linux.WindowSize", (*WindowSize)(nil), state.Fns{Save: (*WindowSize).save, Load: (*WindowSize).load}) +} diff --git a/pkg/abi/linux/netfilter_test.go b/pkg/abi/linux/netfilter_test.go deleted file mode 100644 index 21e237f92..000000000 --- a/pkg/abi/linux/netfilter_test.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package linux - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/binary" -) - -func TestSizes(t *testing.T) { - testCases := []struct { - typ interface{} - defined uintptr - }{ - {IPTEntry{}, SizeOfIPTEntry}, - {IPTGetEntries{}, SizeOfIPTGetEntries}, - {IPTGetinfo{}, SizeOfIPTGetinfo}, - {IPTIP{}, SizeOfIPTIP}, - {IPTReplace{}, SizeOfIPTReplace}, - {XTCounters{}, SizeOfXTCounters}, - {XTEntryMatch{}, SizeOfXTEntryMatch}, - {XTEntryTarget{}, SizeOfXTEntryTarget}, - {XTErrorTarget{}, SizeOfXTErrorTarget}, - {XTStandardTarget{}, SizeOfXTStandardTarget}, - } - - for _, tc := range testCases { - if calculated := binary.Size(tc.typ); calculated != tc.defined { - t.Errorf("%T has a defined size of %d and calculated size of %d", tc.typ, tc.defined, calculated) - } - } -} diff --git a/pkg/abi/linux/rseq.go b/pkg/abi/linux/rseq.go index 76253ba30..76253ba30 100644..100755 --- a/pkg/abi/linux/rseq.go +++ b/pkg/abi/linux/rseq.go diff --git a/pkg/abi/linux/signalfd.go b/pkg/abi/linux/signalfd.go index 85fad9956..85fad9956 100644..100755 --- a/pkg/abi/linux/signalfd.go +++ b/pkg/abi/linux/signalfd.go diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go index 99180b208..99180b208 100644..100755 --- a/pkg/abi/linux/xattr.go +++ b/pkg/abi/linux/xattr.go diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD deleted file mode 100644 index 9612f072e..000000000 --- a/pkg/amutex/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "amutex", - srcs = ["amutex.go"], - visibility = ["//:sandbox"], -) - -go_test( - name = "amutex_test", - size = "small", - srcs = ["amutex_test.go"], - library = ":amutex", - deps = ["//pkg/sync"], -) diff --git a/pkg/amutex/amutex_state_autogen.go b/pkg/amutex/amutex_state_autogen.go new file mode 100755 index 000000000..5a09c71ed --- /dev/null +++ b/pkg/amutex/amutex_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package amutex diff --git a/pkg/amutex/amutex_test.go b/pkg/amutex/amutex_test.go deleted file mode 100644 index 8a3952f2a..000000000 --- a/pkg/amutex/amutex_test.go +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package amutex - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" -) - -type sleeper struct { - ch chan struct{} -} - -func (s *sleeper) SleepStart() <-chan struct{} { - return s.ch -} - -func (*sleeper) SleepFinish(bool) { -} - -func (s *sleeper) Interrupted() bool { - return len(s.ch) != 0 -} - -func TestMutualExclusion(t *testing.T) { - var m AbortableMutex - m.Init() - - // Test mutual exclusion by running "gr" goroutines concurrently, and - // have each one increment a counter "iters" times within the critical - // section established by the mutex. - // - // If at the end of the counter is not gr * iters, then we know that - // goroutines ran concurrently within the critical section. - // - // If one of the goroutines doesn't complete, it's likely a bug that - // causes it to wait forever. - const gr = 1000 - const iters = 100000 - v := 0 - var wg sync.WaitGroup - for i := 0; i < gr; i++ { - wg.Add(1) - go func() { - for j := 0; j < iters; j++ { - m.Lock(nil) - v++ - m.Unlock() - } - wg.Done() - }() - } - - wg.Wait() - - if v != gr*iters { - t.Fatalf("Bad count: got %v, want %v", v, gr*iters) - } -} - -func TestAbortWait(t *testing.T) { - var s sleeper - var m AbortableMutex - m.Init() - - // Lock the mutex. - m.Lock(&s) - - // Lock again, but this time cancel after 500ms. - s.ch = make(chan struct{}, 1) - go func() { - time.Sleep(500 * time.Millisecond) - s.ch <- struct{}{} - }() - if v := m.Lock(&s); v { - t.Fatalf("Lock succeeded when it should have failed") - } - - // Lock again, but cancel right away. - s.ch <- struct{}{} - if v := m.Lock(&s); v { - t.Fatalf("Lock succeeded when it should have failed") - } -} diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD deleted file mode 100644 index 1a30f6967..000000000 --- a/pkg/atomicbitops/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "atomicbitops", - srcs = [ - "atomicbitops.go", - "atomicbitops_amd64.s", - "atomicbitops_arm64.s", - "atomicbitops_noasm.go", - ], - visibility = ["//:sandbox"], -) - -go_test( - name = "atomicbitops_test", - size = "small", - srcs = ["atomicbitops_test.go"], - library = ":atomicbitops", - deps = ["//pkg/sync"], -) diff --git a/pkg/atomicbitops/atomicbitops.go b/pkg/atomicbitops/atomicbitops.go index 1be081719..1be081719 100644..100755 --- a/pkg/atomicbitops/atomicbitops.go +++ b/pkg/atomicbitops/atomicbitops.go diff --git a/pkg/atomicbitops/atomicbitops_amd64.s b/pkg/atomicbitops/atomicbitops_amd64.s index 54c887ee5..54c887ee5 100644..100755 --- a/pkg/atomicbitops/atomicbitops_amd64.s +++ b/pkg/atomicbitops/atomicbitops_amd64.s diff --git a/pkg/atomicbitops/atomicbitops_arm64.s b/pkg/atomicbitops/atomicbitops_arm64.s index 5c780851b..5c780851b 100644..100755 --- a/pkg/atomicbitops/atomicbitops_arm64.s +++ b/pkg/atomicbitops/atomicbitops_arm64.s diff --git a/pkg/atomicbitops/atomicbitops_noasm.go b/pkg/atomicbitops/atomicbitops_noasm.go index 3b2898256..3b2898256 100644..100755 --- a/pkg/atomicbitops/atomicbitops_noasm.go +++ b/pkg/atomicbitops/atomicbitops_noasm.go diff --git a/pkg/atomicbitops/atomicbitops_state_autogen.go b/pkg/atomicbitops/atomicbitops_state_autogen.go new file mode 100755 index 000000000..06fcf712a --- /dev/null +++ b/pkg/atomicbitops/atomicbitops_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +// +build amd64 arm64 +// +build !amd64,!arm64 + +package atomicbitops diff --git a/pkg/atomicbitops/atomicbitops_test.go b/pkg/atomicbitops/atomicbitops_test.go deleted file mode 100644 index 73af71bb4..000000000 --- a/pkg/atomicbitops/atomicbitops_test.go +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package atomicbitops - -import ( - "runtime" - "testing" - - "gvisor.dev/gvisor/pkg/sync" -) - -const iterations = 100 - -func detectRaces32(val, target uint32, fn func(*uint32, uint32)) bool { - runtime.GOMAXPROCS(100) - for n := 0; n < iterations; n++ { - x := val - var wg sync.WaitGroup - for i := uint32(0); i < 32; i++ { - wg.Add(1) - go func(a *uint32, i uint32) { - defer wg.Done() - fn(a, uint32(1<<i)) - }(&x, i) - } - wg.Wait() - if x != target { - return true - } - } - return false -} - -func detectRaces64(val, target uint64, fn func(*uint64, uint64)) bool { - runtime.GOMAXPROCS(100) - for n := 0; n < iterations; n++ { - x := val - var wg sync.WaitGroup - for i := uint64(0); i < 64; i++ { - wg.Add(1) - go func(a *uint64, i uint64) { - defer wg.Done() - fn(a, uint64(1<<i)) - }(&x, i) - } - wg.Wait() - if x != target { - return true - } - } - return false -} - -func TestOrUint32(t *testing.T) { - if detectRaces32(0x0, 0xffffffff, OrUint32) { - t.Error("Data race detected!") - } -} - -func TestAndUint32(t *testing.T) { - if detectRaces32(0xf0f0f0f0, 0x00000000, AndUint32) { - t.Error("Data race detected!") - } -} - -func TestXorUint32(t *testing.T) { - if detectRaces32(0xf0f0f0f0, 0x0f0f0f0f, XorUint32) { - t.Error("Data race detected!") - } -} - -func TestOrUint64(t *testing.T) { - if detectRaces64(0x0, 0xffffffffffffffff, OrUint64) { - t.Error("Data race detected!") - } -} - -func TestAndUint64(t *testing.T) { - if detectRaces64(0xf0f0f0f0f0f0f0f0, 0x0, AndUint64) { - t.Error("Data race detected!") - } -} - -func TestXorUint64(t *testing.T) { - if detectRaces64(0xf0f0f0f0f0f0f0f0, 0x0f0f0f0f0f0f0f0f, XorUint64) { - t.Error("Data race detected!") - } -} - -func TestCompareAndSwapUint32(t *testing.T) { - tests := []struct { - name string - prev uint32 - old uint32 - new uint32 - next uint32 - }{ - { - name: "Successful compare-and-swap with prev == new", - prev: 10, - old: 10, - new: 10, - next: 10, - }, - { - name: "Successful compare-and-swap with prev != new", - prev: 20, - old: 20, - new: 22, - next: 22, - }, - { - name: "Failed compare-and-swap with prev == new", - prev: 31, - old: 30, - new: 31, - next: 31, - }, - { - name: "Failed compare-and-swap with prev != new", - prev: 41, - old: 40, - new: 42, - next: 41, - }, - } - for _, test := range tests { - val := test.prev - prev := CompareAndSwapUint32(&val, test.old, test.new) - if got, want := prev, test.prev; got != want { - t.Errorf("%s: incorrect returned previous value: got %d, expected %d", test.name, got, want) - } - if got, want := val, test.next; got != want { - t.Errorf("%s: incorrect value stored in val: got %d, expected %d", test.name, got, want) - } - } -} - -func TestCompareAndSwapUint64(t *testing.T) { - tests := []struct { - name string - prev uint64 - old uint64 - new uint64 - next uint64 - }{ - { - name: "Successful compare-and-swap with prev == new", - prev: 0x100000000, - old: 0x100000000, - new: 0x100000000, - next: 0x100000000, - }, - { - name: "Successful compare-and-swap with prev != new", - prev: 0x200000000, - old: 0x200000000, - new: 0x200000002, - next: 0x200000002, - }, - { - name: "Failed compare-and-swap with prev == new", - prev: 0x300000001, - old: 0x300000000, - new: 0x300000001, - next: 0x300000001, - }, - { - name: "Failed compare-and-swap with prev != new", - prev: 0x400000001, - old: 0x400000000, - new: 0x400000002, - next: 0x400000001, - }, - } - for _, test := range tests { - val := test.prev - prev := CompareAndSwapUint64(&val, test.old, test.new) - if got, want := prev, test.prev; got != want { - t.Errorf("%s: incorrect returned previous value: got %d, expected %d", test.name, got, want) - } - if got, want := val, test.next; got != want { - t.Errorf("%s: incorrect value stored in val: got %d, expected %d", test.name, got, want) - } - } -} diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD deleted file mode 100644 index 7ca2fda90..000000000 --- a/pkg/binary/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "binary", - srcs = ["binary.go"], - visibility = ["//:sandbox"], -) - -go_test( - name = "binary_test", - size = "small", - srcs = ["binary_test.go"], - library = ":binary", -) diff --git a/pkg/binary/binary_state_autogen.go b/pkg/binary/binary_state_autogen.go new file mode 100755 index 000000000..4661a5982 --- /dev/null +++ b/pkg/binary/binary_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package binary diff --git a/pkg/binary/binary_test.go b/pkg/binary/binary_test.go deleted file mode 100644 index 4d609a438..000000000 --- a/pkg/binary/binary_test.go +++ /dev/null @@ -1,266 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package binary - -import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "io" - "reflect" - "strings" - "testing" -) - -func newInt32(i int32) *int32 { - return &i -} - -func TestSize(t *testing.T) { - if got, want := Size(uint32(10)), uintptr(4); got != want { - t.Errorf("Got = %d, want = %d", got, want) - } -} - -func TestPanic(t *testing.T) { - tests := []struct { - name string - f func([]byte, binary.ByteOrder, interface{}) - data interface{} - want string - }{ - {"Unmarshal int", Unmarshal, 5, "invalid type: int"}, - {"Unmarshal []int", Unmarshal, []int{5}, "invalid type: int"}, - {"Marshal int", func(_ []byte, bo binary.ByteOrder, d interface{}) { Marshal(nil, bo, d) }, 5, "invalid type: int"}, - {"Marshal int[]", func(_ []byte, bo binary.ByteOrder, d interface{}) { Marshal(nil, bo, d) }, []int{5}, "invalid type: int"}, - {"Unmarshal short buffer", Unmarshal, newInt32(5), "runtime error: index out of range"}, - {"Unmarshal long buffer", func(_ []byte, bo binary.ByteOrder, d interface{}) { Unmarshal(make([]byte, 50), bo, d) }, newInt32(5), "buffer too long by 46 bytes"}, - {"marshal int", func(_ []byte, bo binary.ByteOrder, d interface{}) { marshal(nil, bo, reflect.ValueOf(d)) }, 5, "invalid type: int"}, - {"Size int", func(_ []byte, _ binary.ByteOrder, d interface{}) { Size(d) }, 5, "invalid type: int"}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - defer func() { - r := recover() - if got := fmt.Sprint(r); !strings.HasPrefix(got, test.want) { - t.Errorf("Got recover() = %q, want prefix = %q", got, test.want) - } - }() - - test.f(nil, LittleEndian, test.data) - }) - } -} - -type inner struct { - Field int32 -} - -type outer struct { - Int8 int8 - Int16 int16 - Int32 int32 - Int64 int64 - Uint8 uint8 - Uint16 uint16 - Uint32 uint32 - Uint64 uint64 - - Slice []int32 - Array [5]int32 - Struct inner -} - -func TestMarshalUnmarshal(t *testing.T) { - want := outer{ - 1, 2, 3, 4, 5, 6, 7, 8, - []int32{9, 10, 11}, - [5]int32{12, 13, 14, 15, 16}, - inner{17}, - } - buf := Marshal(nil, LittleEndian, want) - got := outer{Slice: []int32{0, 0, 0}} - Unmarshal(buf, LittleEndian, &got) - if !reflect.DeepEqual(&got, &want) { - t.Errorf("Got = %#v, want = %#v", got, want) - } -} - -type outerBenchmark struct { - Int8 int8 - Int16 int16 - Int32 int32 - Int64 int64 - Uint8 uint8 - Uint16 uint16 - Uint32 uint32 - Uint64 uint64 - - Array [5]int32 - Struct inner -} - -func BenchmarkMarshalUnmarshal(b *testing.B) { - b.ReportAllocs() - - in := outerBenchmark{ - 1, 2, 3, 4, 5, 6, 7, 8, - [5]int32{9, 10, 11, 12, 13}, - inner{14}, - } - buf := make([]byte, Size(&in)) - out := outerBenchmark{} - - for i := 0; i < b.N; i++ { - buf := Marshal(buf[:0], LittleEndian, &in) - Unmarshal(buf, LittleEndian, &out) - } -} - -func BenchmarkReadWrite(b *testing.B) { - b.ReportAllocs() - - in := outerBenchmark{ - 1, 2, 3, 4, 5, 6, 7, 8, - [5]int32{9, 10, 11, 12, 13}, - inner{14}, - } - buf := bytes.NewBuffer(make([]byte, binary.Size(&in))) - out := outerBenchmark{} - - for i := 0; i < b.N; i++ { - buf.Reset() - if err := binary.Write(buf, LittleEndian, &in); err != nil { - b.Error("Write:", err) - } - if err := binary.Read(buf, LittleEndian, &out); err != nil { - b.Error("Read:", err) - } - } -} - -type outerPadding struct { - _ int8 - _ int16 - _ int32 - _ int64 - _ uint8 - _ uint16 - _ uint32 - _ uint64 - - _ []int32 - _ [5]int32 - _ inner -} - -func TestMarshalUnmarshalPadding(t *testing.T) { - var want outerPadding - buf := Marshal(nil, LittleEndian, want) - var got outerPadding - Unmarshal(buf, LittleEndian, &got) - if !reflect.DeepEqual(&got, &want) { - t.Errorf("Got = %#v, want = %#v", got, want) - } -} - -// Numbers with bits in every byte that distinguishable in big and little endian. -const ( - want16 = 64<<8 | 128 - want32 = 16<<24 | 32<<16 | want16 - want64 = 1<<56 | 2<<48 | 4<<40 | 8<<32 | want32 -) - -func TestReadWriteUint16(t *testing.T) { - const want = uint16(want16) - var buf bytes.Buffer - if err := WriteUint16(&buf, LittleEndian, want); err != nil { - t.Error("WriteUint16:", err) - } - got, err := ReadUint16(&buf, LittleEndian) - if err != nil { - t.Error("ReadUint16:", err) - } - if got != want { - t.Errorf("got = %d, want = %d", got, want) - } -} - -func TestReadWriteUint32(t *testing.T) { - const want = uint32(want32) - var buf bytes.Buffer - if err := WriteUint32(&buf, LittleEndian, want); err != nil { - t.Error("WriteUint32:", err) - } - got, err := ReadUint32(&buf, LittleEndian) - if err != nil { - t.Error("ReadUint32:", err) - } - if got != want { - t.Errorf("got = %d, want = %d", got, want) - } -} - -func TestReadWriteUint64(t *testing.T) { - const want = uint64(want64) - var buf bytes.Buffer - if err := WriteUint64(&buf, LittleEndian, want); err != nil { - t.Error("WriteUint64:", err) - } - got, err := ReadUint64(&buf, LittleEndian) - if err != nil { - t.Error("ReadUint64:", err) - } - if got != want { - t.Errorf("got = %d, want = %d", got, want) - } -} - -type readWriter struct { - err error -} - -func (rw *readWriter) Write([]byte) (int, error) { - return 0, rw.err -} - -func (rw *readWriter) Read([]byte) (int, error) { - return 0, rw.err -} - -func TestReadWriteError(t *testing.T) { - tests := []struct { - name string - f func(rw io.ReadWriter) error - }{ - {"WriteUint16", func(rw io.ReadWriter) error { return WriteUint16(rw, LittleEndian, 0) }}, - {"ReadUint16", func(rw io.ReadWriter) error { _, err := ReadUint16(rw, LittleEndian); return err }}, - {"WriteUint32", func(rw io.ReadWriter) error { return WriteUint32(rw, LittleEndian, 0) }}, - {"ReadUint32", func(rw io.ReadWriter) error { _, err := ReadUint32(rw, LittleEndian); return err }}, - {"WriteUint64", func(rw io.ReadWriter) error { return WriteUint64(rw, LittleEndian, 0) }}, - {"ReadUint64", func(rw io.ReadWriter) error { _, err := ReadUint64(rw, LittleEndian); return err }}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - want := errors.New("want") - if got := test.f(&readWriter{want}); got != want { - t.Errorf("got = %v, want = %v", got, want) - } - }) - } -} diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD deleted file mode 100644 index 63f4670d7..000000000 --- a/pkg/bits/BUILD +++ /dev/null @@ -1,55 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - -package(licenses = ["notice"]) - -go_library( - name = "bits", - srcs = [ - "bits.go", - "bits32.go", - "bits64.go", - "uint64_arch.go", - "uint64_arch_amd64_asm.s", - "uint64_arch_arm64_asm.s", - "uint64_arch_generic.go", - ], - visibility = ["//:sandbox"], -) - -go_template( - name = "bits_template", - srcs = ["bits_template.go"], - types = [ - "T", - ], -) - -go_template_instance( - name = "bits64", - out = "bits64.go", - package = "bits", - suffix = "64", - template = ":bits_template", - types = { - "T": "uint64", - }, -) - -go_template_instance( - name = "bits32", - out = "bits32.go", - package = "bits", - suffix = "32", - template = ":bits_template", - types = { - "T": "uint32", - }, -) - -go_test( - name = "bits_test", - size = "small", - srcs = ["uint64_test.go"], - library = ":bits", -) diff --git a/pkg/bits/bits32.go b/pkg/bits/bits32.go new file mode 100755 index 000000000..4e9e45dce --- /dev/null +++ b/pkg/bits/bits32.go @@ -0,0 +1,25 @@ +package bits + +// IsOn returns true if *all* bits set in 'bits' are set in 'mask'. +func IsOn32(mask, bits uint32) bool { + return mask&bits == bits +} + +// IsAnyOn returns true if *any* bit set in 'bits' is set in 'mask'. +func IsAnyOn32(mask, bits uint32) bool { + return mask&bits != 0 +} + +// Mask returns a T with all of the given bits set. +func Mask32(is ...int) uint32 { + ret := uint32(0) + for _, i := range is { + ret |= MaskOf32(i) + } + return ret +} + +// MaskOf is like Mask, but sets only a single bit (more efficiently). +func MaskOf32(i int) uint32 { + return uint32(1) << uint32(i) +} diff --git a/pkg/bits/bits64.go b/pkg/bits/bits64.go new file mode 100755 index 000000000..f49158792 --- /dev/null +++ b/pkg/bits/bits64.go @@ -0,0 +1,25 @@ +package bits + +// IsOn returns true if *all* bits set in 'bits' are set in 'mask'. +func IsOn64(mask, bits uint64) bool { + return mask&bits == bits +} + +// IsAnyOn returns true if *any* bit set in 'bits' is set in 'mask'. +func IsAnyOn64(mask, bits uint64) bool { + return mask&bits != 0 +} + +// Mask returns a T with all of the given bits set. +func Mask64(is ...int) uint64 { + ret := uint64(0) + for _, i := range is { + ret |= MaskOf64(i) + } + return ret +} + +// MaskOf is like Mask, but sets only a single bit (more efficiently). +func MaskOf64(i int) uint64 { + return uint64(1) << uint64(i) +} diff --git a/pkg/bits/bits_state_autogen.go b/pkg/bits/bits_state_autogen.go new file mode 100755 index 000000000..22b8250c6 --- /dev/null +++ b/pkg/bits/bits_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +// +build amd64 arm64 +// +build !amd64,!arm64 + +package bits diff --git a/pkg/bits/bits_template.go b/pkg/bits/bits_template.go deleted file mode 100644 index 93a435b80..000000000 --- a/pkg/bits/bits_template.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bits - -// Non-atomic bit operations on a template type T. - -// T is a required type parameter that must be an integral type. -type T uint64 - -// IsOn returns true if *all* bits set in 'bits' are set in 'mask'. -func IsOn(mask, bits T) bool { - return mask&bits == bits -} - -// IsAnyOn returns true if *any* bit set in 'bits' is set in 'mask'. -func IsAnyOn(mask, bits T) bool { - return mask&bits != 0 -} - -// Mask returns a T with all of the given bits set. -func Mask(is ...int) T { - ret := T(0) - for _, i := range is { - ret |= MaskOf(i) - } - return ret -} - -// MaskOf is like Mask, but sets only a single bit (more efficiently). -func MaskOf(i int) T { - return T(1) << T(i) -} diff --git a/pkg/bits/uint64_arch.go b/pkg/bits/uint64_arch.go index 9f23eff77..9f23eff77 100644..100755 --- a/pkg/bits/uint64_arch.go +++ b/pkg/bits/uint64_arch.go diff --git a/pkg/bits/uint64_arch_arm64_asm.s b/pkg/bits/uint64_arch_arm64_asm.s index 814ba562d..814ba562d 100644..100755 --- a/pkg/bits/uint64_arch_arm64_asm.s +++ b/pkg/bits/uint64_arch_arm64_asm.s diff --git a/pkg/bits/uint64_test.go b/pkg/bits/uint64_test.go deleted file mode 100644 index 1b018d808..000000000 --- a/pkg/bits/uint64_test.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bits - -import ( - "reflect" - "testing" -) - -func TestTrailingZeros64(t *testing.T) { - for i := 0; i <= 64; i++ { - n := uint64(1) << uint(i) - if got, want := TrailingZeros64(n), i; got != want { - t.Errorf("TrailingZeros64(%#x): got %d, wanted %d", n, got, want) - } - } - - for i := 0; i < 64; i++ { - n := ^uint64(0) << uint(i) - if got, want := TrailingZeros64(n), i; got != want { - t.Errorf("TrailingZeros64(%#x): got %d, wanted %d", n, got, want) - } - } - - for i := 0; i < 64; i++ { - n := ^uint64(0) >> uint(i) - if got, want := TrailingZeros64(n), 0; got != want { - t.Errorf("TrailingZeros64(%#x): got %d, wanted %d", n, got, want) - } - } -} - -func TestMostSignificantOne64(t *testing.T) { - for i := 0; i <= 64; i++ { - n := uint64(1) << uint(i) - if got, want := MostSignificantOne64(n), i; got != want { - t.Errorf("MostSignificantOne64(%#x): got %d, wanted %d", n, got, want) - } - } - - for i := 0; i < 64; i++ { - n := ^uint64(0) >> uint(i) - if got, want := MostSignificantOne64(n), 63-i; got != want { - t.Errorf("MostSignificantOne64(%#x): got %d, wanted %d", n, got, want) - } - } - - for i := 0; i < 64; i++ { - n := ^uint64(0) << uint(i) - if got, want := MostSignificantOne64(n), 63; got != want { - t.Errorf("MostSignificantOne64(%#x): got %d, wanted %d", n, got, want) - } - } -} - -func TestForEachSetBit64(t *testing.T) { - for _, want := range [][]int{ - {}, - {0}, - {1}, - {63}, - {0, 1}, - {1, 3, 5}, - {0, 63}, - } { - n := Mask64(want...) - // "Slice values are deeply equal when ... they are both nil or both - // non-nil ..." - got := make([]int, 0) - ForEachSetBit64(n, func(i int) { - got = append(got, i) - }) - if !reflect.DeepEqual(got, want) { - t.Errorf("ForEachSetBit64(%#x): iterated bits %v, wanted %v", n, got, want) - } - } -} - -func TestIsOn(t *testing.T) { - type spec struct { - mask uint64 - bits uint64 - any bool - all bool - } - for _, s := range []spec{ - {Mask64(0), Mask64(0), true, true}, - {Mask64(63), Mask64(63), true, true}, - {Mask64(0), Mask64(1), false, false}, - {Mask64(0), Mask64(0, 1), true, false}, - - {Mask64(1, 63), Mask64(1), true, true}, - {Mask64(1, 63), Mask64(1, 63), true, true}, - {Mask64(1, 63), Mask64(0, 1, 63), true, false}, - {Mask64(1, 63), Mask64(0, 62), false, false}, - } { - if ok := IsAnyOn64(s.mask, s.bits); ok != s.any { - t.Errorf("IsAnyOn(%#x, %#x) = %v, wanted: %v", s.mask, s.bits, ok, s.any) - } - if ok := IsOn64(s.mask, s.bits); ok != s.all { - t.Errorf("IsOn(%#x, %#x) = %v, wanted: %v", s.mask, s.bits, ok, s.all) - } - } -} diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD deleted file mode 100644 index 2a6977f85..000000000 --- a/pkg/bpf/BUILD +++ /dev/null @@ -1,31 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "bpf", - srcs = [ - "bpf.go", - "decoder.go", - "input_bytes.go", - "interpreter.go", - "program_builder.go", - ], - visibility = ["//visibility:public"], - deps = ["//pkg/abi/linux"], -) - -go_test( - name = "bpf_test", - size = "small", - srcs = [ - "decoder_test.go", - "interpreter_test.go", - "program_builder_test.go", - ], - library = ":bpf", - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - ], -) diff --git a/pkg/bpf/bpf_state_autogen.go b/pkg/bpf/bpf_state_autogen.go new file mode 100755 index 000000000..ae8a36d57 --- /dev/null +++ b/pkg/bpf/bpf_state_autogen.go @@ -0,0 +1,22 @@ +// automatically generated by stateify. + +package bpf + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Program) beforeSave() {} +func (x *Program) save(m state.Map) { + x.beforeSave() + m.Save("instructions", &x.instructions) +} + +func (x *Program) afterLoad() {} +func (x *Program) load(m state.Map) { + m.Load("instructions", &x.instructions) +} + +func init() { + state.Register("pkg/bpf.Program", (*Program)(nil), state.Fns{Save: (*Program).save, Load: (*Program).load}) +} diff --git a/pkg/bpf/decoder_test.go b/pkg/bpf/decoder_test.go deleted file mode 100644 index 6a023f0c0..000000000 --- a/pkg/bpf/decoder_test.go +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bpf - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" -) - -func TestDecode(t *testing.T) { - for _, test := range []struct { - filter linux.BPFInstruction - expected string - fail bool - }{ - {filter: Stmt(Ld+Imm, 10), expected: "A <- 10"}, - {filter: Stmt(Ld+Abs+W, 10), expected: "A <- P[10:4]"}, - {filter: Stmt(Ld+Ind+H, 10), expected: "A <- P[X+10:2]"}, - {filter: Stmt(Ld+Ind+B, 10), expected: "A <- P[X+10:1]"}, - {filter: Stmt(Ld+Mem, 10), expected: "A <- M[10]"}, - {filter: Stmt(Ld+Len, 0), expected: "A <- len"}, - {filter: Stmt(Ldx+Imm, 10), expected: "X <- 10"}, - {filter: Stmt(Ldx+Mem, 10), expected: "X <- M[10]"}, - {filter: Stmt(Ldx+Len, 0), expected: "X <- len"}, - {filter: Stmt(Ldx+Msh, 10), expected: "X <- 4*(P[10:1]&0xf)"}, - {filter: Stmt(St, 10), expected: "M[10] <- A"}, - {filter: Stmt(Stx, 10), expected: "M[10] <- X"}, - {filter: Stmt(Alu+Add+K, 10), expected: "A <- A + 10"}, - {filter: Stmt(Alu+Sub+K, 10), expected: "A <- A - 10"}, - {filter: Stmt(Alu+Mul+K, 10), expected: "A <- A * 10"}, - {filter: Stmt(Alu+Div+K, 10), expected: "A <- A / 10"}, - {filter: Stmt(Alu+Or+K, 10), expected: "A <- A | 10"}, - {filter: Stmt(Alu+And+K, 10), expected: "A <- A & 10"}, - {filter: Stmt(Alu+Lsh+K, 10), expected: "A <- A << 10"}, - {filter: Stmt(Alu+Rsh+K, 10), expected: "A <- A >> 10"}, - {filter: Stmt(Alu+Mod+K, 10), expected: "A <- A % 10"}, - {filter: Stmt(Alu+Xor+K, 10), expected: "A <- A ^ 10"}, - {filter: Stmt(Alu+Add+X, 0), expected: "A <- A + X"}, - {filter: Stmt(Alu+Sub+X, 0), expected: "A <- A - X"}, - {filter: Stmt(Alu+Mul+X, 0), expected: "A <- A * X"}, - {filter: Stmt(Alu+Div+X, 0), expected: "A <- A / X"}, - {filter: Stmt(Alu+Or+X, 0), expected: "A <- A | X"}, - {filter: Stmt(Alu+And+X, 0), expected: "A <- A & X"}, - {filter: Stmt(Alu+Lsh+X, 0), expected: "A <- A << X"}, - {filter: Stmt(Alu+Rsh+X, 0), expected: "A <- A >> X"}, - {filter: Stmt(Alu+Mod+X, 0), expected: "A <- A % X"}, - {filter: Stmt(Alu+Xor+X, 0), expected: "A <- A ^ X"}, - {filter: Stmt(Alu+Neg, 0), expected: "A <- -A"}, - {filter: Stmt(Jmp+Ja, 10), expected: "pc += 10"}, - {filter: Jump(Jmp+Jeq+K, 10, 2, 5), expected: "pc += (A == 10) ? 2 : 5"}, - {filter: Jump(Jmp+Jgt+K, 10, 2, 5), expected: "pc += (A > 10) ? 2 : 5"}, - {filter: Jump(Jmp+Jge+K, 10, 2, 5), expected: "pc += (A >= 10) ? 2 : 5"}, - {filter: Jump(Jmp+Jset+K, 10, 2, 5), expected: "pc += (A & 10) ? 2 : 5"}, - {filter: Jump(Jmp+Jeq+X, 0, 2, 5), expected: "pc += (A == X) ? 2 : 5"}, - {filter: Jump(Jmp+Jgt+X, 0, 2, 5), expected: "pc += (A > X) ? 2 : 5"}, - {filter: Jump(Jmp+Jge+X, 0, 2, 5), expected: "pc += (A >= X) ? 2 : 5"}, - {filter: Jump(Jmp+Jset+X, 0, 2, 5), expected: "pc += (A & X) ? 2 : 5"}, - {filter: Stmt(Ret+K, 10), expected: "ret 10"}, - {filter: Stmt(Ret+A, 0), expected: "ret A"}, - {filter: Stmt(Misc+Tax, 0), expected: "X <- A"}, - {filter: Stmt(Misc+Txa, 0), expected: "A <- X"}, - {filter: Stmt(Ld+Ind+Msh, 0), fail: true}, - } { - got, err := Decode(test.filter) - if test.fail { - if err == nil { - t.Errorf("Decode(%v) failed, expected: 'error', got: %q", test.filter, got) - continue - } - } else { - if err != nil { - t.Errorf("Decode(%v) failed for test %q, error: %q", test.filter, test.expected, err) - continue - } - if got != test.expected { - t.Errorf("Decode(%v) failed, expected: %q, got: %q", test.filter, test.expected, got) - continue - } - } - } -} - -func TestDecodeProgram(t *testing.T) { - for _, test := range []struct { - name string - program []linux.BPFInstruction - expected string - fail bool - }{ - {name: "basic with jump indexes", - program: []linux.BPFInstruction{ - Stmt(Ld+Abs+W, 10), - Stmt(Ldx+Mem, 10), - Stmt(St, 10), - Stmt(Stx, 10), - Stmt(Alu+Add+K, 10), - Stmt(Jmp+Ja, 10), - Jump(Jmp+Jeq+K, 10, 2, 5), - Jump(Jmp+Jset+X, 0, 0, 5), - Stmt(Misc+Tax, 0), - }, - expected: "0: A <- P[10:4]\n" + - "1: X <- M[10]\n" + - "2: M[10] <- A\n" + - "3: M[10] <- X\n" + - "4: A <- A + 10\n" + - "5: pc += 10 [16]\n" + - "6: pc += (A == 10) ? 2 [9] : 5 [12]\n" + - "7: pc += (A & X) ? 0 [8] : 5 [13]\n" + - "8: X <- A\n", - }, - {name: "invalid instruction", - program: []linux.BPFInstruction{Stmt(Ld+Abs+W, 10), Stmt(Ld+Len+Mem, 0)}, - fail: true}, - } { - got, err := DecodeProgram(test.program) - if test.fail { - if err == nil { - t.Errorf("%s: Decode(...) failed, expected: 'error', got: %q", test.name, got) - continue - } - } else { - if err != nil { - t.Errorf("%s: Decode failed: %v", test.name, err) - continue - } - if got != test.expected { - t.Errorf("%s: Decode(...) failed, expected: %q, got: %q", test.name, test.expected, got) - continue - } - } - } -} diff --git a/pkg/bpf/interpreter_test.go b/pkg/bpf/interpreter_test.go deleted file mode 100644 index 547921d0a..000000000 --- a/pkg/bpf/interpreter_test.go +++ /dev/null @@ -1,797 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bpf - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" -) - -func TestCompilationErrors(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // insns is the BPF instructions to be compiled. - insns []linux.BPFInstruction - - // expectedErr is the expected compilation error. - expectedErr error - }{ - { - desc: "Instructions must not be nil", - expectedErr: Error{InvalidInstructionCount, 0}, - }, - { - desc: "Instructions must not be empty", - insns: []linux.BPFInstruction{}, - expectedErr: Error{InvalidInstructionCount, 0}, - }, - { - desc: "A program must end with a return", - insns: make([]linux.BPFInstruction, MaxInstructions), - expectedErr: Error{InvalidEndOfProgram, MaxInstructions - 1}, - }, - { - desc: "A program must have MaxInstructions or fewer instructions", - insns: append(make([]linux.BPFInstruction, MaxInstructions), Stmt(Ret|K, 0)), - expectedErr: Error{InvalidInstructionCount, MaxInstructions + 1}, - }, - { - desc: "A load from an invalid M register is a compilation error", - insns: []linux.BPFInstruction{ - Stmt(Ld|Mem|W, ScratchMemRegisters), // A = M[16] - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidRegister, 0}, - }, - { - desc: "A store to an invalid M register is a compilation error", - insns: []linux.BPFInstruction{ - Stmt(St, ScratchMemRegisters), // M[16] = A - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidRegister, 0}, - }, - { - desc: "Division by literal zero is a compilation error", - insns: []linux.BPFInstruction{ - Stmt(Alu|Div|K, 0), // A /= 0 - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{DivisionByZero, 0}, - }, - { - desc: "An unconditional jump outside of the program is a compilation error", - insns: []linux.BPFInstruction{ - Jump(Jmp|Ja, 1, 0, 0), // jmp nextpc+1 - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidJumpTarget, 0}, - }, - { - desc: "A conditional jump outside of the program in the true case is a compilation error", - insns: []linux.BPFInstruction{ - Jump(Jmp|Jeq|K, 0, 1, 0), // if (A == K) jmp nextpc+1 - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidJumpTarget, 0}, - }, - { - desc: "A conditional jump outside of the program in the false case is a compilation error", - insns: []linux.BPFInstruction{ - Jump(Jmp|Jeq|K, 0, 0, 1), // if (A != K) jmp nextpc+1 - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidJumpTarget, 0}, - }, - } { - _, err := Compile(test.insns) - if err != test.expectedErr { - t.Errorf("%s: expected error %q, got error %q", test.desc, test.expectedErr, err) - } - } -} - -func TestExecErrors(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // insns is the BPF instructions to be executed. - insns []linux.BPFInstruction - - // expectedErr is the expected execution error. - expectedErr error - }{ - { - desc: "An out-of-bounds load of input data is an execution error", - insns: []linux.BPFInstruction{ - Stmt(Ld|Abs|B, 0), // A = input[0] - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidLoad, 0}, - }, - { - desc: "Division by zero at runtime is an execution error", - insns: []linux.BPFInstruction{ - Stmt(Alu|Div|X, 0), // A /= X - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{DivisionByZero, 0}, - }, - { - desc: "Modulo zero at runtime is an execution error", - insns: []linux.BPFInstruction{ - Stmt(Alu|Mod|X, 0), // A %= X - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{DivisionByZero, 0}, - }, - } { - p, err := Compile(test.insns) - if err != nil { - t.Errorf("%s: unexpected compilation error: %v", test.desc, err) - continue - } - ret, err := Exec(p, InputBytes{nil, binary.BigEndian}) - if err != test.expectedErr { - t.Errorf("%s: expected execution error %q, got (%d, %v)", test.desc, test.expectedErr, ret, err) - } - } -} - -func TestValidInstructions(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // insns is the BPF instructions to be compiled. - insns []linux.BPFInstruction - - // input is the input data. Note that input will be read as big-endian. - input []byte - - // expectedRet is the expected return value of the BPF program. - expectedRet uint32 - }{ - { - desc: "Return of immediate", - insns: []linux.BPFInstruction{ - Stmt(Ret|K, 42), // return 42 - }, - expectedRet: 42, - }, - { - desc: "Load of immediate into A", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 42, - }, - { - desc: "Load of immediate into X and copying of X into A", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Imm|W, 42), // X = 42 - Stmt(Misc|Tax, 0), // A = X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 42, - }, - { - desc: "Copying of A into X and back", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Stmt(Misc|Txa, 0), // X = A - Stmt(Ld|Imm|W, 0), // A = 0 - Stmt(Misc|Tax, 0), // A = X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 42, - }, - { - desc: "Load of 32-bit input by absolute offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ld|Abs|W, 1), // A = input[1..4] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11, 0x22, 0x33, 0x44}, - expectedRet: 0x11223344, - }, - { - desc: "Load of 16-bit input by absolute offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ld|Abs|H, 1), // A = input[1..2] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11, 0x22}, - expectedRet: 0x1122, - }, - { - desc: "Load of 8-bit input by absolute offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ld|Abs|B, 1), // A = input[1] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11}, - expectedRet: 0x11, - }, - { - desc: "Load of 32-bit input by relative offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Imm|W, 1), // X = 1 - Stmt(Ld|Ind|W, 1), // A = input[X+1..X+4] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55}, - expectedRet: 0x22334455, - }, - { - desc: "Load of 16-bit input by relative offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Imm|W, 1), // X = 1 - Stmt(Ld|Ind|H, 1), // A = input[X+1..X+2] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11, 0x22, 0x33}, - expectedRet: 0x2233, - }, - { - desc: "Load of 8-bit input by relative offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Imm|W, 1), // X = 1 - Stmt(Ld|Ind|B, 1), // A = input[X+1] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11, 0x22}, - expectedRet: 0x22, - }, - { - desc: "Load/store between A and scratch memory", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Stmt(St, 2), // M[2] = A - Stmt(Ld|Imm|W, 0), // A = 0 - Stmt(Ld|Mem|W, 2), // A = M[2] - Stmt(Ret|A, 0), // return A - }, - expectedRet: 42, - }, - { - desc: "Load/store between X and scratch memory", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Imm|W, 42), // X = 42 - Stmt(Stx, 3), // M[3] = X - Stmt(Ldx|Imm|W, 0), // X = 0 - Stmt(Ldx|Mem|W, 3), // X = M[3] - Stmt(Misc|Tax, 0), // A = X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 42, - }, - { - desc: "Load of input length into A", - insns: []linux.BPFInstruction{ - Stmt(Ld|Len|W, 0), // A = len(input) - Stmt(Ret|A, 0), // return A - }, - input: []byte{1, 2, 3}, - expectedRet: 3, - }, - { - desc: "Load of input length into X", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Len|W, 0), // X = len(input) - Stmt(Misc|Tax, 0), // A = X - Stmt(Ret|A, 0), // return A - }, - input: []byte{1, 2, 3}, - expectedRet: 3, - }, - { - desc: "Load of MSH (?) into X", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Msh|B, 0), // X = 4*(input[0]&0xf) - Stmt(Misc|Tax, 0), // A = X - Stmt(Ret|A, 0), // return A - }, - input: []byte{0xf1}, - expectedRet: 4, - }, - { - desc: "Addition of immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Alu|Add|K, 20), // A += 20 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 30, - }, - { - desc: "Addition of X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Ldx|Imm|W, 20), // X = 20 - Stmt(Alu|Add|X, 0), // A += X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 30, - }, - { - desc: "Subtraction of immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 30), // A = 30 - Stmt(Alu|Sub|K, 20), // A -= 20 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 10, - }, - { - desc: "Subtraction of X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 30), // A = 30 - Stmt(Ldx|Imm|W, 20), // X = 20 - Stmt(Alu|Sub|X, 0), // A -= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 10, - }, - { - desc: "Multiplication of immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 2), // A = 2 - Stmt(Alu|Mul|K, 3), // A *= 3 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 6, - }, - { - desc: "Multiplication of X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 2), // A = 2 - Stmt(Ldx|Imm|W, 3), // X = 3 - Stmt(Alu|Mul|X, 0), // A *= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 6, - }, - { - desc: "Division by immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 6), // A = 6 - Stmt(Alu|Div|K, 3), // A /= 3 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 2, - }, - { - desc: "Division by X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 6), // A = 6 - Stmt(Ldx|Imm|W, 3), // X = 3 - Stmt(Alu|Div|X, 0), // A /= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 2, - }, - { - desc: "Modulo immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 17), // A = 17 - Stmt(Alu|Mod|K, 7), // A %= 7 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 3, - }, - { - desc: "Modulo X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 17), // A = 17 - Stmt(Ldx|Imm|W, 7), // X = 7 - Stmt(Alu|Mod|X, 0), // A %= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 3, - }, - { - desc: "Arithmetic negation", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 1), // A = 1 - Stmt(Alu|Neg, 0), // A = -A - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0xffffffff, - }, - { - desc: "Bitwise OR with immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Alu|Or|K, 0xff0055aa), // A |= 0xff0055aa - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0xff00ffff, - }, - { - desc: "Bitwise OR with X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Ldx|Imm|W, 0xff0055aa), // X = 0xff0055aa - Stmt(Alu|Or|X, 0), // A |= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0xff00ffff, - }, - { - desc: "Bitwise AND with immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Alu|And|K, 0xff0055aa), // A &= 0xff0055aa - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0xff000000, - }, - { - desc: "Bitwise AND with X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Ldx|Imm|W, 0xff0055aa), // X = 0xff0055aa - Stmt(Alu|And|X, 0), // A &= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0xff000000, - }, - { - desc: "Bitwise XOR with immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Alu|Xor|K, 0xff0055aa), // A ^= 0xff0055aa - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0x0000ffff, - }, - { - desc: "Bitwise XOR with X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Ldx|Imm|W, 0xff0055aa), // X = 0xff0055aa - Stmt(Alu|Xor|X, 0), // A ^= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0x0000ffff, - }, - { - desc: "Left shift by immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 1), // A = 1 - Stmt(Alu|Lsh|K, 5), // A <<= 5 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 32, - }, - { - desc: "Left shift by X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 1), // A = 1 - Stmt(Ldx|Imm|W, 5), // X = 5 - Stmt(Alu|Lsh|X, 0), // A <<= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 32, - }, - { - desc: "Right shift by immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xffffffff), // A = 0xffffffff - Stmt(Alu|Rsh|K, 31), // A >>= 31 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 1, - }, - { - desc: "Right shift by X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xffffffff), // A = 0xffffffff - Stmt(Ldx|Imm|W, 31), // X = 31 - Stmt(Alu|Rsh|X, 0), // A >>= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 1, - }, - { - desc: "Unconditional jump", - insns: []linux.BPFInstruction{ - Jump(Jmp|Ja, 1, 0, 0), // jmp nextpc+1 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - }, - expectedRet: 1, - }, - { - desc: "Jump when A == immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Jump(Jmp|Jeq|K, 42, 1, 2), // if (A == 42) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A != immediate", - insns: []linux.BPFInstruction{ - Jump(Jmp|Jeq|K, 42, 1, 2), // if (A == 42) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A == X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Stmt(Ldx|Imm|W, 42), // X = 42 - Jump(Jmp|Jeq|X, 0, 1, 2), // if (A == X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A != X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Jump(Jmp|Jeq|X, 0, 1, 2), // if (A == X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A > immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Jump(Jmp|Jgt|K, 9, 1, 2), // if (A > 9) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A <= immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Jump(Jmp|Jgt|K, 10, 1, 2), // if (A > 10) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A > X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Ldx|Imm|W, 9), // X = 9 - Jump(Jmp|Jgt|X, 0, 1, 2), // if (A > X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A <= X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Ldx|Imm|W, 10), // X = 10 - Jump(Jmp|Jgt|X, 0, 1, 2), // if (A > X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A >= immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Jump(Jmp|Jge|K, 10, 1, 2), // if (A >= 10) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A < immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Jump(Jmp|Jge|K, 11, 1, 2), // if (A >= 11) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A >= X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Ldx|Imm|W, 10), // X = 10 - Jump(Jmp|Jge|X, 0, 1, 2), // if (A >= X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A < X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Ldx|Imm|W, 11), // X = 11 - Jump(Jmp|Jge|X, 0, 1, 2), // if (A >= X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A & immediate != 0", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff), // A = 0xff - Jump(Jmp|Jset|K, 0x101, 1, 2), // if (A & 0x101) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A & immediate == 0", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xfe), // A = 0xfe - Jump(Jmp|Jset|K, 0x101, 1, 2), // if (A & 0x101) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A & X != 0", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff), // A = 0xff - Stmt(Ldx|Imm|W, 0x101), // X = 0x101 - Jump(Jmp|Jset|X, 0, 1, 2), // if (A & X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A & X == 0", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xfe), // A = 0xfe - Stmt(Ldx|Imm|W, 0x101), // X = 0x101 - Jump(Jmp|Jset|X, 0, 1, 2), // if (A & X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - } { - p, err := Compile(test.insns) - if err != nil { - t.Errorf("%s: unexpected compilation error: %v", test.desc, err) - continue - } - ret, err := Exec(p, InputBytes{test.input, binary.BigEndian}) - if err != nil { - t.Errorf("%s: expected return value of %d, got execution error: %v", test.desc, test.expectedRet, err) - continue - } - if ret != test.expectedRet { - t.Errorf("%s: expected return value of %d, got value %d", test.desc, test.expectedRet, ret) - } - } -} - -func TestSimpleFilter(t *testing.T) { - // Seccomp filter example given in Linux's - // Documentation/networking/filter.txt, translated to bytecode using the - // Linux kernel tree's tools/net/bpf_asm. - filter := []linux.BPFInstruction{ - {0x20, 0, 0, 0x00000004}, // ld [4] /* offsetof(struct seccomp_data, arch) */ - {0x15, 0, 11, 0xc000003e}, // jne #0xc000003e, bad /* AUDIT_ARCH_X86_64 */ - {0x20, 0, 0, 0000000000}, // ld [0] /* offsetof(struct seccomp_data, nr) */ - {0x15, 10, 0, 0x0000000f}, // jeq #15, good /* __NR_rt_sigreturn */ - {0x15, 9, 0, 0x000000e7}, // jeq #231, good /* __NR_exit_group */ - {0x15, 8, 0, 0x0000003c}, // jeq #60, good /* __NR_exit */ - {0x15, 7, 0, 0000000000}, // jeq #0, good /* __NR_read */ - {0x15, 6, 0, 0x00000001}, // jeq #1, good /* __NR_write */ - {0x15, 5, 0, 0x00000005}, // jeq #5, good /* __NR_fstat */ - {0x15, 4, 0, 0x00000009}, // jeq #9, good /* __NR_mmap */ - {0x15, 3, 0, 0x0000000e}, // jeq #14, good /* __NR_rt_sigprocmask */ - {0x15, 2, 0, 0x0000000d}, // jeq #13, good /* __NR_rt_sigaction */ - {0x15, 1, 0, 0x00000023}, // jeq #35, good /* __NR_nanosleep */ - {0x06, 0, 0, 0000000000}, // bad: ret #0 /* SECCOMP_RET_KILL */ - {0x06, 0, 0, 0x7fff0000}, // good: ret #0x7fff0000 /* SECCOMP_RET_ALLOW */ - } - p, err := Compile(filter) - if err != nil { - t.Fatalf("Unexpected compilation error: %v", err) - } - - for _, test := range []struct { - // desc is the test's description. - desc string - - // seccompData is the input data. - seccompData - - // expectedRet is the expected return value of the BPF program. - expectedRet uint32 - }{ - { - desc: "Invalid arch is rejected", - seccompData: seccompData{nr: 1 /* x86 exit */, arch: 0x40000003 /* AUDIT_ARCH_I386 */}, - expectedRet: 0, - }, - { - desc: "Disallowed syscall is rejected", - seccompData: seccompData{nr: 105 /* __NR_setuid */, arch: 0xc000003e}, - expectedRet: 0, - }, - { - desc: "Whitelisted syscall is allowed", - seccompData: seccompData{nr: 231 /* __NR_exit_group */, arch: 0xc000003e}, - expectedRet: 0x7fff0000, - }, - } { - ret, err := Exec(p, test.seccompData.asInput()) - if err != nil { - t.Errorf("%s: expected return value of %d, got execution error: %v", test.desc, test.expectedRet, err) - continue - } - if ret != test.expectedRet { - t.Errorf("%s: expected return value of %d, got value %d", test.desc, test.expectedRet, ret) - } - } -} - -// seccompData is equivalent to struct seccomp_data. -type seccompData struct { - nr uint32 - arch uint32 - instructionPointer uint64 - args [6]uint64 -} - -// asInput converts a seccompData to a bpf.Input. -func (d *seccompData) asInput() Input { - return InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian} -} diff --git a/pkg/bpf/program_builder_test.go b/pkg/bpf/program_builder_test.go deleted file mode 100644 index 92ca5f4c3..000000000 --- a/pkg/bpf/program_builder_test.go +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bpf - -import ( - "fmt" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" -) - -func validate(p *ProgramBuilder, expected []linux.BPFInstruction) error { - instructions, err := p.Instructions() - if err != nil { - return fmt.Errorf("Instructions() failed: %v", err) - } - got, err := DecodeProgram(instructions) - if err != nil { - return fmt.Errorf("DecodeProgram('instructions') failed: %v", err) - } - expectedDecoded, err := DecodeProgram(expected) - if err != nil { - return fmt.Errorf("DecodeProgram('expected') failed: %v", err) - } - if got != expectedDecoded { - return fmt.Errorf("DecodeProgram() failed, expected: %q, got: %q", expectedDecoded, got) - } - return nil -} - -func TestProgramBuilderSimple(t *testing.T) { - p := NewProgramBuilder() - p.AddStmt(Ld+Abs+W, 10) - p.AddJump(Jmp+Ja, 10, 0, 0) - - expected := []linux.BPFInstruction{ - Stmt(Ld+Abs+W, 10), - Jump(Jmp+Ja, 10, 0, 0), - } - - if err := validate(p, expected); err != nil { - t.Errorf("Validate() failed: %v", err) - } -} - -func TestProgramBuilderLabels(t *testing.T) { - p := NewProgramBuilder() - p.AddJumpTrueLabel(Jmp+Jeq+K, 11, "label_1", 0) - p.AddJumpFalseLabel(Jmp+Jeq+K, 12, 0, "label_2") - p.AddJumpLabels(Jmp+Jeq+K, 13, "label_3", "label_4") - if err := p.AddLabel("label_1"); err != nil { - t.Errorf("AddLabel(label_1) failed: %v", err) - } - p.AddStmt(Ld+Abs+W, 1) - if err := p.AddLabel("label_3"); err != nil { - t.Errorf("AddLabel(label_3) failed: %v", err) - } - p.AddJumpLabels(Jmp+Jeq+K, 14, "label_4", "label_5") - if err := p.AddLabel("label_2"); err != nil { - t.Errorf("AddLabel(label_2) failed: %v", err) - } - p.AddJumpLabels(Jmp+Jeq+K, 15, "label_4", "label_6") - if err := p.AddLabel("label_4"); err != nil { - t.Errorf("AddLabel(label_4) failed: %v", err) - } - p.AddStmt(Ld+Abs+W, 4) - if err := p.AddLabel("label_5"); err != nil { - t.Errorf("AddLabel(label_5) failed: %v", err) - } - if err := p.AddLabel("label_6"); err != nil { - t.Errorf("AddLabel(label_6) failed: %v", err) - } - p.AddStmt(Ld+Abs+W, 5) - - expected := []linux.BPFInstruction{ - Jump(Jmp+Jeq+K, 11, 2, 0), - Jump(Jmp+Jeq+K, 12, 0, 3), - Jump(Jmp+Jeq+K, 13, 1, 3), - Stmt(Ld+Abs+W, 1), - Jump(Jmp+Jeq+K, 14, 1, 2), - Jump(Jmp+Jeq+K, 15, 0, 1), - Stmt(Ld+Abs+W, 4), - Stmt(Ld+Abs+W, 5), - } - if err := validate(p, expected); err != nil { - t.Errorf("Validate() failed: %v", err) - } - // Calling validate()=>p.Instructions() again to make sure - // Instructions can be called multiple times without ruining - // the program. - if err := validate(p, expected); err != nil { - t.Errorf("Validate() failed: %v", err) - } -} - -func TestProgramBuilderMissingErrorTarget(t *testing.T) { - p := NewProgramBuilder() - p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0) - if _, err := p.Instructions(); err == nil { - t.Errorf("Instructions() should have failed") - } -} - -func TestProgramBuilderLabelWithNoInstruction(t *testing.T) { - p := NewProgramBuilder() - p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0) - if err := p.AddLabel("label_1"); err != nil { - t.Errorf("AddLabel(label_1) failed: %v", err) - } - if _, err := p.Instructions(); err == nil { - t.Errorf("Instructions() should have failed") - } -} - -func TestProgramBuilderUnusedLabel(t *testing.T) { - p := NewProgramBuilder() - if err := p.AddLabel("unused"); err == nil { - t.Errorf("AddLabel(unused) should have failed") - } -} - -func TestProgramBuilderLabelAddedTwice(t *testing.T) { - p := NewProgramBuilder() - p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0) - if err := p.AddLabel("label_1"); err != nil { - t.Errorf("AddLabel(label_1) failed: %v", err) - } - p.AddStmt(Ld+Abs+W, 0) - if err := p.AddLabel("label_1"); err == nil { - t.Errorf("AddLabel(label_1) failed: %v", err) - } -} - -func TestProgramBuilderJumpBackwards(t *testing.T) { - p := NewProgramBuilder() - p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0) - if err := p.AddLabel("label_1"); err != nil { - t.Errorf("AddLabel(label_1) failed: %v", err) - } - p.AddStmt(Ld+Abs+W, 0) - p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0) - if _, err := p.Instructions(); err == nil { - t.Errorf("Instructions() should have failed") - } -} diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD deleted file mode 100644 index dcd086298..000000000 --- a/pkg/buffer/BUILD +++ /dev/null @@ -1,43 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "buffer_list", - out = "buffer_list.go", - package = "buffer", - prefix = "buffer", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*buffer", - "Linker": "*buffer", - }, -) - -go_library( - name = "buffer", - srcs = [ - "buffer.go", - "buffer_list.go", - "safemem.go", - "view.go", - "view_unsafe.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/safemem", - ], -) - -go_test( - name = "buffer_test", - size = "small", - srcs = [ - "safemem_test.go", - "view_test.go", - ], - library = ":buffer", - deps = ["//pkg/safemem"], -) diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go index c6d089fd9..c6d089fd9 100644..100755 --- a/pkg/buffer/buffer.go +++ b/pkg/buffer/buffer.go diff --git a/pkg/buffer/buffer_list.go b/pkg/buffer/buffer_list.go new file mode 100755 index 000000000..e2d519538 --- /dev/null +++ b/pkg/buffer/buffer_list.go @@ -0,0 +1,186 @@ +package buffer + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type bufferElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (bufferElementMapper) linkerFor(elem *buffer) *buffer { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type bufferList struct { + head *buffer + tail *buffer +} + +// Reset resets list l to the empty state. +func (l *bufferList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *bufferList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *bufferList) Front() *buffer { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *bufferList) Back() *buffer { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *bufferList) PushFront(e *buffer) { + linker := bufferElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + bufferElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *bufferList) PushBack(e *buffer) { + linker := bufferElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + bufferElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *bufferList) PushBackList(m *bufferList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + bufferElementMapper{}.linkerFor(l.tail).SetNext(m.head) + bufferElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *bufferList) InsertAfter(b, e *buffer) { + bLinker := bufferElementMapper{}.linkerFor(b) + eLinker := bufferElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + bufferElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *bufferList) InsertBefore(a, e *buffer) { + aLinker := bufferElementMapper{}.linkerFor(a) + eLinker := bufferElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + bufferElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *bufferList) Remove(e *buffer) { + linker := bufferElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + bufferElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + bufferElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type bufferEntry struct { + next *buffer + prev *buffer +} + +// Next returns the entry that follows e in the list. +func (e *bufferEntry) Next() *buffer { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *bufferEntry) Prev() *buffer { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *bufferEntry) SetNext(elem *buffer) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *bufferEntry) SetPrev(elem *buffer) { + e.prev = elem +} diff --git a/pkg/buffer/buffer_state_autogen.go b/pkg/buffer/buffer_state_autogen.go new file mode 100755 index 000000000..2e6299f81 --- /dev/null +++ b/pkg/buffer/buffer_state_autogen.go @@ -0,0 +1,70 @@ +// automatically generated by stateify. + +package buffer + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *buffer) beforeSave() {} +func (x *buffer) save(m state.Map) { + x.beforeSave() + m.Save("data", &x.data) + m.Save("read", &x.read) + m.Save("write", &x.write) + m.Save("bufferEntry", &x.bufferEntry) +} + +func (x *buffer) afterLoad() {} +func (x *buffer) load(m state.Map) { + m.Load("data", &x.data) + m.Load("read", &x.read) + m.Load("write", &x.write) + m.Load("bufferEntry", &x.bufferEntry) +} + +func (x *bufferList) beforeSave() {} +func (x *bufferList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *bufferList) afterLoad() {} +func (x *bufferList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *bufferEntry) beforeSave() {} +func (x *bufferEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *bufferEntry) afterLoad() {} +func (x *bufferEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *View) beforeSave() {} +func (x *View) save(m state.Map) { + x.beforeSave() + m.Save("data", &x.data) + m.Save("size", &x.size) +} + +func (x *View) afterLoad() {} +func (x *View) load(m state.Map) { + m.Load("data", &x.data) + m.Load("size", &x.size) +} + +func init() { + state.Register("pkg/buffer.buffer", (*buffer)(nil), state.Fns{Save: (*buffer).save, Load: (*buffer).load}) + state.Register("pkg/buffer.bufferList", (*bufferList)(nil), state.Fns{Save: (*bufferList).save, Load: (*bufferList).load}) + state.Register("pkg/buffer.bufferEntry", (*bufferEntry)(nil), state.Fns{Save: (*bufferEntry).save, Load: (*bufferEntry).load}) + state.Register("pkg/buffer.View", (*View)(nil), state.Fns{Save: (*View).save, Load: (*View).load}) +} diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go index 0e5b86344..0e5b86344 100644..100755 --- a/pkg/buffer/safemem.go +++ b/pkg/buffer/safemem.go diff --git a/pkg/buffer/safemem_test.go b/pkg/buffer/safemem_test.go deleted file mode 100644 index 47f357e0c..000000000 --- a/pkg/buffer/safemem_test.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package buffer - -import ( - "bytes" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/safemem" -) - -func TestSafemem(t *testing.T) { - testCases := []struct { - name string - input string - output string - readLen int - op func(*View) - }{ - // Basic coverage. - { - name: "short", - input: "010", - output: "010", - }, - { - name: "long", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: "0" + strings.Repeat("1", bufferSize) + "0", - }, - { - name: "short-read", - input: "0", - readLen: 100, // > size. - output: "0", - }, - { - name: "zero-read", - input: "0", - output: "", - }, - { - name: "read-empty", - input: "", - readLen: 1, // > size. - output: "", - }, - - // Ensure offsets work. - { - name: "offsets-short", - input: "012", - output: "2", - op: func(v *View) { - v.TrimFront(2) - }, - }, - { - name: "offsets-long0", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: strings.Repeat("1", bufferSize) + "0", - op: func(v *View) { - v.TrimFront(1) - }, - }, - { - name: "offsets-long1", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: strings.Repeat("1", bufferSize-1) + "0", - op: func(v *View) { - v.TrimFront(2) - }, - }, - { - name: "offsets-long2", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: "10", - op: func(v *View) { - v.TrimFront(bufferSize) - }, - }, - - // Ensure truncation works. - { - name: "truncate-short", - input: "012", - output: "01", - op: func(v *View) { - v.Truncate(2) - }, - }, - { - name: "truncate-long0", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: "0" + strings.Repeat("1", bufferSize), - op: func(v *View) { - v.Truncate(bufferSize + 1) - }, - }, - { - name: "truncate-long1", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: "0" + strings.Repeat("1", bufferSize-1), - op: func(v *View) { - v.Truncate(bufferSize) - }, - }, - { - name: "truncate-long2", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: "01", - op: func(v *View) { - v.Truncate(2) - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Construct the new view. - var view View - bs := safemem.BlockSeqOf(safemem.BlockFromSafeSlice([]byte(tc.input))) - n, err := view.WriteFromBlocks(bs) - if err != nil { - t.Errorf("expected err nil, got %v", err) - } - if n != uint64(len(tc.input)) { - t.Errorf("expected %d bytes, got %d", len(tc.input), n) - } - - // Run the operation. - if tc.op != nil { - tc.op(&view) - } - - // Read and validate. - readLen := tc.readLen - if readLen == 0 { - readLen = len(tc.output) // Default. - } - out := make([]byte, readLen) - bs = safemem.BlockSeqOf(safemem.BlockFromSafeSlice(out)) - n, err = view.ReadToBlocks(bs) - if err != nil { - t.Errorf("expected nil, got %v", err) - } - if n != uint64(len(tc.output)) { - t.Errorf("expected %d bytes, got %d", len(tc.output), n) - } - - // Ensure the contents are correct. - if !bytes.Equal(out[:n], []byte(tc.output[:n])) { - t.Errorf("contents are wrong: expected %q, got %q", tc.output, string(out)) - } - }) - } -} diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go index e6901eadb..e6901eadb 100644..100755 --- a/pkg/buffer/view.go +++ b/pkg/buffer/view.go diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go deleted file mode 100644 index 3db1bc6ee..000000000 --- a/pkg/buffer/view_test.go +++ /dev/null @@ -1,467 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package buffer - -import ( - "bytes" - "io" - "strings" - "testing" -) - -func fillAppend(v *View, data []byte) { - v.Append(data) -} - -func fillAppendEnd(v *View, data []byte) { - v.Grow(bufferSize-1, false) - v.Append(data) - v.TrimFront(bufferSize - 1) -} - -func fillWriteFromReader(v *View, data []byte) { - b := bytes.NewBuffer(data) - v.WriteFromReader(b, int64(len(data))) -} - -func fillWriteFromReaderEnd(v *View, data []byte) { - v.Grow(bufferSize-1, false) - b := bytes.NewBuffer(data) - v.WriteFromReader(b, int64(len(data))) - v.TrimFront(bufferSize - 1) -} - -var fillFuncs = map[string]func(*View, []byte){ - "append": fillAppend, - "appendEnd": fillAppendEnd, - "writeFromReader": fillWriteFromReader, - "writeFromReaderEnd": fillWriteFromReaderEnd, -} - -func testReadAt(t *testing.T, v *View, offset int64, n int, wantStr string, wantErr error) { - t.Helper() - d := make([]byte, n) - n, err := v.ReadAt(d, offset) - if n != len(wantStr) { - t.Errorf("got %d, want %d", n, len(wantStr)) - } - if err != wantErr { - t.Errorf("got err %v, want %v", err, wantErr) - } - if !bytes.Equal(d[:n], []byte(wantStr)) { - t.Errorf("got %q, want %q", string(d[:n]), wantStr) - } -} - -func TestView(t *testing.T) { - testCases := []struct { - name string - input string - output string - op func(*testing.T, *View) - }{ - // Preconditions. - { - name: "truncate-check", - input: "hello", - output: "hello", // Not touched. - op: func(t *testing.T, v *View) { - defer func() { - if r := recover(); r == nil { - t.Errorf("Truncate(-1) did not panic") - } - }() - v.Truncate(-1) - }, - }, - { - name: "grow-check", - input: "hello", - output: "hello", // Not touched. - op: func(t *testing.T, v *View) { - defer func() { - if r := recover(); r == nil { - t.Errorf("Grow(-1) did not panic") - } - }() - v.Grow(-1, false) - }, - }, - { - name: "advance-check", - input: "hello", - output: "", // Consumed. - op: func(t *testing.T, v *View) { - defer func() { - if r := recover(); r == nil { - t.Errorf("advanceRead(Size()+1) did not panic") - } - }() - v.advanceRead(v.Size() + 1) - }, - }, - - // Prepend. - { - name: "prepend", - input: "world", - output: "hello world", - op: func(t *testing.T, v *View) { - v.Prepend([]byte("hello ")) - }, - }, - { - name: "prepend-backfill-full", - input: "hello world", - output: "jello world", - op: func(t *testing.T, v *View) { - v.TrimFront(1) - v.Prepend([]byte("j")) - }, - }, - { - name: "prepend-backfill-under", - input: "hello world", - output: "hola world", - op: func(t *testing.T, v *View) { - v.TrimFront(5) - v.Prepend([]byte("hola")) - }, - }, - { - name: "prepend-backfill-over", - input: "hello world", - output: "smello world", - op: func(t *testing.T, v *View) { - v.TrimFront(1) - v.Prepend([]byte("sm")) - }, - }, - { - name: "prepend-fill", - input: strings.Repeat("1", bufferSize-1), - output: "0" + strings.Repeat("1", bufferSize-1), - op: func(t *testing.T, v *View) { - v.Prepend([]byte("0")) - }, - }, - { - name: "prepend-overflow", - input: strings.Repeat("1", bufferSize), - output: "0" + strings.Repeat("1", bufferSize), - op: func(t *testing.T, v *View) { - v.Prepend([]byte("0")) - }, - }, - { - name: "prepend-multiple-buffers", - input: strings.Repeat("1", bufferSize-1), - output: strings.Repeat("0", bufferSize*3) + strings.Repeat("1", bufferSize-1), - op: func(t *testing.T, v *View) { - v.Prepend([]byte(strings.Repeat("0", bufferSize*3))) - }, - }, - - // Append and write. - { - name: "append", - input: "hello", - output: "hello world", - op: func(t *testing.T, v *View) { - v.Append([]byte(" world")) - }, - }, - { - name: "append-fill", - input: strings.Repeat("1", bufferSize-1), - output: strings.Repeat("1", bufferSize-1) + "0", - op: func(t *testing.T, v *View) { - v.Append([]byte("0")) - }, - }, - { - name: "append-overflow", - input: strings.Repeat("1", bufferSize), - output: strings.Repeat("1", bufferSize) + "0", - op: func(t *testing.T, v *View) { - v.Append([]byte("0")) - }, - }, - { - name: "append-multiple-buffers", - input: strings.Repeat("1", bufferSize-1), - output: strings.Repeat("1", bufferSize-1) + strings.Repeat("0", bufferSize*3), - op: func(t *testing.T, v *View) { - v.Append([]byte(strings.Repeat("0", bufferSize*3))) - }, - }, - - // Truncate. - { - name: "truncate", - input: "hello world", - output: "hello", - op: func(t *testing.T, v *View) { - v.Truncate(5) - }, - }, - { - name: "truncate-noop", - input: "hello world", - output: "hello world", - op: func(t *testing.T, v *View) { - v.Truncate(v.Size() + 1) - }, - }, - { - name: "truncate-multiple-buffers", - input: strings.Repeat("1", bufferSize*2), - output: strings.Repeat("1", bufferSize*2-1), - op: func(t *testing.T, v *View) { - v.Truncate(bufferSize*2 - 1) - }, - }, - { - name: "truncate-multiple-buffers-to-one", - input: strings.Repeat("1", bufferSize*2), - output: "11111", - op: func(t *testing.T, v *View) { - v.Truncate(5) - }, - }, - - // TrimFront. - { - name: "trim", - input: "hello world", - output: "world", - op: func(t *testing.T, v *View) { - v.TrimFront(6) - }, - }, - { - name: "trim-too-large", - input: "hello world", - output: "", - op: func(t *testing.T, v *View) { - v.TrimFront(v.Size() + 1) - }, - }, - { - name: "trim-multiple-buffers", - input: strings.Repeat("1", bufferSize*2), - output: strings.Repeat("1", bufferSize*2-1), - op: func(t *testing.T, v *View) { - v.TrimFront(1) - }, - }, - { - name: "trim-multiple-buffers-to-one-buffer", - input: strings.Repeat("1", bufferSize*2), - output: "1", - op: func(t *testing.T, v *View) { - v.TrimFront(bufferSize*2 - 1) - }, - }, - - // Grow. - { - name: "grow", - input: "hello world", - output: "hello world", - op: func(t *testing.T, v *View) { - v.Grow(1, true) - }, - }, - { - name: "grow-from-zero", - output: strings.Repeat("\x00", 1024), - op: func(t *testing.T, v *View) { - v.Grow(1024, true) - }, - }, - { - name: "grow-from-non-zero", - input: strings.Repeat("1", bufferSize), - output: strings.Repeat("1", bufferSize) + strings.Repeat("\x00", bufferSize), - op: func(t *testing.T, v *View) { - v.Grow(bufferSize*2, true) - }, - }, - - // Copy. - { - name: "copy", - input: "hello", - output: "hello", - op: func(t *testing.T, v *View) { - other := v.Copy() - bs := other.Flatten() - want := []byte("hello") - if !bytes.Equal(bs, want) { - t.Errorf("expected %v, got %v", want, bs) - } - }, - }, - { - name: "copy-large", - input: strings.Repeat("1", bufferSize+1), - output: strings.Repeat("1", bufferSize+1), - op: func(t *testing.T, v *View) { - other := v.Copy() - bs := other.Flatten() - want := []byte(strings.Repeat("1", bufferSize+1)) - if !bytes.Equal(bs, want) { - t.Errorf("expected %v, got %v", want, bs) - } - }, - }, - - // Merge. - { - name: "merge", - input: "hello", - output: "hello world", - op: func(t *testing.T, v *View) { - var other View - other.Append([]byte(" world")) - v.Merge(&other) - if sz := other.Size(); sz != 0 { - t.Errorf("expected 0, got %d", sz) - } - }, - }, - { - name: "merge-large", - input: strings.Repeat("1", bufferSize+1), - output: strings.Repeat("1", bufferSize+1) + strings.Repeat("0", bufferSize+1), - op: func(t *testing.T, v *View) { - var other View - other.Append([]byte(strings.Repeat("0", bufferSize+1))) - v.Merge(&other) - if sz := other.Size(); sz != 0 { - t.Errorf("expected 0, got %d", sz) - } - }, - }, - - // ReadAt. - { - name: "readat", - input: "hello", - output: "hello", - op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 6, "hello", io.EOF) }, - }, - { - name: "readat-long", - input: "hello", - output: "hello", - op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 8, "hello", io.EOF) }, - }, - { - name: "readat-short", - input: "hello", - output: "hello", - op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 3, "hel", nil) }, - }, - { - name: "readat-offset", - input: "hello", - output: "hello", - op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 3, "llo", io.EOF) }, - }, - { - name: "readat-long-offset", - input: "hello", - output: "hello", - op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 8, "llo", io.EOF) }, - }, - { - name: "readat-short-offset", - input: "hello", - output: "hello", - op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 2, "ll", nil) }, - }, - { - name: "readat-skip-all", - input: "hello", - output: "hello", - op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 1, "", io.EOF) }, - }, - { - name: "readat-second-buffer", - input: strings.Repeat("0", bufferSize+1) + "12", - output: strings.Repeat("0", bufferSize+1) + "12", - op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 1, "1", nil) }, - }, - { - name: "readat-second-buffer-end", - input: strings.Repeat("0", bufferSize+1) + "12", - output: strings.Repeat("0", bufferSize+1) + "12", - op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 2, "12", io.EOF) }, - }, - } - - for _, tc := range testCases { - for fillName, fn := range fillFuncs { - t.Run(fillName+"/"+tc.name, func(t *testing.T) { - // Construct & fill the view. - var view View - fn(&view, []byte(tc.input)) - - // Run the operation. - if tc.op != nil { - tc.op(t, &view) - } - - // Flatten and validate. - out := view.Flatten() - if !bytes.Equal([]byte(tc.output), out) { - t.Errorf("expected %q, got %q", tc.output, string(out)) - } - - // Ensure the size is correct. - if len(out) != int(view.Size()) { - t.Errorf("size is wrong: expected %d, got %d", len(out), view.Size()) - } - - // Calculate contents via apply. - var appliedOut []byte - view.Apply(func(b []byte) { - appliedOut = append(appliedOut, b...) - }) - if len(appliedOut) != len(out) { - t.Errorf("expected %d, got %d", len(out), len(appliedOut)) - } - if !bytes.Equal(appliedOut, out) { - t.Errorf("expected %v, got %v", out, appliedOut) - } - - // Calculate contents via ReadToWriter. - var b bytes.Buffer - n, err := view.ReadToWriter(&b, int64(len(out))) - if n != int64(len(out)) { - t.Errorf("expected %d, got %d", len(out), n) - } - if err != nil { - t.Errorf("expected nil, got %v", err) - } - if !bytes.Equal(b.Bytes(), out) { - t.Errorf("expected %v, got %v", out, b.Bytes()) - } - }) - } - } -} diff --git a/pkg/buffer/view_unsafe.go b/pkg/buffer/view_unsafe.go index d1ef39b26..d1ef39b26 100644..100755 --- a/pkg/buffer/view_unsafe.go +++ b/pkg/buffer/view_unsafe.go diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD deleted file mode 100644 index 1f75319a7..000000000 --- a/pkg/compressio/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "compressio", - srcs = ["compressio.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/binary", - "//pkg/sync", - ], -) - -go_test( - name = "compressio_test", - size = "medium", - srcs = ["compressio_test.go"], - library = ":compressio", -) diff --git a/pkg/compressio/compressio_state_autogen.go b/pkg/compressio/compressio_state_autogen.go new file mode 100755 index 000000000..c47e0dd17 --- /dev/null +++ b/pkg/compressio/compressio_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package compressio diff --git a/pkg/compressio/compressio_test.go b/pkg/compressio/compressio_test.go deleted file mode 100644 index 86dc47e44..000000000 --- a/pkg/compressio/compressio_test.go +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package compressio - -import ( - "bytes" - "compress/flate" - "encoding/base64" - "fmt" - "io" - "math/rand" - "runtime" - "testing" - "time" -) - -type harness interface { - Errorf(format string, v ...interface{}) - Fatalf(format string, v ...interface{}) - Logf(format string, v ...interface{}) -} - -func initTest(t harness, size int) []byte { - // Set number of processes to number of CPUs. - runtime.GOMAXPROCS(runtime.NumCPU()) - - // Construct synthetic data. We do this by encoding random data with - // base64. This gives a high level of entropy, but still quite a bit of - // structure, to give reasonable compression ratios (~75%). - var buf bytes.Buffer - bufW := base64.NewEncoder(base64.RawStdEncoding, &buf) - bufR := rand.New(rand.NewSource(0)) - if _, err := io.CopyN(bufW, bufR, int64(size)); err != nil { - t.Fatalf("unable to seed random data: %v", err) - } - return buf.Bytes() -} - -type testOpts struct { - Name string - Data []byte - NewWriter func(*bytes.Buffer) (io.Writer, error) - NewReader func(*bytes.Buffer) (io.Reader, error) - PreCompress func() - PostCompress func() - PreDecompress func() - PostDecompress func() - CompressIters int - DecompressIters int - CorruptData bool -} - -func doTest(t harness, opts testOpts) { - // Compress. - var compressed bytes.Buffer - compressionStartTime := time.Now() - if opts.PreCompress != nil { - opts.PreCompress() - } - if opts.CompressIters <= 0 { - opts.CompressIters = 1 - } - for i := 0; i < opts.CompressIters; i++ { - compressed.Reset() - w, err := opts.NewWriter(&compressed) - if err != nil { - t.Errorf("%s: NewWriter got err %v, expected nil", opts.Name, err) - } - if _, err := io.Copy(w, bytes.NewBuffer(opts.Data)); err != nil { - t.Errorf("%s: compress got err %v, expected nil", opts.Name, err) - return - } - closer, ok := w.(io.Closer) - if ok { - if err := closer.Close(); err != nil { - t.Errorf("%s: got err %v, expected nil", opts.Name, err) - return - } - } - } - if opts.PostCompress != nil { - opts.PostCompress() - } - compressionTime := time.Since(compressionStartTime) - compressionRatio := float32(compressed.Len()) / float32(len(opts.Data)) - - // Decompress. - var decompressed bytes.Buffer - decompressionStartTime := time.Now() - if opts.PreDecompress != nil { - opts.PreDecompress() - } - if opts.DecompressIters <= 0 { - opts.DecompressIters = 1 - } - if opts.CorruptData { - b := compressed.Bytes() - b[rand.Intn(len(b))]++ - } - for i := 0; i < opts.DecompressIters; i++ { - decompressed.Reset() - r, err := opts.NewReader(bytes.NewBuffer(compressed.Bytes())) - if err != nil { - if opts.CorruptData { - continue - } - t.Errorf("%s: NewReader got err %v, expected nil", opts.Name, err) - return - } - if _, err := io.Copy(&decompressed, r); (err != nil) != opts.CorruptData { - t.Errorf("%s: decompress got err %v unexpectly", opts.Name, err) - return - } - } - if opts.PostDecompress != nil { - opts.PostDecompress() - } - decompressionTime := time.Since(decompressionStartTime) - - if opts.CorruptData { - return - } - - // Verify. - if decompressed.Len() != len(opts.Data) { - t.Errorf("%s: got %d bytes, expected %d", opts.Name, decompressed.Len(), len(opts.Data)) - } - if !bytes.Equal(opts.Data, decompressed.Bytes()) { - t.Errorf("%s: got mismatch, expected match", opts.Name) - if len(opts.Data) < 32 { // Don't flood the logs. - t.Errorf("got %v, expected %v", decompressed.Bytes(), opts.Data) - } - } - - t.Logf("%s: compression time %v, ratio %2.2f, decompression time %v", - opts.Name, compressionTime, compressionRatio, decompressionTime) -} - -var hashKey = []byte("01234567890123456789012345678901") - -func TestCompress(t *testing.T) { - rand.Seed(time.Now().Unix()) - - var ( - data = initTest(t, 10*1024*1024) - data0 = data[:0] - data1 = data[:1] - data2 = data[:11] - data3 = data[:16] - data4 = data[:] - ) - - for _, data := range [][]byte{data0, data1, data2, data3, data4} { - for _, blockSize := range []uint32{1, 4, 1024, 4 * 1024, 16 * 1024} { - // Skip annoying tests; they just take too long. - if blockSize <= 16 && len(data) > 16 { - continue - } - - for _, key := range [][]byte{nil, hashKey} { - for _, corruptData := range []bool{false, true} { - if key == nil && corruptData { - // No need to test corrupt data - // case when not doing hashing. - continue - } - // Do the compress test. - doTest(t, testOpts{ - Name: fmt.Sprintf("len(data)=%d, blockSize=%d, key=%s, corruptData=%v", len(data), blockSize, string(key), corruptData), - Data: data, - NewWriter: func(b *bytes.Buffer) (io.Writer, error) { - return NewWriter(b, key, blockSize, flate.BestSpeed) - }, - NewReader: func(b *bytes.Buffer) (io.Reader, error) { - return NewReader(b, key) - }, - CorruptData: corruptData, - }) - } - } - } - - // Do the vanilla test. - doTest(t, testOpts{ - Name: fmt.Sprintf("len(data)=%d, vanilla flate", len(data)), - Data: data, - NewWriter: func(b *bytes.Buffer) (io.Writer, error) { - return flate.NewWriter(b, flate.BestSpeed) - }, - NewReader: func(b *bytes.Buffer) (io.Reader, error) { - return flate.NewReader(b), nil - }, - }) - } -} - -const ( - benchDataSize = 600 * 1024 * 1024 -) - -func benchmark(b *testing.B, compress bool, hash bool, blockSize uint32) { - b.StopTimer() - b.SetBytes(benchDataSize) - data := initTest(b, benchDataSize) - compIters := b.N - decompIters := b.N - if compress { - decompIters = 0 - } else { - compIters = 0 - } - key := hashKey - if !hash { - key = nil - } - doTest(b, testOpts{ - Name: fmt.Sprintf("compress=%t, hash=%t, len(data)=%d, blockSize=%d", compress, hash, len(data), blockSize), - Data: data, - PreCompress: b.StartTimer, - PostCompress: b.StopTimer, - NewWriter: func(b *bytes.Buffer) (io.Writer, error) { - return NewWriter(b, key, blockSize, flate.BestSpeed) - }, - NewReader: func(b *bytes.Buffer) (io.Reader, error) { - return NewReader(b, key) - }, - CompressIters: compIters, - DecompressIters: decompIters, - }) -} - -func BenchmarkCompressNoHash64K(b *testing.B) { - benchmark(b, true, false, 64*1024) -} - -func BenchmarkCompressHash64K(b *testing.B) { - benchmark(b, true, true, 64*1024) -} - -func BenchmarkDecompressNoHash64K(b *testing.B) { - benchmark(b, false, false, 64*1024) -} - -func BenchmarkDecompressHash64K(b *testing.B) { - benchmark(b, false, true, 64*1024) -} - -func BenchmarkCompressNoHash1M(b *testing.B) { - benchmark(b, true, false, 1024*1024) -} - -func BenchmarkCompressHash1M(b *testing.B) { - benchmark(b, true, true, 1024*1024) -} - -func BenchmarkDecompressNoHash1M(b *testing.B) { - benchmark(b, false, false, 1024*1024) -} - -func BenchmarkDecompressHash1M(b *testing.B) { - benchmark(b, false, true, 1024*1024) -} - -func BenchmarkCompressNoHash16M(b *testing.B) { - benchmark(b, true, false, 16*1024*1024) -} - -func BenchmarkCompressHash16M(b *testing.B) { - benchmark(b, true, true, 16*1024*1024) -} - -func BenchmarkDecompressNoHash16M(b *testing.B) { - benchmark(b, false, false, 16*1024*1024) -} - -func BenchmarkDecompressHash16M(b *testing.B) { - benchmark(b, false, true, 16*1024*1024) -} diff --git a/pkg/context/BUILD b/pkg/context/BUILD deleted file mode 100644 index 239f31149..000000000 --- a/pkg/context/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "context", - srcs = ["context.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/amutex", - "//pkg/log", - ], -) diff --git a/pkg/context/context.go b/pkg/context/context.go index 23e009ef3..23e009ef3 100644..100755 --- a/pkg/context/context.go +++ b/pkg/context/context.go diff --git a/pkg/context/context_state_autogen.go b/pkg/context/context_state_autogen.go new file mode 100755 index 000000000..fdc3c9fbb --- /dev/null +++ b/pkg/context/context_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package context diff --git a/pkg/control/client/BUILD b/pkg/control/client/BUILD deleted file mode 100644 index 1b9e10ee7..000000000 --- a/pkg/control/client/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "client", - srcs = [ - "client.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/unet", - "//pkg/urpc", - ], -) diff --git a/pkg/control/client/client_state_autogen.go b/pkg/control/client/client_state_autogen.go new file mode 100755 index 000000000..9872f1107 --- /dev/null +++ b/pkg/control/client/client_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package client diff --git a/pkg/control/server/BUILD b/pkg/control/server/BUILD deleted file mode 100644 index 002d2ef44..000000000 --- a/pkg/control/server/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "server", - srcs = ["server.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/log", - "//pkg/sync", - "//pkg/unet", - "//pkg/urpc", - ], -) diff --git a/pkg/control/server/server_state_autogen.go b/pkg/control/server/server_state_autogen.go new file mode 100755 index 000000000..c236b8da5 --- /dev/null +++ b/pkg/control/server/server_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package server diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD deleted file mode 100644 index d6cb1a549..000000000 --- a/pkg/cpuid/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "cpuid", - srcs = [ - "cpu_amd64.s", - "cpuid.go", - "cpuid_arm64.go", - "cpuid_x86.go", - ], - visibility = ["//:sandbox"], - deps = ["//pkg/log"], -) - -go_test( - name = "cpuid_test", - size = "small", - srcs = [ - "cpuid_arm64_test.go", - "cpuid_x86_test.go", - ], - library = ":cpuid", -) - -go_test( - name = "cpuid_parse_test", - size = "small", - srcs = [ - "cpuid_parse_x86_test.go", - ], - library = ":cpuid", - tags = ["manual"], -) diff --git a/pkg/cpuid/cpuid_arm64.go b/pkg/cpuid/cpuid_arm64.go index 08381c1c0..08381c1c0 100644..100755 --- a/pkg/cpuid/cpuid_arm64.go +++ b/pkg/cpuid/cpuid_arm64.go diff --git a/pkg/cpuid/cpuid_arm64_state_autogen.go b/pkg/cpuid/cpuid_arm64_state_autogen.go new file mode 100755 index 000000000..0e671d441 --- /dev/null +++ b/pkg/cpuid/cpuid_arm64_state_autogen.go @@ -0,0 +1,34 @@ +// automatically generated by stateify. + +// +build arm64 + +package cpuid + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *FeatureSet) beforeSave() {} +func (x *FeatureSet) save(m state.Map) { + x.beforeSave() + m.Save("Set", &x.Set) + m.Save("CPUImplementer", &x.CPUImplementer) + m.Save("CPUArchitecture", &x.CPUArchitecture) + m.Save("CPUVariant", &x.CPUVariant) + m.Save("CPUPartnum", &x.CPUPartnum) + m.Save("CPURevision", &x.CPURevision) +} + +func (x *FeatureSet) afterLoad() {} +func (x *FeatureSet) load(m state.Map) { + m.Load("Set", &x.Set) + m.Load("CPUImplementer", &x.CPUImplementer) + m.Load("CPUArchitecture", &x.CPUArchitecture) + m.Load("CPUVariant", &x.CPUVariant) + m.Load("CPUPartnum", &x.CPUPartnum) + m.Load("CPURevision", &x.CPURevision) +} + +func init() { + state.Register("pkg/cpuid.FeatureSet", (*FeatureSet)(nil), state.Fns{Save: (*FeatureSet).save, Load: (*FeatureSet).load}) +} diff --git a/pkg/cpuid/cpuid_arm64_test.go b/pkg/cpuid/cpuid_arm64_test.go deleted file mode 100644 index a34f67779..000000000 --- a/pkg/cpuid/cpuid_arm64_test.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build arm64 - -package cpuid - -import ( - "testing" -) - -var justFP = &FeatureSet{ - Set: map[Feature]bool{ - ARM64FeatureFP: true, - }} - -func TestHostFeatureSet(t *testing.T) { - hostFeatures := HostFeatureSet() - if len(hostFeatures.Set) == 0 { - t.Errorf("Got invalid feature set %v from HostFeatureSet()", hostFeatures) - } -} - -func TestHasFeature(t *testing.T) { - if !justFP.HasFeature(ARM64FeatureFP) { - t.Errorf("HasFeature failed, %v should contain %v", justFP, ARM64FeatureFP) - } - - if justFP.HasFeature(ARM64FeatureSM3) { - t.Errorf("HasFeature failed, %v should not contain %v", justFP, ARM64FeatureSM3) - } -} - -func TestFeatureFromString(t *testing.T) { - f, ok := FeatureFromString("asimd") - if f != ARM64FeatureASIMD || !ok { - t.Errorf("got %v want asimd", f) - } - - f, ok = FeatureFromString("bad") - if ok { - t.Errorf("got %v want nothing", f) - } -} diff --git a/pkg/cpuid/cpuid_parse_x86_test.go b/pkg/cpuid/cpuid_parse_x86_test.go deleted file mode 100644 index d48418e69..000000000 --- a/pkg/cpuid/cpuid_parse_x86_test.go +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build i386 amd64 - -package cpuid - -import ( - "fmt" - "io/ioutil" - "regexp" - "strconv" - "strings" - "syscall" - "testing" -) - -func kernelVersion() (int, int, error) { - var u syscall.Utsname - if err := syscall.Uname(&u); err != nil { - return 0, 0, err - } - - var r string - for _, b := range u.Release { - if b == 0 { - break - } - r += string(b) - } - - s := strings.Split(r, ".") - if len(s) < 2 { - return 0, 0, fmt.Errorf("kernel release missing major and minor component: %s", r) - } - - major, err := strconv.Atoi(s[0]) - if err != nil { - return 0, 0, fmt.Errorf("error parsing major version %q in %q: %v", s[0], r, err) - } - - minor, err := strconv.Atoi(s[1]) - if err != nil { - return 0, 0, fmt.Errorf("error parsing minor version %q in %q: %v", s[1], r, err) - } - - return major, minor, nil -} - -// TestHostFeatureFlags tests that all features detected by HostFeatureSet are -// on the host. -// -// It does *not* verify that all features reported by the host are detected by -// HostFeatureSet. -// -// i.e., test that HostFeatureSet is a subset of the host features. -func TestHostFeatureFlags(t *testing.T) { - cpuinfoBytes, _ := ioutil.ReadFile("/proc/cpuinfo") - cpuinfo := string(cpuinfoBytes) - t.Logf("Host cpu info:\n%s", cpuinfo) - - major, minor, err := kernelVersion() - if err != nil { - t.Fatalf("Unable to parse kernel version: %v", err) - } - - re := regexp.MustCompile(`(?m)^flags\s+: (.*)$`) - m := re.FindStringSubmatch(cpuinfo) - if len(m) != 2 { - t.Fatalf("Unable to extract flags from %q", cpuinfo) - } - - cpuinfoFlags := make(map[string]struct{}) - for _, f := range strings.Split(m[1], " ") { - cpuinfoFlags[f] = struct{}{} - } - - fs := HostFeatureSet() - - // All features have a string and appear in host cpuinfo. - for f := range fs.Set { - name := f.flagString(false) - if name == "" { - t.Errorf("Non-parsable feature: %v", f) - } - - // Special cases not consistently visible. We don't mind if - // they are exposed in earlier versions. - switch { - // Block 0. - case f == X86FeatureSDBG && (major < 4 || major == 4 && minor < 3): - // SDBG only exposed in - // b1c599b8ff80ea79b9f8277a3f9f36a7b0cfedce (4.3). - continue - // Block 2. - case f == X86FeatureRDT && (major < 4 || major == 4 && minor < 10): - // RDT only exposed in - // 4ab1586488cb56ed8728e54c4157cc38646874d9 (4.10). - continue - // Block 3. - case f == X86FeatureAVX512VBMI && (major < 4 || major == 4 && minor < 10): - // AVX512VBMI only exposed in - // a8d9df5a509a232a959e4ef2e281f7ecd77810d6 (4.10). - continue - case f == X86FeatureUMIP && (major < 4 || major == 4 && minor < 15): - // UMIP only exposed in - // 3522c2a6a4f341058b8291326a945e2a2d2aaf55 (4.15). - continue - case f == X86FeaturePKU && (major < 4 || major == 4 && minor < 9): - // PKU only exposed in - // dfb4a70f20c5b3880da56ee4c9484bdb4e8f1e65 (4.9). - continue - // Block 4. - case f == X86FeatureXSAVES && (major < 4 || major == 4 && minor < 8): - // XSAVES only exposed in - // b8be15d588060a03569ac85dc4a0247460988f5b (4.8). - continue - // Block 5. - case f == X86FeaturePERFCTR_LLC && (major < 4 || major == 4 && minor < 14): - // PERFCTR_LLC renamed in - // 910448bbed066ab1082b510eef1ae61bb792d854 (4.14). - continue - } - - hidden := f.flagString(true) == "" - _, ok := cpuinfoFlags[name] - if hidden && ok { - t.Errorf("Unexpectedly hidden flag: %v", f) - } else if !hidden && !ok { - t.Errorf("Non-native flag: %v", f) - } - } -} diff --git a/pkg/cpuid/cpuid_state_autogen.go b/pkg/cpuid/cpuid_state_autogen.go new file mode 100755 index 000000000..845a149d4 --- /dev/null +++ b/pkg/cpuid/cpuid_state_autogen.go @@ -0,0 +1,70 @@ +// automatically generated by stateify. + +// +build i386 amd64 + +package cpuid + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Cache) beforeSave() {} +func (x *Cache) save(m state.Map) { + x.beforeSave() + m.Save("Level", &x.Level) + m.Save("Type", &x.Type) + m.Save("FullyAssociative", &x.FullyAssociative) + m.Save("Partitions", &x.Partitions) + m.Save("Ways", &x.Ways) + m.Save("Sets", &x.Sets) + m.Save("InvalidateHierarchical", &x.InvalidateHierarchical) + m.Save("Inclusive", &x.Inclusive) + m.Save("DirectMapped", &x.DirectMapped) +} + +func (x *Cache) afterLoad() {} +func (x *Cache) load(m state.Map) { + m.Load("Level", &x.Level) + m.Load("Type", &x.Type) + m.Load("FullyAssociative", &x.FullyAssociative) + m.Load("Partitions", &x.Partitions) + m.Load("Ways", &x.Ways) + m.Load("Sets", &x.Sets) + m.Load("InvalidateHierarchical", &x.InvalidateHierarchical) + m.Load("Inclusive", &x.Inclusive) + m.Load("DirectMapped", &x.DirectMapped) +} + +func (x *FeatureSet) beforeSave() {} +func (x *FeatureSet) save(m state.Map) { + x.beforeSave() + m.Save("Set", &x.Set) + m.Save("VendorID", &x.VendorID) + m.Save("ExtendedFamily", &x.ExtendedFamily) + m.Save("ExtendedModel", &x.ExtendedModel) + m.Save("ProcessorType", &x.ProcessorType) + m.Save("Family", &x.Family) + m.Save("Model", &x.Model) + m.Save("SteppingID", &x.SteppingID) + m.Save("Caches", &x.Caches) + m.Save("CacheLine", &x.CacheLine) +} + +func (x *FeatureSet) afterLoad() {} +func (x *FeatureSet) load(m state.Map) { + m.Load("Set", &x.Set) + m.Load("VendorID", &x.VendorID) + m.Load("ExtendedFamily", &x.ExtendedFamily) + m.Load("ExtendedModel", &x.ExtendedModel) + m.Load("ProcessorType", &x.ProcessorType) + m.Load("Family", &x.Family) + m.Load("Model", &x.Model) + m.Load("SteppingID", &x.SteppingID) + m.Load("Caches", &x.Caches) + m.Load("CacheLine", &x.CacheLine) +} + +func init() { + state.Register("pkg/cpuid.Cache", (*Cache)(nil), state.Fns{Save: (*Cache).save, Load: (*Cache).load}) + state.Register("pkg/cpuid.FeatureSet", (*FeatureSet)(nil), state.Fns{Save: (*FeatureSet).save, Load: (*FeatureSet).load}) +} diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go index a0bc55ea1..a0bc55ea1 100644..100755 --- a/pkg/cpuid/cpuid_x86.go +++ b/pkg/cpuid/cpuid_x86.go diff --git a/pkg/cpuid/cpuid_x86_test.go b/pkg/cpuid/cpuid_x86_test.go deleted file mode 100644 index 0fe20c213..000000000 --- a/pkg/cpuid/cpuid_x86_test.go +++ /dev/null @@ -1,243 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build i386 amd64 - -package cpuid - -import ( - "testing" -) - -// These are the default values of various FeatureSet fields. -const ( - defaultVendorID = "GenuineIntel" - - // These processor signature defaults are derived from the values - // listed in Intel Application Note 485 for i7/Xeon processors. - defaultExtFamily uint8 = 0 - defaultExtModel uint8 = 1 - defaultType uint8 = 0 - defaultFamily uint8 = 0x06 - defaultModel uint8 = 0x0a - defaultSteppingID uint8 = 0 -) - -// newEmptyFeatureSet creates a new FeatureSet with a sensible default model and no features. -func newEmptyFeatureSet() *FeatureSet { - return &FeatureSet{ - Set: make(map[Feature]bool), - VendorID: defaultVendorID, - ExtendedFamily: defaultExtFamily, - ExtendedModel: defaultExtModel, - ProcessorType: defaultType, - Family: defaultFamily, - Model: defaultModel, - SteppingID: defaultSteppingID, - } -} - -var justFPU = &FeatureSet{ - Set: map[Feature]bool{ - X86FeatureFPU: true, - }} - -var justFPUandPAE = &FeatureSet{ - Set: map[Feature]bool{ - X86FeatureFPU: true, - X86FeaturePAE: true, - }} - -func TestSubtract(t *testing.T) { - if diff := justFPU.Subtract(justFPUandPAE); diff != nil { - t.Errorf("Got %v is not subset of %v, want diff (%v) to be nil", justFPU, justFPUandPAE, diff) - } - - if justFPUandPAE.Subtract(justFPU) == nil { - t.Errorf("Got %v is a subset of %v, want diff to be nil", justFPU, justFPUandPAE) - } -} - -// TODO(b/73346484): Run this test on a very old platform, and make sure more -// bits are enabled than just FPU and PAE. This test currently may not detect -// if HostFeatureSet gives back junk bits. -func TestHostFeatureSet(t *testing.T) { - hostFeatures := HostFeatureSet() - if justFPUandPAE.Subtract(hostFeatures) != nil { - t.Errorf("Got invalid feature set %v from HostFeatureSet()", hostFeatures) - } -} - -func TestHasFeature(t *testing.T) { - if !justFPU.HasFeature(X86FeatureFPU) { - t.Errorf("HasFeature failed, %v should contain %v", justFPU, X86FeatureFPU) - } - - if justFPU.HasFeature(X86FeatureAVX) { - t.Errorf("HasFeature failed, %v should not contain %v", justFPU, X86FeatureAVX) - } -} - -// Note: these tests are aware of and abuse internal details of FeatureSets. -// Users of FeatureSets should not depend on this. -func TestAdd(t *testing.T) { - // Test a basic insertion into the FeatureSet. - testFeatures := newEmptyFeatureSet() - testFeatures.Add(X86FeatureCLFSH) - if len(testFeatures.Set) != 1 { - t.Errorf("Got length %v want 1", len(testFeatures.Set)) - } - - if !testFeatures.HasFeature(X86FeatureCLFSH) { - t.Errorf("Add failed, got %v want set with %v", testFeatures, X86FeatureCLFSH) - } - - // Test that duplicates are ignored. - testFeatures.Add(X86FeatureCLFSH) - if len(testFeatures.Set) != 1 { - t.Errorf("Got length %v, want 1", len(testFeatures.Set)) - } -} - -func TestRemove(t *testing.T) { - // Try removing the last feature. - testFeatures := newEmptyFeatureSet() - testFeatures.Add(X86FeatureFPU) - testFeatures.Add(X86FeaturePAE) - testFeatures.Remove(X86FeaturePAE) - if !testFeatures.HasFeature(X86FeatureFPU) || len(testFeatures.Set) != 1 || testFeatures.HasFeature(X86FeaturePAE) { - t.Errorf("Remove failed, got %v want %v", testFeatures, justFPU) - } - - // Try removing a feature not in the set. - testFeatures.Remove(X86FeatureRDRAND) - if !testFeatures.HasFeature(X86FeatureFPU) || len(testFeatures.Set) != 1 { - t.Errorf("Remove failed, got %v want %v", testFeatures, justFPU) - } -} - -func TestFeatureFromString(t *testing.T) { - f, ok := FeatureFromString("avx") - if f != X86FeatureAVX || !ok { - t.Errorf("got %v want avx", f) - } - - f, ok = FeatureFromString("bad") - if ok { - t.Errorf("got %v want nothing", f) - } -} - -// This tests function 0 (eax=0), which returns the vendor ID and highest cpuid -// function reported to be available. -func TestEmulateIDVendorAndLength(t *testing.T) { - testFeatures := newEmptyFeatureSet() - - ax, bx, cx, dx := testFeatures.EmulateID(0, 0) - wantEax := uint32(0xd) // Highest supported cpuid function. - - // These magical constants are the characters of "GenuineIntel". - // See Intel AN485 for a reference on why they are laid out like this. - wantEbx := uint32(0x756e6547) - wantEcx := uint32(0x6c65746e) - wantEdx := uint32(0x49656e69) - if wantEax != ax { - t.Errorf("highest function failed, got %x want %x", ax, wantEax) - } - - if wantEbx != bx || wantEcx != cx || wantEdx != dx { - t.Errorf("vendor string emulation failed, bx:cx:dx, got %x:%x:%x want %x:%x:%x", bx, cx, dx, wantEbx, wantEcx, wantEdx) - } -} - -func TestEmulateIDBasicFeatures(t *testing.T) { - // Make a minimal test feature set. - testFeatures := newEmptyFeatureSet() - testFeatures.Add(X86FeatureCLFSH) - testFeatures.Add(X86FeatureAVX) - testFeatures.CacheLine = 64 - - ax, bx, cx, dx := testFeatures.EmulateID(1, 0) - ECXAVXBit := uint32(1 << uint(X86FeatureAVX)) - EDXCLFlushBit := uint32(1 << uint(X86FeatureCLFSH-32)) // We adjust by 32 since it's in block 1. - - if EDXCLFlushBit&dx == 0 || dx&^EDXCLFlushBit != 0 { - t.Errorf("EmulateID failed, got feature bits %x want %x", dx, testFeatures.blockMask(1)) - } - - if ECXAVXBit&cx == 0 || cx&^ECXAVXBit != 0 { - t.Errorf("EmulateID failed, got feature bits %x want %x", cx, testFeatures.blockMask(0)) - } - - // Default signature bits, based on values for i7/Xeon. - // See Intel AN485 for information on stepping/model bits. - defaultSignature := uint32(0x000106a0) - if defaultSignature != ax { - t.Errorf("EmulateID stepping emulation failed, got %x want %x", ax, defaultSignature) - } - - clflushSizeInfo := uint32(8 << 8) - if clflushSizeInfo != bx { - t.Errorf("EmulateID bx emulation failed, got %x want %x", bx, clflushSizeInfo) - } -} - -func TestEmulateIDExtendedFeatures(t *testing.T) { - // Make a minimal test feature set, one bit in each extended feature word. - testFeatures := newEmptyFeatureSet() - testFeatures.Add(X86FeatureSMEP) - testFeatures.Add(X86FeatureAVX512VBMI) - - ax, bx, cx, dx := testFeatures.EmulateID(7, 0) - EBXSMEPBit := uint32(1 << uint(X86FeatureSMEP-2*32)) // Adjust by 2*32 since SMEP is a block 2 feature. - ECXAVXBit := uint32(1 << uint(X86FeatureAVX512VBMI-3*32)) // We adjust by 3*32 since it's a block 3 feature. - - // Test that the desired bit is set and no other bits are set. - if EBXSMEPBit&bx == 0 || bx&^EBXSMEPBit != 0 { - t.Errorf("extended feature emulation failed, got feature bits %x want %x", bx, testFeatures.blockMask(2)) - } - - if ECXAVXBit&cx == 0 || cx&^ECXAVXBit != 0 { - t.Errorf("extended feature emulation failed, got feature bits %x want %x", cx, testFeatures.blockMask(3)) - } - - if ax != 0 || dx != 0 { - t.Errorf("extended feature emulation failed, ax:dx, got %x:%x want 0:0", ax, dx) - } - - // Check that no subleaves other than 0 do anything. - ax, bx, cx, dx = testFeatures.EmulateID(7, 1) - if ax != 0 || bx != 0 || cx != 0 || dx != 0 { - t.Errorf("extended feature emulation failed, got %x:%x:%x:%x want 0:0", ax, bx, cx, dx) - } - -} - -// Checks that the expected extended features are available via cpuid functions -// 0x80000000 and up. -func TestEmulateIDExtended(t *testing.T) { - testFeatures := newEmptyFeatureSet() - testFeatures.Add(X86FeatureSYSCALL) - EDXSYSCALLBit := uint32(1 << uint(X86FeatureSYSCALL-6*32)) // Adjust by 6*32 since SYSCALL is a block 6 feature. - - ax, bx, cx, dx := testFeatures.EmulateID(0x80000000, 0) - if ax != 0x80000001 || bx != 0 || cx != 0 || dx != 0 { - t.Errorf("EmulateID extended emulation failed, ax:bx:cx:dx, got %x:%x:%x:%x want 0x80000001:0:0:0", ax, bx, cx, dx) - } - - _, _, _, dx = testFeatures.EmulateID(0x80000001, 0) - if EDXSYSCALLBit&dx == 0 || dx&^EDXSYSCALLBit != 0 { - t.Errorf("extended feature emulation failed, got feature bits %x want %x", dx, testFeatures.blockMask(6)) - } -} diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD deleted file mode 100644 index bee28b68d..000000000 --- a/pkg/eventchannel/BUILD +++ /dev/null @@ -1,37 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test", "proto_library") - -package(licenses = ["notice"]) - -go_library( - name = "eventchannel", - srcs = [ - "event.go", - "rate.go", - ], - visibility = ["//:sandbox"], - deps = [ - ":eventchannel_go_proto", - "//pkg/log", - "//pkg/sync", - "//pkg/unet", - "@com_github_golang_protobuf//proto:go_default_library", - "@com_github_golang_protobuf//ptypes:go_default_library_gen", - "@org_golang_x_time//rate:go_default_library", - ], -) - -proto_library( - name = "eventchannel", - srcs = ["event.proto"], - visibility = ["//:sandbox"], -) - -go_test( - name = "eventchannel_test", - srcs = ["event_test.go"], - library = ":eventchannel", - deps = [ - "//pkg/sync", - "@com_github_golang_protobuf//proto:go_default_library", - ], -) diff --git a/pkg/eventchannel/event.proto b/pkg/eventchannel/event.proto deleted file mode 100644 index 34468f072..000000000 --- a/pkg/eventchannel/event.proto +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -// A debug event encapsulates any other event protobuf in text format. This is -// useful because clients reading events emitted this way do not need to link -// the event protobufs to display them in a human-readable format. -message DebugEvent { - // Name of the inner message. - string name = 1; - // Text representation of the inner message content. - string text = 2; -} diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go deleted file mode 100644 index 7f41b4a27..000000000 --- a/pkg/eventchannel/event_test.go +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package eventchannel - -import ( - "fmt" - "testing" - "time" - - "github.com/golang/protobuf/proto" - "gvisor.dev/gvisor/pkg/sync" -) - -// testEmitter is an emitter that can be used in tests. It records all events -// emitted, and whether it has been closed. -type testEmitter struct { - // mu protects all fields below. - mu sync.Mutex - - // events contains all emitted events. - events []proto.Message - - // closed records whether Close() was called. - closed bool -} - -// Emit implements Emitter.Emit. -func (te *testEmitter) Emit(msg proto.Message) (bool, error) { - te.mu.Lock() - defer te.mu.Unlock() - te.events = append(te.events, msg) - return false, nil -} - -// Close implements Emitter.Close. -func (te *testEmitter) Close() error { - te.mu.Lock() - defer te.mu.Unlock() - if te.closed { - return fmt.Errorf("closed called twice") - } - te.closed = true - return nil -} - -// testMessage implements proto.Message for testing. -type testMessage struct { - proto.Message - - // name is the name of the message, used by tests to compare messages. - name string -} - -func TestMultiEmitter(t *testing.T) { - // Create three testEmitters, tied together in a multiEmitter. - me := &multiEmitter{} - var emitters []*testEmitter - for i := 0; i < 3; i++ { - te := &testEmitter{} - emitters = append(emitters, te) - me.AddEmitter(te) - } - - // Emit three messages to multiEmitter. - names := []string{"foo", "bar", "baz"} - for _, name := range names { - m := testMessage{name: name} - if _, err := me.Emit(m); err != nil { - t.Fatal("me.Emit(%v) failed: %v", m, err) - } - } - - // All three emitters should have all three events. - for _, te := range emitters { - if got, want := len(te.events), len(names); got != want { - t.Fatalf("emitter got %d events, want %d", got, want) - } - for i, name := range names { - if got := te.events[i].(testMessage).name; got != name { - t.Errorf("emitter got message with name %q, want %q", got, name) - } - } - } - - // Close multiEmitter. - if err := me.Close(); err != nil { - t.Fatal("me.Close() failed: %v", err) - } - - // All testEmitters should be closed. - for _, te := range emitters { - if !te.closed { - t.Errorf("te.closed got false, want true") - } - } -} - -func TestRateLimitedEmitter(t *testing.T) { - // Create a RateLimittedEmitter that wraps a testEmitter. - te := &testEmitter{} - max := float64(5) // events per second - burst := 10 // events - rle := RateLimitedEmitterFrom(te, max, burst) - - // Send 50 messages in one shot. - for i := 0; i < 50; i++ { - if _, err := rle.Emit(testMessage{}); err != nil { - t.Fatalf("rle.Emit failed: %v", err) - } - } - - // We should have received only 10 messages. - if got, want := len(te.events), 10; got != want { - t.Errorf("got %d events, want %d", got, want) - } - - // Sleep for a second and then send another 50. - time.Sleep(1 * time.Second) - for i := 0; i < 50; i++ { - if _, err := rle.Emit(testMessage{}); err != nil { - t.Fatalf("rle.Emit failed: %v", err) - } - } - - // We should have at least 5 more message, plus maybe a few more if the - // test ran slowly. - got, wantAtLeast, wantAtMost := len(te.events), 15, 20 - if got < wantAtLeast { - t.Errorf("got %d events, want at least %d", got, wantAtLeast) - } - if got > wantAtMost { - t.Errorf("got %d events, want at most %d", got, wantAtMost) - } -} diff --git a/pkg/eventchannel/eventchannel_go_proto/event.pb.go b/pkg/eventchannel/eventchannel_go_proto/event.pb.go new file mode 100755 index 000000000..bb71ed3e6 --- /dev/null +++ b/pkg/eventchannel/eventchannel_go_proto/event.pb.go @@ -0,0 +1,85 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/eventchannel/event.proto + +package gvisor + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type DebugEvent struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Text string `protobuf:"bytes,2,opt,name=text,proto3" json:"text,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *DebugEvent) Reset() { *m = DebugEvent{} } +func (m *DebugEvent) String() string { return proto.CompactTextString(m) } +func (*DebugEvent) ProtoMessage() {} +func (*DebugEvent) Descriptor() ([]byte, []int) { + return fileDescriptor_fcfbd51abd9de962, []int{0} +} + +func (m *DebugEvent) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_DebugEvent.Unmarshal(m, b) +} +func (m *DebugEvent) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_DebugEvent.Marshal(b, m, deterministic) +} +func (m *DebugEvent) XXX_Merge(src proto.Message) { + xxx_messageInfo_DebugEvent.Merge(m, src) +} +func (m *DebugEvent) XXX_Size() int { + return xxx_messageInfo_DebugEvent.Size(m) +} +func (m *DebugEvent) XXX_DiscardUnknown() { + xxx_messageInfo_DebugEvent.DiscardUnknown(m) +} + +var xxx_messageInfo_DebugEvent proto.InternalMessageInfo + +func (m *DebugEvent) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func (m *DebugEvent) GetText() string { + if m != nil { + return m.Text + } + return "" +} + +func init() { + proto.RegisterType((*DebugEvent)(nil), "gvisor.DebugEvent") +} + +func init() { proto.RegisterFile("pkg/eventchannel/event.proto", fileDescriptor_fcfbd51abd9de962) } + +var fileDescriptor_fcfbd51abd9de962 = []byte{ + // 103 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x29, 0xc8, 0x4e, 0xd7, + 0x4f, 0x2d, 0x4b, 0xcd, 0x2b, 0x49, 0xce, 0x48, 0xcc, 0xcb, 0x4b, 0xcd, 0x81, 0x70, 0xf4, 0x0a, + 0x8a, 0xf2, 0x4b, 0xf2, 0x85, 0xd8, 0xd2, 0xcb, 0x32, 0x8b, 0xf3, 0x8b, 0x94, 0x4c, 0xb8, 0xb8, + 0x5c, 0x52, 0x93, 0x4a, 0xd3, 0x5d, 0x41, 0x72, 0x42, 0x42, 0x5c, 0x2c, 0x79, 0x89, 0xb9, 0xa9, + 0x12, 0x8c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x60, 0x36, 0x48, 0xac, 0x24, 0xb5, 0xa2, 0x44, 0x82, + 0x09, 0x22, 0x06, 0x62, 0x27, 0xb1, 0x81, 0x0d, 0x31, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x17, + 0xee, 0x7f, 0xef, 0x64, 0x00, 0x00, 0x00, +} diff --git a/pkg/eventchannel/eventchannel_state_autogen.go b/pkg/eventchannel/eventchannel_state_autogen.go new file mode 100755 index 000000000..50b9c54b3 --- /dev/null +++ b/pkg/eventchannel/eventchannel_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package eventchannel diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD deleted file mode 100644 index 872361546..000000000 --- a/pkg/fd/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "fd", - srcs = ["fd.go"], - visibility = ["//visibility:public"], -) - -go_test( - name = "fd_test", - size = "small", - srcs = ["fd_test.go"], - library = ":fd", -) diff --git a/pkg/fd/fd_state_autogen.go b/pkg/fd/fd_state_autogen.go new file mode 100755 index 000000000..5ad412976 --- /dev/null +++ b/pkg/fd/fd_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package fd diff --git a/pkg/fd/fd_test.go b/pkg/fd/fd_test.go deleted file mode 100644 index 5fb0ad47d..000000000 --- a/pkg/fd/fd_test.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fd - -import ( - "math" - "os" - "syscall" - "testing" -) - -func TestSetNegOne(t *testing.T) { - type entry struct { - name string - file *FD - fn func() error - } - var tests []entry - - fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("syscall.Socket:", err) - } - f1 := New(fd) - tests = append(tests, entry{ - "Release", - f1, - func() error { - return syscall.Close(f1.Release()) - }, - }) - - fd, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("syscall.Socket:", err) - } - f2 := New(fd) - tests = append(tests, entry{ - "Close", - f2, - f2.Close, - }) - - for _, test := range tests { - if err := test.fn(); err != nil { - t.Errorf("%s: %v", test.name, err) - continue - } - if fd := test.file.FD(); fd != -1 { - t.Errorf("%s: got FD() = %d, want = -1", test.name, fd) - } - } -} - -func TestStartsNegOne(t *testing.T) { - type entry struct { - name string - file *FD - } - - tests := []entry{ - {"-1", New(-1)}, - {"-2", New(-2)}, - {"MinInt32", New(math.MinInt32)}, - {"MinInt64", New(math.MinInt64)}, - } - - for _, test := range tests { - if fd := test.file.FD(); fd != -1 { - t.Errorf("%s: got FD() = %d, want = -1", test.name, fd) - } - } -} - -func TestFileDotFile(t *testing.T) { - fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("syscall.Socket:", err) - } - - f := New(fd) - of, err := f.File() - if err != nil { - t.Fatalf("File got err %v want nil", err) - } - - if ofd, nfd := int(of.Fd()), f.FD(); ofd == nfd || ofd == -1 { - // Try not to double close the FD. - f.Release() - - t.Fatalf("got %#v.File().Fd() = %d, want new FD", f, ofd) - } - - f.Close() - of.Close() -} - -func TestFileDotFileError(t *testing.T) { - f := &FD{ReadWriter{-2}} - - if of, err := f.File(); err == nil { - t.Errorf("File %v got nil err want non-nil", of) - of.Close() - } -} - -func TestNewFromFile(t *testing.T) { - f, err := NewFromFile(os.Stdin) - if err != nil { - t.Fatalf("NewFromFile got err %v want nil", err) - } - if nfd, ofd := f.FD(), int(os.Stdin.Fd()); nfd == -1 || nfd == ofd { - t.Errorf("got FD() = %d, want = new FD (old FD was %d)", nfd, ofd) - } - f.Close() -} - -func TestNewFromFileError(t *testing.T) { - f, err := NewFromFile(nil) - if err == nil { - t.Errorf("NewFromFile got %v with nil err want non-nil", f) - f.Close() - } -} diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD deleted file mode 100644 index d9104ef02..000000000 --- a/pkg/fdchannel/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -licenses(["notice"]) - -go_library( - name = "fdchannel", - srcs = ["fdchannel_unsafe.go"], - visibility = ["//visibility:public"], -) - -go_test( - name = "fdchannel_test", - size = "small", - srcs = ["fdchannel_test.go"], - library = ":fdchannel", - deps = ["//pkg/sync"], -) diff --git a/pkg/fdchannel/fdchannel_state_autogen.go b/pkg/fdchannel/fdchannel_state_autogen.go new file mode 100755 index 000000000..61447d773 --- /dev/null +++ b/pkg/fdchannel/fdchannel_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris + +package fdchannel diff --git a/pkg/fdchannel/fdchannel_test.go b/pkg/fdchannel/fdchannel_test.go deleted file mode 100644 index 7a8a63a59..000000000 --- a/pkg/fdchannel/fdchannel_test.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fdchannel - -import ( - "io/ioutil" - "os" - "syscall" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" -) - -func TestSendRecvFD(t *testing.T) { - sendFile, err := ioutil.TempFile("", "fdchannel_test_") - if err != nil { - t.Fatalf("failed to create temporary file: %v", err) - } - defer sendFile.Close() - - chanFDs, err := NewConnectedSockets() - if err != nil { - t.Fatalf("failed to create fdchannel sockets: %v", err) - } - sendEP := NewEndpoint(chanFDs[0]) - defer sendEP.Destroy() - recvEP := NewEndpoint(chanFDs[1]) - defer recvEP.Destroy() - - recvFD, err := recvEP.RecvFDNonblock() - if err != syscall.EAGAIN && err != syscall.EWOULDBLOCK { - t.Errorf("RecvFDNonblock before SendFD: got (%d, %v), wanted (<unspecified>, EAGAIN or EWOULDBLOCK", recvFD, err) - } - - if err := sendEP.SendFD(int(sendFile.Fd())); err != nil { - t.Fatalf("SendFD failed: %v", err) - } - recvFD, err = recvEP.RecvFD() - if err != nil { - t.Fatalf("RecvFD failed: %v", err) - } - recvFile := os.NewFile(uintptr(recvFD), "received file") - defer recvFile.Close() - - sendInfo, err := sendFile.Stat() - if err != nil { - t.Fatalf("failed to stat sent file: %v", err) - } - sendInfoSys := sendInfo.Sys() - sendStat, ok := sendInfoSys.(*syscall.Stat_t) - if !ok { - t.Fatalf("sent file's FileInfo is backed by unknown type %T", sendInfoSys) - } - - recvInfo, err := recvFile.Stat() - if err != nil { - t.Fatalf("failed to stat received file: %v", err) - } - recvInfoSys := recvInfo.Sys() - recvStat, ok := recvInfoSys.(*syscall.Stat_t) - if !ok { - t.Fatalf("received file's FileInfo is backed by unknown type %T", recvInfoSys) - } - - if sendStat.Dev != recvStat.Dev || sendStat.Ino != recvStat.Ino { - t.Errorf("sent file (dev=%d, ino=%d) does not match received file (dev=%d, ino=%d)", sendStat.Dev, sendStat.Ino, recvStat.Dev, recvStat.Ino) - } -} - -func TestShutdownThenRecvFD(t *testing.T) { - sendFile, err := ioutil.TempFile("", "fdchannel_test_") - if err != nil { - t.Fatalf("failed to create temporary file: %v", err) - } - defer sendFile.Close() - - chanFDs, err := NewConnectedSockets() - if err != nil { - t.Fatalf("failed to create fdchannel sockets: %v", err) - } - sendEP := NewEndpoint(chanFDs[0]) - defer sendEP.Destroy() - recvEP := NewEndpoint(chanFDs[1]) - defer recvEP.Destroy() - - recvEP.Shutdown() - if _, err := recvEP.RecvFD(); err == nil { - t.Error("RecvFD succeeded unexpectedly") - } -} - -func TestRecvFDThenShutdown(t *testing.T) { - sendFile, err := ioutil.TempFile("", "fdchannel_test_") - if err != nil { - t.Fatalf("failed to create temporary file: %v", err) - } - defer sendFile.Close() - - chanFDs, err := NewConnectedSockets() - if err != nil { - t.Fatalf("failed to create fdchannel sockets: %v", err) - } - sendEP := NewEndpoint(chanFDs[0]) - defer sendEP.Destroy() - recvEP := NewEndpoint(chanFDs[1]) - defer recvEP.Destroy() - - var receiverWG sync.WaitGroup - receiverWG.Add(1) - go func() { - defer receiverWG.Done() - if _, err := recvEP.RecvFD(); err == nil { - t.Error("RecvFD succeeded unexpectedly") - } - }() - defer receiverWG.Wait() - time.Sleep(time.Second) // to ensure recvEP.RecvFD() has blocked - recvEP.Shutdown() -} diff --git a/pkg/fdchannel/fdchannel_unsafe.go b/pkg/fdchannel/fdchannel_unsafe.go index 367235be5..367235be5 100644..100755 --- a/pkg/fdchannel/fdchannel_unsafe.go +++ b/pkg/fdchannel/fdchannel_unsafe.go diff --git a/pkg/fdnotifier/BUILD b/pkg/fdnotifier/BUILD deleted file mode 100644 index 235dcc490..000000000 --- a/pkg/fdnotifier/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "fdnotifier", - srcs = [ - "fdnotifier.go", - "poll_unsafe.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/sync", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/fdnotifier/fdnotifier_state_autogen.go b/pkg/fdnotifier/fdnotifier_state_autogen.go new file mode 100755 index 000000000..c665190ae --- /dev/null +++ b/pkg/fdnotifier/fdnotifier_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +// +build linux +// +build linux + +package fdnotifier diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD deleted file mode 100644 index 9c5ad500b..000000000 --- a/pkg/flipcall/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -licenses(["notice"]) - -go_library( - name = "flipcall", - srcs = [ - "ctrl_futex.go", - "flipcall.go", - "flipcall_unsafe.go", - "futex_linux.go", - "io.go", - "packet_window_allocator.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/abi/linux", - "//pkg/log", - "//pkg/memutil", - "//pkg/sync", - ], -) - -go_test( - name = "flipcall_test", - size = "small", - srcs = [ - "flipcall_example_test.go", - "flipcall_test.go", - ], - library = ":flipcall", - deps = ["//pkg/sync"], -) diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go index e7c3a3a0b..e7c3a3a0b 100644..100755 --- a/pkg/flipcall/ctrl_futex.go +++ b/pkg/flipcall/ctrl_futex.go diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go index 3cdb576e1..3cdb576e1 100644..100755 --- a/pkg/flipcall/flipcall.go +++ b/pkg/flipcall/flipcall.go diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go deleted file mode 100644 index 2e28a149a..000000000 --- a/pkg/flipcall/flipcall_example_test.go +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package flipcall - -import ( - "bytes" - "fmt" - - "gvisor.dev/gvisor/pkg/sync" -) - -func Example() { - const ( - reqPrefix = "request " - respPrefix = "response " - count = 3 - maxMessageLen = len(respPrefix) + 1 // 1 digit - ) - - pwa, err := NewPacketWindowAllocator() - if err != nil { - panic(err) - } - defer pwa.Destroy() - pwd, err := pwa.Allocate(PacketWindowLengthForDataCap(uint32(maxMessageLen))) - if err != nil { - panic(err) - } - var clientEP Endpoint - if err := clientEP.Init(ClientSide, pwd); err != nil { - panic(err) - } - defer clientEP.Destroy() - var serverEP Endpoint - if err := serverEP.Init(ServerSide, pwd); err != nil { - panic(err) - } - defer serverEP.Destroy() - - var serverRun sync.WaitGroup - serverRun.Add(1) - go func() { - defer serverRun.Done() - i := 0 - var buf bytes.Buffer - // wait for first request - n, err := serverEP.RecvFirst() - if err != nil { - return - } - for { - // read request - buf.Reset() - buf.Write(serverEP.Data()[:n]) - fmt.Println(buf.String()) - // write response - buf.Reset() - fmt.Fprintf(&buf, "%s%d", respPrefix, i) - copy(serverEP.Data(), buf.Bytes()) - // send response and wait for next request - n, err = serverEP.SendRecv(uint32(buf.Len())) - if err != nil { - return - } - i++ - } - }() - defer func() { - serverEP.Shutdown() - serverRun.Wait() - }() - - // establish connection as client - if err := clientEP.Connect(); err != nil { - panic(err) - } - var buf bytes.Buffer - for i := 0; i < count; i++ { - // write request - buf.Reset() - fmt.Fprintf(&buf, "%s%d", reqPrefix, i) - copy(clientEP.Data(), buf.Bytes()) - // send request and wait for response - n, err := clientEP.SendRecv(uint32(buf.Len())) - if err != nil { - panic(err) - } - // read response - buf.Reset() - buf.Write(clientEP.Data()[:n]) - fmt.Println(buf.String()) - } - - // Output: - // request 0 - // response 0 - // request 1 - // response 1 - // request 2 - // response 2 -} diff --git a/pkg/flipcall/flipcall_linux_state_autogen.go b/pkg/flipcall/flipcall_linux_state_autogen.go new file mode 100755 index 000000000..ce37ac4e1 --- /dev/null +++ b/pkg/flipcall/flipcall_linux_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build linux + +package flipcall diff --git a/pkg/flipcall/flipcall_state_autogen.go b/pkg/flipcall/flipcall_state_autogen.go new file mode 100755 index 000000000..0e03c2a65 --- /dev/null +++ b/pkg/flipcall/flipcall_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package flipcall diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go deleted file mode 100644 index 33fd55a44..000000000 --- a/pkg/flipcall/flipcall_test.go +++ /dev/null @@ -1,405 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package flipcall - -import ( - "runtime" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" -) - -var testPacketWindowSize = pageSize - -type testConnection struct { - pwa PacketWindowAllocator - clientEP Endpoint - serverEP Endpoint -} - -func newTestConnectionWithOptions(tb testing.TB, clientOpts, serverOpts []EndpointOption) *testConnection { - c := &testConnection{} - if err := c.pwa.Init(); err != nil { - tb.Fatalf("failed to create PacketWindowAllocator: %v", err) - } - pwd, err := c.pwa.Allocate(testPacketWindowSize) - if err != nil { - c.pwa.Destroy() - tb.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err) - } - if err := c.clientEP.Init(ClientSide, pwd, clientOpts...); err != nil { - c.pwa.Destroy() - tb.Fatalf("failed to create client Endpoint: %v", err) - } - if err := c.serverEP.Init(ServerSide, pwd, serverOpts...); err != nil { - c.pwa.Destroy() - c.clientEP.Destroy() - tb.Fatalf("failed to create server Endpoint: %v", err) - } - return c -} - -func newTestConnection(tb testing.TB) *testConnection { - return newTestConnectionWithOptions(tb, nil, nil) -} - -func (c *testConnection) destroy() { - c.pwa.Destroy() - c.clientEP.Destroy() - c.serverEP.Destroy() -} - -func testSendRecv(t *testing.T, c *testConnection) { - // This shared variable is used to confirm that synchronization between - // flipcall endpoints is visible to the Go race detector. - state := 0 - var serverRun sync.WaitGroup - serverRun.Add(1) - go func() { - defer serverRun.Done() - t.Logf("server Endpoint waiting for packet 1") - if _, err := c.serverEP.RecvFirst(); err != nil { - t.Errorf("server Endpoint.RecvFirst() failed: %v", err) - return - } - state++ - if state != 2 { - t.Errorf("shared state counter: got %d, wanted 2", state) - } - t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3") - if _, err := c.serverEP.SendRecv(0); err != nil { - t.Errorf("server Endpoint.SendRecv() failed: %v", err) - return - } - state++ - if state != 4 { - t.Errorf("shared state counter: got %d, wanted 4", state) - } - t.Logf("server Endpoint got packet 3") - }() - defer func() { - // Ensure that the server goroutine is cleaned up before - // c.serverEP.Destroy(), even if the test fails. - c.serverEP.Shutdown() - serverRun.Wait() - }() - - t.Logf("client Endpoint establishing connection") - if err := c.clientEP.Connect(); err != nil { - t.Fatalf("client Endpoint.Connect() failed: %v", err) - } - state++ - if state != 1 { - t.Errorf("shared state counter: got %d, wanted 1", state) - } - t.Logf("client Endpoint sending packet 1 and waiting for packet 2") - if _, err := c.clientEP.SendRecv(0); err != nil { - t.Fatalf("client Endpoint.SendRecv() failed: %v", err) - } - state++ - if state != 3 { - t.Errorf("shared state counter: got %d, wanted 3", state) - } - t.Logf("client Endpoint got packet 2, sending packet 3") - if err := c.clientEP.SendLast(0); err != nil { - t.Fatalf("client Endpoint.SendLast() failed: %v", err) - } - t.Logf("waiting for server goroutine to complete") - serverRun.Wait() -} - -func TestSendRecv(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testSendRecv(t, c) -} - -func testShutdownBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) { - if remoteShutdown { - c.serverEP.Shutdown() - } else { - c.clientEP.Shutdown() - } - if err := c.clientEP.Connect(); err == nil { - t.Errorf("client Endpoint.Connect() succeeded unexpectedly") - } -} - -func TestShutdownBeforeConnectLocal(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownBeforeConnect(t, c, false) -} - -func TestShutdownBeforeConnectRemote(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownBeforeConnect(t, c, true) -} - -func testShutdownDuringConnect(t *testing.T, c *testConnection, remoteShutdown bool) { - var clientRun sync.WaitGroup - clientRun.Add(1) - go func() { - defer clientRun.Done() - if err := c.clientEP.Connect(); err == nil { - t.Errorf("client Endpoint.Connect() succeeded unexpectedly") - } - }() - time.Sleep(time.Second) // to allow c.clientEP.Connect() to block - if remoteShutdown { - c.serverEP.Shutdown() - } else { - c.clientEP.Shutdown() - } - clientRun.Wait() -} - -func TestShutdownDuringConnectLocal(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringConnect(t, c, false) -} - -func TestShutdownDuringConnectRemote(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringConnect(t, c, true) -} - -func testShutdownBeforeRecvFirst(t *testing.T, c *testConnection, remoteShutdown bool) { - if remoteShutdown { - c.clientEP.Shutdown() - } else { - c.serverEP.Shutdown() - } - if _, err := c.serverEP.RecvFirst(); err == nil { - t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") - } -} - -func TestShutdownBeforeRecvFirstLocal(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownBeforeRecvFirst(t, c, false) -} - -func TestShutdownBeforeRecvFirstRemote(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownBeforeRecvFirst(t, c, true) -} - -func testShutdownDuringRecvFirstBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) { - var serverRun sync.WaitGroup - serverRun.Add(1) - go func() { - defer serverRun.Done() - if _, err := c.serverEP.RecvFirst(); err == nil { - t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") - } - }() - time.Sleep(time.Second) // to allow c.serverEP.RecvFirst() to block - if remoteShutdown { - c.clientEP.Shutdown() - } else { - c.serverEP.Shutdown() - } - serverRun.Wait() -} - -func TestShutdownDuringRecvFirstBeforeConnectLocal(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringRecvFirstBeforeConnect(t, c, false) -} - -func TestShutdownDuringRecvFirstBeforeConnectRemote(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringRecvFirstBeforeConnect(t, c, true) -} - -func testShutdownDuringRecvFirstAfterConnect(t *testing.T, c *testConnection, remoteShutdown bool) { - var serverRun sync.WaitGroup - serverRun.Add(1) - go func() { - defer serverRun.Done() - if _, err := c.serverEP.RecvFirst(); err == nil { - t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") - } - }() - defer func() { - // Ensure that the server goroutine is cleaned up before - // c.serverEP.Destroy(), even if the test fails. - c.serverEP.Shutdown() - serverRun.Wait() - }() - if err := c.clientEP.Connect(); err != nil { - t.Fatalf("client Endpoint.Connect() failed: %v", err) - } - if remoteShutdown { - c.clientEP.Shutdown() - } else { - c.serverEP.Shutdown() - } - serverRun.Wait() -} - -func TestShutdownDuringRecvFirstAfterConnectLocal(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringRecvFirstAfterConnect(t, c, false) -} - -func TestShutdownDuringRecvFirstAfterConnectRemote(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringRecvFirstAfterConnect(t, c, true) -} - -func testShutdownDuringClientSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) { - var serverRun sync.WaitGroup - serverRun.Add(1) - go func() { - defer serverRun.Done() - if _, err := c.serverEP.RecvFirst(); err != nil { - t.Errorf("server Endpoint.RecvFirst() failed: %v", err) - } - // At this point, the client must be blocked in c.clientEP.SendRecv(). - if remoteShutdown { - c.serverEP.Shutdown() - } else { - c.clientEP.Shutdown() - } - }() - defer func() { - // Ensure that the server goroutine is cleaned up before - // c.serverEP.Destroy(), even if the test fails. - c.serverEP.Shutdown() - serverRun.Wait() - }() - if err := c.clientEP.Connect(); err != nil { - t.Fatalf("client Endpoint.Connect() failed: %v", err) - } - if _, err := c.clientEP.SendRecv(0); err == nil { - t.Errorf("client Endpoint.SendRecv() succeeded unexpectedly") - } -} - -func TestShutdownDuringClientSendRecvLocal(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringClientSendRecv(t, c, false) -} - -func TestShutdownDuringClientSendRecvRemote(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringClientSendRecv(t, c, true) -} - -func testShutdownDuringServerSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) { - var serverRun sync.WaitGroup - serverRun.Add(1) - go func() { - defer serverRun.Done() - if _, err := c.serverEP.RecvFirst(); err != nil { - t.Errorf("server Endpoint.RecvFirst() failed: %v", err) - return - } - if _, err := c.serverEP.SendRecv(0); err == nil { - t.Errorf("server Endpoint.SendRecv() succeeded unexpectedly") - } - }() - defer func() { - // Ensure that the server goroutine is cleaned up before - // c.serverEP.Destroy(), even if the test fails. - c.serverEP.Shutdown() - serverRun.Wait() - }() - if err := c.clientEP.Connect(); err != nil { - t.Fatalf("client Endpoint.Connect() failed: %v", err) - } - if _, err := c.clientEP.SendRecv(0); err != nil { - t.Fatalf("client Endpoint.SendRecv() failed: %v", err) - } - time.Sleep(time.Second) // to allow serverEP.SendRecv() to block - if remoteShutdown { - c.clientEP.Shutdown() - } else { - c.serverEP.Shutdown() - } - serverRun.Wait() -} - -func TestShutdownDuringServerSendRecvLocal(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringServerSendRecv(t, c, false) -} - -func TestShutdownDuringServerSendRecvRemote(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringServerSendRecv(t, c, true) -} - -func benchmarkSendRecv(b *testing.B, c *testConnection) { - var serverRun sync.WaitGroup - serverRun.Add(1) - go func() { - defer serverRun.Done() - if b.N == 0 { - return - } - if _, err := c.serverEP.RecvFirst(); err != nil { - b.Errorf("server Endpoint.RecvFirst() failed: %v", err) - return - } - for i := 1; i < b.N; i++ { - if _, err := c.serverEP.SendRecv(0); err != nil { - b.Errorf("server Endpoint.SendRecv() failed: %v", err) - return - } - } - if err := c.serverEP.SendLast(0); err != nil { - b.Errorf("server Endpoint.SendLast() failed: %v", err) - } - }() - defer func() { - c.serverEP.Shutdown() - serverRun.Wait() - }() - - if err := c.clientEP.Connect(); err != nil { - b.Fatalf("client Endpoint.Connect() failed: %v", err) - } - runtime.GC() - b.ResetTimer() - for i := 0; i < b.N; i++ { - if _, err := c.clientEP.SendRecv(0); err != nil { - b.Fatalf("client Endpoint.SendRecv() failed: %v", err) - } - } - b.StopTimer() -} - -func BenchmarkSendRecv(b *testing.B) { - c := newTestConnection(b) - defer c.destroy() - benchmarkSendRecv(b, c) -} diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go index ac974b232..ac974b232 100644..100755 --- a/pkg/flipcall/flipcall_unsafe.go +++ b/pkg/flipcall/flipcall_unsafe.go diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go index 168c1ccff..168c1ccff 100644..100755 --- a/pkg/flipcall/futex_linux.go +++ b/pkg/flipcall/futex_linux.go diff --git a/pkg/flipcall/io.go b/pkg/flipcall/io.go index 85e40b932..85e40b932 100644..100755 --- a/pkg/flipcall/io.go +++ b/pkg/flipcall/io.go diff --git a/pkg/flipcall/packet_window_allocator.go b/pkg/flipcall/packet_window_allocator.go index ccb918fab..ccb918fab 100644..100755 --- a/pkg/flipcall/packet_window_allocator.go +++ b/pkg/flipcall/packet_window_allocator.go diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD deleted file mode 100644 index 67dd1e225..000000000 --- a/pkg/fspath/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) - -go_library( - name = "fspath", - srcs = [ - "builder.go", - "fspath.go", - ], - deps = [ - "//pkg/gohacks", - ], -) - -go_test( - name = "fspath_test", - size = "small", - srcs = [ - "builder_test.go", - "fspath_test.go", - ], - library = ":fspath", -) diff --git a/pkg/fspath/builder.go b/pkg/fspath/builder.go index 6318d3874..6318d3874 100644..100755 --- a/pkg/fspath/builder.go +++ b/pkg/fspath/builder.go diff --git a/pkg/fspath/builder_test.go b/pkg/fspath/builder_test.go deleted file mode 100644 index 22f890273..000000000 --- a/pkg/fspath/builder_test.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fspath - -import ( - "testing" -) - -func TestBuilder(t *testing.T) { - type testCase struct { - pcs []string // path components in reverse order - after string - want string - } - tests := []testCase{ - { - // Empty case. - }, - { - pcs: []string{"foo"}, - want: "foo", - }, - { - pcs: []string{"foo", "bar", "baz"}, - want: "baz/bar/foo", - }, - { - pcs: []string{"foo", "bar"}, - after: " (deleted)", - want: "bar/foo (deleted)", - }, - } - - for _, test := range tests { - t.Run(test.want, func(t *testing.T) { - var b Builder - for _, pc := range test.pcs { - b.PrependComponent(pc) - } - b.AppendString(test.after) - if got := b.String(); got != test.want { - t.Errorf("got %q, wanted %q", got, test.want) - } - }) - } -} diff --git a/pkg/fspath/fspath.go b/pkg/fspath/fspath.go index 4c983d5fd..4c983d5fd 100644..100755 --- a/pkg/fspath/fspath.go +++ b/pkg/fspath/fspath.go diff --git a/pkg/fspath/fspath_state_autogen.go b/pkg/fspath/fspath_state_autogen.go new file mode 100755 index 000000000..6ceea8003 --- /dev/null +++ b/pkg/fspath/fspath_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package fspath diff --git a/pkg/fspath/fspath_test.go b/pkg/fspath/fspath_test.go deleted file mode 100644 index d5e9a549a..000000000 --- a/pkg/fspath/fspath_test.go +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fspath - -import ( - "reflect" - "strings" - "testing" -) - -func TestParseIteratorPartialPathnames(t *testing.T) { - path := Parse("/foo//bar///baz////") - // Parse strips leading slashes, and records their presence as - // Path.Absolute. - if !path.Absolute { - t.Errorf("Path.Absolute: got false, wanted true") - } - // Parse strips trailing slashes, and records their presence as Path.Dir. - if !path.Dir { - t.Errorf("Path.Dir: got false, wanted true") - } - // The first Iterator.partialPathname is the input pathname, with leading - // and trailing slashes stripped. - it := path.Begin - if want := "foo//bar///baz"; it.partialPathname != want { - t.Errorf("first Iterator.partialPathname: got %q, wanted %q", it.partialPathname, want) - } - // Successive Iterator.partialPathnames remove the leading path component - // and following slashes, until we run out of path components and get a - // terminal Iterator. - it = it.Next() - if want := "bar///baz"; it.partialPathname != want { - t.Errorf("second Iterator.partialPathname: got %q, wanted %q", it.partialPathname, want) - } - it = it.Next() - if want := "baz"; it.partialPathname != want { - t.Errorf("third Iterator.partialPathname: got %q, wanted %q", it.partialPathname, want) - } - it = it.Next() - if want := ""; it.partialPathname != want { - t.Errorf("fourth Iterator.partialPathname: got %q, wanted %q", it.partialPathname, want) - } - if it.Ok() { - t.Errorf("fourth Iterator.Ok(): got true, wanted false") - } -} - -func TestParse(t *testing.T) { - type testCase struct { - pathname string - relpath []string - abs bool - dir bool - } - tests := []testCase{ - { - pathname: "", - relpath: []string{}, - abs: false, - dir: false, - }, - { - pathname: "/", - relpath: []string{}, - abs: true, - dir: true, - }, - { - pathname: "//", - relpath: []string{}, - abs: true, - dir: true, - }, - } - for _, sep := range []string{"/", "//"} { - for _, abs := range []bool{false, true} { - for _, dir := range []bool{false, true} { - for _, pcs := range [][]string{ - // single path component - {"foo"}, - // multiple path components, including non-UTF-8 - {".", "foo", "..", "\xe6", "bar"}, - } { - prefix := "" - if abs { - prefix = sep - } - suffix := "" - if dir { - suffix = sep - } - tests = append(tests, testCase{ - pathname: prefix + strings.Join(pcs, sep) + suffix, - relpath: pcs, - abs: abs, - dir: dir, - }) - } - } - } - } - - for _, test := range tests { - t.Run(test.pathname, func(t *testing.T) { - p := Parse(test.pathname) - t.Logf("pathname %q => path %q", test.pathname, p) - if p.Absolute != test.abs { - t.Errorf("path absoluteness: got %v, wanted %v", p.Absolute, test.abs) - } - if p.Dir != test.dir { - t.Errorf("path must resolve to a directory: got %v, wanted %v", p.Dir, test.dir) - } - pcs := []string{} - for pit := p.Begin; pit.Ok(); pit = pit.Next() { - pcs = append(pcs, pit.String()) - } - if !reflect.DeepEqual(pcs, test.relpath) { - t.Errorf("relative path: got %v, wanted %v", pcs, test.relpath) - } - }) - } -} diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD deleted file mode 100644 index dd3141143..000000000 --- a/pkg/gate/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "gate", - srcs = [ - "gate.go", - ], - visibility = ["//visibility:public"], -) - -go_test( - name = "gate_test", - srcs = [ - "gate_test.go", - ], - deps = [ - ":gate", - "//pkg/sync", - ], -) diff --git a/pkg/gate/gate_state_autogen.go b/pkg/gate/gate_state_autogen.go new file mode 100755 index 000000000..221af659e --- /dev/null +++ b/pkg/gate/gate_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package gate diff --git a/pkg/gate/gate_test.go b/pkg/gate/gate_test.go deleted file mode 100644 index 850693df8..000000000 --- a/pkg/gate/gate_test.go +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gate_test - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/gate" - "gvisor.dev/gvisor/pkg/sync" -) - -func TestBasicEnter(t *testing.T) { - var g gate.Gate - - if !g.Enter() { - t.Fatalf("Failed to enter when it should be allowed") - } - - g.Leave() - - g.Close() - - if g.Enter() { - t.Fatalf("Allowed to enter when it should fail") - } -} - -func enterFunc(t *testing.T, g *gate.Gate, enter, leave, reenter chan struct{}, done1, done2, done3 *sync.WaitGroup) { - // Wait until instructed to enter. - <-enter - if !g.Enter() { - t.Errorf("Failed to enter when it should be allowed") - } - - done1.Done() - - // Wait until instructed to leave. - <-leave - g.Leave() - - done2.Done() - - // Wait until instructed to reenter. - <-reenter - if g.Enter() { - t.Errorf("Allowed to enter when it should fail") - } - done3.Done() -} - -func TestConcurrentEnter(t *testing.T) { - var g gate.Gate - var done1, done2, done3 sync.WaitGroup - - // Create 1000 worker goroutines. - enter := make(chan struct{}) - leave := make(chan struct{}) - reenter := make(chan struct{}) - done1.Add(1000) - done2.Add(1000) - done3.Add(1000) - for i := 0; i < 1000; i++ { - go enterFunc(t, &g, enter, leave, reenter, &done1, &done2, &done3) - } - - // Tell them all to enter, then leave. - close(enter) - done1.Wait() - - close(leave) - done2.Wait() - - // Close the gate, then have the workers try to enter again. - g.Close() - close(reenter) - done3.Wait() -} - -func closeFunc(g *gate.Gate, done chan struct{}) { - g.Close() - close(done) -} - -func TestCloseWaits(t *testing.T) { - var g gate.Gate - - // Enter 10 times. - for i := 0; i < 10; i++ { - if !g.Enter() { - t.Fatalf("Failed to enter when it should be allowed") - } - } - - // Launch closer. Check that it doesn't complete. - done := make(chan struct{}) - go closeFunc(&g, done) - - for i := 0; i < 10; i++ { - select { - case <-done: - t.Fatalf("Close function completed too soon") - case <-time.After(100 * time.Millisecond): - } - - g.Leave() - } - - // Now the closer must complete. - <-done -} - -func TestMultipleSerialCloses(t *testing.T) { - var g gate.Gate - - // Enter 10 times. - for i := 0; i < 10; i++ { - if !g.Enter() { - t.Fatalf("Failed to enter when it should be allowed") - } - } - - // Launch closer. Check that it doesn't complete. - done := make(chan struct{}) - go closeFunc(&g, done) - - for i := 0; i < 10; i++ { - select { - case <-done: - t.Fatalf("Close function completed too soon") - case <-time.After(100 * time.Millisecond): - } - - g.Leave() - } - - // Now the closer must complete. - <-done - - // Close again should not block. - done = make(chan struct{}) - go closeFunc(&g, done) - - select { - case <-done: - case <-time.After(2 * time.Second): - t.Fatalf("Second Close is blocking") - } -} - -func worker(g *gate.Gate, done *sync.WaitGroup) { - for { - if !g.Enter() { - break - } - g.Leave() - } - done.Done() -} - -func TestConcurrentAll(t *testing.T) { - var g gate.Gate - var done sync.WaitGroup - - // Launch 1000 goroutines to concurrently enter/leave. - done.Add(1000) - for i := 0; i < 1000; i++ { - go worker(&g, &done) - } - - // Wait for the goroutines to do some work, then close the gate. - time.Sleep(2 * time.Second) - g.Close() - - // Wait for all of them to complete. - done.Wait() -} diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD deleted file mode 100644 index 798a65eca..000000000 --- a/pkg/gohacks/BUILD +++ /dev/null @@ -1,11 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "gohacks", - srcs = [ - "gohacks_unsafe.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/gohacks/gohacks_state_autogen.go b/pkg/gohacks/gohacks_state_autogen.go new file mode 100755 index 000000000..c651ff01e --- /dev/null +++ b/pkg/gohacks/gohacks_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package gohacks diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go index aad675172..aad675172 100644..100755 --- a/pkg/gohacks/gohacks_unsafe.go +++ b/pkg/gohacks/gohacks_unsafe.go diff --git a/pkg/goid/BUILD b/pkg/goid/BUILD deleted file mode 100644 index ea8d2422c..000000000 --- a/pkg/goid/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "goid", - srcs = [ - "goid.go", - "goid_amd64.s", - "goid_race.go", - "goid_unsafe.go", - ], - visibility = ["//visibility:public"], -) - -go_test( - name = "goid_test", - size = "small", - srcs = [ - "empty_test.go", - "goid_test.go", - ], - library = ":goid", -) diff --git a/pkg/goid/empty_test.go b/pkg/goid/empty_test.go deleted file mode 100644 index c0a4b17ab..000000000 --- a/pkg/goid/empty_test.go +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build !race - -package goid - -import "testing" - -// TestNothing exists to make the build system happy. -func TestNothing(t *testing.T) {} diff --git a/pkg/goid/goid.go b/pkg/goid/goid.go deleted file mode 100644 index 39df30031..000000000 --- a/pkg/goid/goid.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build !race - -// Package goid provides access to the ID of the current goroutine in -// race/gotsan builds. -package goid - -// Get returns the ID of the current goroutine. -func Get() int64 { - panic("unimplemented for non-race builds") -} diff --git a/pkg/goid/goid_amd64.s b/pkg/goid/goid_amd64.s deleted file mode 100644 index d9f5cd2a3..000000000 --- a/pkg/goid/goid_amd64.s +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" - -// func getg() *g -TEXT ·getg(SB),NOSPLIT,$0-8 - MOVQ (TLS), R14 - MOVQ R14, ret+0(FP) - RET diff --git a/pkg/goid/goid_race.go b/pkg/goid/goid_race.go deleted file mode 100644 index 1766beaee..000000000 --- a/pkg/goid/goid_race.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Only available in race/gotsan builds. -// +build race - -// Package goid provides access to the ID of the current goroutine in -// race/gotsan builds. -package goid - -// Get returns the ID of the current goroutine. -func Get() int64 { - return goid() -} diff --git a/pkg/goid/goid_test.go b/pkg/goid/goid_test.go deleted file mode 100644 index 31970ce79..000000000 --- a/pkg/goid/goid_test.go +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build race - -package goid - -import ( - "runtime" - "sync" - "testing" -) - -func TestInitialGoID(t *testing.T) { - const max = 10000 - if id := goid(); id < 0 || id > max { - t.Errorf("got goid = %d, want 0 < goid <= %d", id, max) - } -} - -// TestGoIDSquence verifies that goid returns values which could plausibly be -// goroutine IDs. If this test breaks or becomes flaky, the structs in -// goid_unsafe.go may need to be updated. -func TestGoIDSquence(t *testing.T) { - // Goroutine IDs are cached by each P. - runtime.GOMAXPROCS(1) - - // Fill any holes in lower range. - for i := 0; i < 50; i++ { - var wg sync.WaitGroup - wg.Add(1) - go func() { - wg.Done() - - // Leak the goroutine to prevent the ID from being - // reused. - select {} - }() - wg.Wait() - } - - id := goid() - for i := 0; i < 100; i++ { - var ( - newID int64 - wg sync.WaitGroup - ) - wg.Add(1) - go func() { - newID = goid() - wg.Done() - - // Leak the goroutine to prevent the ID from being - // reused. - select {} - }() - wg.Wait() - if max := id + 100; newID <= id || newID > max { - t.Errorf("unexpected goroutine ID pattern, got goid = %d, want %d < goid <= %d (previous = %d)", newID, id, max, id) - } - id = newID - } -} diff --git a/pkg/goid/goid_unsafe.go b/pkg/goid/goid_unsafe.go deleted file mode 100644 index ded8004dd..000000000 --- a/pkg/goid/goid_unsafe.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package goid - -// Structs from Go runtime. These may change in the future and require -// updating. These structs are currently the same on both AMD64 and ARM64, -// but may diverge in the future. - -type stack struct { - lo uintptr - hi uintptr -} - -type gobuf struct { - sp uintptr - pc uintptr - g uintptr - ctxt uintptr - ret uint64 - lr uintptr - bp uintptr -} - -type g struct { - stack stack - stackguard0 uintptr - stackguard1 uintptr - - _panic uintptr - _defer uintptr - m uintptr - sched gobuf - syscallsp uintptr - syscallpc uintptr - stktopsp uintptr - param uintptr - atomicstatus uint32 - stackLock uint32 - goid int64 - - // More fields... - // - // We only use goid and the fields before it are only listed to - // calculate the correct offset. -} - -func getg() *g - -// goid returns the ID of the current goroutine. -func goid() int64 { - return getg().goid -} diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD deleted file mode 100644 index 3f6eb07df..000000000 --- a/pkg/ilist/BUILD +++ /dev/null @@ -1,56 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - -package(licenses = ["notice"]) - -go_library( - name = "ilist", - srcs = [ - "interface_list.go", - ], - visibility = ["//visibility:public"], -) - -go_template_instance( - name = "interface_list", - out = "interface_list.go", - package = "ilist", - template = ":generic_list", - types = {}, -) - -# This list is used for benchmarking. -go_template_instance( - name = "test_list", - out = "test_list.go", - package = "ilist", - prefix = "direct", - template = ":generic_list", - types = { - "Element": "*direct", - "Linker": "*direct", - }, -) - -go_test( - name = "list_test", - size = "small", - srcs = [ - "list_test.go", - "test_list.go", - ], - library = ":ilist", -) - -go_template( - name = "generic_list", - srcs = [ - "list.go", - ], - opt_types = [ - "Element", - "ElementMapper", - "Linker", - ], - visibility = ["//visibility:public"], -) diff --git a/pkg/ilist/ilist_state_autogen.go b/pkg/ilist/ilist_state_autogen.go new file mode 100755 index 000000000..4294bcb90 --- /dev/null +++ b/pkg/ilist/ilist_state_autogen.go @@ -0,0 +1,38 @@ +// automatically generated by stateify. + +package ilist + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *List) beforeSave() {} +func (x *List) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *List) afterLoad() {} +func (x *List) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *Entry) beforeSave() {} +func (x *Entry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *Entry) afterLoad() {} +func (x *Entry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("pkg/ilist.List", (*List)(nil), state.Fns{Save: (*List).save, Load: (*List).load}) + state.Register("pkg/ilist.Entry", (*Entry)(nil), state.Fns{Save: (*Entry).save, Load: (*Entry).load}) +} diff --git a/pkg/ilist/list.go b/pkg/ilist/interface_list.go index 8f93e4d6d..aeb636f52 100644..100755 --- a/pkg/ilist/list.go +++ b/pkg/ilist/interface_list.go @@ -1,18 +1,3 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package ilist provides the implementation of intrusive linked lists. package ilist // Linker is the interface that objects must implement if they want to be added diff --git a/pkg/ilist/list_test.go b/pkg/ilist/list_test.go deleted file mode 100644 index 3f9abfb56..000000000 --- a/pkg/ilist/list_test.go +++ /dev/null @@ -1,240 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ilist - -import ( - "testing" -) - -type testEntry struct { - Entry - value int -} - -type direct struct { - directEntry - value int -} - -func verifyEquality(t *testing.T, entries []testEntry, l *List) { - t.Helper() - - i := 0 - for it := l.Front(); it != nil; it = it.Next() { - e := it.(*testEntry) - if e != &entries[i] { - t.Errorf("Wrong entry at index %d", i) - return - } - i++ - } - - if i != len(entries) { - t.Errorf("Wrong number of entries; want = %d, got = %d", len(entries), i) - return - } - - i = 0 - for it := l.Back(); it != nil; it = it.Prev() { - e := it.(*testEntry) - if e != &entries[len(entries)-1-i] { - t.Errorf("Wrong entry at index %d", i) - return - } - i++ - } - - if i != len(entries) { - t.Errorf("Wrong number of entries; want = %d, got = %d", len(entries), i) - return - } -} - -func TestZeroEmpty(t *testing.T) { - var l List - if l.Front() != nil { - t.Error("Front is non-nil") - } - if l.Back() != nil { - t.Error("Back is non-nil") - } -} - -func TestPushBack(t *testing.T) { - var l List - - // Test single entry insertion. - var entry testEntry - l.PushBack(&entry) - - e := l.Front().(*testEntry) - if e != &entry { - t.Error("Wrong entry returned") - } - - // Test inserting 100 entries. - l.Reset() - var entries [100]testEntry - for i := range entries { - l.PushBack(&entries[i]) - } - - verifyEquality(t, entries[:], &l) -} - -func TestPushFront(t *testing.T) { - var l List - - // Test single entry insertion. - var entry testEntry - l.PushFront(&entry) - - e := l.Front().(*testEntry) - if e != &entry { - t.Error("Wrong entry returned") - } - - // Test inserting 100 entries. - l.Reset() - var entries [100]testEntry - for i := range entries { - l.PushFront(&entries[len(entries)-1-i]) - } - - verifyEquality(t, entries[:], &l) -} - -func TestRemove(t *testing.T) { - // Remove entry from single-element list. - var l List - var entry testEntry - l.PushBack(&entry) - l.Remove(&entry) - if l.Front() != nil { - t.Error("List is empty") - } - - var entries [100]testEntry - - // Remove single element from lists of lengths 2 to 101. - for n := 1; n <= len(entries); n++ { - for extra := 0; extra <= n; extra++ { - l.Reset() - for i := 0; i < n; i++ { - if extra == i { - l.PushBack(&entry) - } - l.PushBack(&entries[i]) - } - if extra == n { - l.PushBack(&entry) - } - - l.Remove(&entry) - verifyEquality(t, entries[:n], &l) - } - } -} - -func TestReset(t *testing.T) { - var l List - - // Resetting list of one element. - l.PushBack(&testEntry{}) - if l.Front() == nil { - t.Error("List is empty") - } - - l.Reset() - if l.Front() != nil { - t.Error("List is not empty") - } - - // Resetting list of 10 elements. - for i := 0; i < 10; i++ { - l.PushBack(&testEntry{}) - } - - if l.Front() == nil { - t.Error("List is empty") - } - - l.Reset() - if l.Front() != nil { - t.Error("List is not empty") - } - - // Resetting empty list. - l.Reset() - if l.Front() != nil { - t.Error("List is not empty") - } -} - -func BenchmarkIterateForward(b *testing.B) { - var l List - for i := 0; i < 1000000; i++ { - l.PushBack(&testEntry{value: i}) - } - - for i := b.N; i > 0; i-- { - tmp := 0 - for e := l.Front(); e != nil; e = e.Next() { - tmp += e.(*testEntry).value - } - } -} - -func BenchmarkIterateBackward(b *testing.B) { - var l List - for i := 0; i < 1000000; i++ { - l.PushBack(&testEntry{value: i}) - } - - for i := b.N; i > 0; i-- { - tmp := 0 - for e := l.Back(); e != nil; e = e.Prev() { - tmp += e.(*testEntry).value - } - } -} - -func BenchmarkDirectIterateForward(b *testing.B) { - var l directList - for i := 0; i < 1000000; i++ { - l.PushBack(&direct{value: i}) - } - - for i := b.N; i > 0; i-- { - tmp := 0 - for e := l.Front(); e != nil; e = e.Next() { - tmp += e.value - } - } -} - -func BenchmarkDirectIterateBackward(b *testing.B) { - var l directList - for i := 0; i < 1000000; i++ { - l.PushBack(&direct{value: i}) - } - - for i := b.N; i > 0; i-- { - tmp := 0 - for e := l.Back(); e != nil; e = e.Prev() { - tmp += e.value - } - } -} diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD deleted file mode 100644 index 41bf104d0..000000000 --- a/pkg/linewriter/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "linewriter", - srcs = ["linewriter.go"], - visibility = ["//visibility:public"], - deps = ["//pkg/sync"], -) - -go_test( - name = "linewriter_test", - srcs = ["linewriter_test.go"], - library = ":linewriter", -) diff --git a/pkg/linewriter/linewriter_state_autogen.go b/pkg/linewriter/linewriter_state_autogen.go new file mode 100755 index 000000000..1cd1df9b8 --- /dev/null +++ b/pkg/linewriter/linewriter_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package linewriter diff --git a/pkg/linewriter/linewriter_test.go b/pkg/linewriter/linewriter_test.go deleted file mode 100644 index 96dc7e6e0..000000000 --- a/pkg/linewriter/linewriter_test.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package linewriter - -import ( - "bytes" - "testing" -) - -func TestWriter(t *testing.T) { - testCases := []struct { - input []string - want []string - }{ - { - input: []string{"1\n", "2\n"}, - want: []string{"1", "2"}, - }, - { - input: []string{"1\n", "\n", "2\n"}, - want: []string{"1", "", "2"}, - }, - { - input: []string{"1\n2\n", "3\n"}, - want: []string{"1", "2", "3"}, - }, - { - input: []string{"1", "2\n"}, - want: []string{"12"}, - }, - { - // Data with no newline yet is omitted. - input: []string{"1\n", "2\n", "3"}, - want: []string{"1", "2"}, - }, - } - - for _, c := range testCases { - var lines [][]byte - - w := NewWriter(func(p []byte) { - // We must not retain p, so we must make a copy. - b := make([]byte, len(p)) - copy(b, p) - - lines = append(lines, b) - }) - - for _, in := range c.input { - n, err := w.Write([]byte(in)) - if err != nil { - t.Errorf("Write(%q) err got %v want nil (case %+v)", in, err, c) - } - if n != len(in) { - t.Errorf("Write(%q) b got %d want %d (case %+v)", in, n, len(in), c) - } - } - - if len(lines) != len(c.want) { - t.Errorf("len(lines) got %d want %d (case %+v)", len(lines), len(c.want), c) - } - - for i := range lines { - if !bytes.Equal(lines[i], []byte(c.want[i])) { - t.Errorf("item %d got %q want %q (case %+v)", i, lines[i], c.want[i], c) - } - } - } -} diff --git a/pkg/log/BUILD b/pkg/log/BUILD deleted file mode 100644 index a7c8f7bef..000000000 --- a/pkg/log/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "log", - srcs = [ - "glog.go", - "json.go", - "json_k8s.go", - "log.go", - ], - visibility = [ - "//visibility:public", - ], - deps = [ - "//pkg/linewriter", - "//pkg/sync", - ], -) - -go_test( - name = "log_test", - size = "small", - srcs = [ - "json_test.go", - "log_test.go", - ], - library = ":log", -) diff --git a/pkg/log/json_test.go b/pkg/log/json_test.go deleted file mode 100644 index f25224fe1..000000000 --- a/pkg/log/json_test.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package log - -import ( - "encoding/json" - "testing" -) - -// Tests that Level can marshal/unmarshal properly. -func TestLevelMarshal(t *testing.T) { - lvs := []Level{Warning, Info, Debug} - for _, lv := range lvs { - bs, err := lv.MarshalJSON() - if err != nil { - t.Errorf("error marshaling %v: %v", lv, err) - } - var lv2 Level - if err := lv2.UnmarshalJSON(bs); err != nil { - t.Errorf("error unmarshaling %v: %v", bs, err) - } - if lv != lv2 { - t.Errorf("marshal/unmarshal level got %v wanted %v", lv2, lv) - } - } -} - -// Test that integers can be properly unmarshaled. -func TestUnmarshalFromInt(t *testing.T) { - tcs := []struct { - i int - want Level - }{ - {0, Warning}, - {1, Info}, - {2, Debug}, - } - - for _, tc := range tcs { - j, err := json.Marshal(tc.i) - if err != nil { - t.Errorf("error marshaling %v: %v", tc.i, err) - } - var lv Level - if err := lv.UnmarshalJSON(j); err != nil { - t.Errorf("error unmarshaling %v: %v", j, err) - } - if lv != tc.want { - t.Errorf("marshal/unmarshal %v got %v want %v", tc.i, lv, tc.want) - } - } -} diff --git a/pkg/log/log_state_autogen.go b/pkg/log/log_state_autogen.go new file mode 100755 index 000000000..4e243c216 --- /dev/null +++ b/pkg/log/log_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package log diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go deleted file mode 100644 index 402cc29ae..000000000 --- a/pkg/log/log_test.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package log - -import ( - "fmt" - "strings" - "testing" -) - -type testWriter struct { - lines []string - fail bool - limit int -} - -func (w *testWriter) Write(bytes []byte) (int, error) { - if w.fail { - return 0, fmt.Errorf("simulated failure") - } - if w.limit > 0 && len(w.lines) >= w.limit { - return len(bytes), nil - } - w.lines = append(w.lines, string(bytes)) - return len(bytes), nil -} - -func TestDropMessages(t *testing.T) { - tw := &testWriter{} - w := Writer{Next: tw} - if _, err := w.Write([]byte("line 1\n")); err != nil { - t.Fatalf("Write failed, err: %v", err) - } - - tw.fail = true - if _, err := w.Write([]byte("error\n")); err == nil { - t.Fatalf("Write should have failed") - } - if _, err := w.Write([]byte("error\n")); err == nil { - t.Fatalf("Write should have failed") - } - - fmt.Printf("writer: %+v\n", w) - - tw.fail = false - if _, err := w.Write([]byte("line 2\n")); err != nil { - t.Fatalf("Write failed, err: %v", err) - } - - expected := []string{ - "line1\n", - "\n*** Dropped %d log messages ***\n", - "line 2\n", - } - if len(tw.lines) != len(expected) { - t.Fatalf("Writer should have logged %d lines, got: %v, expected: %v", len(expected), tw.lines, expected) - } - for i, l := range tw.lines { - if l == expected[i] { - t.Fatalf("line %d doesn't match, got: %v, expected: %v", i, l, expected[i]) - } - } -} - -func TestCaller(t *testing.T) { - tw := &testWriter{} - e := &GoogleEmitter{Writer: Writer{Next: tw}} - bl := &BasicLogger{ - Emitter: e, - Level: Debug, - } - bl.Debugf("testing...\n") // Just for file + line. - if len(tw.lines) != 1 { - t.Errorf("expected 1 line, got %d", len(tw.lines)) - } - if !strings.Contains(tw.lines[0], "log_test.go") { - t.Errorf("expected log_test.go, got %q", tw.lines[0]) - } -} - -func BenchmarkGoogleLogging(b *testing.B) { - tw := &testWriter{ - limit: 1, // Only record one message. - } - e := &GoogleEmitter{Writer: Writer{Next: tw}} - bl := &BasicLogger{ - Emitter: e, - Level: Debug, - } - for i := 0; i < b.N; i++ { - bl.Debugf("hello %d, %d, %d", 1, 2, 3) - } -} diff --git a/pkg/memutil/BUILD b/pkg/memutil/BUILD deleted file mode 100644 index 9d07d98b4..000000000 --- a/pkg/memutil/BUILD +++ /dev/null @@ -1,10 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "memutil", - srcs = ["memutil_unsafe.go"], - visibility = ["//visibility:public"], - deps = ["@org_golang_x_sys//unix:go_default_library"], -) diff --git a/pkg/memutil/memutil_state_autogen.go b/pkg/memutil/memutil_state_autogen.go new file mode 100755 index 000000000..173297149 --- /dev/null +++ b/pkg/memutil/memutil_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build linux + +package memutil diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD deleted file mode 100644 index 58305009d..000000000 --- a/pkg/metric/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test", "proto_library") - -package(licenses = ["notice"]) - -go_library( - name = "metric", - srcs = ["metric.go"], - visibility = ["//:sandbox"], - deps = [ - ":metric_go_proto", - "//pkg/eventchannel", - "//pkg/log", - "//pkg/sync", - ], -) - -proto_library( - name = "metric", - srcs = ["metric.proto"], - visibility = ["//:sandbox"], -) - -go_test( - name = "metric_test", - srcs = ["metric_test.go"], - library = ":metric", - deps = [ - ":metric_go_proto", - "//pkg/eventchannel", - "@com_github_golang_protobuf//proto:go_default_library", - ], -) diff --git a/pkg/metric/metric.proto b/pkg/metric/metric.proto deleted file mode 100644 index a2c2bd1ba..000000000 --- a/pkg/metric/metric.proto +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -// MetricMetadata contains all of the metadata describing a single metric. -message MetricMetadata { - // name is the unique name of the metric, usually in a "directory" format - // (e.g., /foo/count). - string name = 1; - - // description is a human-readable description of the metric. - string description = 2; - - // cumulative indicates that this metric is never decremented. - bool cumulative = 3; - - // sync indicates that values from the final metric event should be - // synchronized to the backing monitoring system at exit. - // - // If sync is false, values are only sent to the monitoring system - // periodically. There is no guarantee that values will ever be received by - // the monitoring system. - bool sync = 4; - - enum Type { UINT64 = 0; } - - // type is the type of the metric value. - Type type = 5; -} - -// MetricRegistration contains the metadata for all metrics that will be in -// future MetricUpdates. -message MetricRegistration { - repeated MetricMetadata metrics = 1; -} - -// MetricValue the value of a metric at a single point in time. -message MetricValue { - // name is the unique name of the metric, as in MetricMetadata. - string name = 1; - - // value is the value of the metric at a single point in time. The field set - // depends on the type of the metric. - oneof value { - uint64 uint64_value = 2; - } -} - -// MetricUpdate contains new values for multiple distinct metrics. -// -// Metrics whose values have not changed are not included. -message MetricUpdate { - repeated MetricValue metrics = 1; -} diff --git a/pkg/metric/metric_go_proto/metric.pb.go b/pkg/metric/metric_go_proto/metric.pb.go new file mode 100755 index 000000000..553236535 --- /dev/null +++ b/pkg/metric/metric_go_proto/metric.pb.go @@ -0,0 +1,297 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/metric/metric.proto + +package gvisor + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type MetricMetadata_Type int32 + +const ( + MetricMetadata_UINT64 MetricMetadata_Type = 0 +) + +var MetricMetadata_Type_name = map[int32]string{ + 0: "UINT64", +} + +var MetricMetadata_Type_value = map[string]int32{ + "UINT64": 0, +} + +func (x MetricMetadata_Type) String() string { + return proto.EnumName(MetricMetadata_Type_name, int32(x)) +} + +func (MetricMetadata_Type) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_87b8778a4ff2ab5c, []int{0, 0} +} + +type MetricMetadata struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Description string `protobuf:"bytes,2,opt,name=description,proto3" json:"description,omitempty"` + Cumulative bool `protobuf:"varint,3,opt,name=cumulative,proto3" json:"cumulative,omitempty"` + Sync bool `protobuf:"varint,4,opt,name=sync,proto3" json:"sync,omitempty"` + Type MetricMetadata_Type `protobuf:"varint,5,opt,name=type,proto3,enum=gvisor.MetricMetadata_Type" json:"type,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *MetricMetadata) Reset() { *m = MetricMetadata{} } +func (m *MetricMetadata) String() string { return proto.CompactTextString(m) } +func (*MetricMetadata) ProtoMessage() {} +func (*MetricMetadata) Descriptor() ([]byte, []int) { + return fileDescriptor_87b8778a4ff2ab5c, []int{0} +} + +func (m *MetricMetadata) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_MetricMetadata.Unmarshal(m, b) +} +func (m *MetricMetadata) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_MetricMetadata.Marshal(b, m, deterministic) +} +func (m *MetricMetadata) XXX_Merge(src proto.Message) { + xxx_messageInfo_MetricMetadata.Merge(m, src) +} +func (m *MetricMetadata) XXX_Size() int { + return xxx_messageInfo_MetricMetadata.Size(m) +} +func (m *MetricMetadata) XXX_DiscardUnknown() { + xxx_messageInfo_MetricMetadata.DiscardUnknown(m) +} + +var xxx_messageInfo_MetricMetadata proto.InternalMessageInfo + +func (m *MetricMetadata) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func (m *MetricMetadata) GetDescription() string { + if m != nil { + return m.Description + } + return "" +} + +func (m *MetricMetadata) GetCumulative() bool { + if m != nil { + return m.Cumulative + } + return false +} + +func (m *MetricMetadata) GetSync() bool { + if m != nil { + return m.Sync + } + return false +} + +func (m *MetricMetadata) GetType() MetricMetadata_Type { + if m != nil { + return m.Type + } + return MetricMetadata_UINT64 +} + +type MetricRegistration struct { + Metrics []*MetricMetadata `protobuf:"bytes,1,rep,name=metrics,proto3" json:"metrics,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *MetricRegistration) Reset() { *m = MetricRegistration{} } +func (m *MetricRegistration) String() string { return proto.CompactTextString(m) } +func (*MetricRegistration) ProtoMessage() {} +func (*MetricRegistration) Descriptor() ([]byte, []int) { + return fileDescriptor_87b8778a4ff2ab5c, []int{1} +} + +func (m *MetricRegistration) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_MetricRegistration.Unmarshal(m, b) +} +func (m *MetricRegistration) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_MetricRegistration.Marshal(b, m, deterministic) +} +func (m *MetricRegistration) XXX_Merge(src proto.Message) { + xxx_messageInfo_MetricRegistration.Merge(m, src) +} +func (m *MetricRegistration) XXX_Size() int { + return xxx_messageInfo_MetricRegistration.Size(m) +} +func (m *MetricRegistration) XXX_DiscardUnknown() { + xxx_messageInfo_MetricRegistration.DiscardUnknown(m) +} + +var xxx_messageInfo_MetricRegistration proto.InternalMessageInfo + +func (m *MetricRegistration) GetMetrics() []*MetricMetadata { + if m != nil { + return m.Metrics + } + return nil +} + +type MetricValue struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + // Types that are valid to be assigned to Value: + // *MetricValue_Uint64Value + Value isMetricValue_Value `protobuf_oneof:"value"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *MetricValue) Reset() { *m = MetricValue{} } +func (m *MetricValue) String() string { return proto.CompactTextString(m) } +func (*MetricValue) ProtoMessage() {} +func (*MetricValue) Descriptor() ([]byte, []int) { + return fileDescriptor_87b8778a4ff2ab5c, []int{2} +} + +func (m *MetricValue) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_MetricValue.Unmarshal(m, b) +} +func (m *MetricValue) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_MetricValue.Marshal(b, m, deterministic) +} +func (m *MetricValue) XXX_Merge(src proto.Message) { + xxx_messageInfo_MetricValue.Merge(m, src) +} +func (m *MetricValue) XXX_Size() int { + return xxx_messageInfo_MetricValue.Size(m) +} +func (m *MetricValue) XXX_DiscardUnknown() { + xxx_messageInfo_MetricValue.DiscardUnknown(m) +} + +var xxx_messageInfo_MetricValue proto.InternalMessageInfo + +func (m *MetricValue) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +type isMetricValue_Value interface { + isMetricValue_Value() +} + +type MetricValue_Uint64Value struct { + Uint64Value uint64 `protobuf:"varint,2,opt,name=uint64_value,json=uint64Value,proto3,oneof"` +} + +func (*MetricValue_Uint64Value) isMetricValue_Value() {} + +func (m *MetricValue) GetValue() isMetricValue_Value { + if m != nil { + return m.Value + } + return nil +} + +func (m *MetricValue) GetUint64Value() uint64 { + if x, ok := m.GetValue().(*MetricValue_Uint64Value); ok { + return x.Uint64Value + } + return 0 +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*MetricValue) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*MetricValue_Uint64Value)(nil), + } +} + +type MetricUpdate struct { + Metrics []*MetricValue `protobuf:"bytes,1,rep,name=metrics,proto3" json:"metrics,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *MetricUpdate) Reset() { *m = MetricUpdate{} } +func (m *MetricUpdate) String() string { return proto.CompactTextString(m) } +func (*MetricUpdate) ProtoMessage() {} +func (*MetricUpdate) Descriptor() ([]byte, []int) { + return fileDescriptor_87b8778a4ff2ab5c, []int{3} +} + +func (m *MetricUpdate) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_MetricUpdate.Unmarshal(m, b) +} +func (m *MetricUpdate) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_MetricUpdate.Marshal(b, m, deterministic) +} +func (m *MetricUpdate) XXX_Merge(src proto.Message) { + xxx_messageInfo_MetricUpdate.Merge(m, src) +} +func (m *MetricUpdate) XXX_Size() int { + return xxx_messageInfo_MetricUpdate.Size(m) +} +func (m *MetricUpdate) XXX_DiscardUnknown() { + xxx_messageInfo_MetricUpdate.DiscardUnknown(m) +} + +var xxx_messageInfo_MetricUpdate proto.InternalMessageInfo + +func (m *MetricUpdate) GetMetrics() []*MetricValue { + if m != nil { + return m.Metrics + } + return nil +} + +func init() { + proto.RegisterEnum("gvisor.MetricMetadata_Type", MetricMetadata_Type_name, MetricMetadata_Type_value) + proto.RegisterType((*MetricMetadata)(nil), "gvisor.MetricMetadata") + proto.RegisterType((*MetricRegistration)(nil), "gvisor.MetricRegistration") + proto.RegisterType((*MetricValue)(nil), "gvisor.MetricValue") + proto.RegisterType((*MetricUpdate)(nil), "gvisor.MetricUpdate") +} + +func init() { proto.RegisterFile("pkg/metric/metric.proto", fileDescriptor_87b8778a4ff2ab5c) } + +var fileDescriptor_87b8778a4ff2ab5c = []byte{ + // 288 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x6c, 0x91, 0xc1, 0x4b, 0xc3, 0x30, + 0x14, 0xc6, 0x17, 0xd7, 0x75, 0xfa, 0x3a, 0x86, 0x44, 0xd0, 0x80, 0x20, 0xa5, 0x5e, 0x7a, 0xb1, + 0x93, 0x39, 0x76, 0xf3, 0xe2, 0x41, 0xf4, 0x30, 0x85, 0xb0, 0x79, 0x95, 0xd8, 0x86, 0x12, 0x5c, + 0xdb, 0xd0, 0xa4, 0x85, 0xfe, 0x75, 0xfe, 0x6b, 0xd2, 0x17, 0x95, 0x4d, 0x76, 0xca, 0xcb, 0xfb, + 0xde, 0xf7, 0xf1, 0xcb, 0x0b, 0x5c, 0xe8, 0xcf, 0x7c, 0x56, 0x48, 0x5b, 0xab, 0xf4, 0xe7, 0x48, + 0x74, 0x5d, 0xd9, 0x8a, 0xfa, 0x79, 0xab, 0x4c, 0x55, 0x47, 0x5f, 0x04, 0xa6, 0x2b, 0x14, 0x56, + 0xd2, 0x8a, 0x4c, 0x58, 0x41, 0x29, 0x78, 0xa5, 0x28, 0x24, 0x23, 0x21, 0x89, 0x4f, 0x38, 0xd6, + 0x34, 0x84, 0x20, 0x93, 0x26, 0xad, 0x95, 0xb6, 0xaa, 0x2a, 0xd9, 0x11, 0x4a, 0xbb, 0x2d, 0x7a, + 0x05, 0x90, 0x36, 0x45, 0xb3, 0x15, 0x56, 0xb5, 0x92, 0x0d, 0x43, 0x12, 0x1f, 0xf3, 0x9d, 0x4e, + 0x9f, 0x6a, 0xba, 0x32, 0x65, 0x1e, 0x2a, 0x58, 0xd3, 0x19, 0x78, 0xb6, 0xd3, 0x92, 0x8d, 0x42, + 0x12, 0x4f, 0xe7, 0x97, 0x89, 0x63, 0x4a, 0xf6, 0x79, 0x92, 0x75, 0xa7, 0x25, 0xc7, 0xc1, 0x88, + 0x82, 0xd7, 0xdf, 0x28, 0x80, 0xbf, 0x79, 0x7e, 0x59, 0x2f, 0x17, 0xa7, 0x83, 0xe8, 0x11, 0xa8, + 0x33, 0x70, 0x99, 0x2b, 0x63, 0x6b, 0x81, 0x38, 0xb7, 0x30, 0x76, 0xef, 0x35, 0x8c, 0x84, 0xc3, + 0x38, 0x98, 0x9f, 0x1f, 0x4e, 0xe7, 0xbf, 0x63, 0xd1, 0x2b, 0x04, 0x4e, 0x7a, 0x13, 0xdb, 0x46, + 0x1e, 0xdc, 0xc2, 0x35, 0x4c, 0x1a, 0x55, 0xda, 0xe5, 0xe2, 0xbd, 0xed, 0x67, 0x70, 0x0d, 0xde, + 0xd3, 0x80, 0x07, 0xae, 0x8b, 0xc6, 0x87, 0x31, 0x8c, 0x50, 0x8d, 0xee, 0x61, 0xe2, 0x02, 0x37, + 0x3a, 0x13, 0x56, 0xd2, 0x9b, 0xff, 0x48, 0x67, 0xfb, 0x48, 0x68, 0xff, 0xe3, 0xf9, 0xf0, 0xf1, + 0xa3, 0xee, 0xbe, 0x03, 0x00, 0x00, 0xff, 0xff, 0xcb, 0x7f, 0xcb, 0x46, 0xc3, 0x01, 0x00, 0x00, +} diff --git a/pkg/metric/metric_state_autogen.go b/pkg/metric/metric_state_autogen.go new file mode 100755 index 000000000..36e5ed81b --- /dev/null +++ b/pkg/metric/metric_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package metric diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go deleted file mode 100644 index 34969385a..000000000 --- a/pkg/metric/metric_test.go +++ /dev/null @@ -1,252 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package metric - -import ( - "testing" - - "github.com/golang/protobuf/proto" - "gvisor.dev/gvisor/pkg/eventchannel" - pb "gvisor.dev/gvisor/pkg/metric/metric_go_proto" -) - -// sliceEmitter implements eventchannel.Emitter by appending all messages to a -// slice. -type sliceEmitter []proto.Message - -// Emit implements eventchannel.Emitter.Emit. -func (s *sliceEmitter) Emit(msg proto.Message) (bool, error) { - *s = append(*s, msg) - return false, nil -} - -// Emit implements eventchannel.Emitter.Close. -func (s *sliceEmitter) Close() error { - return nil -} - -// Reset clears all events in s. -func (s *sliceEmitter) Reset() { - *s = nil -} - -// emitter is the eventchannel.Emitter used for all tests. Package eventchannel -// doesn't allow removing Emitters, so we must use one global emitter for all -// test cases. -var emitter sliceEmitter - -func init() { - eventchannel.AddEmitter(&emitter) -} - -// reset clears all global state in the metric package. -func reset() { - initialized = false - allMetrics = makeMetricSet() - emitter.Reset() -} - -const ( - fooDescription = "Foo!" - barDescription = "Bar Baz" -) - -func TestInitialize(t *testing.T) { - defer reset() - - _, err := NewUint64Metric("/foo", false, fooDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - _, err = NewUint64Metric("/bar", true, barDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - Initialize() - - if len(emitter) != 1 { - t.Fatalf("Initialize emitted %d events want 1", len(emitter)) - } - - mr, ok := emitter[0].(*pb.MetricRegistration) - if !ok { - t.Fatalf("emitter %v got %T want pb.MetricRegistration", emitter[0], emitter[0]) - } - - if len(mr.Metrics) != 2 { - t.Errorf("MetricRegistration got %d metrics want 2", len(mr.Metrics)) - } - - foundFoo := false - foundBar := false - for _, m := range mr.Metrics { - if m.Type != pb.MetricMetadata_UINT64 { - t.Errorf("Metadata %+v Type got %v want %v", m, m.Type, pb.MetricMetadata_UINT64) - } - if !m.Cumulative { - t.Errorf("Metadata %+v Cumulative got false want true", m) - } - - switch m.Name { - case "/foo": - foundFoo = true - if m.Description != fooDescription { - t.Errorf("/foo %+v Description got %q want %q", m, m.Description, fooDescription) - } - if m.Sync { - t.Errorf("/foo %+v Sync got true want false", m) - } - case "/bar": - foundBar = true - if m.Description != barDescription { - t.Errorf("/bar %+v Description got %q want %q", m, m.Description, barDescription) - } - if !m.Sync { - t.Errorf("/bar %+v Sync got true want false", m) - } - } - } - - if !foundFoo { - t.Errorf("/foo not found: %+v", emitter) - } - if !foundBar { - t.Errorf("/bar not found: %+v", emitter) - } -} - -func TestDisable(t *testing.T) { - defer reset() - - _, err := NewUint64Metric("/foo", false, fooDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - _, err = NewUint64Metric("/bar", true, barDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - Disable() - - if len(emitter) != 1 { - t.Fatalf("Initialize emitted %d events want 1", len(emitter)) - } - - mr, ok := emitter[0].(*pb.MetricRegistration) - if !ok { - t.Fatalf("emitter %v got %T want pb.MetricRegistration", emitter[0], emitter[0]) - } - - if len(mr.Metrics) != 0 { - t.Errorf("MetricRegistration got %d metrics want 0", len(mr.Metrics)) - } -} - -func TestEmitMetricUpdate(t *testing.T) { - defer reset() - - foo, err := NewUint64Metric("/foo", false, fooDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - _, err = NewUint64Metric("/bar", true, barDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - Initialize() - - // Don't care about the registration metrics. - emitter.Reset() - EmitMetricUpdate() - - if len(emitter) != 1 { - t.Fatalf("EmitMetricUpdate emitted %d events want 1", len(emitter)) - } - - update, ok := emitter[0].(*pb.MetricUpdate) - if !ok { - t.Fatalf("emitter %v got %T want pb.MetricUpdate", emitter[0], emitter[0]) - } - - if len(update.Metrics) != 2 { - t.Errorf("MetricUpdate got %d metrics want 2", len(update.Metrics)) - } - - // Both are included for their initial values. - foundFoo := false - foundBar := false - for _, m := range update.Metrics { - switch m.Name { - case "/foo": - foundFoo = true - case "/bar": - foundBar = true - } - uv, ok := m.Value.(*pb.MetricValue_Uint64Value) - if !ok { - t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value) - continue - } - if uv.Uint64Value != 0 { - t.Errorf("%v: Value got %v want 0", m, uv.Uint64Value) - } - } - - if !foundFoo { - t.Errorf("/foo not found: %+v", emitter) - } - if !foundBar { - t.Errorf("/bar not found: %+v", emitter) - } - - // Increment foo. Only it is included in the next update. - foo.Increment() - - emitter.Reset() - EmitMetricUpdate() - - if len(emitter) != 1 { - t.Fatalf("EmitMetricUpdate emitted %d events want 1", len(emitter)) - } - - update, ok = emitter[0].(*pb.MetricUpdate) - if !ok { - t.Fatalf("emitter %v got %T want pb.MetricUpdate", emitter[0], emitter[0]) - } - - if len(update.Metrics) != 1 { - t.Errorf("MetricUpdate got %d metrics want 1", len(update.Metrics)) - } - - m := update.Metrics[0] - - if m.Name != "/foo" { - t.Errorf("Metric %+v name got %q want '/foo'", m, m.Name) - } - - uv, ok := m.Value.(*pb.MetricValue_Uint64Value) - if !ok { - t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value) - } - if uv.Uint64Value != 1 { - t.Errorf("%v: Value got %v want 1", m, uv.Uint64Value) - } -} diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD deleted file mode 100644 index 8904afad9..000000000 --- a/pkg/p9/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -go_library( - name = "p9", - srcs = [ - "buffer.go", - "client.go", - "client_file.go", - "file.go", - "handlers.go", - "messages.go", - "p9.go", - "path_tree.go", - "server.go", - "transport.go", - "transport_flipcall.go", - "version.go", - ], - deps = [ - "//pkg/fd", - "//pkg/fdchannel", - "//pkg/flipcall", - "//pkg/log", - "//pkg/pool", - "//pkg/sync", - "//pkg/unet", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "p9_test", - size = "small", - srcs = [ - "buffer_test.go", - "client_test.go", - "messages_test.go", - "p9_test.go", - "transport_test.go", - "version_test.go", - ], - library = ":p9", - deps = [ - "//pkg/fd", - "//pkg/unet", - ], -) diff --git a/pkg/p9/buffer_test.go b/pkg/p9/buffer_test.go deleted file mode 100644 index a9c75f86b..000000000 --- a/pkg/p9/buffer_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "testing" -) - -func TestBufferOverrun(t *testing.T) { - buf := &buffer{ - // This header indicates that a large string should follow, but - // it is only two bytes. Reading a string should cause an - // overrun. - data: []byte{0x0, 0x16}, - } - if s := buf.ReadString(); s != "" { - t.Errorf("overrun read got %s, want empty", s) - } -} diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go deleted file mode 100644 index 29a0afadf..000000000 --- a/pkg/p9/client_test.go +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/unet" -) - -// TestVersion tests the version negotiation. -func TestVersion(t *testing.T) { - // First, create a new server and connection. - serverSocket, clientSocket, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer clientSocket.Close() - - // Create a new server and client. - s := NewServer(nil) - go s.Handle(serverSocket) - - // NewClient does a Tversion exchange, so this is our test for success. - c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString()) - if err != nil { - t.Fatalf("got %v, expected nil", err) - } - - // Check a bogus version string. - if err := c.sendRecv(&Tversion{Version: "notokay", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL { - t.Errorf("got %v expected %v", err, syscall.EINVAL) - } - - // Check a bogus version number. - if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL { - t.Errorf("got %v expected %v", err, syscall.EINVAL) - } - - // Check a too high version number. - if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EAGAIN { - t.Errorf("got %v expected %v", err, syscall.EAGAIN) - } - - // Check an invalid MSize. - if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion), MSize: 0}, &Rversion{}); err != syscall.EINVAL { - t.Errorf("got %v expected %v", err, syscall.EINVAL) - } -} - -func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) { - // See above. - serverSocket, clientSocket, err := unet.SocketPair(false) - if err != nil { - b.Fatalf("socketpair got err %v expected nil", err) - } - defer clientSocket.Close() - - // See above. - s := NewServer(nil) - go s.Handle(serverSocket) - - // See above. - c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString()) - if err != nil { - b.Fatalf("got %v, expected nil", err) - } - - // Initialize messages. - sendRecv := fn(c) - tversion := &Tversion{ - Version: versionString(highestSupportedVersion), - MSize: DefaultMessageSize, - } - rversion := new(Rversion) - - // Run in a loop. - for i := 0; i < b.N; i++ { - if err := sendRecv(tversion, rversion); err != nil { - b.Fatalf("got unexpected err: %v", err) - } - } -} - -func BenchmarkSendRecvLegacy(b *testing.B) { - benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvLegacy }) -} - -func BenchmarkSendRecvChannel(b *testing.B) { - benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvChannel }) -} diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go deleted file mode 100644 index c20324404..000000000 --- a/pkg/p9/messages_test.go +++ /dev/null @@ -1,483 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "fmt" - "reflect" - "testing" -) - -func TestEncodeDecode(t *testing.T) { - objs := []encoder{ - &QID{ - Type: 1, - Version: 2, - Path: 3, - }, - &FSStat{ - Type: 1, - BlockSize: 2, - Blocks: 3, - BlocksFree: 4, - BlocksAvailable: 5, - Files: 6, - FilesFree: 7, - FSID: 8, - NameLength: 9, - }, - &AttrMask{ - Mode: true, - NLink: true, - UID: true, - GID: true, - RDev: true, - ATime: true, - MTime: true, - CTime: true, - INo: true, - Size: true, - Blocks: true, - BTime: true, - Gen: true, - DataVersion: true, - }, - &Attr{ - Mode: Exec, - UID: 2, - GID: 3, - NLink: 4, - RDev: 5, - Size: 6, - BlockSize: 7, - Blocks: 8, - ATimeSeconds: 9, - ATimeNanoSeconds: 10, - MTimeSeconds: 11, - MTimeNanoSeconds: 12, - CTimeSeconds: 13, - CTimeNanoSeconds: 14, - BTimeSeconds: 15, - BTimeNanoSeconds: 16, - Gen: 17, - DataVersion: 18, - }, - &SetAttrMask{ - Permissions: true, - UID: true, - GID: true, - Size: true, - ATime: true, - MTime: true, - CTime: true, - ATimeNotSystemTime: true, - MTimeNotSystemTime: true, - }, - &SetAttr{ - Permissions: 1, - UID: 2, - GID: 3, - Size: 4, - ATimeSeconds: 5, - ATimeNanoSeconds: 6, - MTimeSeconds: 7, - MTimeNanoSeconds: 8, - }, - &Dirent{ - QID: QID{Type: 1}, - Offset: 2, - Type: 3, - Name: "a", - }, - &Rlerror{ - Error: 1, - }, - &Tstatfs{ - FID: 1, - }, - &Rstatfs{ - FSStat: FSStat{Type: 1}, - }, - &Tlopen{ - FID: 1, - Flags: WriteOnly, - }, - &Rlopen{ - QID: QID{Type: 1}, - IoUnit: 2, - }, - &Tlconnect{ - FID: 1, - }, - &Rlconnect{}, - &Tlcreate{ - FID: 1, - Name: "a", - OpenFlags: 2, - Permissions: 3, - GID: 4, - }, - &Rlcreate{ - Rlopen{QID: QID{Type: 1}}, - }, - &Tsymlink{ - Directory: 1, - Name: "a", - Target: "b", - GID: 2, - }, - &Rsymlink{ - QID: QID{Type: 1}, - }, - &Tmknod{ - Directory: 1, - Name: "a", - Mode: 2, - Major: 3, - Minor: 4, - GID: 5, - }, - &Rmknod{ - QID: QID{Type: 1}, - }, - &Trename{ - FID: 1, - Directory: 2, - Name: "a", - }, - &Rrename{}, - &Treadlink{ - FID: 1, - }, - &Rreadlink{ - Target: "a", - }, - &Tgetattr{ - FID: 1, - AttrMask: AttrMask{Mode: true}, - }, - &Rgetattr{ - Valid: AttrMask{Mode: true}, - QID: QID{Type: 1}, - Attr: Attr{Mode: Write}, - }, - &Tsetattr{ - FID: 1, - Valid: SetAttrMask{Permissions: true}, - SetAttr: SetAttr{Permissions: Write}, - }, - &Rsetattr{}, - &Txattrwalk{ - FID: 1, - NewFID: 2, - Name: "a", - }, - &Rxattrwalk{ - Size: 1, - }, - &Txattrcreate{ - FID: 1, - Name: "a", - AttrSize: 2, - Flags: 3, - }, - &Rxattrcreate{}, - &Tgetxattr{ - FID: 1, - Name: "abc", - Size: 2, - }, - &Rgetxattr{ - Value: "xyz", - }, - &Tsetxattr{ - FID: 1, - Name: "abc", - Value: "xyz", - Flags: 2, - }, - &Rsetxattr{}, - &Treaddir{ - Directory: 1, - Offset: 2, - Count: 3, - }, - &Rreaddir{ - // Count must be sufficient to encode a dirent. - Count: 0x18, - Entries: []Dirent{{QID: QID{Type: 2}}}, - }, - &Tfsync{ - FID: 1, - }, - &Rfsync{}, - &Tlink{ - Directory: 1, - Target: 2, - Name: "a", - }, - &Rlink{}, - &Tmkdir{ - Directory: 1, - Name: "a", - Permissions: 2, - GID: 3, - }, - &Rmkdir{ - QID: QID{Type: 1}, - }, - &Trenameat{ - OldDirectory: 1, - OldName: "a", - NewDirectory: 2, - NewName: "b", - }, - &Rrenameat{}, - &Tunlinkat{ - Directory: 1, - Name: "a", - Flags: 2, - }, - &Runlinkat{}, - &Tversion{ - MSize: 1, - Version: "a", - }, - &Rversion{ - MSize: 1, - Version: "a", - }, - &Tauth{ - AuthenticationFID: 1, - UserName: "a", - AttachName: "b", - UID: 2, - }, - &Rauth{ - QID: QID{Type: 1}, - }, - &Tattach{ - FID: 1, - Auth: Tauth{AuthenticationFID: 2}, - }, - &Rattach{ - QID: QID{Type: 1}, - }, - &Tflush{ - OldTag: 1, - }, - &Rflush{}, - &Twalk{ - FID: 1, - NewFID: 2, - Names: []string{"a"}, - }, - &Rwalk{ - QIDs: []QID{{Type: 1}}, - }, - &Tread{ - FID: 1, - Offset: 2, - Count: 3, - }, - &Rread{ - Data: []byte{'a'}, - }, - &Twrite{ - FID: 1, - Offset: 2, - Data: []byte{'a'}, - }, - &Rwrite{ - Count: 1, - }, - &Tclunk{ - FID: 1, - }, - &Rclunk{}, - &Tremove{ - FID: 1, - }, - &Rremove{}, - &Tflushf{ - FID: 1, - }, - &Rflushf{}, - &Twalkgetattr{ - FID: 1, - NewFID: 2, - Names: []string{"a"}, - }, - &Rwalkgetattr{ - QIDs: []QID{{Type: 1}}, - Valid: AttrMask{Mode: true}, - Attr: Attr{Mode: Write}, - }, - &Tucreate{ - Tlcreate: Tlcreate{ - FID: 1, - Name: "a", - OpenFlags: 2, - Permissions: 3, - GID: 4, - }, - UID: 5, - }, - &Rucreate{ - Rlcreate{Rlopen{QID: QID{Type: 1}}}, - }, - &Tumkdir{ - Tmkdir: Tmkdir{ - Directory: 1, - Name: "a", - Permissions: 2, - GID: 3, - }, - UID: 4, - }, - &Rumkdir{ - Rmkdir{QID: QID{Type: 1}}, - }, - &Tusymlink{ - Tsymlink: Tsymlink{ - Directory: 1, - Name: "a", - Target: "b", - GID: 2, - }, - UID: 3, - }, - &Rusymlink{ - Rsymlink{QID: QID{Type: 1}}, - }, - &Tumknod{ - Tmknod: Tmknod{ - Directory: 1, - Name: "a", - Mode: 2, - Major: 3, - Minor: 4, - GID: 5, - }, - UID: 6, - }, - &Rumknod{ - Rmknod{QID: QID{Type: 1}}, - }, - } - - for _, enc := range objs { - // Encode the original. - data := make([]byte, initialBufferLength) - buf := buffer{data: data[:0]} - enc.encode(&buf) - - // Create a new object, same as the first. - enc2 := reflect.New(reflect.ValueOf(enc).Elem().Type()).Interface().(encoder) - buf2 := buffer{data: buf.data} - - // To be fair, we need to add any payloads (directly). - if pl, ok := enc.(payloader); ok { - enc2.(payloader).SetPayload(pl.Payload()) - } - - // And any file payloads (directly). - if fl, ok := enc.(filer); ok { - enc2.(filer).SetFilePayload(fl.FilePayload()) - } - - // Mark sure it was okay. - enc2.decode(&buf2) - if buf2.isOverrun() { - t.Errorf("object %#v->%#v got overrun on decode", enc, enc2) - continue - } - - // Check that they are equal. - if !reflect.DeepEqual(enc, enc2) { - t.Errorf("object %#v and %#v differ", enc, enc2) - continue - } - } -} - -func TestMessageStrings(t *testing.T) { - for typ := range msgRegistry.factories { - entry := &msgRegistry.factories[typ] - if entry.create != nil { - name := fmt.Sprintf("%+v", typ) - t.Run(name, func(t *testing.T) { - defer func() { // Ensure no panic. - if r := recover(); r != nil { - t.Errorf("printing %s failed: %v", name, r) - } - }() - m := entry.create() - _ = fmt.Sprintf("%v", m) - err := ErrInvalidMsgType{MsgType(typ)} - _ = err.Error() - }) - } - } -} - -func TestRegisterDuplicate(t *testing.T) { - defer func() { - if r := recover(); r == nil { - // We expect a panic. - t.FailNow() - } - }() - - // Register a duplicate. - msgRegistry.register(MsgRlerror, func() message { return &Rlerror{} }) -} - -func TestMsgCache(t *testing.T) { - // Cache starts empty. - if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want { - t.Errorf("Wrong cache size, got: %d, want: %d", got, want) - } - - // Message can be created with an empty cache. - msg, err := msgRegistry.get(0, MsgRlerror) - if err != nil { - t.Errorf("msgRegistry.get(): %v", err) - } - if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want { - t.Errorf("Wrong cache size, got: %d, want: %d", got, want) - } - - // Check that message is added to the cache when returned. - msgRegistry.put(msg) - if got, want := len(msgRegistry.factories[MsgRlerror].cache), 1; got != want { - t.Errorf("Wrong cache size, got: %d, want: %d", got, want) - } - - // Check that returned message is reused. - if got, err := msgRegistry.get(0, MsgRlerror); err != nil { - t.Errorf("msgRegistry.get(): %v", err) - } else if msg != got { - t.Errorf("Message not reused, got: %d, want: %d", got, msg) - } - - // Check that cache doesn't grow beyond max size. - for i := 0; i < maxCacheSize+1; i++ { - msgRegistry.put(&Rlerror{}) - } - if got, want := len(msgRegistry.factories[MsgRlerror].cache), maxCacheSize; got != want { - t.Errorf("Wrong cache size, got: %d, want: %d", got, want) - } -} diff --git a/pkg/p9/p9_state_autogen.go b/pkg/p9/p9_state_autogen.go new file mode 100755 index 000000000..bc9b1bd57 --- /dev/null +++ b/pkg/p9/p9_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package p9 diff --git a/pkg/p9/p9_test.go b/pkg/p9/p9_test.go deleted file mode 100644 index 8dda6cc64..000000000 --- a/pkg/p9/p9_test.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "os" - "testing" -) - -func TestFileModeHelpers(t *testing.T) { - fns := map[FileMode]struct { - // name identifies the file mode. - name string - - // function is the function that should return true given the - // right FileMode. - function func(m FileMode) bool - }{ - ModeRegular: { - name: "regular", - function: FileMode.IsRegular, - }, - ModeDirectory: { - name: "directory", - function: FileMode.IsDir, - }, - ModeNamedPipe: { - name: "named pipe", - function: FileMode.IsNamedPipe, - }, - ModeCharacterDevice: { - name: "character device", - function: FileMode.IsCharacterDevice, - }, - ModeBlockDevice: { - name: "block device", - function: FileMode.IsBlockDevice, - }, - ModeSymlink: { - name: "symlink", - function: FileMode.IsSymlink, - }, - ModeSocket: { - name: "socket", - function: FileMode.IsSocket, - }, - } - for mode, info := range fns { - // Make sure the mode doesn't identify as anything but itself. - for testMode, testfns := range fns { - if mode != testMode && testfns.function(mode) { - t.Errorf("Mode %s returned true when asked if it was mode %s", info.name, testfns.name) - } - } - - // Make sure mode identifies as itself. - if !info.function(mode) { - t.Errorf("Mode %s returned false when asked if it was itself", info.name) - } - } -} - -func TestFileModeToQID(t *testing.T) { - for _, test := range []struct { - // name identifies the test. - name string - - // mode is the FileMode we start out with. - mode FileMode - - // want is the corresponding QIDType we expect. - want QIDType - }{ - { - name: "Directories are of type directory", - mode: ModeDirectory, - want: TypeDir, - }, - { - name: "Sockets are append-only files", - mode: ModeSocket, - want: TypeAppendOnly, - }, - { - name: "Named pipes are append-only files", - mode: ModeNamedPipe, - want: TypeAppendOnly, - }, - { - name: "Character devices are append-only files", - mode: ModeCharacterDevice, - want: TypeAppendOnly, - }, - { - name: "Symlinks are of type symlink", - mode: ModeSymlink, - want: TypeSymlink, - }, - { - name: "Regular files are of type regular", - mode: ModeRegular, - want: TypeRegular, - }, - { - name: "Block devices are regular files", - mode: ModeBlockDevice, - want: TypeRegular, - }, - } { - if qidType := test.mode.QIDType(); qidType != test.want { - t.Errorf("ModeToQID test %s failed: got %o, wanted %o", test.name, qidType, test.want) - } - } -} - -func TestP9ModeConverters(t *testing.T) { - for _, m := range []FileMode{ - ModeRegular, - ModeDirectory, - ModeCharacterDevice, - ModeBlockDevice, - ModeSocket, - ModeSymlink, - ModeNamedPipe, - } { - if mb := ModeFromOS(m.OSMode()); mb != m { - t.Errorf("Converting %o to OS.FileMode gives %o and is converted back as %o", m, m.OSMode(), mb) - } - } -} - -func TestOSModeConverters(t *testing.T) { - // Modes that can be converted back and forth. - for _, m := range []os.FileMode{ - 0, // Regular file. - os.ModeDir, - os.ModeCharDevice | os.ModeDevice, - os.ModeDevice, - os.ModeSocket, - os.ModeSymlink, - os.ModeNamedPipe, - } { - if mb := ModeFromOS(m).OSMode(); mb != m { - t.Errorf("Converting %o to p9.FileMode gives %o and is converted back as %o", m, ModeFromOS(m), mb) - } - } - - // Modes that will be converted to a regular file since p9 cannot - // express these. - for _, m := range []os.FileMode{ - os.ModeAppend, - os.ModeExclusive, - os.ModeTemporary, - } { - if p9Mode := ModeFromOS(m); p9Mode != ModeRegular { - t.Errorf("Converting %o to p9.FileMode should have given ModeRegular, but yielded %o", m, p9Mode) - } - } -} - -func TestAttrMaskContains(t *testing.T) { - req := AttrMask{Mode: true, Size: true} - have := AttrMask{} - if have.Contains(req) { - t.Fatalf("AttrMask %v should not be a superset of %v", have, req) - } - have.Mode = true - if have.Contains(req) { - t.Fatalf("AttrMask %v should not be a superset of %v", have, req) - } - have.Size = true - have.MTime = true - if !have.Contains(req) { - t.Fatalf("AttrMask %v should be a superset of %v", have, req) - } -} diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD deleted file mode 100644 index 7ca67cb19..000000000 --- a/pkg/p9/p9test/BUILD +++ /dev/null @@ -1,88 +0,0 @@ -load("//tools:defs.bzl", "go_binary", "go_library", "go_test") - -package(licenses = ["notice"]) - -alias( - name = "mockgen", - actual = "@com_github_golang_mock//mockgen:mockgen", -) - -MOCK_SRC_PACKAGE = "gvisor.dev/gvisor/pkg/p9" - -# mockgen_reflect is a source file that contains mock generation code that -# imports the p9 package and generates a specification via reflection. The -# usual generation path must be split into two distinct parts because the full -# source tree is not available to all build targets. Only declared depencies -# are available (and even then, not the Go source files). -genrule( - name = "mockgen_reflect", - testonly = 1, - outs = ["mockgen_reflect.go"], - cmd = ( - "$(location :mockgen) " + - "-package p9test " + - "-prog_only " + MOCK_SRC_PACKAGE + " " + - "Attacher,File > $@" - ), - tools = [":mockgen"], -) - -# mockgen_exec is the binary that includes the above reflection generator. -# Running this binary will emit an encoded version of the p9 Attacher and File -# structures. This is consumed by the mocks genrule, below. -go_binary( - name = "mockgen_exec", - testonly = 1, - srcs = ["mockgen_reflect.go"], - deps = [ - "//pkg/p9", - "@com_github_golang_mock//mockgen/model:go_default_library", - ], -) - -# mocks consumes the encoded output above, and generates the full source for a -# set of mocks. These are included directly in the p9test library. -genrule( - name = "mocks", - testonly = 1, - outs = ["mocks.go"], - cmd = ( - "$(location :mockgen) " + - "-package p9test " + - "-exec_only $(location :mockgen_exec) " + MOCK_SRC_PACKAGE + " File > $@" - ), - tools = [ - ":mockgen", - ":mockgen_exec", - ], -) - -go_library( - name = "p9test", - srcs = [ - "mocks.go", - "p9test.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/fd", - "//pkg/log", - "//pkg/p9", - "//pkg/sync", - "//pkg/unet", - "@com_github_golang_mock//gomock:go_default_library", - ], -) - -go_test( - name = "client_test", - size = "medium", - srcs = ["client_test.go"], - library = ":p9test", - deps = [ - "//pkg/fd", - "//pkg/p9", - "//pkg/sync", - "@com_github_golang_mock//gomock:go_default_library", - ], -) diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go deleted file mode 100644 index 6e7bb3db2..000000000 --- a/pkg/p9/p9test/client_test.go +++ /dev/null @@ -1,2242 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9test - -import ( - "bytes" - "fmt" - "io" - "math/rand" - "os" - "reflect" - "strings" - "syscall" - "testing" - "time" - - "github.com/golang/mock/gomock" - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sync" -) - -func TestPanic(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - // Create a new root. - d := h.NewDirectory(nil)(nil) - defer d.Close() // Needed manually. - h.Attacher.EXPECT().Attach().Return(d, nil).Do(func() { - // Panic here, and ensure that we get back EFAULT. - panic("handler") - }) - - // Attach to the client. - if _, err := c.Attach("/"); err != syscall.EFAULT { - t.Fatalf("got attach err %v, want EFAULT", err) - } -} - -func TestAttachNoLeak(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - // Create a new root. - d := h.NewDirectory(nil)(nil) - h.Attacher.EXPECT().Attach().Return(d, nil).Times(1) - - // Attach to the client. - f, err := c.Attach("/") - if err != nil { - t.Fatalf("got attach err %v, want nil", err) - } - - // Don't close the file. This should be closed automatically when the - // client disconnects. The mock asserts that everything is closed - // exactly once. This statement just removes the unused variable error. - _ = f -} - -func TestBadAttach(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - // Return an error on attach. - h.Attacher.EXPECT().Attach().Return(nil, syscall.EINVAL).Times(1) - - // Attach to the client. - if _, err := c.Attach("/"); err != syscall.EINVAL { - t.Fatalf("got attach err %v, want syscall.EINVAL", err) - } -} - -func TestWalkAttach(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - // Create a new root. - d := h.NewDirectory(map[string]Generator{ - "a": h.NewDirectory(map[string]Generator{ - "b": h.NewFile(), - }), - })(nil) - h.Attacher.EXPECT().Attach().Return(d, nil).Times(1) - - // Attach to the client as a non-root, and ensure that the walk above - // occurs as expected. We should get back b, and all references should - // be dropped when the file is closed. - f, err := c.Attach("/a/b") - if err != nil { - t.Fatalf("got attach err %v, want nil", err) - } - defer f.Close() - - // Check that's a regular file. - if _, _, attr, err := f.GetAttr(p9.AttrMaskAll()); err != nil { - t.Errorf("got err %v, want nil", err) - } else if !attr.Mode.IsRegular() { - t.Errorf("got mode %v, want regular file", err) - } -} - -// newTypeMap returns a new type map dictionary. -func newTypeMap(h *Harness) map[string]Generator { - return map[string]Generator{ - "directory": h.NewDirectory(map[string]Generator{}), - "file": h.NewFile(), - "symlink": h.NewSymlink(), - "block-device": h.NewBlockDevice(), - "character-device": h.NewCharacterDevice(), - "named-pipe": h.NewNamedPipe(), - "socket": h.NewSocket(), - } -} - -// newRoot returns a new root filesystem. -// -// This is set up in a deterministic way for testing most operations. -// -// The represented file system looks like: -// - file -// - symlink -// - directory -// ... -// + one -// - file -// - symlink -// - directory -// ... -// + two -// - file -// - symlink -// - directory -// ... -// + three -// - file -// - symlink -// - directory -// ... -func newRoot(h *Harness, c *p9.Client) (*Mock, p9.File) { - root := newTypeMap(h) - one := newTypeMap(h) - two := newTypeMap(h) - three := newTypeMap(h) - one["two"] = h.NewDirectory(two) // Will be nested in one. - root["one"] = h.NewDirectory(one) // Top level. - root["three"] = h.NewDirectory(three) // Alternate top-level. - - // Create a new root. - rootBackend := h.NewDirectory(root)(nil) - h.Attacher.EXPECT().Attach().Return(rootBackend, nil) - - // Attach to the client. - r, err := c.Attach("/") - if err != nil { - h.t.Fatalf("got attach err %v, want nil", err) - } - - return rootBackend, r -} - -func allInvalidNames(from string) []string { - return []string{ - from + "/other", - from + "/..", - from + "/.", - from + "/", - "other/" + from, - "/" + from, - "./" + from, - "../" + from, - ".", - "..", - "/", - "", - } -} - -func TestWalkInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Run relevant tests. - for name := range newTypeMap(h) { - // These are all the various ways that one might attempt to - // construct compound paths. They should all be rejected, as - // any compound that contains a / is not allowed, as well as - // the singular paths of '.' and '..'. - if _, _, err := root.Walk([]string{".", name}); err != syscall.EINVAL { - t.Errorf("Walk through . %s wanted EINVAL, got %v", name, err) - } - if _, _, err := root.Walk([]string{"..", name}); err != syscall.EINVAL { - t.Errorf("Walk through . %s wanted EINVAL, got %v", name, err) - } - if _, _, err := root.Walk([]string{name, "."}); err != syscall.EINVAL { - t.Errorf("Walk through %s . wanted EINVAL, got %v", name, err) - } - if _, _, err := root.Walk([]string{name, ".."}); err != syscall.EINVAL { - t.Errorf("Walk through %s .. wanted EINVAL, got %v", name, err) - } - for _, invalidName := range allInvalidNames(name) { - if _, _, err := root.Walk([]string{invalidName}); err != syscall.EINVAL { - t.Errorf("Walk through %s wanted EINVAL, got %v", invalidName, err) - } - } - wantErr := syscall.EINVAL - if name == "directory" { - // We can attempt a walk through a directory. However, - // we should never see a file named "other", so we - // expect this to return ENOENT. - wantErr = syscall.ENOENT - } - if _, _, err := root.Walk([]string{name, "other"}); err != wantErr { - t.Errorf("Walk through %s/other wanted %v, got %v", name, wantErr, err) - } - - // Do a successful walk. - _, f, err := root.Walk([]string{name}) - if err != nil { - t.Errorf("Walk to %s wanted nil, got %v", name, err) - } - defer f.Close() - local := h.Pop(f) - - // Check that the file matches. - _, localMask, localAttr, localErr := local.GetAttr(p9.AttrMaskAll()) - if _, mask, attr, err := f.GetAttr(p9.AttrMaskAll()); mask != localMask || attr != localAttr || err != localErr { - t.Errorf("GetAttr got (%v, %v, %v), wanted (%v, %v, %v)", - mask, attr, err, localMask, localAttr, localErr) - } - - // Ensure we can't walk backwards. - if _, _, err := f.Walk([]string{"."}); err != syscall.EINVAL { - t.Errorf("Walk through %s/. wanted EINVAL, got %v", name, err) - } - if _, _, err := f.Walk([]string{".."}); err != syscall.EINVAL { - t.Errorf("Walk through %s/.. wanted EINVAL, got %v", name, err) - } - } -} - -// fileGenerator is a function to generate files via walk or create. -// -// Examples are: -// - walkHelper -// - walkAndOpenHelper -// - createHelper -type fileGenerator func(*Harness, string, p9.File) (*Mock, *Mock, p9.File) - -// walkHelper walks to the given file. -// -// The backends of the parent and walked file are returned, as well as the -// walked client file. -func walkHelper(h *Harness, name string, dir p9.File) (parentBackend *Mock, walkedBackend *Mock, walked p9.File) { - _, parent, err := dir.Walk(nil) - if err != nil { - h.t.Fatalf("Walk(nil) got err %v, want nil", err) - } - defer parent.Close() - parentBackend = h.Pop(parent) - - _, walked, err = parent.Walk([]string{name}) - if err != nil { - h.t.Fatalf("Walk(%s) got err %v, want nil", name, err) - } - walkedBackend = h.Pop(walked) - - return parentBackend, walkedBackend, walked -} - -// walkAndOpenHelper additionally opens the walked file, if possible. -func walkAndOpenHelper(h *Harness, name string, dir p9.File) (*Mock, *Mock, p9.File) { - parentBackend, walkedBackend, walked := walkHelper(h, name, dir) - if p9.CanOpen(walkedBackend.Attr.Mode) { - // Open for all file types that we can. We stick to a read-only - // open here because directories may not be opened otherwise. - walkedBackend.EXPECT().Open(p9.ReadOnly).Times(1) - if _, _, _, err := walked.Open(p9.ReadOnly); err != nil { - h.t.Errorf("got open err %v, want nil", err) - } - } else { - // ... or assert an error for others. - if _, _, _, err := walked.Open(p9.ReadOnly); err != syscall.EINVAL { - h.t.Errorf("got open err %v, want EINVAL", err) - } - } - return parentBackend, walkedBackend, walked -} - -// createHelper creates the given file and returns the parent directory, -// created file and client file, which must be closed when done. -func createHelper(h *Harness, name string, dir p9.File) (*Mock, *Mock, p9.File) { - // Clone the directory first, since Create replaces the existing file. - // We change the type after calling create. - _, dirThenFile, err := dir.Walk(nil) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - - // Create a new server-side file. On the server-side, the a new file is - // returned from a create call. The client will reuse the same file, - // but we still expect the normal chain of closes. This complicates - // things a bit because the "parent" will always chain to the cloned - // dir above. - dirBackend := h.Pop(dirThenFile) // New backend directory. - newFile := h.NewFile()(dirBackend) // New file with backend parent. - dirBackend.EXPECT().Create(name, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, newFile, newFile.QID, uint32(0), nil) - - // Create via the client. - _, dirThenFile, _, _, err = dirThenFile.Create(name, p9.ReadOnly, 0, 0, 0) - if err != nil { - h.t.Fatalf("got create err %v, want nil", err) - } - - // Ensure subsequent walks succeed. - dirBackend.AddChild(name, h.NewFile()) - return dirBackend, newFile, dirThenFile -} - -// deprecatedRemover allows us to access the deprecated Remove operation within -// the p9.File client object. -type deprecatedRemover interface { - Remove() error -} - -// checkDeleted asserts that relevant methods fail for an unlinked file. -// -// This function will close the file at the end. -func checkDeleted(h *Harness, file p9.File) { - defer file.Close() // See doc. - - if _, _, _, err := file.Open(p9.ReadOnly); err != syscall.EINVAL { - h.t.Errorf("open while deleted, got %v, want EINVAL", err) - } - if _, _, _, _, err := file.Create("created", p9.ReadOnly, 0, 0, 0); err != syscall.EINVAL { - h.t.Errorf("create while deleted, got %v, want EINVAL", err) - } - if _, err := file.Symlink("old", "new", 0, 0); err != syscall.EINVAL { - h.t.Errorf("symlink while deleted, got %v, want EINVAL", err) - } - // N.B. This link is technically invalid, but if a call to link is - // actually made in the backend then the mock will panic. - if err := file.Link(file, "new"); err != syscall.EINVAL { - h.t.Errorf("link while deleted, got %v, want EINVAL", err) - } - if err := file.RenameAt("src", file, "dst"); err != syscall.EINVAL { - h.t.Errorf("renameAt while deleted, got %v, want EINVAL", err) - } - if err := file.UnlinkAt("file", 0); err != syscall.EINVAL { - h.t.Errorf("unlinkAt while deleted, got %v, want EINVAL", err) - } - if err := file.Rename(file, "dst"); err != syscall.EINVAL { - h.t.Errorf("rename while deleted, got %v, want EINVAL", err) - } - if _, err := file.Readlink(); err != syscall.EINVAL { - h.t.Errorf("readlink while deleted, got %v, want EINVAL", err) - } - if _, err := file.Mkdir("dir", p9.ModeDirectory, 0, 0); err != syscall.EINVAL { - h.t.Errorf("mkdir while deleted, got %v, want EINVAL", err) - } - if _, err := file.Mknod("dir", p9.ModeDirectory, 0, 0, 0, 0); err != syscall.EINVAL { - h.t.Errorf("mknod while deleted, got %v, want EINVAL", err) - } - if _, err := file.Readdir(0, 1); err != syscall.EINVAL { - h.t.Errorf("readdir while deleted, got %v, want EINVAL", err) - } - if _, err := file.Connect(p9.ConnectFlags(0)); err != syscall.EINVAL { - h.t.Errorf("connect while deleted, got %v, want EINVAL", err) - } - - // The remove method is technically deprecated, but we want to ensure - // that it still checks for deleted appropriately. We must first clone - // the file because remove is equivalent to close. - _, newFile, err := file.Walk(nil) - if err == syscall.EBUSY { - // We can't walk from here because this reference is open - // already. Okay, we will also have unopened cases through - // TestUnlink, just skip the remove operation for now. - return - } else if err != nil { - h.t.Fatalf("clone failed, got %v, want nil", err) - } - if err := newFile.(deprecatedRemover).Remove(); err != syscall.EINVAL { - h.t.Errorf("remove while deleted, got %v, want EINVAL", err) - } -} - -// deleter is a function to remove a file. -type deleter func(parent p9.File, name string) error - -// unlinkAt is a deleter. -func unlinkAt(parent p9.File, name string) error { - // Call unlink. Note that a filesystem may normally impose additional - // constaints on unlinkat success, such as ensuring that a directory is - // empty, requiring AT_REMOVEDIR in flags to remove a directory, etc. - // None of that is required internally (entire trees can be marked - // deleted when this operation succeeds), so the mock will succeed. - return parent.UnlinkAt(name, 0) -} - -// remove is a deleter. -func remove(parent p9.File, name string) error { - // See notes above re: remove. - _, newFile, err := parent.Walk([]string{name}) - if err != nil { - // Should not be expected. - return err - } - - // Do the actual remove. - if err := newFile.(deprecatedRemover).Remove(); err != nil { - return err - } - - // Ensure that the remove closed the file. - if err := newFile.(deprecatedRemover).Remove(); err != syscall.EBADF { - return syscall.EBADF // Propagate this code. - } - - return nil -} - -// unlinkHelper unlinks the noted path, and ensures that all relevant -// operations on that path, acquired from multiple paths, start failing. -func unlinkHelper(h *Harness, root p9.File, targetNames []string, targetGen fileGenerator, deleteFn deleter) { - // name is the file to be unlinked. - name := targetNames[len(targetNames)-1] - - // Walk to the directory containing the target. - _, parent, err := root.Walk(targetNames[:len(targetNames)-1]) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer parent.Close() - parentBackend := h.Pop(parent) - - // Walk to or generate the target file. - _, _, target := targetGen(h, name, parent) - defer checkDeleted(h, target) - - // Walk to a second reference. - _, second, err := parent.Walk([]string{name}) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer checkDeleted(h, second) - - // Walk to a third reference, from the start. - _, third, err := root.Walk(targetNames) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer checkDeleted(h, third) - - // This will be translated in the backend to an unlinkat. - parentBackend.EXPECT().UnlinkAt(name, uint32(0)).Return(nil) - - // Actually perform the deletion. - if err := deleteFn(parent, name); err != nil { - h.t.Fatalf("got delete err %v, want nil", err) - } -} - -func unlinkTest(t *testing.T, targetNames []string, targetGen fileGenerator) { - t.Run(fmt.Sprintf("unlinkAt(%s)", strings.Join(targetNames, "/")), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - unlinkHelper(h, root, targetNames, targetGen, unlinkAt) - }) - t.Run(fmt.Sprintf("remove(%s)", strings.Join(targetNames, "/")), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - unlinkHelper(h, root, targetNames, targetGen, remove) - }) -} - -func TestUnlink(t *testing.T) { - // Unlink all files. - for name := range newTypeMap(nil) { - unlinkTest(t, []string{name}, walkHelper) - unlinkTest(t, []string{name}, walkAndOpenHelper) - unlinkTest(t, []string{"one", name}, walkHelper) - unlinkTest(t, []string{"one", name}, walkAndOpenHelper) - unlinkTest(t, []string{"one", "two", name}, walkHelper) - unlinkTest(t, []string{"one", "two", name}, walkAndOpenHelper) - } - - // Unlink a directory. - unlinkTest(t, []string{"one"}, walkHelper) - unlinkTest(t, []string{"one"}, walkAndOpenHelper) - unlinkTest(t, []string{"one", "two"}, walkHelper) - unlinkTest(t, []string{"one", "two"}, walkAndOpenHelper) - - // Unlink created files. - unlinkTest(t, []string{"created"}, createHelper) - unlinkTest(t, []string{"one", "created"}, createHelper) - unlinkTest(t, []string{"one", "two", "created"}, createHelper) -} - -func TestUnlinkAtInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if err := root.UnlinkAt(invalidName, 0); err != syscall.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -// expectRenamed asserts an ordered sequence of rename calls, based on all the -// elements in elements being the source, and the first element therein -// changing to dstName, parented at dstParent. -func expectRenamed(file *Mock, elements []string, dstParent *Mock, dstName string) *gomock.Call { - if len(elements) > 0 { - // Recurse to the parent, if necessary. - call := expectRenamed(file.parent, elements[:len(elements)-1], dstParent, dstName) - - // Recursive case: this element is unchanged, but should have - // it's hook called after the parent. - return file.EXPECT().Renamed(file.parent, elements[len(elements)-1]).Do(func(p p9.File, _ string) { - file.parent = p.(*Mock) - }).After(call) - } - - // Base case: this is the changed element. - return file.EXPECT().Renamed(dstParent, dstName).Do(func(p p9.File, name string) { - file.parent = p.(*Mock) - }) -} - -// renamer is a rename function. -type renamer func(h *Harness, srcParent, dstParent p9.File, origName, newName string, selfRename bool) error - -// renameAt is a renamer. -func renameAt(_ *Harness, srcParent, dstParent p9.File, srcName, dstName string, selfRename bool) error { - return srcParent.RenameAt(srcName, dstParent, dstName) -} - -// rename is a renamer. -func rename(h *Harness, srcParent, dstParent p9.File, srcName, dstName string, selfRename bool) error { - _, f, err := srcParent.Walk([]string{srcName}) - if err != nil { - return err - } - defer f.Close() - if !selfRename { - backend := h.Pop(f) - backend.EXPECT().Renamed(gomock.Any(), dstName).Do(func(p p9.File, name string) { - backend.parent = p.(*Mock) // Required for close ordering. - }) - } - return f.Rename(dstParent, dstName) -} - -// renameHelper executes a rename, and asserts that all relevant elements -// receive expected notifications. If overwriting a file, this includes -// ensuring that the target has been appropriately marked as unlinked. -func renameHelper(h *Harness, root p9.File, srcNames []string, dstNames []string, target fileGenerator, renameFn renamer) { - // Walk to the directory containing the target. - srcQID, targetParent, err := root.Walk(srcNames[:len(srcNames)-1]) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer targetParent.Close() - targetParentBackend := h.Pop(targetParent) - - // Walk to or generate the target file. - _, targetBackend, src := target(h, srcNames[len(srcNames)-1], targetParent) - defer src.Close() - - // Walk to a second reference. - _, second, err := targetParent.Walk([]string{srcNames[len(srcNames)-1]}) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer second.Close() - secondBackend := h.Pop(second) - - // Walk to a third reference, from the start. - _, third, err := root.Walk(srcNames) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer third.Close() - thirdBackend := h.Pop(third) - - // Find the common suffix to identify the rename parent. - var ( - renameDestPath []string - renameSrcPath []string - selfRename bool - ) - for i := 1; i <= len(srcNames) && i <= len(dstNames); i++ { - if srcNames[len(srcNames)-i] != dstNames[len(dstNames)-i] { - // Take the full prefix of dstNames up until this - // point, including the first mismatched name. The - // first mismatch must be the renamed entry. - renameDestPath = dstNames[:len(dstNames)-i+1] - renameSrcPath = srcNames[:len(srcNames)-i+1] - - // Does the renameDestPath fully contain the - // renameSrcPath here? If yes, then this is a mismatch. - // We can't rename the src to some subpath of itself. - if len(renameDestPath) > len(renameSrcPath) && - reflect.DeepEqual(renameDestPath[:len(renameSrcPath)], renameSrcPath) { - renameDestPath = nil - renameSrcPath = nil - continue - } - break - } - } - if len(renameSrcPath) == 0 || len(renameDestPath) == 0 { - // This must be a rename to self, or a tricky look-alike. This - // happens iff we fail to find a suitable divergence in the two - // paths. It's a true self move if the path length is the same. - renameDestPath = dstNames - renameSrcPath = srcNames - selfRename = len(srcNames) == len(dstNames) - } - - // Walk to the source parent. - _, srcParent, err := root.Walk(renameSrcPath[:len(renameSrcPath)-1]) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer srcParent.Close() - srcParentBackend := h.Pop(srcParent) - - // Walk to the destination parent. - _, dstParent, err := root.Walk(renameDestPath[:len(renameDestPath)-1]) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer dstParent.Close() - dstParentBackend := h.Pop(dstParent) - - // expectedErr is the result of the rename operation. - var expectedErr error - - // Walk to the target file, if one exists. - dstQID, dst, err := root.Walk(renameDestPath) - if err == nil { - if !selfRename && srcQID[0].Type == dstQID[0].Type { - // If there is a destination file, and is it of the - // same type as the source file, then we expect the - // rename to succeed. We expect the destination file to - // be deleted, so we run a deletion test on it in this - // case. - defer checkDeleted(h, dst) - } else { - if !selfRename { - // If the type is different than the - // destination, then we expect the rename to - // fail. We expect ensure that this is - // returned. - expectedErr = syscall.EINVAL - } else { - // This is the file being renamed to itself. - // This is technically allowed and a no-op, but - // all the triggers will fire. - } - dst.Close() - } - } - dstName := renameDestPath[len(renameDestPath)-1] // Renamed element. - srcName := renameSrcPath[len(renameSrcPath)-1] // Renamed element. - if expectedErr == nil && !selfRename { - // Expect all to be renamed appropriately. Note that if this is - // a final file being renamed, then we expect the file to be - // called with the new parent. If not, then we expect the - // rename hook to be called, but the parent will remain - // unchanged. - elements := srcNames[len(renameSrcPath):] - expectRenamed(targetBackend, elements, dstParentBackend, dstName) - expectRenamed(secondBackend, elements, dstParentBackend, dstName) - expectRenamed(thirdBackend, elements, dstParentBackend, dstName) - - // The target parent has also been opened, and may be moved - // directly or indirectly. - if len(elements) > 1 { - expectRenamed(targetParentBackend, elements[:len(elements)-1], dstParentBackend, dstName) - } - } - - // Expect the rename if it's not the same file. Note that like unlink, - // renames are always translated to the at variant in the backend. - if !selfRename { - srcParentBackend.EXPECT().RenameAt(srcName, dstParentBackend, dstName).Return(expectedErr) - } - - // Perform the actual rename; everything has been lined up. - if err := renameFn(h, srcParent, dstParent, srcName, dstName, selfRename); err != expectedErr { - h.t.Fatalf("got rename err %v, want %v", err, expectedErr) - } -} - -func renameTest(t *testing.T, srcNames []string, dstNames []string, target fileGenerator) { - t.Run(fmt.Sprintf("renameAt(%s->%s)", strings.Join(srcNames, "/"), strings.Join(dstNames, "/")), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - renameHelper(h, root, srcNames, dstNames, target, renameAt) - }) - t.Run(fmt.Sprintf("rename(%s->%s)", strings.Join(srcNames, "/"), strings.Join(dstNames, "/")), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - renameHelper(h, root, srcNames, dstNames, target, rename) - }) -} - -func TestRename(t *testing.T) { - // In-directory rename, simple case. - for name := range newTypeMap(nil) { - // Within the root. - renameTest(t, []string{name}, []string{"renamed"}, walkHelper) - renameTest(t, []string{name}, []string{"renamed"}, walkAndOpenHelper) - - // Within a subdirectory. - renameTest(t, []string{"one", name}, []string{"one", "renamed"}, walkHelper) - renameTest(t, []string{"one", name}, []string{"one", "renamed"}, walkAndOpenHelper) - } - - // ... with created files. - renameTest(t, []string{"created"}, []string{"renamed"}, createHelper) - renameTest(t, []string{"one", "created"}, []string{"one", "renamed"}, createHelper) - - // Across directories. - for name := range newTypeMap(nil) { - // Down one level. - renameTest(t, []string{"one", name}, []string{"one", "two", "renamed"}, walkHelper) - renameTest(t, []string{"one", name}, []string{"one", "two", "renamed"}, walkAndOpenHelper) - - // Up one level. - renameTest(t, []string{"one", "two", name}, []string{"one", "renamed"}, walkHelper) - renameTest(t, []string{"one", "two", name}, []string{"one", "renamed"}, walkAndOpenHelper) - - // Across at the same level. - renameTest(t, []string{"one", name}, []string{"three", "renamed"}, walkHelper) - renameTest(t, []string{"one", name}, []string{"three", "renamed"}, walkAndOpenHelper) - } - - // ... with created files. - renameTest(t, []string{"one", "created"}, []string{"one", "two", "renamed"}, createHelper) - renameTest(t, []string{"one", "two", "created"}, []string{"one", "renamed"}, createHelper) - renameTest(t, []string{"one", "created"}, []string{"three", "renamed"}, createHelper) - - // Renaming parents. - for name := range newTypeMap(nil) { - // Rename a parent. - renameTest(t, []string{"one", name}, []string{"renamed", name}, walkHelper) - renameTest(t, []string{"one", name}, []string{"renamed", name}, walkAndOpenHelper) - - // Rename a super parent. - renameTest(t, []string{"one", "two", name}, []string{"renamed", name}, walkHelper) - renameTest(t, []string{"one", "two", name}, []string{"renamed", name}, walkAndOpenHelper) - } - - // ... with created files. - renameTest(t, []string{"one", "created"}, []string{"renamed", "created"}, createHelper) - renameTest(t, []string{"one", "two", "created"}, []string{"renamed", "created"}, createHelper) - - // Over existing files, including itself. - for name := range newTypeMap(nil) { - for other := range newTypeMap(nil) { - // Overwrite the noted file (may be itself). - renameTest(t, []string{"one", name}, []string{"one", other}, walkHelper) - renameTest(t, []string{"one", name}, []string{"one", other}, walkAndOpenHelper) - - // Overwrite other files in another directory. - renameTest(t, []string{"one", name}, []string{"one", "two", other}, walkHelper) - renameTest(t, []string{"one", name}, []string{"one", "two", other}, walkAndOpenHelper) - } - - // Overwrite by moving the parent. - renameTest(t, []string{"three", name}, []string{"one", name}, walkHelper) - renameTest(t, []string{"three", name}, []string{"one", name}, walkAndOpenHelper) - - // Create over the types. - renameTest(t, []string{"one", "created"}, []string{"one", name}, createHelper) - renameTest(t, []string{"one", "created"}, []string{"one", "two", name}, createHelper) - renameTest(t, []string{"three", "created"}, []string{"one", name}, createHelper) - } -} - -func TestRenameInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if err := root.Rename(root, invalidName); err != syscall.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -func TestRenameAtInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if err := root.RenameAt(invalidName, root, "okay"); err != syscall.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - if err := root.RenameAt("okay", root, invalidName); err != syscall.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -// TestRenameSecondOrder tests that indirect rename targets continue to receive -// Renamed calls after a rename of its renamed parent. i.e., -// -// 1. Create /one/file -// 2. Create /directory -// 3. Rename /one -> /directory/one -// 4. Rename /directory -> /three/foo -// 5. file from (1) should still receive Renamed. -// -// This is a regression test for b/135219260. -func TestRenameSecondOrder(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - rootBackend, root := newRoot(h, c) - defer root.Close() - - // Walk to /one. - _, oneBackend, oneFile := walkHelper(h, "one", root) - defer oneFile.Close() - - // Walk to and generate /one/file. - // - // walkHelper re-walks to oneFile, so we need the second backend, - // which will also receive Renamed calls. - oneSecondBackend, fileBackend, fileFile := walkHelper(h, "file", oneFile) - defer fileFile.Close() - - // Walk to and generate /directory. - _, directoryBackend, directoryFile := walkHelper(h, "directory", root) - defer directoryFile.Close() - - // Rename /one to /directory/one. - rootBackend.EXPECT().RenameAt("one", directoryBackend, "one").Return(nil) - expectRenamed(oneBackend, []string{}, directoryBackend, "one") - expectRenamed(oneSecondBackend, []string{}, directoryBackend, "one") - expectRenamed(fileBackend, []string{}, oneBackend, "file") - if err := renameAt(h, root, directoryFile, "one", "one", false); err != nil { - h.t.Fatalf("got rename err %v, want nil", err) - } - - // Walk to /three. - _, threeBackend, threeFile := walkHelper(h, "three", root) - defer threeFile.Close() - - // Rename /directory to /three/foo. - rootBackend.EXPECT().RenameAt("directory", threeBackend, "foo").Return(nil) - expectRenamed(directoryBackend, []string{}, threeBackend, "foo") - expectRenamed(oneBackend, []string{}, directoryBackend, "one") - expectRenamed(oneSecondBackend, []string{}, directoryBackend, "one") - expectRenamed(fileBackend, []string{}, oneBackend, "file") - if err := renameAt(h, root, threeFile, "directory", "foo", false); err != nil { - h.t.Fatalf("got rename err %v, want nil", err) - } -} - -func TestReadlink(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to the file normally. - _, f, err := root.Walk([]string{name}) - if err != nil { - t.Fatalf("walk failed: got %v, wanted nil", err) - } - defer f.Close() - backend := h.Pop(f) - - const symlinkTarget = "symlink-target" - - if backend.Attr.Mode.IsSymlink() { - // This should only go through on symlinks. - backend.EXPECT().Readlink().Return(symlinkTarget, nil) - } - - // Attempt a Readlink operation. - target, err := f.Readlink() - if err != nil && err != syscall.EINVAL { - t.Errorf("readlink got %v, wanted EINVAL", err) - } else if err == nil && target != symlinkTarget { - t.Errorf("readlink got %v, wanted %v", target, symlinkTarget) - } - }) - } -} - -// fdTest is a wrapper around operations that may send file descriptors. This -// asserts that the file descriptors are working as intended. -func fdTest(t *testing.T, sendFn func(*fd.FD) *fd.FD) { - // Create a pipe that we can read from. - r, w, err := os.Pipe() - if err != nil { - t.Fatalf("unable to create pipe: %v", err) - } - defer r.Close() - defer w.Close() - - // Attempt to send the write end. - wFD, err := fd.NewFromFile(w) - if err != nil { - t.Fatalf("unable to convert file: %v", err) - } - defer wFD.Close() // This is a copy. - - // Send wFD and receive newFD. - newFD := sendFn(wFD) - defer newFD.Close() - - // Attempt to write. - const message = "hello" - if _, err := newFD.Write([]byte(message)); err != nil { - t.Fatalf("write got %v, wanted nil", err) - } - - // Should see the message on our end. - buffer := []byte(message) - if _, err := io.ReadFull(r, buffer); err != nil { - t.Fatalf("read got %v, wanted nil", err) - } - if string(buffer) != message { - t.Errorf("got message %v, wanted %v", string(buffer), message) - } -} - -func TestConnect(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - // Catch all the non-socket cases. - if !backend.Attr.Mode.IsSocket() { - // This has been set up to fail if Connect is called. - if _, err := f.Connect(p9.ConnectFlags(0)); err != syscall.EINVAL { - t.Errorf("connect got %v, wanted EINVAL", err) - } - return - } - - // Ensure the fd exchange works. - fdTest(t, func(send *fd.FD) *fd.FD { - backend.EXPECT().Connect(p9.ConnectFlags(0)).Return(send, nil) - recv, err := backend.Connect(p9.ConnectFlags(0)) - if err != nil { - t.Fatalf("connect got %v, wanted nil", err) - } - return recv - }) - }) - } -} - -func TestReaddir(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - // Catch all the non-directory cases. - if !backend.Attr.Mode.IsDir() { - // This has also been set up to fail if Readdir is called. - if _, err := f.Readdir(0, 1); err != syscall.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) - } - return - } - - // Ensure that readdir works for directories. - if _, err := f.Readdir(0, 1); err != syscall.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) - } - if _, _, _, err := f.Open(p9.ReadWrite); err != syscall.EISDIR { - t.Errorf("readdir got %v, wanted EISDIR", err) - } - if _, _, _, err := f.Open(p9.WriteOnly); err != syscall.EISDIR { - t.Errorf("readdir got %v, wanted EISDIR", err) - } - backend.EXPECT().Open(p9.ReadOnly).Times(1) - if _, _, _, err := f.Open(p9.ReadOnly); err != nil { - t.Errorf("readdir got %v, wanted nil", err) - } - backend.EXPECT().Readdir(uint64(0), uint32(1)).Times(1) - if _, err := f.Readdir(0, 1); err != nil { - t.Errorf("readdir got %v, wanted nil", err) - } - }) - } -} - -func TestOpen(t *testing.T) { - type openTest struct { - name string - flags p9.OpenFlags - err error - match func(p9.FileMode) bool - } - - cases := []openTest{ - { - name: "not-openable-read-only", - flags: p9.ReadOnly, - err: syscall.EINVAL, - match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, - }, - { - name: "not-openable-write-only", - flags: p9.WriteOnly, - err: syscall.EINVAL, - match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, - }, - { - name: "not-openable-read-write", - flags: p9.ReadWrite, - err: syscall.EINVAL, - match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, - }, - { - name: "directory-read-only", - flags: p9.ReadOnly, - err: nil, - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - }, - { - name: "directory-read-write", - flags: p9.ReadWrite, - err: syscall.EISDIR, - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - }, - { - name: "directory-write-only", - flags: p9.WriteOnly, - err: syscall.EISDIR, - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - }, - { - name: "read-only", - flags: p9.ReadOnly, - err: nil, - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) }, - }, - { - name: "write-only", - flags: p9.WriteOnly, - err: nil, - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, - }, - { - name: "read-write", - flags: p9.ReadWrite, - err: nil, - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, - }, - { - name: "directory-read-only-truncate", - flags: p9.ReadOnly | p9.OpenTruncate, - err: syscall.EISDIR, - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - }, - { - name: "read-only-truncate", - flags: p9.ReadOnly | p9.OpenTruncate, - err: nil, - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, - }, - { - name: "write-only-truncate", - flags: p9.WriteOnly | p9.OpenTruncate, - err: nil, - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, - }, - { - name: "read-write-truncate", - flags: p9.ReadWrite | p9.OpenTruncate, - err: nil, - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, - }, - } - - // Open(flags OpenFlags) (*fd.FD, QID, uint32, error) - // - only works on Regular, NamedPipe, BLockDevice, CharacterDevice - // - returning a file works as expected - for name := range newTypeMap(nil) { - for _, tc := range cases { - t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - // Does this match the case? - if !tc.match(backend.Attr.Mode) { - t.SkipNow() - } - - // Ensure open-required operations fail. - if _, err := f.ReadAt([]byte("hello"), 0); err != syscall.EINVAL { - t.Errorf("readAt got %v, wanted EINVAL", err) - } - if _, err := f.WriteAt(make([]byte, 6), 0); err != syscall.EINVAL { - t.Errorf("writeAt got %v, wanted EINVAL", err) - } - if err := f.FSync(); err != syscall.EINVAL { - t.Errorf("fsync got %v, wanted EINVAL", err) - } - if _, err := f.Readdir(0, 1); err != syscall.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) - } - - // Attempt the given open. - if tc.err != nil { - // We expect an error, just test and return. - if _, _, _, err := f.Open(tc.flags); err != tc.err { - t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err) - } - return - } - - // Run an FD test, since we expect success. - fdTest(t, func(send *fd.FD) *fd.FD { - backend.EXPECT().Open(tc.flags).Return(send, p9.QID{}, uint32(0), nil).Times(1) - recv, _, _, err := f.Open(tc.flags) - if err != tc.err { - t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err) - } - return recv - }) - - // If the open was successful, attempt another one. - if _, _, _, err := f.Open(tc.flags); err != syscall.EINVAL { - t.Errorf("second open with flags %v got %v, want EINVAL", tc.flags, err) - } - - // Ensure that all illegal operations fail. - if _, _, err := f.Walk(nil); err != syscall.EINVAL && err != syscall.EBUSY { - t.Errorf("walk got %v, wanted EINVAL or EBUSY", err) - } - if _, _, _, _, err := f.WalkGetAttr(nil); err != syscall.EINVAL && err != syscall.EBUSY { - t.Errorf("walkgetattr got %v, wanted EINVAL or EBUSY", err) - } - }) - } - } -} - -func TestClose(t *testing.T) { - type closeTest struct { - name string - closeFn func(backend *Mock, f p9.File) - } - - cases := []closeTest{ - { - name: "close", - closeFn: func(_ *Mock, f p9.File) { - f.Close() - }, - }, - { - name: "remove", - closeFn: func(backend *Mock, f p9.File) { - // Allow the rename call in the parent, automatically translated. - backend.parent.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Times(1) - f.(deprecatedRemover).Remove() - }, - }, - } - - for name := range newTypeMap(nil) { - for _, tc := range cases { - t.Run(fmt.Sprintf("%s(%s)", tc.name, name), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - - // Close via the prescribed method. - tc.closeFn(backend, f) - - // Everything should fail with EBADF. - if _, _, err := f.Walk(nil); err != syscall.EBADF { - t.Errorf("walk got %v, wanted EBADF", err) - } - if _, err := f.StatFS(); err != syscall.EBADF { - t.Errorf("statfs got %v, wanted EBADF", err) - } - if _, _, _, err := f.GetAttr(p9.AttrMaskAll()); err != syscall.EBADF { - t.Errorf("getattr got %v, wanted EBADF", err) - } - if err := f.SetAttr(p9.SetAttrMask{}, p9.SetAttr{}); err != syscall.EBADF { - t.Errorf("setattrk got %v, wanted EBADF", err) - } - if err := f.Rename(root, "new-name"); err != syscall.EBADF { - t.Errorf("rename got %v, wanted EBADF", err) - } - if err := f.Close(); err != syscall.EBADF { - t.Errorf("close got %v, wanted EBADF", err) - } - if _, _, _, err := f.Open(p9.ReadOnly); err != syscall.EBADF { - t.Errorf("open got %v, wanted EBADF", err) - } - if _, err := f.ReadAt([]byte("hello"), 0); err != syscall.EBADF { - t.Errorf("readAt got %v, wanted EBADF", err) - } - if _, err := f.WriteAt(make([]byte, 6), 0); err != syscall.EBADF { - t.Errorf("writeAt got %v, wanted EBADF", err) - } - if err := f.FSync(); err != syscall.EBADF { - t.Errorf("fsync got %v, wanted EBADF", err) - } - if _, _, _, _, err := f.Create("new-file", p9.ReadWrite, 0, 0, 0); err != syscall.EBADF { - t.Errorf("create got %v, wanted EBADF", err) - } - if _, err := f.Mkdir("new-directory", 0, 0, 0); err != syscall.EBADF { - t.Errorf("mkdir got %v, wanted EBADF", err) - } - if _, err := f.Symlink("old-name", "new-name", 0, 0); err != syscall.EBADF { - t.Errorf("symlink got %v, wanted EBADF", err) - } - if err := f.Link(root, "new-name"); err != syscall.EBADF { - t.Errorf("link got %v, wanted EBADF", err) - } - if _, err := f.Mknod("new-block-device", 0, 0, 0, 0, 0); err != syscall.EBADF { - t.Errorf("mknod got %v, wanted EBADF", err) - } - if err := f.RenameAt("old-name", root, "new-name"); err != syscall.EBADF { - t.Errorf("renameAt got %v, wanted EBADF", err) - } - if err := f.UnlinkAt("name", 0); err != syscall.EBADF { - t.Errorf("unlinkAt got %v, wanted EBADF", err) - } - if _, err := f.Readdir(0, 1); err != syscall.EBADF { - t.Errorf("readdir got %v, wanted EBADF", err) - } - if _, err := f.Readlink(); err != syscall.EBADF { - t.Errorf("readlink got %v, wanted EBADF", err) - } - if err := f.Flush(); err != syscall.EBADF { - t.Errorf("flush got %v, wanted EBADF", err) - } - if _, _, _, _, err := f.WalkGetAttr(nil); err != syscall.EBADF { - t.Errorf("walkgetattr got %v, wanted EBADF", err) - } - if _, err := f.Connect(p9.ConnectFlags(0)); err != syscall.EBADF { - t.Errorf("connect got %v, wanted EBADF", err) - } - }) - } - } -} - -// onlyWorksOnOpenThings is a helper test method for operations that should -// only work on files that have been explicitly opened. -func onlyWorksOnOpenThings(h *Harness, t *testing.T, name string, root p9.File, mode p9.OpenFlags, expectedErr error, fn func(backend *Mock, f p9.File, shouldSucceed bool) error) { - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - // Does it work before opening? - if err := fn(backend, f, false); err != syscall.EINVAL { - t.Errorf("operation got %v, wanted EINVAL", err) - } - - // Is this openable? - if !p9.CanOpen(backend.Attr.Mode) { - return // Nothing to do. - } - - // If this is a directory, we can't handle writing. - if backend.Attr.Mode.IsDir() && (mode == p9.ReadWrite || mode == p9.WriteOnly) { - return // Skip. - } - - // Open the file. - backend.EXPECT().Open(mode) - if _, _, _, err := f.Open(mode); err != nil { - t.Fatalf("open got %v, wanted nil", err) - } - - // Attempt the operation. - if err := fn(backend, f, expectedErr == nil); err != expectedErr { - t.Fatalf("operation got %v, wanted %v", err, expectedErr) - } -} - -func TestRead(t *testing.T) { - type readTest struct { - name string - mode p9.OpenFlags - err error - } - - cases := []readTest{ - { - name: "read-only", - mode: p9.ReadOnly, - err: nil, - }, - { - name: "read-write", - mode: p9.ReadWrite, - err: nil, - }, - { - name: "write-only", - mode: p9.WriteOnly, - err: syscall.EPERM, - }, - } - - for name := range newTypeMap(nil) { - for _, tc := range cases { - t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - const message = "hello" - - onlyWorksOnOpenThings(h, t, name, root, tc.mode, tc.err, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if !shouldSucceed { - _, err := f.ReadAt([]byte(message), 0) - return err - } - - // Prepare for the call to readAt in the backend. - backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { - copy(p, message) - }).Return(len(message), nil) - - // Make the client call. - p := make([]byte, 2*len(message)) // Double size. - n, err := f.ReadAt(p, 0) - - // Sanity check result. - if err != nil { - return err - } - if n != len(message) { - t.Fatalf("message length incorrect, got %d, want %d", n, len(message)) - } - if !bytes.Equal(p[:n], []byte(message)) { - t.Fatalf("message incorrect, got %v, want %v", p, []byte(message)) - } - return nil // Success. - }) - }) - } - } -} - -func TestWrite(t *testing.T) { - type writeTest struct { - name string - mode p9.OpenFlags - err error - } - - cases := []writeTest{ - { - name: "read-only", - mode: p9.ReadOnly, - err: syscall.EPERM, - }, - { - name: "read-write", - mode: p9.ReadWrite, - err: nil, - }, - { - name: "write-only", - mode: p9.WriteOnly, - err: nil, - }, - } - - for name := range newTypeMap(nil) { - for _, tc := range cases { - t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - const message = "hello" - - onlyWorksOnOpenThings(h, t, name, root, tc.mode, tc.err, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if !shouldSucceed { - _, err := f.WriteAt([]byte(message), 0) - return err - } - - // Prepare for the call to readAt in the backend. - var output []byte // Saved by Do below. - backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { - output = p - }).Return(len(message), nil) - - // Make the client call. - n, err := f.WriteAt([]byte(message), 0) - - // Sanity check result. - if err != nil { - return err - } - if n != len(message) { - t.Fatalf("message length incorrect, got %d, want %d", n, len(message)) - } - if !bytes.Equal(output, []byte(message)) { - t.Fatalf("message incorrect, got %v, want %v", output, []byte(message)) - } - return nil // Success. - }) - }) - } - } -} - -func TestFSync(t *testing.T) { - for name := range newTypeMap(nil) { - for _, mode := range []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} { - t.Run(fmt.Sprintf("%s-%s", mode, name), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnOpenThings(h, t, name, root, mode, nil, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if shouldSucceed { - backend.EXPECT().FSync().Times(1) - } - return f.FSync() - }) - }) - } - } -} - -func TestFlush(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - backend.EXPECT().Flush() - f.Flush() - }) - } -} - -// onlyWorksOnDirectories is a helper test method for operations that should -// only work on unopened directories, such as create, mkdir and symlink. -func onlyWorksOnDirectories(h *Harness, t *testing.T, name string, root p9.File, fn func(backend *Mock, f p9.File, shouldSucceed bool) error) { - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - // Only directories support mknod. - if !backend.Attr.Mode.IsDir() { - if err := fn(backend, f, false); err != syscall.EINVAL { - t.Errorf("operation got %v, wanted EINVAL", err) - } - return // Nothing else to do. - } - - // Should succeed. - if err := fn(backend, f, true); err != nil { - t.Fatalf("operation got %v, wanted nil", err) - } - - // Open the directory. - backend.EXPECT().Open(p9.ReadOnly).Times(1) - if _, _, _, err := f.Open(p9.ReadOnly); err != nil { - t.Fatalf("open got %v, wanted nil", err) - } - - // Should not work again. - if err := fn(backend, f, false); err != syscall.EINVAL { - t.Fatalf("operation got %v, wanted EINVAL", err) - } -} - -func TestCreate(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if !shouldSucceed { - _, _, _, _, err := f.Create("new-file", p9.ReadWrite, 0, 1, 2) - return err - } - - // If the create is going to succeed, then we - // need to create a new backend file, and we - // clone to ensure that we don't close the - // original. - _, newF, err := f.Walk(nil) - if err != nil { - t.Fatalf("clone got %v, wanted nil", err) - } - defer newF.Close() - newBackend := h.Pop(newF) - - // Run a regular FD test to validate that path. - fdTest(t, func(send *fd.FD) *fd.FD { - // Return the send FD on success. - newFile := h.NewFile()(backend) // New file with the parent backend. - newBackend.EXPECT().Create("new-file", p9.ReadWrite, p9.FileMode(0), p9.UID(1), p9.GID(2)).Return(send, newFile, p9.QID{}, uint32(0), nil) - - // Receive the fd back. - recv, _, _, _, err := newF.Create("new-file", p9.ReadWrite, 0, 1, 2) - if err != nil { - t.Fatalf("create got %v, wanted nil", err) - } - return recv - }) - - // The above will fail via normal test flow, so - // we can assume that it passed. - return nil - }) - }) - } -} - -func TestCreateInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if _, _, _, _, err := root.Create(invalidName, p9.ReadWrite, 0, 0, 0); err != syscall.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -func TestMkdir(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if shouldSucceed { - backend.EXPECT().Mkdir("new-directory", p9.FileMode(0), p9.UID(1), p9.GID(2)) - } - _, err := f.Mkdir("new-directory", 0, 1, 2) - return err - }) - }) - } -} - -func TestMkdirInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if _, err := root.Mkdir(invalidName, 0, 0, 0); err != syscall.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -func TestSymlink(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if shouldSucceed { - backend.EXPECT().Symlink("old-name", "new-name", p9.UID(1), p9.GID(2)) - } - _, err := f.Symlink("old-name", "new-name", 1, 2) - return err - }) - }) - } -} - -func TestSyminkInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - // We need only test for invalid names in the new name, - // the target can be an arbitrary string and we don't - // need to sanity check it. - if _, err := root.Symlink("old-name", invalidName, 0, 0); err != syscall.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -func TestLink(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if shouldSucceed { - backend.EXPECT().Link(gomock.Any(), "new-link") - } - return f.Link(f, "new-link") - }) - }) - } -} - -func TestLinkInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if err := root.Link(root, invalidName); err != syscall.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -func TestMknod(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if shouldSucceed { - backend.EXPECT().Mknod("new-block-device", p9.FileMode(0), uint32(1), uint32(2), p9.UID(3), p9.GID(4)).Times(1) - } - _, err := f.Mknod("new-block-device", 0, 1, 2, 3, 4) - return err - }) - }) - } -} - -// concurrentFn is a specification of a concurrent operation. This is used to -// drive the concurrency tests below. -type concurrentFn struct { - name string - match func(p9.FileMode) bool - op func(h *Harness, backend *Mock, f p9.File, callback func()) -} - -func concurrentTest(t *testing.T, name string, fn1, fn2 concurrentFn, sameDir, expectedOkay bool) { - var ( - names1 []string - names2 []string - ) - if sameDir { - // Use the same file one directory up. - names1, names2 = []string{"one", name}, []string{"one", name} - } else { - // For different directories, just use siblings. - names1, names2 = []string{"one", name}, []string{"three", name} - } - - t.Run(fmt.Sprintf("%s(%v)+%s(%v)", fn1.name, names1, fn2.name, names2), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to both files as given. - _, f1, err := root.Walk(names1) - if err != nil { - t.Fatalf("error walking, got %v, want nil", err) - } - defer f1.Close() - b1 := h.Pop(f1) - _, f2, err := root.Walk(names2) - if err != nil { - t.Fatalf("error walking, got %v, want nil", err) - } - defer f2.Close() - b2 := h.Pop(f2) - - // Are these a good match for the current test case? - if !fn1.match(b1.Attr.Mode) { - t.SkipNow() - } - if !fn2.match(b2.Attr.Mode) { - t.SkipNow() - } - - // Construct our "concurrency creator". - in1 := make(chan struct{}, 1) - in2 := make(chan struct{}, 1) - var top sync.WaitGroup - var fns sync.WaitGroup - defer top.Wait() - top.Add(2) // Accounting for below. - defer fns.Done() - fns.Add(1) // See line above; released before top.Wait. - go func() { - defer top.Done() - fn1.op(h, b1, f1, func() { - in1 <- struct{}{} - fns.Wait() - }) - }() - go func() { - defer top.Done() - fn2.op(h, b2, f2, func() { - in2 <- struct{}{} - fns.Wait() - }) - }() - - // Compute a reasonable timeout. If we expect the operation to hang, - // give it 10 milliseconds before we assert that it's fine. After all, - // there will be a lot of these tests. If we don't expect it to hang, - // give it a full minute, since the machine could be slow. - timeout := 10 * time.Millisecond - if expectedOkay { - timeout = 1 * time.Minute - } - - // Read the first channel. - var second chan struct{} - select { - case <-in1: - second = in2 - case <-in2: - second = in1 - } - - // Catch concurrency. - select { - case <-second: - // We finished successful. Is this good? Depends on the - // expected result. - if !expectedOkay { - t.Errorf("%q and %q proceeded concurrently!", fn1.name, fn2.name) - } - case <-time.After(timeout): - // Great, things did not proceed concurrently. Is that what we - // expected? - if expectedOkay { - t.Errorf("%q and %q hung concurrently!", fn1.name, fn2.name) - } - } - }) -} - -func randomFileName() string { - return fmt.Sprintf("%x", rand.Int63()) -} - -func TestConcurrency(t *testing.T) { - readExclusive := []concurrentFn{ - { - // N.B. We can't explicitly check WalkGetAttr behavior, - // but we rely on the fact that the internal code paths - // are the same. - name: "walk", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - // See the documentation of WalkCallback. - // Because walk is actually implemented by the - // mock, we need a special place for this - // callback. - // - // Note that a clone actually locks the parent - // node. So we walk from this node to test - // concurrent operations appropriately. - backend.WalkCallback = func() error { - callback() - return nil - } - f.Walk([]string{randomFileName()}) // Won't exist. - }, - }, - { - name: "fsync", - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Open(gomock.Any()) - backend.EXPECT().FSync().Do(func() { - callback() - }) - f.Open(p9.ReadOnly) // Required. - f.FSync() - }, - }, - { - name: "readdir", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Open(gomock.Any()) - backend.EXPECT().Readdir(gomock.Any(), gomock.Any()).Do(func(uint64, uint32) { - callback() - }) - f.Open(p9.ReadOnly) // Required. - f.Readdir(0, 1) - }, - }, - { - name: "readlink", - match: func(mode p9.FileMode) bool { return mode.IsSymlink() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Readlink().Do(func() { - callback() - }) - f.Readlink() - }, - }, - { - name: "connect", - match: func(mode p9.FileMode) bool { return mode.IsSocket() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Connect(gomock.Any()).Do(func(p9.ConnectFlags) { - callback() - }) - f.Connect(0) - }, - }, - { - name: "open", - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Open(gomock.Any()).Do(func(p9.OpenFlags) { - callback() - }) - f.Open(p9.ReadOnly) - }, - }, - { - name: "flush", - match: func(mode p9.FileMode) bool { return true }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Flush().Do(func() { - callback() - }) - f.Flush() - }, - }, - } - writeExclusive := []concurrentFn{ - { - // N.B. We can't really check getattr. But this is an - // extremely low-risk function, it seems likely that - // this check is paranoid anyways. - name: "setattr", - match: func(mode p9.FileMode) bool { return true }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().SetAttr(gomock.Any(), gomock.Any()).Do(func(p9.SetAttrMask, p9.SetAttr) { - callback() - }) - f.SetAttr(p9.SetAttrMask{}, p9.SetAttr{}) - }, - }, - { - name: "unlinkAt", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Do(func(string, uint32) { - callback() - }) - f.UnlinkAt(randomFileName(), 0) - }, - }, - { - name: "mknod", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Mknod(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.FileMode, uint32, uint32, p9.UID, p9.GID) { - callback() - }) - f.Mknod(randomFileName(), 0, 0, 0, 0, 0) - }, - }, - { - name: "link", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Link(gomock.Any(), gomock.Any()).Do(func(p9.File, string) { - callback() - }) - f.Link(f, randomFileName()) - }, - }, - { - name: "symlink", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Symlink(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, string, p9.UID, p9.GID) { - callback() - }) - f.Symlink(randomFileName(), randomFileName(), 0, 0) - }, - }, - { - name: "mkdir", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Mkdir(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.FileMode, p9.UID, p9.GID) { - callback() - }) - f.Mkdir(randomFileName(), 0, 0, 0) - }, - }, - { - name: "create", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - // Return an error for the creation operation, as this is the simplest. - backend.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil, p9.QID{}, uint32(0), syscall.EINVAL).Do(func(string, p9.OpenFlags, p9.FileMode, p9.UID, p9.GID) { - callback() - }) - f.Create(randomFileName(), p9.ReadOnly, 0, 0, 0) - }, - }, - } - globalExclusive := []concurrentFn{ - { - name: "remove", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - // Remove operates on a locked parent. So we - // add a child, walk to it and call remove. - // Note that because this operation can operate - // concurrently with itself, we need to - // generate a random file name. - randomFile := randomFileName() - backend.AddChild(randomFile, h.NewFile()) - defer backend.RemoveChild(randomFile) - _, file, err := f.Walk([]string{randomFile}) - if err != nil { - h.t.Fatalf("walk got %v, want nil", err) - } - - // Remove is automatically translated to the parent. - backend.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Do(func(string, uint32) { - callback() - }) - - // Remove is also a close. - file.(deprecatedRemover).Remove() - }, - }, - { - name: "rename", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - // Similarly to remove, because we need to - // operate on a child, we allow a walk. - randomFile := randomFileName() - backend.AddChild(randomFile, h.NewFile()) - defer backend.RemoveChild(randomFile) - _, file, err := f.Walk([]string{randomFile}) - if err != nil { - h.t.Fatalf("walk got %v, want nil", err) - } - defer file.Close() - fileBackend := h.Pop(file) - - // Rename is automatically translated to the parent. - backend.EXPECT().RenameAt(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.File, string) { - callback() - }) - - // Attempt the rename. - fileBackend.EXPECT().Renamed(gomock.Any(), gomock.Any()) - file.Rename(f, randomFileName()) - }, - }, - { - name: "renameAt", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().RenameAt(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.File, string) { - callback() - }) - - // Attempt the rename. There are no active fids - // with this name, so we don't need to expect - // Renamed hooks on anything. - f.RenameAt(randomFileName(), f, randomFileName()) - }, - }, - } - - for _, fn1 := range readExclusive { - for _, fn2 := range readExclusive { - for name := range newTypeMap(nil) { - // Everything should be able to proceed in parallel. - concurrentTest(t, name, fn1, fn2, true, true) - concurrentTest(t, name, fn1, fn2, false, true) - } - } - } - - for _, fn1 := range append(readExclusive, writeExclusive...) { - for _, fn2 := range writeExclusive { - for name := range newTypeMap(nil) { - // Only cross-directory functions should proceed in parallel. - concurrentTest(t, name, fn1, fn2, true, false) - concurrentTest(t, name, fn1, fn2, false, true) - } - } - } - - for _, fn1 := range append(append(readExclusive, writeExclusive...), globalExclusive...) { - for _, fn2 := range globalExclusive { - for name := range newTypeMap(nil) { - // Nothing should be able to run in parallel. - concurrentTest(t, name, fn1, fn2, true, false) - concurrentTest(t, name, fn1, fn2, false, false) - } - } - } -} - -func TestReadWriteConcurrent(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - const ( - instances = 10 - iterations = 10000 - dataSize = 1024 - ) - var ( - dataSets [instances][dataSize]byte - backends [instances]*Mock - files [instances]p9.File - ) - - // Walk to the file normally. - for i := 0; i < instances; i++ { - _, backends[i], files[i] = walkHelper(h, "file", root) - defer files[i].Close() - } - - // Open the files. - for i := 0; i < instances; i++ { - backends[i].EXPECT().Open(p9.ReadWrite) - if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil { - t.Fatalf("open got %v, wanted nil", err) - } - } - - // Initialize random data for each instance. - for i := 0; i < instances; i++ { - if _, err := rand.Read(dataSets[i][:]); err != nil { - t.Fatalf("error initializing dataSet#%d, got %v", i, err) - } - } - - // Define our random read/write mechanism. - randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) { - // Prepare the backend. - backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { - if n := copy(p, data); n != len(data) { - // Note that we have to assert the result here, as the Return statement - // below cannot be dynamic: it will be bound before this call is made. - h.t.Errorf("wanted length %d, got %d", len(data), n) - } - }).Return(len(data), nil) - - // Execute the read. - if n, err := f.ReadAt(test, 0); n != len(test) || err != nil { - t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err) - return // No sense doing check below. - } - if !bytes.Equal(test, data) { - t.Errorf("data integrity failed during read") // Not as expected. - } - } - randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) { - // Prepare the backend. - backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { - if !bytes.Equal(p, data) { - h.t.Errorf("data integrity failed during write") // Not as expected. - } - }).Return(len(data), nil) - - // Execute the write. - if n, err := f.WriteAt(data, 0); n != len(data) || err != nil { - t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(data), n, err) - } - } - randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) { - test := make([]byte, len(data)) - for i := 0; i < n; i++ { - if rand.Intn(2) == 0 { - randRead(h, backend, f, data, test) - } else { - randWrite(h, backend, f, data) - } - } - } - - // Start reading and writing. - var wg sync.WaitGroup - for i := 0; i < instances; i++ { - wg.Add(1) - go func(i int) { - defer wg.Done() - randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:]) - }(i) - } - wg.Wait() -} diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go deleted file mode 100644 index dd8b01b6d..000000000 --- a/pkg/p9/p9test/p9test.go +++ /dev/null @@ -1,329 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package p9test provides standard mocks for p9. -package p9test - -import ( - "fmt" - "sync/atomic" - "syscall" - "testing" - - "github.com/golang/mock/gomock" - "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/unet" -) - -// Harness is an attacher mock. -type Harness struct { - t *testing.T - mockCtrl *gomock.Controller - Attacher *MockAttacher - wg sync.WaitGroup - clientSocket *unet.Socket - mu sync.Mutex - created []*Mock -} - -// globalPath is a QID.Path Generator. -var globalPath uint64 - -// MakePath returns a globally unique path. -func MakePath() uint64 { - return atomic.AddUint64(&globalPath, 1) -} - -// Generator is a function that generates a new file. -type Generator func(parent *Mock) *Mock - -// Mock is a common mock element. -type Mock struct { - p9.DefaultWalkGetAttr - *MockFile - parent *Mock - closed bool - harness *Harness - QID p9.QID - Attr p9.Attr - children map[string]Generator - - // WalkCallback is a special function that will be called from within - // the walk context. This is needed for the concurrent tests within - // this package. - WalkCallback func() error -} - -// globalMu protects the children maps in all mocks. Note that this is not a -// particularly elegant solution, but because the test has walks from the root -// through to final nodes, we must share maps below, and it's easiest to simply -// protect against concurrent access globally. -var globalMu sync.RWMutex - -// AddChild adds a new child to the Mock. -func (m *Mock) AddChild(name string, generator Generator) { - globalMu.Lock() - defer globalMu.Unlock() - m.children[name] = generator -} - -// RemoveChild removes the child with the given name. -func (m *Mock) RemoveChild(name string) { - globalMu.Lock() - defer globalMu.Unlock() - delete(m.children, name) -} - -// Matches implements gomock.Matcher.Matches. -func (m *Mock) Matches(x interface{}) bool { - if om, ok := x.(*Mock); ok { - return m.QID.Path == om.QID.Path - } - return false -} - -// String implements gomock.Matcher.String. -func (m *Mock) String() string { - return fmt.Sprintf("Mock{Mode: 0x%x, QID.Path: %d}", m.Attr.Mode, m.QID.Path) -} - -// GetAttr returns the current attributes. -func (m *Mock) GetAttr(mask p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) { - return m.QID, p9.AttrMaskAll(), m.Attr, nil -} - -// Walk supports clone and walking in directories. -func (m *Mock) Walk(names []string) ([]p9.QID, p9.File, error) { - if m.WalkCallback != nil { - if err := m.WalkCallback(); err != nil { - return nil, nil, err - } - } - if len(names) == 0 { - // Clone the file appropriately. - nm := m.harness.NewMock(m.parent, m.QID.Path, m.Attr) - nm.children = m.children // Inherit children. - return []p9.QID{nm.QID}, nm, nil - } else if len(names) != 1 { - m.harness.t.Fail() // Should not happen. - return nil, nil, syscall.EINVAL - } - - if m.Attr.Mode.IsDir() { - globalMu.RLock() - defer globalMu.RUnlock() - if fn, ok := m.children[names[0]]; ok { - // Generate the child. - nm := fn(m) - return []p9.QID{nm.QID}, nm, nil - } - // No child found. - return nil, nil, syscall.ENOENT - } - - // Call the underlying mock. - return m.MockFile.Walk(names) -} - -// WalkGetAttr calls the default implementation; this is a client-side optimization. -func (m *Mock) WalkGetAttr(names []string) ([]p9.QID, p9.File, p9.AttrMask, p9.Attr, error) { - return m.DefaultWalkGetAttr.WalkGetAttr(names) -} - -// Pop pops off the most recently created Mock and assert that this mock -// represents the same file passed in. If nil is passed in, no check is -// performed. -// -// Precondition: there must be at least one Mock or this will panic. -func (h *Harness) Pop(clientFile p9.File) *Mock { - h.mu.Lock() - defer h.mu.Unlock() - - if clientFile == nil { - // If no clientFile is provided, then we always return the last - // created file. The caller can safely use this as long as - // there is no concurrency. - m := h.created[len(h.created)-1] - h.created = h.created[:len(h.created)-1] - return m - } - - qid, _, _, err := clientFile.GetAttr(p9.AttrMaskAll()) - if err != nil { - // We do not expect this to happen. - panic(fmt.Sprintf("err during Pop: %v", err)) - } - - // Find the relevant file in our created list. We must scan the last - // from back to front to ensure that we favor the most recently - // generated file. - for i := len(h.created) - 1; i >= 0; i-- { - m := h.created[i] - if qid.Path == m.QID.Path { - // Copy and truncate. - copy(h.created[i:], h.created[i+1:]) - h.created = h.created[:len(h.created)-1] - return m - } - } - - // Unable to find relevant file. - panic(fmt.Sprintf("unable to locate file with QID %+v", qid.Path)) -} - -// NewMock returns a new base file. -func (h *Harness) NewMock(parent *Mock, path uint64, attr p9.Attr) *Mock { - m := &Mock{ - MockFile: NewMockFile(h.mockCtrl), - parent: parent, - harness: h, - QID: p9.QID{ - Type: p9.QIDType((attr.Mode & p9.FileModeMask) >> 12), - Path: path, - }, - Attr: attr, - } - - // Always ensure Close is after the parent's close. Note that this - // can't be done via a straight-forward After call, because the parent - // might change after initial creation. We ensure that this is true at - // close time. - m.EXPECT().Close().Return(nil).Times(1).Do(func() { - if m.parent != nil && m.parent.closed { - h.t.FailNow() - } - // Note that this should not be racy, as this operation should - // be protected by the Times(1) above first. - m.closed = true - }) - - // Remember what was created. - h.mu.Lock() - defer h.mu.Unlock() - h.created = append(h.created, m) - - return m -} - -// NewFile returns a new file mock. -// -// Note that ReadAt and WriteAt must be mocked separately. -func (h *Harness) NewFile() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeRegular}) - } -} - -// NewDirectory returns a new mock directory. -// -// Note that Mkdir, Link, Mknod, RenameAt, UnlinkAt and Readdir must be mocked -// separately. Walk is provided and children may be manipulated via AddChild -// and RemoveChild. After calling Walk remotely, one can use Pop to find the -// corresponding backend mock on the server side. -func (h *Harness) NewDirectory(contents map[string]Generator) Generator { - return func(parent *Mock) *Mock { - m := h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeDirectory}) - m.children = contents // Save contents. - return m - } -} - -// NewSymlink returns a new mock directory. -// -// Note that Readlink must be mocked separately. -func (h *Harness) NewSymlink() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeSymlink}) - } -} - -// NewBlockDevice returns a new mock block device. -func (h *Harness) NewBlockDevice() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeBlockDevice}) - } -} - -// NewCharacterDevice returns a new mock character device. -func (h *Harness) NewCharacterDevice() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeCharacterDevice}) - } -} - -// NewNamedPipe returns a new mock named pipe. -func (h *Harness) NewNamedPipe() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeNamedPipe}) - } -} - -// NewSocket returns a new mock socket. -func (h *Harness) NewSocket() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeSocket}) - } -} - -// Finish completes all checks and shuts down the server. -func (h *Harness) Finish() { - h.clientSocket.Shutdown() - h.wg.Wait() - h.mockCtrl.Finish() -} - -// NewHarness creates and returns a new test server. -// -// It should always be used as: -// -// h, c := NewHarness(t) -// defer h.Finish() -// -func NewHarness(t *testing.T) (*Harness, *p9.Client) { - // Create the mock. - mockCtrl := gomock.NewController(t) - h := &Harness{ - t: t, - mockCtrl: mockCtrl, - Attacher: NewMockAttacher(mockCtrl), - } - - // Make socket pair. - serverSocket, clientSocket, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v wanted nil", err) - } - - // Start the server, synchronized on exit. - server := p9.NewServer(h.Attacher) - h.wg.Add(1) - go func() { - defer h.wg.Done() - server.Handle(serverSocket) - }() - - // Create the client. - client, err := p9.NewClient(clientSocket, p9.DefaultMessageSize, p9.HighestVersionString()) - if err != nil { - serverSocket.Close() - clientSocket.Close() - t.Fatalf("new client got %v, expected nil", err) - return nil, nil // Never hit. - } - - // Capture the client socket. - h.clientSocket = clientSocket - return h, client -} diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go index a0d274f3b..a0d274f3b 100644..100755 --- a/pkg/p9/transport_flipcall.go +++ b/pkg/p9/transport_flipcall.go diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go deleted file mode 100644 index 3668fcad7..000000000 --- a/pkg/p9/transport_test.go +++ /dev/null @@ -1,231 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "io/ioutil" - "os" - "testing" - - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/unet" -) - -const ( - MsgTypeBadEncode = iota + 252 - MsgTypeBadDecode - MsgTypeUnregistered -) - -func TestSendRecv(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - defer client.Close() - - if err := send(client, Tag(1), &Tlopen{}); err != nil { - t.Fatalf("send got err %v expected nil", err) - } - - tag, m, err := recv(server, maximumLength, msgRegistry.get) - if err != nil { - t.Fatalf("recv got err %v expected nil", err) - } - if tag != Tag(1) { - t.Fatalf("got tag %v expected 1", tag) - } - if _, ok := m.(*Tlopen); !ok { - t.Fatalf("got message %v expected *Tlopen", m) - } -} - -// badDecode overruns on decode. -type badDecode struct{} - -func (*badDecode) decode(b *buffer) { b.markOverrun() } -func (*badDecode) encode(b *buffer) {} -func (*badDecode) Type() MsgType { return MsgTypeBadDecode } -func (*badDecode) String() string { return "badDecode{}" } - -func TestRecvOverrun(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - defer client.Close() - - if err := send(client, Tag(1), &badDecode{}); err != nil { - t.Fatalf("send got err %v expected nil", err) - } - - if _, _, err := recv(server, maximumLength, msgRegistry.get); err == nil { - t.Fatalf("recv got err %v expected ErrSocket{ErrNoValidMessage}", err) - } -} - -// unregistered is not registered on decode. -type unregistered struct{} - -func (*unregistered) decode(b *buffer) {} -func (*unregistered) encode(b *buffer) {} -func (*unregistered) Type() MsgType { return MsgTypeUnregistered } -func (*unregistered) String() string { return "unregistered{}" } - -func TestRecvInvalidType(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - defer client.Close() - - if err := send(client, Tag(1), &unregistered{}); err != nil { - t.Fatalf("send got err %v expected nil", err) - } - - _, _, err = recv(server, maximumLength, msgRegistry.get) - if _, ok := err.(*ErrInvalidMsgType); !ok { - t.Fatalf("recv got err %v expected ErrInvalidMsgType", err) - } -} - -func TestSendRecvWithFile(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - defer client.Close() - - // Create a tempfile. - osf, err := ioutil.TempFile("", "p9") - if err != nil { - t.Fatalf("tempfile got err %v expected nil", err) - } - os.Remove(osf.Name()) - f, err := fd.NewFromFile(osf) - osf.Close() - if err != nil { - t.Fatalf("unable to create file: %v", err) - } - - rlopen := &Rlopen{} - rlopen.SetFilePayload(f) - if err := send(client, Tag(1), rlopen); err != nil { - t.Fatalf("send got err %v expected nil", err) - } - - // Enable withFile. - tag, m, err := recv(server, maximumLength, msgRegistry.get) - if err != nil { - t.Fatalf("recv got err %v expected nil", err) - } - if tag != Tag(1) { - t.Fatalf("got tag %v expected 1", tag) - } - rlopen, ok := m.(*Rlopen) - if !ok { - t.Fatalf("got m %v expected *Rlopen", m) - } - if rlopen.File == nil { - t.Fatalf("got nil file expected non-nil") - } -} - -func TestRecvClosed(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - client.Close() - - _, _, err = recv(server, maximumLength, msgRegistry.get) - if err == nil { - t.Fatalf("got err nil expected non-nil") - } - if _, ok := err.(ErrSocket); !ok { - t.Fatalf("got err %v expected ErrSocket", err) - } -} - -func TestSendClosed(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - server.Close() - defer client.Close() - - err = send(client, Tag(1), &Tlopen{}) - if err == nil { - t.Fatalf("send got err nil expected non-nil") - } - if _, ok := err.(ErrSocket); !ok { - t.Fatalf("got err %v expected ErrSocket", err) - } -} - -func BenchmarkSendRecv(b *testing.B) { - server, client, err := unet.SocketPair(false) - if err != nil { - b.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - defer client.Close() - - // Exchange Rflush messages since these contain no data and therefore incur - // no additional marshaling overhead. - go func() { - for i := 0; i < b.N; i++ { - tag, m, err := recv(server, maximumLength, msgRegistry.get) - if err != nil { - b.Fatalf("recv got err %v expected nil", err) - } - if tag != Tag(1) { - b.Fatalf("got tag %v expected 1", tag) - } - if _, ok := m.(*Rflush); !ok { - b.Fatalf("got message %T expected *Rflush", m) - } - if err := send(server, Tag(2), &Rflush{}); err != nil { - b.Fatalf("send got err %v expected nil", err) - } - } - }() - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := send(client, Tag(1), &Rflush{}); err != nil { - b.Fatalf("send got err %v expected nil", err) - } - tag, m, err := recv(client, maximumLength, msgRegistry.get) - if err != nil { - b.Fatalf("recv got err %v expected nil", err) - } - if tag != Tag(2) { - b.Fatalf("got tag %v expected 2", tag) - } - if _, ok := m.(*Rflush); !ok { - b.Fatalf("got message %v expected *Rflush", m) - } - } -} - -func init() { - msgRegistry.register(MsgTypeBadDecode, func() message { return &badDecode{} }) -} diff --git a/pkg/p9/version_test.go b/pkg/p9/version_test.go deleted file mode 100644 index 291e8580e..000000000 --- a/pkg/p9/version_test.go +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "testing" -) - -func TestVersionNumberEquivalent(t *testing.T) { - for i := uint32(0); i < 1024; i++ { - str := versionString(i) - version, ok := parseVersion(str) - if !ok { - t.Errorf("#%d: parseVersion(%q) failed, want success", i, str) - continue - } - if i != version { - t.Errorf("#%d: got version %d, want %d", i, i, version) - } - } -} - -func TestVersionStringEquivalent(t *testing.T) { - // There is one case where the version is not equivalent on purpose, - // that is 9P2000.L.Google.0. It is not equivalent because versionString - // must always return the more generic 9P2000.L for legacy servers that - // check for it. See net/9p/client.c. - str := "9P2000.L.Google.0" - version, ok := parseVersion(str) - if !ok { - t.Errorf("parseVersion(%q) failed, want success", str) - } - if got := versionString(version); got != "9P2000.L" { - t.Errorf("versionString(%d) got %q, want %q", version, got, "9P2000.L") - } - - for _, test := range []struct { - versionString string - }{ - { - versionString: "9P2000.L", - }, - { - versionString: "9P2000.L.Google.1", - }, - { - versionString: "9P2000.L.Google.347823894", - }, - } { - version, ok := parseVersion(test.versionString) - if !ok { - t.Errorf("parseVersion(%q) failed, want success", test.versionString) - continue - } - if got := versionString(version); got != test.versionString { - t.Errorf("versionString(%d) got %q, want %q", version, got, test.versionString) - } - } -} - -func TestParseVersion(t *testing.T) { - for _, test := range []struct { - versionString string - expectSuccess bool - expectedVersion uint32 - }{ - { - versionString: "9P", - expectSuccess: false, - }, - { - versionString: "9P.L", - expectSuccess: false, - }, - { - versionString: "9P200.L", - expectSuccess: false, - }, - { - versionString: "9P2000", - expectSuccess: false, - }, - { - versionString: "9P2000.L.Google.-1", - expectSuccess: false, - }, - { - versionString: "9P2000.L.Google.", - expectSuccess: false, - }, - { - versionString: "9P2000.L.Google.3546343826724305832", - expectSuccess: false, - }, - { - versionString: "9P2001.L", - expectSuccess: false, - }, - { - versionString: "9P2000.L", - expectSuccess: true, - expectedVersion: 0, - }, - { - versionString: "9P2000.L.Google.0", - expectSuccess: true, - expectedVersion: 0, - }, - { - versionString: "9P2000.L.Google.1", - expectSuccess: true, - expectedVersion: 1, - }, - } { - version, ok := parseVersion(test.versionString) - if ok != test.expectSuccess { - t.Errorf("parseVersion(%q) got (_, %v), want (_, %v)", test.versionString, ok, test.expectSuccess) - continue - } - if !test.expectSuccess { - continue - } - if version != test.expectedVersion { - t.Errorf("parseVersion(%q) got (%d, _), want (%d, _)", test.versionString, version, test.expectedVersion) - } - } -} - -func BenchmarkParseVersion(b *testing.B) { - for n := 0; n < b.N; n++ { - parseVersion("9P2000.L.Google.1") - } -} diff --git a/pkg/pool/BUILD b/pkg/pool/BUILD deleted file mode 100644 index 7b1c6b75b..000000000 --- a/pkg/pool/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -go_library( - name = "pool", - srcs = [ - "pool.go", - ], - deps = [ - "//pkg/sync", - ], -) - -go_test( - name = "pool_test", - size = "small", - srcs = [ - "pool_test.go", - ], - library = ":pool", -) diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go index a1b2e0cfe..a1b2e0cfe 100644..100755 --- a/pkg/pool/pool.go +++ b/pkg/pool/pool.go diff --git a/pkg/pool/pool_state_autogen.go b/pkg/pool/pool_state_autogen.go new file mode 100755 index 000000000..1f4164c00 --- /dev/null +++ b/pkg/pool/pool_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package pool diff --git a/pkg/pool/pool_test.go b/pkg/pool/pool_test.go deleted file mode 100644 index d928439c1..000000000 --- a/pkg/pool/pool_test.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pool - -import ( - "testing" -) - -func TestPoolUnique(t *testing.T) { - p := Pool{Start: 1, Limit: 3} - got := make(map[uint64]bool) - - for { - n, ok := p.Get() - if !ok { - break - } - - // Check unique. - if _, ok := got[n]; ok { - t.Errorf("pool spit out %v multiple times", n) - } - - // Record. - got[n] = true - } -} - -func TestExausted(t *testing.T) { - p := Pool{Start: 1, Limit: 500} - for i := 0; i < 499; i++ { - _, ok := p.Get() - if !ok { - t.Fatalf("pool exhausted before 499 items") - } - } - - _, ok := p.Get() - if ok { - t.Errorf("pool not exhausted when it should be") - } -} - -func TestPoolRecycle(t *testing.T) { - p := Pool{Start: 1, Limit: 500} - n1, _ := p.Get() - p.Put(n1) - n2, _ := p.Get() - if n1 != n2 { - t.Errorf("pool not recycling items") - } -} diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD deleted file mode 100644 index aa3e3ac0b..000000000 --- a/pkg/procid/BUILD +++ /dev/null @@ -1,34 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "procid", - srcs = [ - "procid.go", - "procid_amd64.s", - "procid_arm64.s", - ], - visibility = ["//visibility:public"], -) - -go_test( - name = "procid_test", - size = "small", - srcs = [ - "procid_test.go", - ], - library = ":procid", - deps = ["//pkg/sync"], -) - -go_test( - name = "procid_net_test", - size = "small", - srcs = [ - "procid_net_test.go", - "procid_test.go", - ], - library = ":procid", - deps = ["//pkg/sync"], -) diff --git a/pkg/procid/procid_net_test.go b/pkg/procid/procid_net_test.go deleted file mode 100644 index b628e2285..000000000 --- a/pkg/procid/procid_net_test.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package procid - -// This file is just to force the inclusion of the "net" package, which will -// make the test binary a cgo one. -import ( - _ "net" -) diff --git a/pkg/procid/procid_state_autogen.go b/pkg/procid/procid_state_autogen.go new file mode 100755 index 000000000..662988d79 --- /dev/null +++ b/pkg/procid/procid_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package procid diff --git a/pkg/procid/procid_test.go b/pkg/procid/procid_test.go deleted file mode 100644 index 9ec08c3d6..000000000 --- a/pkg/procid/procid_test.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package procid - -import ( - "os" - "runtime" - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/sync" -) - -// runOnMain is used to send functions to run on the main (initial) thread. -var runOnMain = make(chan func(), 10) - -func checkProcid(t *testing.T, start *sync.WaitGroup, done *sync.WaitGroup) { - defer done.Done() - - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - start.Done() - start.Wait() - - procID := Current() - tid := syscall.Gettid() - - if procID != uint64(tid) { - t.Logf("Bad procid: expected %v, got %v", tid, procID) - t.Fail() - } -} - -func TestProcidInitialized(t *testing.T) { - var start sync.WaitGroup - var done sync.WaitGroup - - count := 100 - start.Add(count + 1) - done.Add(count + 1) - - // Run the check on the main thread. - // - // When cgo is not included, the only case when procid isn't initialized - // is in the main (initial) thread, so we have to test this case - // specifically. - runOnMain <- func() { - checkProcid(t, &start, &done) - } - - // Run the check on a number of different threads. - for i := 0; i < count; i++ { - go checkProcid(t, &start, &done) - } - - done.Wait() -} - -func TestMain(m *testing.M) { - // Make sure we remain at the main (initial) thread. - runtime.LockOSThread() - - // Start running tests in a different goroutine. - go func() { - os.Exit(m.Run()) - }() - - // Execute any functions that have been sent for execution in the main - // thread. - for f := range runOnMain { - f() - } -} diff --git a/pkg/rand/BUILD b/pkg/rand/BUILD deleted file mode 100644 index 80b8ceb02..000000000 --- a/pkg/rand/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "rand", - srcs = [ - "rand.go", - "rand_linux.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/rand/rand_linux_state_autogen.go b/pkg/rand/rand_linux_state_autogen.go new file mode 100755 index 000000000..f727c9314 --- /dev/null +++ b/pkg/rand/rand_linux_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package rand diff --git a/pkg/rand/rand_state_autogen.go b/pkg/rand/rand_state_autogen.go new file mode 100755 index 000000000..e0a5cd184 --- /dev/null +++ b/pkg/rand/rand_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build !linux + +package rand diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD deleted file mode 100644 index 74affc887..000000000 --- a/pkg/refs/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "weak_ref_list", - out = "weak_ref_list.go", - package = "refs", - prefix = "weakRef", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*WeakRef", - "Linker": "*WeakRef", - }, -) - -go_library( - name = "refs", - srcs = [ - "refcounter.go", - "refcounter_state.go", - "weak_ref_list.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/log", - "//pkg/sync", - ], -) - -go_test( - name = "refs_test", - size = "small", - srcs = ["refcounter_test.go"], - library = ":refs", - deps = ["//pkg/sync"], -) diff --git a/pkg/refs/refcounter_test.go b/pkg/refs/refcounter_test.go deleted file mode 100644 index 1ab4a4440..000000000 --- a/pkg/refs/refcounter_test.go +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package refs - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/sync" -) - -type testCounter struct { - AtomicRefCount - - // mu protects the boolean below. - mu sync.Mutex - - // destroyed indicates whether this was destroyed. - destroyed bool -} - -func (t *testCounter) DecRef() { - t.AtomicRefCount.DecRefWithDestructor(t.destroy) -} - -func (t *testCounter) destroy() { - t.mu.Lock() - defer t.mu.Unlock() - t.destroyed = true -} - -func (t *testCounter) IsDestroyed() bool { - t.mu.Lock() - defer t.mu.Unlock() - return t.destroyed -} - -func newTestCounter() *testCounter { - return &testCounter{destroyed: false} -} - -func TestOneRef(t *testing.T) { - tc := newTestCounter() - tc.DecRef() - - if !tc.IsDestroyed() { - t.Errorf("object should have been destroyed") - } -} - -func TestTwoRefs(t *testing.T) { - tc := newTestCounter() - tc.IncRef() - tc.DecRef() - tc.DecRef() - - if !tc.IsDestroyed() { - t.Errorf("object should have been destroyed") - } -} - -func TestMultiRefs(t *testing.T) { - tc := newTestCounter() - tc.IncRef() - tc.DecRef() - - tc.IncRef() - tc.DecRef() - - tc.DecRef() - - if !tc.IsDestroyed() { - t.Errorf("object should have been destroyed") - } -} - -func TestWeakRef(t *testing.T) { - tc := newTestCounter() - w := NewWeakRef(tc, nil) - - // Try resolving. - if x := w.Get(); x == nil { - t.Errorf("weak reference didn't resolve: expected %v, got nil", tc) - } else { - x.DecRef() - } - - // Try resolving again. - if x := w.Get(); x == nil { - t.Errorf("weak reference didn't resolve: expected %v, got nil", tc) - } else { - x.DecRef() - } - - // Shouldn't be destroyed yet. (Can't continue if this fails.) - if tc.IsDestroyed() { - t.Fatalf("original object destroyed earlier than expected") - } - - // Drop the original reference. - tc.DecRef() - - // Assert destroyed. - if !tc.IsDestroyed() { - t.Errorf("original object not destroyed as expected") - } - - // Shouldn't be anything. - if x := w.Get(); x != nil { - t.Errorf("weak reference resolved: expected nil, got %v", x) - } -} - -func TestWeakRefDrop(t *testing.T) { - tc := newTestCounter() - w := NewWeakRef(tc, nil) - w.Drop() - - // Just assert the list is empty. - if !tc.weakRefs.Empty() { - t.Errorf("weak reference not dropped") - } - - // Drop the original reference. - tc.DecRef() -} - -type testWeakRefUser struct { - weakRefGone func() -} - -func (u *testWeakRefUser) WeakRefGone() { - u.weakRefGone() -} - -func TestCallback(t *testing.T) { - called := false - tc := newTestCounter() - var w *WeakRef - w = NewWeakRef(tc, &testWeakRefUser{func() { - called = true - - // Check that the weak ref has been zapped. - rc := w.obj.Load().(RefCounter) - if v := reflect.ValueOf(rc); v != reflect.Zero(v.Type()) { - t.Fatalf("Callback called with non-nil ptr") - } - - // Check that we're not holding the mutex by acquiring and - // releasing it. - tc.mu.Lock() - tc.mu.Unlock() - }}) - - // Drop the original reference, this must trigger the callback. - tc.DecRef() - - if !called { - t.Fatalf("Callback not called") - } -} diff --git a/pkg/refs/refs_state_autogen.go b/pkg/refs/refs_state_autogen.go new file mode 100755 index 000000000..4c5591d30 --- /dev/null +++ b/pkg/refs/refs_state_autogen.go @@ -0,0 +1,81 @@ +// automatically generated by stateify. + +package refs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *WeakRef) beforeSave() {} +func (x *WeakRef) save(m state.Map) { + x.beforeSave() + var obj savedReference = x.saveObj() + m.SaveValue("obj", obj) + m.Save("user", &x.user) +} + +func (x *WeakRef) afterLoad() {} +func (x *WeakRef) load(m state.Map) { + m.Load("user", &x.user) + m.LoadValue("obj", new(savedReference), func(y interface{}) { x.loadObj(y.(savedReference)) }) +} + +func (x *AtomicRefCount) beforeSave() {} +func (x *AtomicRefCount) save(m state.Map) { + x.beforeSave() + m.Save("refCount", &x.refCount) + m.Save("name", &x.name) + m.Save("stack", &x.stack) +} + +func (x *AtomicRefCount) afterLoad() {} +func (x *AtomicRefCount) load(m state.Map) { + m.Load("refCount", &x.refCount) + m.Load("name", &x.name) + m.Load("stack", &x.stack) +} + +func (x *savedReference) beforeSave() {} +func (x *savedReference) save(m state.Map) { + x.beforeSave() + m.Save("obj", &x.obj) +} + +func (x *savedReference) afterLoad() {} +func (x *savedReference) load(m state.Map) { + m.Load("obj", &x.obj) +} + +func (x *weakRefList) beforeSave() {} +func (x *weakRefList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *weakRefList) afterLoad() {} +func (x *weakRefList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *weakRefEntry) beforeSave() {} +func (x *weakRefEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *weakRefEntry) afterLoad() {} +func (x *weakRefEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("pkg/refs.WeakRef", (*WeakRef)(nil), state.Fns{Save: (*WeakRef).save, Load: (*WeakRef).load}) + state.Register("pkg/refs.AtomicRefCount", (*AtomicRefCount)(nil), state.Fns{Save: (*AtomicRefCount).save, Load: (*AtomicRefCount).load}) + state.Register("pkg/refs.savedReference", (*savedReference)(nil), state.Fns{Save: (*savedReference).save, Load: (*savedReference).load}) + state.Register("pkg/refs.weakRefList", (*weakRefList)(nil), state.Fns{Save: (*weakRefList).save, Load: (*weakRefList).load}) + state.Register("pkg/refs.weakRefEntry", (*weakRefEntry)(nil), state.Fns{Save: (*weakRefEntry).save, Load: (*weakRefEntry).load}) +} diff --git a/pkg/refs/weak_ref_list.go b/pkg/refs/weak_ref_list.go new file mode 100755 index 000000000..1d0ae2099 --- /dev/null +++ b/pkg/refs/weak_ref_list.go @@ -0,0 +1,186 @@ +package refs + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type weakRefElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (weakRefElementMapper) linkerFor(elem *WeakRef) *WeakRef { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type weakRefList struct { + head *WeakRef + tail *WeakRef +} + +// Reset resets list l to the empty state. +func (l *weakRefList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *weakRefList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *weakRefList) Front() *WeakRef { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *weakRefList) Back() *WeakRef { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *weakRefList) PushFront(e *WeakRef) { + linker := weakRefElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + weakRefElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *weakRefList) PushBack(e *WeakRef) { + linker := weakRefElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + weakRefElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *weakRefList) PushBackList(m *weakRefList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + weakRefElementMapper{}.linkerFor(l.tail).SetNext(m.head) + weakRefElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *weakRefList) InsertAfter(b, e *WeakRef) { + bLinker := weakRefElementMapper{}.linkerFor(b) + eLinker := weakRefElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + weakRefElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *weakRefList) InsertBefore(a, e *WeakRef) { + aLinker := weakRefElementMapper{}.linkerFor(a) + eLinker := weakRefElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + weakRefElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *weakRefList) Remove(e *WeakRef) { + linker := weakRefElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + weakRefElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + weakRefElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type weakRefEntry struct { + next *WeakRef + prev *WeakRef +} + +// Next returns the entry that follows e in the list. +func (e *weakRefEntry) Next() *WeakRef { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *weakRefEntry) Prev() *WeakRef { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *weakRefEntry) SetNext(elem *WeakRef) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *weakRefEntry) SetPrev(elem *WeakRef) { + e.prev = elem +} diff --git a/pkg/safecopy/BUILD b/pkg/safecopy/BUILD deleted file mode 100644 index 426ef30c9..000000000 --- a/pkg/safecopy/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "safecopy", - srcs = [ - "atomic_amd64.s", - "atomic_arm64.s", - "memclr_amd64.s", - "memclr_arm64.s", - "memcpy_amd64.s", - "memcpy_arm64.s", - "safecopy.go", - "safecopy_unsafe.go", - "sighandler_amd64.s", - "sighandler_arm64.s", - ], - visibility = ["//:sandbox"], - deps = ["//pkg/syserror"], -) - -go_test( - name = "safecopy_test", - srcs = [ - "safecopy_test.go", - ], - library = ":safecopy", -) diff --git a/pkg/safecopy/LICENSE b/pkg/safecopy/LICENSE deleted file mode 100644 index 6a66aea5e..000000000 --- a/pkg/safecopy/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/safecopy/atomic_amd64.s b/pkg/safecopy/atomic_amd64.s index a0cd78f33..a0cd78f33 100644..100755 --- a/pkg/safecopy/atomic_amd64.s +++ b/pkg/safecopy/atomic_amd64.s diff --git a/pkg/safecopy/atomic_arm64.s b/pkg/safecopy/atomic_arm64.s index d58ed71f7..d58ed71f7 100644..100755 --- a/pkg/safecopy/atomic_arm64.s +++ b/pkg/safecopy/atomic_arm64.s diff --git a/pkg/safecopy/memclr_amd64.s b/pkg/safecopy/memclr_amd64.s index 64cf32f05..64cf32f05 100644..100755 --- a/pkg/safecopy/memclr_amd64.s +++ b/pkg/safecopy/memclr_amd64.s diff --git a/pkg/safecopy/memclr_arm64.s b/pkg/safecopy/memclr_arm64.s index 7361b9067..7361b9067 100644..100755 --- a/pkg/safecopy/memclr_arm64.s +++ b/pkg/safecopy/memclr_arm64.s diff --git a/pkg/safecopy/memcpy_amd64.s b/pkg/safecopy/memcpy_amd64.s index 129691d68..129691d68 100644..100755 --- a/pkg/safecopy/memcpy_amd64.s +++ b/pkg/safecopy/memcpy_amd64.s diff --git a/pkg/safecopy/memcpy_arm64.s b/pkg/safecopy/memcpy_arm64.s index e7e541565..e7e541565 100644..100755 --- a/pkg/safecopy/memcpy_arm64.s +++ b/pkg/safecopy/memcpy_arm64.s diff --git a/pkg/safecopy/safecopy.go b/pkg/safecopy/safecopy.go index 2fb7e5809..2fb7e5809 100644..100755 --- a/pkg/safecopy/safecopy.go +++ b/pkg/safecopy/safecopy.go diff --git a/pkg/safecopy/safecopy_state_autogen.go b/pkg/safecopy/safecopy_state_autogen.go new file mode 100755 index 000000000..791eef959 --- /dev/null +++ b/pkg/safecopy/safecopy_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package safecopy diff --git a/pkg/safecopy/safecopy_test.go b/pkg/safecopy/safecopy_test.go deleted file mode 100644 index 7f7f69d61..000000000 --- a/pkg/safecopy/safecopy_test.go +++ /dev/null @@ -1,629 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package safecopy - -import ( - "bytes" - "fmt" - "io/ioutil" - "math/rand" - "os" - "runtime/debug" - "syscall" - "testing" - "unsafe" -) - -// Size of a page in bytes. Cloned from usermem.PageSize to avoid a circular -// dependency. -const pageSize = 4096 - -func initRandom(b []byte) { - for i := range b { - b[i] = byte(rand.Intn(256)) - } -} - -func randBuf(size int) []byte { - b := make([]byte, size) - initRandom(b) - return b -} - -func TestCopyInSuccess(t *testing.T) { - // Test that CopyIn does not return an error when all pages are accessible. - const bufLen = 8192 - a := randBuf(bufLen) - b := make([]byte, bufLen) - - n, err := CopyIn(b, unsafe.Pointer(&a[0])) - if n != bufLen { - t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) - } - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !bytes.Equal(a, b) { - t.Errorf("Buffers are not equal when they should be: %v %v", a, b) - } -} - -func TestCopyOutSuccess(t *testing.T) { - // Test that CopyOut does not return an error when all pages are - // accessible. - const bufLen = 8192 - a := randBuf(bufLen) - b := make([]byte, bufLen) - - n, err := CopyOut(unsafe.Pointer(&b[0]), a) - if n != bufLen { - t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) - } - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !bytes.Equal(a, b) { - t.Errorf("Buffers are not equal when they should be: %v %v", a, b) - } -} - -func TestCopySuccess(t *testing.T) { - // Test that Copy does not return an error when all pages are accessible. - const bufLen = 8192 - a := randBuf(bufLen) - b := make([]byte, bufLen) - - n, err := Copy(unsafe.Pointer(&b[0]), unsafe.Pointer(&a[0]), bufLen) - if n != bufLen { - t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) - } - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !bytes.Equal(a, b) { - t.Errorf("Buffers are not equal when they should be: %v %v", a, b) - } -} - -func TestZeroOutSuccess(t *testing.T) { - // Test that ZeroOut does not return an error when all pages are - // accessible. - const bufLen = 8192 - a := make([]byte, bufLen) - b := randBuf(bufLen) - - n, err := ZeroOut(unsafe.Pointer(&b[0]), bufLen) - if n != bufLen { - t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) - } - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !bytes.Equal(a, b) { - t.Errorf("Buffers are not equal when they should be: %v %v", a, b) - } -} - -func TestSwapUint32Success(t *testing.T) { - // Test that SwapUint32 does not return an error when the page is - // accessible. - before := uint32(rand.Int31()) - after := uint32(rand.Int31()) - val := before - - old, err := SwapUint32(unsafe.Pointer(&val), after) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if old != before { - t.Errorf("Unexpected old value: got %v, want %v", old, before) - } - if val != after { - t.Errorf("Unexpected new value: got %v, want %v", val, after) - } -} - -func TestSwapUint32AlignmentError(t *testing.T) { - // Test that SwapUint32 returns an AlignmentError when passed an unaligned - // address. - data := make([]byte, 8) // 2 * sizeof(uint32). - alignedIndex := uintptr(0) - if offset := uintptr(unsafe.Pointer(&data[0])) % 4; offset != 0 { - alignedIndex = 4 - offset - } - ptr := unsafe.Pointer(&data[alignedIndex+1]) - want := AlignmentError{Addr: uintptr(ptr), Alignment: 4} - if _, err := SwapUint32(ptr, 1); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } -} - -func TestSwapUint64Success(t *testing.T) { - // Test that SwapUint64 does not return an error when the page is - // accessible. - before := uint64(rand.Int63()) - after := uint64(rand.Int63()) - // "The first word in ... an allocated struct or slice can be relied upon - // to be 64-bit aligned." - sync/atomic docs - data := new(struct{ val uint64 }) - data.val = before - - old, err := SwapUint64(unsafe.Pointer(&data.val), after) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if old != before { - t.Errorf("Unexpected old value: got %v, want %v", old, before) - } - if data.val != after { - t.Errorf("Unexpected new value: got %v, want %v", data.val, after) - } -} - -func TestSwapUint64AlignmentError(t *testing.T) { - // Test that SwapUint64 returns an AlignmentError when passed an unaligned - // address. - data := make([]byte, 16) // 2 * sizeof(uint64). - alignedIndex := uintptr(0) - if offset := uintptr(unsafe.Pointer(&data[0])) % 8; offset != 0 { - alignedIndex = 8 - offset - } - ptr := unsafe.Pointer(&data[alignedIndex+1]) - want := AlignmentError{Addr: uintptr(ptr), Alignment: 8} - if _, err := SwapUint64(ptr, 1); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } -} - -func TestCompareAndSwapUint32Success(t *testing.T) { - // Test that CompareAndSwapUint32 does not return an error when the page is - // accessible. - before := uint32(rand.Int31()) - after := uint32(rand.Int31()) - val := before - - old, err := CompareAndSwapUint32(unsafe.Pointer(&val), before, after) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if old != before { - t.Errorf("Unexpected old value: got %v, want %v", old, before) - } - if val != after { - t.Errorf("Unexpected new value: got %v, want %v", val, after) - } -} - -func TestCompareAndSwapUint32AlignmentError(t *testing.T) { - // Test that CompareAndSwapUint32 returns an AlignmentError when passed an - // unaligned address. - data := make([]byte, 8) // 2 * sizeof(uint32). - alignedIndex := uintptr(0) - if offset := uintptr(unsafe.Pointer(&data[0])) % 4; offset != 0 { - alignedIndex = 4 - offset - } - ptr := unsafe.Pointer(&data[alignedIndex+1]) - want := AlignmentError{Addr: uintptr(ptr), Alignment: 4} - if _, err := CompareAndSwapUint32(ptr, 0, 1); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } -} - -// withSegvErrorTestMapping calls fn with a two-page mapping. The first page -// contains random data, and the second page generates SIGSEGV when accessed. -func withSegvErrorTestMapping(t *testing.T, fn func(m []byte)) { - mapping, err := syscall.Mmap(-1, 0, 2*pageSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_ANONYMOUS|syscall.MAP_PRIVATE) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - } - defer syscall.Munmap(mapping) - if err := syscall.Mprotect(mapping[pageSize:], syscall.PROT_NONE); err != nil { - t.Fatalf("Mprotect failed: %v", err) - } - initRandom(mapping[:pageSize]) - - fn(mapping) -} - -// withBusErrorTestMapping calls fn with a two-page mapping. The first page -// contains random data, and the second page generates SIGBUS when accessed. -func withBusErrorTestMapping(t *testing.T, fn func(m []byte)) { - f, err := ioutil.TempFile("", "sigbus_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - defer f.Close() - if err := f.Truncate(pageSize); err != nil { - t.Fatalf("Truncate failed: %v", err) - } - mapping, err := syscall.Mmap(int(f.Fd()), 0, 2*pageSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - } - defer syscall.Munmap(mapping) - initRandom(mapping[:pageSize]) - - fn(mapping) -} - -func TestCopyInSegvError(t *testing.T) { - // Test that CopyIn returns a SegvError when reaching a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - dst := randBuf(pageSize) - n, err := CopyIn(dst, src) - if n != bytesBeforeFault { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyInBusError(t *testing.T) { - // Test that CopyIn returns a BusError when reaching a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - dst := randBuf(pageSize) - n, err := CopyIn(dst, src) - if n != bytesBeforeFault { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyOutSegvError(t *testing.T) { - // Test that CopyOut returns a SegvError when reaching a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - src := randBuf(pageSize) - n, err := CopyOut(dst, src) - if n != bytesBeforeFault { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyOutBusError(t *testing.T) { - // Test that CopyOut returns a BusError when reaching a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - src := randBuf(pageSize) - n, err := CopyOut(dst, src) - if n != bytesBeforeFault { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopySourceSegvError(t *testing.T) { - // Test that Copy returns a SegvError when copying from a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - dst := randBuf(pageSize) - n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopySourceBusError(t *testing.T) { - // Test that Copy returns a BusError when copying from a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - dst := randBuf(pageSize) - n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyDestinationSegvError(t *testing.T) { - // Test that Copy returns a SegvError when copying to a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - src := randBuf(pageSize) - n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyDestinationBusError(t *testing.T) { - // Test that Copy returns a BusError when copying to a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - src := randBuf(pageSize) - n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestZeroOutSegvError(t *testing.T) { - // Test that ZeroOut returns a SegvError when reaching a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting write %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - n, err := ZeroOut(dst, pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], make([]byte, bytesBeforeFault); !bytes.Equal(got, want) { - t.Errorf("Non-zero bytes in written part of mapping: %v", got) - } - }) - }) - } -} - -func TestZeroOutBusError(t *testing.T) { - // Test that ZeroOut returns a BusError when reaching a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting write %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - n, err := ZeroOut(dst, pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], make([]byte, bytesBeforeFault); !bytes.Equal(got, want) { - t.Errorf("Non-zero bytes in written part of mapping: %v", got) - } - }) - }) - } -} - -func TestSwapUint32SegvError(t *testing.T) { - // Test that SwapUint32 returns a SegvError when reaching a page that - // signals SIGSEGV. - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - _, err := SwapUint32(unsafe.Pointer(secondPage), 1) - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestSwapUint32BusError(t *testing.T) { - // Test that SwapUint32 returns a BusError when reaching a page that - // signals SIGBUS. - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - _, err := SwapUint32(unsafe.Pointer(secondPage), 1) - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestSwapUint64SegvError(t *testing.T) { - // Test that SwapUint64 returns a SegvError when reaching a page that - // signals SIGSEGV. - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - _, err := SwapUint64(unsafe.Pointer(secondPage), 1) - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestSwapUint64BusError(t *testing.T) { - // Test that SwapUint64 returns a BusError when reaching a page that - // signals SIGBUS. - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - _, err := SwapUint64(unsafe.Pointer(secondPage), 1) - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestCompareAndSwapUint32SegvError(t *testing.T) { - // Test that CompareAndSwapUint32 returns a SegvError when reaching a page - // that signals SIGSEGV. - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - _, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1) - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestCompareAndSwapUint32BusError(t *testing.T) { - // Test that CompareAndSwapUint32 returns a BusError when reaching a page - // that signals SIGBUS. - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - _, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1) - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func testCopy(dst, src []byte) (panicked bool) { - defer func() { - if r := recover(); r != nil { - panicked = true - } - }() - debug.SetPanicOnFault(true) - copy(dst, src) - return -} - -func TestSegVOnMemmove(t *testing.T) { - // Test that SIGSEGVs received by runtime.memmove when *not* doing - // CopyIn or CopyOut work gets propagated to the runtime. - const bufLen = pageSize - a, err := syscall.Mmap(-1, 0, bufLen, syscall.PROT_NONE, syscall.MAP_ANON|syscall.MAP_PRIVATE) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - - } - defer syscall.Munmap(a) - b := randBuf(bufLen) - - if !testCopy(b, a) { - t.Fatalf("testCopy didn't panic when it should have") - } - - if !testCopy(a, b) { - t.Fatalf("testCopy didn't panic when it should have") - } -} - -func TestSigbusOnMemmove(t *testing.T) { - // Test that SIGBUS received by runtime.memmove when *not* doing - // CopyIn or CopyOut work gets propagated to the runtime. - const bufLen = pageSize - f, err := ioutil.TempFile("", "sigbus_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - os.Remove(f.Name()) - defer f.Close() - - a, err := syscall.Mmap(int(f.Fd()), 0, bufLen, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - - } - defer syscall.Munmap(a) - b := randBuf(bufLen) - - if !testCopy(b, a) { - t.Fatalf("testCopy didn't panic when it should have") - } - - if !testCopy(a, b) { - t.Fatalf("testCopy didn't panic when it should have") - } -} diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go index 41dd567f3..41dd567f3 100644..100755 --- a/pkg/safecopy/safecopy_unsafe.go +++ b/pkg/safecopy/safecopy_unsafe.go diff --git a/pkg/safecopy/sighandler_amd64.s b/pkg/safecopy/sighandler_amd64.s index 475ae48e9..475ae48e9 100644..100755 --- a/pkg/safecopy/sighandler_amd64.s +++ b/pkg/safecopy/sighandler_amd64.s diff --git a/pkg/safecopy/sighandler_arm64.s b/pkg/safecopy/sighandler_arm64.s index 53e4ac2c1..53e4ac2c1 100644..100755 --- a/pkg/safecopy/sighandler_arm64.s +++ b/pkg/safecopy/sighandler_arm64.s diff --git a/pkg/safemem/BUILD b/pkg/safemem/BUILD deleted file mode 100644 index ce30382ab..000000000 --- a/pkg/safemem/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "safemem", - srcs = [ - "block_unsafe.go", - "io.go", - "safemem.go", - "seq_unsafe.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/safecopy", - ], -) - -go_test( - name = "safemem_test", - size = "small", - srcs = [ - "io_test.go", - "seq_test.go", - ], - library = ":safemem", -) diff --git a/pkg/safemem/block_unsafe.go b/pkg/safemem/block_unsafe.go index e7fd30743..e7fd30743 100644..100755 --- a/pkg/safemem/block_unsafe.go +++ b/pkg/safemem/block_unsafe.go diff --git a/pkg/safemem/io.go b/pkg/safemem/io.go index f039a5c34..f039a5c34 100644..100755 --- a/pkg/safemem/io.go +++ b/pkg/safemem/io.go diff --git a/pkg/safemem/io_test.go b/pkg/safemem/io_test.go deleted file mode 100644 index 629741bee..000000000 --- a/pkg/safemem/io_test.go +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package safemem - -import ( - "bytes" - "io" - "testing" -) - -func makeBlocks(slices ...[]byte) []Block { - blocks := make([]Block, 0, len(slices)) - for _, s := range slices { - blocks = append(blocks, BlockFromSafeSlice(s)) - } - return blocks -} - -func TestFromIOReaderFullRead(t *testing.T) { - r := FromIOReader{bytes.NewBufferString("foobar")} - dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) - n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) - if wantN := uint64(6); n != wantN || err != nil { - t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - for i, want := range [][]byte{[]byte("foo"), []byte("bar")} { - if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { - t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) - } - } -} - -type eofHidingReader struct { - Reader io.Reader -} - -func (r eofHidingReader) Read(dst []byte) (int, error) { - n, err := r.Reader.Read(dst) - if err == io.EOF { - return n, nil - } - return n, err -} - -func TestFromIOReaderPartialRead(t *testing.T) { - r := FromIOReader{eofHidingReader{bytes.NewBufferString("foob")}} - dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) - n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) - // FromIOReader should stop after the eofHidingReader returns (1, nil) - // for a 3-byte read. - if wantN := uint64(4); n != wantN || err != nil { - t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - for i, want := range [][]byte{[]byte("foo"), []byte("b\x00\x00")} { - if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { - t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) - } - } -} - -type singleByteReader struct { - Reader io.Reader -} - -func (r singleByteReader) Read(dst []byte) (int, error) { - if len(dst) == 0 { - return r.Reader.Read(dst) - } - return r.Reader.Read(dst[:1]) -} - -func TestSingleByteReader(t *testing.T) { - r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}} - dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) - n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) - // FromIOReader should stop after the singleByteReader returns (1, nil) - // for a 3-byte read. - if wantN := uint64(1); n != wantN || err != nil { - t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - for i, want := range [][]byte{[]byte("f\x00\x00"), []byte("\x00\x00\x00")} { - if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { - t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) - } - } -} - -func TestReadFullToBlocks(t *testing.T) { - r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}} - dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) - n, err := ReadFullToBlocks(r, BlockSeqFromSlice(dsts)) - // ReadFullToBlocks should call into FromIOReader => singleByteReader - // repeatedly until dsts is exhausted. - if wantN := uint64(6); n != wantN || err != nil { - t.Errorf("ReadFullToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - for i, want := range [][]byte{[]byte("foo"), []byte("bar")} { - if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { - t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) - } - } -} - -func TestFromIOWriterFullWrite(t *testing.T) { - srcs := makeBlocks([]byte("foo"), []byte("bar")) - var dst bytes.Buffer - w := FromIOWriter{&dst} - n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) - if wantN := uint64(6); n != wantN || err != nil { - t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -type limitedWriter struct { - Writer io.Writer - Done int - Limit int -} - -func (w *limitedWriter) Write(src []byte) (int, error) { - count := len(src) - if count > (w.Limit - w.Done) { - count = w.Limit - w.Done - } - n, err := w.Writer.Write(src[:count]) - w.Done += n - return n, err -} - -func TestFromIOWriterPartialWrite(t *testing.T) { - srcs := makeBlocks([]byte("foo"), []byte("bar")) - var dst bytes.Buffer - w := FromIOWriter{&limitedWriter{&dst, 0, 4}} - n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) - // FromIOWriter should stop after the limitedWriter returns (1, nil) for a - // 3-byte write. - if wantN := uint64(4); n != wantN || err != nil { - t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -type singleByteWriter struct { - Writer io.Writer -} - -func (w singleByteWriter) Write(src []byte) (int, error) { - if len(src) == 0 { - return w.Writer.Write(src) - } - return w.Writer.Write(src[:1]) -} - -func TestSingleByteWriter(t *testing.T) { - srcs := makeBlocks([]byte("foo"), []byte("bar")) - var dst bytes.Buffer - w := FromIOWriter{singleByteWriter{&dst}} - n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) - // FromIOWriter should stop after the singleByteWriter returns (1, nil) - // for a 3-byte write. - if wantN := uint64(1); n != wantN || err != nil { - t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("f"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -func TestWriteFullToBlocks(t *testing.T) { - srcs := makeBlocks([]byte("foo"), []byte("bar")) - var dst bytes.Buffer - w := FromIOWriter{singleByteWriter{&dst}} - n, err := WriteFullFromBlocks(w, BlockSeqFromSlice(srcs)) - // WriteFullToBlocks should call into FromIOWriter => singleByteWriter - // repeatedly until srcs is exhausted. - if wantN := uint64(6); n != wantN || err != nil { - t.Errorf("WriteFullFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} diff --git a/pkg/safemem/safemem.go b/pkg/safemem/safemem.go index 3e70d33a2..3e70d33a2 100644..100755 --- a/pkg/safemem/safemem.go +++ b/pkg/safemem/safemem.go diff --git a/pkg/safemem/safemem_state_autogen.go b/pkg/safemem/safemem_state_autogen.go new file mode 100755 index 000000000..66d53f22d --- /dev/null +++ b/pkg/safemem/safemem_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package safemem diff --git a/pkg/safemem/seq_test.go b/pkg/safemem/seq_test.go deleted file mode 100644 index de34005e9..000000000 --- a/pkg/safemem/seq_test.go +++ /dev/null @@ -1,217 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package safemem - -import ( - "bytes" - "reflect" - "testing" -) - -func TestBlockSeqOfEmptyBlock(t *testing.T) { - bs := BlockSeqOf(Block{}) - if !bs.IsEmpty() { - t.Errorf("BlockSeqOf(Block{}).IsEmpty(): got false, wanted true; BlockSeq is %v", bs) - } -} - -func TestBlockSeqOfNonemptyBlock(t *testing.T) { - b := BlockFromSafeSlice(make([]byte, 1)) - bs := BlockSeqOf(b) - if bs.IsEmpty() { - t.Fatalf("BlockSeqOf(non-empty Block).IsEmpty(): got true, wanted false; BlockSeq is %v", bs) - } - if head := bs.Head(); head != b { - t.Fatalf("BlockSeqOf(non-empty Block).Head(): got %v, wanted %v", head, b) - } - if tail := bs.Tail(); !tail.IsEmpty() { - t.Fatalf("BlockSeqOf(non-empty Block).Tail().IsEmpty(): got false, wanted true: tail is %v", tail) - } -} - -type blockSeqTest struct { - desc string - - pieces []string - haveOffset bool - offset uint64 - haveLimit bool - limit uint64 - - want string -} - -func (t blockSeqTest) NonEmptyByteSlices() [][]byte { - // t is a value, so we can mutate it freely. - slices := make([][]byte, 0, len(t.pieces)) - for _, str := range t.pieces { - if t.haveOffset { - strOff := t.offset - if strOff > uint64(len(str)) { - strOff = uint64(len(str)) - } - str = str[strOff:] - t.offset -= strOff - } - if t.haveLimit { - strLim := t.limit - if strLim > uint64(len(str)) { - strLim = uint64(len(str)) - } - str = str[:strLim] - t.limit -= strLim - } - if len(str) != 0 { - slices = append(slices, []byte(str)) - } - } - return slices -} - -func (t blockSeqTest) BlockSeq() BlockSeq { - blocks := make([]Block, 0, len(t.pieces)) - for _, str := range t.pieces { - blocks = append(blocks, BlockFromSafeSlice([]byte(str))) - } - bs := BlockSeqFromSlice(blocks) - if t.haveOffset { - bs = bs.DropFirst64(t.offset) - } - if t.haveLimit { - bs = bs.TakeFirst64(t.limit) - } - return bs -} - -var blockSeqTests = []blockSeqTest{ - { - desc: "Empty sequence", - }, - { - desc: "Sequence of length 1", - pieces: []string{"foobar"}, - want: "foobar", - }, - { - desc: "Sequence of length 2", - pieces: []string{"foo", "bar"}, - want: "foobar", - }, - { - desc: "Empty Blocks", - pieces: []string{"", "foo", "", "", "bar", ""}, - want: "foobar", - }, - { - desc: "Sequence with non-zero offset", - pieces: []string{"foo", "bar"}, - haveOffset: true, - offset: 2, - want: "obar", - }, - { - desc: "Sequence with non-maximal limit", - pieces: []string{"foo", "bar"}, - haveLimit: true, - limit: 5, - want: "fooba", - }, - { - desc: "Sequence with offset and limit", - pieces: []string{"foo", "bar"}, - haveOffset: true, - offset: 2, - haveLimit: true, - limit: 3, - want: "oba", - }, -} - -func TestBlockSeqNumBytes(t *testing.T) { - for _, test := range blockSeqTests { - t.Run(test.desc, func(t *testing.T) { - if got, want := test.BlockSeq().NumBytes(), uint64(len(test.want)); got != want { - t.Errorf("NumBytes: got %d, wanted %d", got, want) - } - }) - } -} - -func TestBlockSeqIterBlocks(t *testing.T) { - // Tests BlockSeq iteration using Head/Tail. - for _, test := range blockSeqTests { - t.Run(test.desc, func(t *testing.T) { - srcs := test.BlockSeq() - // "Note that a non-nil empty slice and a nil slice ... are not - // deeply equal." - reflect - slices := make([][]byte, 0, 0) - for !srcs.IsEmpty() { - src := srcs.Head() - slices = append(slices, src.ToSlice()) - nextSrcs := srcs.Tail() - if got, want := nextSrcs.NumBytes(), srcs.NumBytes()-uint64(src.Len()); got != want { - t.Fatalf("%v.Tail(): got %v (%d bytes), wanted %d bytes", srcs, nextSrcs, got, want) - } - srcs = nextSrcs - } - if wantSlices := test.NonEmptyByteSlices(); !reflect.DeepEqual(slices, wantSlices) { - t.Errorf("Accumulated slices: got %v, wanted %v", slices, wantSlices) - } - }) - } -} - -func TestBlockSeqIterBytes(t *testing.T) { - // Tests BlockSeq iteration using Head/DropFirst. - for _, test := range blockSeqTests { - t.Run(test.desc, func(t *testing.T) { - srcs := test.BlockSeq() - var dst bytes.Buffer - for !srcs.IsEmpty() { - src := srcs.Head() - var b [1]byte - n, err := Copy(BlockFromSafeSlice(b[:]), src) - if n != 1 || err != nil { - t.Fatalf("Copy: got (%v, %v), wanted (1, nil)", n, err) - } - dst.WriteByte(b[0]) - nextSrcs := srcs.DropFirst(1) - if got, want := nextSrcs.NumBytes(), srcs.NumBytes()-1; got != want { - t.Fatalf("%v.DropFirst(1): got %v (%d bytes), wanted %d bytes", srcs, nextSrcs, got, want) - } - srcs = nextSrcs - } - if got := string(dst.Bytes()); got != test.want { - t.Errorf("Copied string: got %q, wanted %q", got, test.want) - } - }) - } -} - -func TestBlockSeqDropBeyondLimit(t *testing.T) { - blocks := []Block{BlockFromSafeSlice([]byte("123")), BlockFromSafeSlice([]byte("4"))} - bs := BlockSeqFromSlice(blocks) - if got, want := bs.NumBytes(), uint64(4); got != want { - t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) - } - bs = bs.TakeFirst(1) - if got, want := bs.NumBytes(), uint64(1); got != want { - t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) - } - bs = bs.DropFirst(2) - if got, want := bs.NumBytes(), uint64(0); got != want { - t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) - } -} diff --git a/pkg/safemem/seq_unsafe.go b/pkg/safemem/seq_unsafe.go index f5f0574f8..f5f0574f8 100644..100755 --- a/pkg/safemem/seq_unsafe.go +++ b/pkg/safemem/seq_unsafe.go diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD deleted file mode 100644 index c5fca2ba3..000000000 --- a/pkg/seccomp/BUILD +++ /dev/null @@ -1,50 +0,0 @@ -load("//tools:defs.bzl", "go_binary", "go_embed_data", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_binary( - name = "victim", - testonly = 1, - srcs = ["seccomp_test_victim.go"], - deps = [":seccomp"], -) - -go_embed_data( - name = "victim_data", - testonly = 1, - src = "victim", - package = "seccomp", - var = "victimData", -) - -go_library( - name = "seccomp", - srcs = [ - "seccomp.go", - "seccomp_amd64.go", - "seccomp_arm64.go", - "seccomp_rules.go", - "seccomp_unsafe.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/bpf", - "//pkg/log", - ], -) - -go_test( - name = "seccomp_test", - size = "small", - srcs = [ - "seccomp_test.go", - ":victim_data", - ], - library = ":seccomp", - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/bpf", - ], -) diff --git a/pkg/seccomp/seccomp_amd64_state_autogen.go b/pkg/seccomp/seccomp_amd64_state_autogen.go new file mode 100755 index 000000000..27a96018b --- /dev/null +++ b/pkg/seccomp/seccomp_amd64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build amd64 + +package seccomp diff --git a/pkg/seccomp/seccomp_arm64_state_autogen.go b/pkg/seccomp/seccomp_arm64_state_autogen.go new file mode 100755 index 000000000..96c64c23d --- /dev/null +++ b/pkg/seccomp/seccomp_arm64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build arm64 + +package seccomp diff --git a/pkg/seccomp/seccomp_state_autogen.go b/pkg/seccomp/seccomp_state_autogen.go new file mode 100755 index 000000000..e16b5d7c2 --- /dev/null +++ b/pkg/seccomp/seccomp_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package seccomp diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go deleted file mode 100644 index 88766f33b..000000000 --- a/pkg/seccomp/seccomp_test.go +++ /dev/null @@ -1,580 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package seccomp - -import ( - "bytes" - "fmt" - "io" - "io/ioutil" - "math" - "math/rand" - "os" - "os/exec" - "strings" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/bpf" -) - -type seccompData struct { - nr uint32 - arch uint32 - instructionPointer uint64 - args [6]uint64 -} - -// newVictim makes a victim binary. -func newVictim() (string, error) { - f, err := ioutil.TempFile("", "victim") - if err != nil { - return "", err - } - defer f.Close() - path := f.Name() - if _, err := io.Copy(f, bytes.NewBuffer(victimData)); err != nil { - os.Remove(path) - return "", err - } - if err := os.Chmod(path, 0755); err != nil { - os.Remove(path) - return "", err - } - return path, nil -} - -// asInput converts a seccompData to a bpf.Input. -func (d *seccompData) asInput() bpf.Input { - return bpf.InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian} -} - -func TestBasic(t *testing.T) { - type spec struct { - // desc is the test's description. - desc string - - // data is the input data. - data seccompData - - // want is the expected return value of the BPF program. - want linux.BPFAction - } - - for _, test := range []struct { - ruleSets []RuleSet - defaultAction linux.BPFAction - specs []spec - }{ - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{1: {}}, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - specs: []spec{ - { - desc: "Single syscall allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "Single syscall disallowed", - data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - AllowValue(0x1), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - { - Rules: SyscallRules{ - 1: {}, - 2: {}, - }, - Action: linux.SECCOMP_RET_TRAP, - }, - }, - defaultAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "Multiple rulesets allowed (1a)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x1}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "Multiple rulesets allowed (1b)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "Multiple rulesets allowed (2)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "Multiple rulesets allowed (2)", - data: seccompData{nr: 0, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_KILL_THREAD, - }, - }, - }, - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: {}, - 3: {}, - 5: {}, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - specs: []spec{ - { - desc: "Multiple syscalls allowed (1)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "Multiple syscalls allowed (3)", - data: seccompData{nr: 3, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "Multiple syscalls allowed (5)", - data: seccompData{nr: 5, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "Multiple syscalls disallowed (0)", - data: seccompData{nr: 0, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "Multiple syscalls disallowed (2)", - data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "Multiple syscalls disallowed (4)", - data: seccompData{nr: 4, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "Multiple syscalls disallowed (6)", - data: seccompData{nr: 6, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "Multiple syscalls disallowed (100)", - data: seccompData{nr: 100, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: {}, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - specs: []spec{ - { - desc: "Wrong architecture", - data: seccompData{nr: 1, arch: 123}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: {}, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - specs: []spec{ - { - desc: "Syscall disallowed, action trap", - data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - AllowAny{}, - AllowValue(0xf), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - specs: []spec{ - { - desc: "Syscall argument allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xf}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "Syscall argument disallowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xe}}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - AllowValue(0xf), - }, - { - AllowValue(0xe), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - specs: []spec{ - { - desc: "Syscall argument allowed, two rules", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "Syscall argument allowed, two rules", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xe}}, - want: linux.SECCOMP_RET_ALLOW, - }, - }, - }, - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - AllowValue(0), - AllowValue(math.MaxUint64 - 1), - AllowValue(math.MaxUint32), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - specs: []spec{ - { - desc: "64bit syscall argument allowed", - data: seccompData{ - nr: 1, - arch: linux.AUDIT_ARCH_X86_64, - args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32}, - }, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "64bit syscall argument disallowed", - data: seccompData{ - nr: 1, - arch: linux.AUDIT_ARCH_X86_64, - args: [6]uint64{0, math.MaxUint64, math.MaxUint32}, - }, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "64bit syscall argument disallowed", - data: seccompData{ - nr: 1, - arch: linux.AUDIT_ARCH_X86_64, - args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, - }, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - GreaterThan(0xf), - GreaterThan(0xabcd000d), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - specs: []spec{ - { - desc: "GreaterThan: Syscall argument allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xffffffff}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "GreaterThan: Syscall argument disallowed (equal)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xffffffff}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "Syscall argument disallowed (smaller)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x0, 0xffffffff}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "GreaterThan2: Syscall argument allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xfbcd000d}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "GreaterThan2: Syscall argument disallowed (equal)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xabcd000d}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "GreaterThan2: Syscall argument disallowed (smaller)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xa000ffff}}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - RuleIP: AllowValue(0x7aabbccdd), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - specs: []spec{ - { - desc: "IP: Syscall instruction pointer allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{}, instructionPointer: 0x7aabbccdd}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "IP: Syscall instruction pointer disallowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{}, instructionPointer: 0x711223344}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - } { - instrs, err := BuildProgram(test.ruleSets, test.defaultAction) - if err != nil { - t.Errorf("%s: buildProgram() got error: %v", test.specs[0].desc, err) - continue - } - p, err := bpf.Compile(instrs) - if err != nil { - t.Errorf("%s: bpf.Compile() got error: %v", test.specs[0].desc, err) - continue - } - for _, spec := range test.specs { - got, err := bpf.Exec(p, spec.data.asInput()) - if err != nil { - t.Errorf("%s: bpf.Exec() got error: %v", spec.desc, err) - continue - } - if got != uint32(spec.want) { - t.Errorf("%s: bpd.Exec() = %d, want: %d", spec.desc, got, spec.want) - } - } - } -} - -// TestRandom tests that randomly generated rules are encoded correctly. -func TestRandom(t *testing.T) { - rand.Seed(time.Now().UnixNano()) - size := rand.Intn(50) + 1 - syscallRules := make(map[uintptr][]Rule) - for len(syscallRules) < size { - n := uintptr(rand.Intn(200)) - if _, ok := syscallRules[n]; !ok { - syscallRules[n] = []Rule{} - } - } - - t.Logf("Testing filters: %v", syscallRules) - instrs, err := BuildProgram([]RuleSet{ - RuleSet{ - Rules: syscallRules, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, linux.SECCOMP_RET_TRAP) - if err != nil { - t.Fatalf("buildProgram() got error: %v", err) - } - p, err := bpf.Compile(instrs) - if err != nil { - t.Fatalf("bpf.Compile() got error: %v", err) - } - for i := uint32(0); i < 200; i++ { - data := seccompData{nr: i, arch: linux.AUDIT_ARCH_X86_64} - got, err := bpf.Exec(p, data.asInput()) - if err != nil { - t.Errorf("bpf.Exec() got error: %v, for syscall %d", err, i) - continue - } - want := linux.SECCOMP_RET_TRAP - if _, ok := syscallRules[uintptr(i)]; ok { - want = linux.SECCOMP_RET_ALLOW - } - if got != uint32(want) { - t.Errorf("bpf.Exec() = %d, want: %d, for syscall %d", got, want, i) - } - } -} - -// TestReadDeal checks that a process dies when it trips over the filter and -// that it doesn't die when the filter is not triggered. -func TestRealDeal(t *testing.T) { - for _, test := range []struct { - die bool - want string - }{ - {die: true, want: "bad system call"}, - {die: false, want: "Syscall was allowed!!!"}, - } { - victim, err := newVictim() - if err != nil { - t.Fatalf("unable to get victim: %v", err) - } - defer os.Remove(victim) - dieFlag := fmt.Sprintf("-die=%v", test.die) - cmd := exec.Command(victim, dieFlag) - - out, err := cmd.CombinedOutput() - if test.die { - if err == nil { - t.Errorf("victim was not killed as expected, output: %s", out) - continue - } - // Depending on kernel version, either RET_TRAP or RET_KILL_PROCESS is - // used. RET_TRAP dumps reason for exit in output, while RET_KILL_PROCESS - // returns SIGSYS as exit status. - if !strings.Contains(string(out), test.want) && - !strings.Contains(err.Error(), test.want) { - t.Errorf("Victim error is wrong, got: %v, err: %v, want: %v", string(out), err, test.want) - continue - } - } else { - if err != nil { - t.Errorf("victim failed to execute, err: %v", err) - continue - } - if !strings.Contains(string(out), test.want) { - t.Errorf("Victim output is wrong, got: %v, want: %v", string(out), test.want) - continue - } - } - } -} - -// TestMerge ensures that empty rules are not erased when rules are merged. -func TestMerge(t *testing.T) { - for _, tst := range []struct { - name string - main []Rule - merge []Rule - want []Rule - }{ - { - name: "empty both", - main: nil, - merge: nil, - want: []Rule{{}, {}}, - }, - { - name: "empty main", - main: nil, - merge: []Rule{{}}, - want: []Rule{{}, {}}, - }, - { - name: "empty merge", - main: []Rule{{}}, - merge: nil, - want: []Rule{{}, {}}, - }, - } { - t.Run(tst.name, func(t *testing.T) { - mainRules := SyscallRules{1: tst.main} - mergeRules := SyscallRules{1: tst.merge} - mainRules.Merge(mergeRules) - if got, want := len(mainRules[1]), len(tst.want); got != want { - t.Errorf("wrong length, got: %d, want: %d", got, want) - } - for i, r := range mainRules[1] { - if r != tst.want[i] { - t.Errorf("result, got: %v, want: %v", r, tst.want[i]) - } - } - }) - } -} - -// TestAddRule ensures that empty rules are not erased when rules are added. -func TestAddRule(t *testing.T) { - rules := SyscallRules{1: {}} - rules.AddRule(1, Rule{}) - if got, want := len(rules[1]), 2; got != want { - t.Errorf("len(rules[1]), got: %d, want: %d", got, want) - } -} diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go deleted file mode 100644 index da6b9eaaf..000000000 --- a/pkg/seccomp/seccomp_test_victim.go +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Test binary used to test that seccomp filters are properly constructed and -// indeed kill the process on violation. -package main - -import ( - "flag" - "fmt" - "os" - "syscall" - - "gvisor.dev/gvisor/pkg/seccomp" -) - -func main() { - dieFlag := flag.Bool("die", false, "trips over the filter if true") - flag.Parse() - - syscalls := seccomp.SyscallRules{ - syscall.SYS_ACCEPT: {}, - syscall.SYS_ARCH_PRCTL: {}, - syscall.SYS_BIND: {}, - syscall.SYS_BRK: {}, - syscall.SYS_CLOCK_GETTIME: {}, - syscall.SYS_CLONE: {}, - syscall.SYS_CLOSE: {}, - syscall.SYS_DUP: {}, - syscall.SYS_DUP3: {}, - syscall.SYS_EPOLL_CREATE1: {}, - syscall.SYS_EPOLL_CTL: {}, - syscall.SYS_EPOLL_WAIT: {}, - syscall.SYS_EPOLL_PWAIT: {}, - syscall.SYS_EXIT: {}, - syscall.SYS_EXIT_GROUP: {}, - syscall.SYS_FALLOCATE: {}, - syscall.SYS_FCHMOD: {}, - syscall.SYS_FCNTL: {}, - syscall.SYS_FSTAT: {}, - syscall.SYS_FSYNC: {}, - syscall.SYS_FTRUNCATE: {}, - syscall.SYS_FUTEX: {}, - syscall.SYS_GETDENTS64: {}, - syscall.SYS_GETPEERNAME: {}, - syscall.SYS_GETPID: {}, - syscall.SYS_GETSOCKNAME: {}, - syscall.SYS_GETSOCKOPT: {}, - syscall.SYS_GETTID: {}, - syscall.SYS_GETTIMEOFDAY: {}, - syscall.SYS_LISTEN: {}, - syscall.SYS_LSEEK: {}, - syscall.SYS_MADVISE: {}, - syscall.SYS_MINCORE: {}, - syscall.SYS_MMAP: {}, - syscall.SYS_MPROTECT: {}, - syscall.SYS_MUNLOCK: {}, - syscall.SYS_MUNMAP: {}, - syscall.SYS_NANOSLEEP: {}, - syscall.SYS_NEWFSTATAT: {}, - syscall.SYS_OPEN: {}, - syscall.SYS_PPOLL: {}, - syscall.SYS_PREAD64: {}, - syscall.SYS_PSELECT6: {}, - syscall.SYS_PWRITE64: {}, - syscall.SYS_READ: {}, - syscall.SYS_READLINKAT: {}, - syscall.SYS_READV: {}, - syscall.SYS_RECVMSG: {}, - syscall.SYS_RENAMEAT: {}, - syscall.SYS_RESTART_SYSCALL: {}, - syscall.SYS_RT_SIGACTION: {}, - syscall.SYS_RT_SIGPROCMASK: {}, - syscall.SYS_RT_SIGRETURN: {}, - syscall.SYS_SCHED_YIELD: {}, - syscall.SYS_SENDMSG: {}, - syscall.SYS_SETITIMER: {}, - syscall.SYS_SET_ROBUST_LIST: {}, - syscall.SYS_SETSOCKOPT: {}, - syscall.SYS_SHUTDOWN: {}, - syscall.SYS_SIGALTSTACK: {}, - syscall.SYS_SOCKET: {}, - syscall.SYS_SYNC_FILE_RANGE: {}, - syscall.SYS_TGKILL: {}, - syscall.SYS_UTIMENSAT: {}, - syscall.SYS_WRITE: {}, - syscall.SYS_WRITEV: {}, - } - die := *dieFlag - if !die { - syscalls[syscall.SYS_OPENAT] = []seccomp.Rule{ - { - seccomp.AllowValue(10), - }, - } - } - - if err := seccomp.Install(syscalls); err != nil { - fmt.Printf("Failed to install seccomp: %v", err) - os.Exit(1) - } - fmt.Printf("Filters installed\n") - - syscall.RawSyscall(syscall.SYS_OPENAT, 10, 0, 0) - fmt.Printf("Syscall was allowed!!!\n") -} diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD deleted file mode 100644 index 60f63c7a6..000000000 --- a/pkg/secio/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "secio", - srcs = [ - "full_reader.go", - "secio.go", - ], - visibility = ["//pkg/sentry:internal"], -) - -go_test( - name = "secio_test", - size = "small", - srcs = ["secio_test.go"], - library = ":secio", -) diff --git a/pkg/secio/secio_state_autogen.go b/pkg/secio/secio_state_autogen.go new file mode 100755 index 000000000..372ac4b92 --- /dev/null +++ b/pkg/secio/secio_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package secio diff --git a/pkg/secio/secio_test.go b/pkg/secio/secio_test.go deleted file mode 100644 index d1d905187..000000000 --- a/pkg/secio/secio_test.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package secio - -import ( - "bytes" - "errors" - "io" - "io/ioutil" - "math" - "testing" -) - -var errEndOfBuffer = errors.New("write beyond end of buffer") - -// buffer resembles bytes.Buffer, but implements io.ReaderAt and io.WriterAt. -// Reads beyond the end of the buffer return io.EOF. Writes beyond the end of -// the buffer return errEndOfBuffer. -type buffer struct { - Bytes []byte -} - -// ReadAt implements io.ReaderAt.ReadAt. -func (b *buffer) ReadAt(dst []byte, off int64) (int, error) { - if off >= int64(len(b.Bytes)) { - return 0, io.EOF - } - n := copy(dst, b.Bytes[off:]) - if n < len(dst) { - return n, io.EOF - } - return n, nil -} - -// WriteAt implements io.WriterAt.WriteAt. -func (b *buffer) WriteAt(src []byte, off int64) (int, error) { - if off >= int64(len(b.Bytes)) { - return 0, errEndOfBuffer - } - n := copy(b.Bytes[off:], src) - if n < len(src) { - return n, errEndOfBuffer - } - return n, nil -} - -func newBufferString(s string) *buffer { - return &buffer{[]byte(s)} -} - -func TestOffsetReader(t *testing.T) { - buf := newBufferString("foobar") - r := NewOffsetReader(buf, 3) - dst, err := ioutil.ReadAll(r) - if want := []byte("bar"); !bytes.Equal(dst, want) || err != nil { - t.Errorf("ReadAll: got (%q, %v), wanted (%q, nil)", dst, err, want) - } -} - -func TestSectionReader(t *testing.T) { - buf := newBufferString("foobarbaz") - r := NewSectionReader(buf, 3, 3) - dst, err := ioutil.ReadAll(r) - if want, wantErr := []byte("bar"), ErrReachedLimit; !bytes.Equal(dst, want) || err != wantErr { - t.Errorf("ReadAll: got (%q, %v), wanted (%q, %v)", dst, err, want, wantErr) - } -} - -func TestSectionReaderLimitOverflow(t *testing.T) { - // SectionReader behaves like OffsetReader when limit overflows int64. - buf := newBufferString("foobar") - r := NewSectionReader(buf, 3, math.MaxInt64) - dst, err := ioutil.ReadAll(r) - if want := []byte("bar"); !bytes.Equal(dst, want) || err != nil { - t.Errorf("ReadAll: got (%q, %v), wanted (%q, nil)", dst, err, want) - } -} - -func TestOffsetWriter(t *testing.T) { - buf := newBufferString("ABCDEF") - w := NewOffsetWriter(buf, 3) - n, err := w.Write([]byte("foobar")) - if wantN, wantErr := 3, errEndOfBuffer; n != wantN || err != wantErr { - t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := buf.Bytes, []byte("ABCfoo"); !bytes.Equal(got, want) { - t.Errorf("buf.Bytes: got %q, wanted %q", got, want) - } -} - -func TestSectionWriter(t *testing.T) { - buf := newBufferString("ABCDEFGHI") - w := NewSectionWriter(buf, 3, 3) - n, err := w.Write([]byte("foobar")) - if wantN, wantErr := 3, ErrReachedLimit; n != wantN || err != wantErr { - t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := buf.Bytes, []byte("ABCfooGHI"); !bytes.Equal(got, want) { - t.Errorf("buf.Bytes: got %q, wanted %q", got, want) - } -} - -func TestSectionWriterLimitOverflow(t *testing.T) { - // SectionWriter behaves like OffsetWriter when limit overflows int64. - buf := newBufferString("ABCDEF") - w := NewSectionWriter(buf, 3, math.MaxInt64) - n, err := w.Write([]byte("foobar")) - if wantN, wantErr := 3, errEndOfBuffer; n != wantN || err != wantErr { - t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := buf.Bytes, []byte("ABCfoo"); !bytes.Equal(got, want) { - t.Errorf("buf.Bytes: got %q, wanted %q", got, want) - } -} diff --git a/pkg/segment/BUILD b/pkg/segment/BUILD deleted file mode 100644 index 1b487b887..000000000 --- a/pkg/segment/BUILD +++ /dev/null @@ -1,31 +0,0 @@ -load("//tools/go_generics:defs.bzl", "go_template") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -go_template( - name = "generic_range", - srcs = ["range.go"], - types = [ - "T", - ], -) - -go_template( - name = "generic_set", - srcs = [ - "set.go", - "set_state.go", - ], - opt_consts = [ - "minDegree", - ], - types = [ - "Key", - "Range", - "Value", - "Functions", - ], -) diff --git a/pkg/segment/set_state.go b/pkg/segment/set_state.go deleted file mode 100644 index 76de92591..000000000 --- a/pkg/segment/set_state.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package segment - -func (s *Set) saveRoot() *SegmentDataSlices { - return s.ExportSortedSlices() -} - -func (s *Set) loadRoot(sds *SegmentDataSlices) { - if err := s.ImportSortedSlices(sds); err != nil { - panic(err) - } -} diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD deleted file mode 100644 index f2d8462d8..000000000 --- a/pkg/segment/test/BUILD +++ /dev/null @@ -1,50 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package( - default_visibility = ["//visibility:private"], - licenses = ["notice"], -) - -go_template_instance( - name = "int_range", - out = "int_range.go", - package = "segment", - template = "//pkg/segment:generic_range", - types = { - "T": "int", - }, -) - -go_template_instance( - name = "int_set", - out = "int_set.go", - package = "segment", - template = "//pkg/segment:generic_set", - types = { - "Key": "int", - "Range": "Range", - "Value": "int", - "Functions": "setFunctions", - }, -) - -go_library( - name = "segment", - testonly = 1, - srcs = [ - "int_range.go", - "int_set.go", - "set_functions.go", - ], - deps = [ - "//pkg/state", - ], -) - -go_test( - name = "segment_test", - size = "small", - srcs = ["segment_test.go"], - library = ":segment", -) diff --git a/pkg/segment/test/segment_test.go b/pkg/segment/test/segment_test.go deleted file mode 100644 index f19a005f3..000000000 --- a/pkg/segment/test/segment_test.go +++ /dev/null @@ -1,564 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package segment - -import ( - "fmt" - "math/rand" - "testing" -) - -const ( - // testSize is the baseline number of elements inserted into sets under - // test, and is chosen to be large enough to ensure interesting amounts of - // tree rebalancing. - // - // Note that because checkSet is called between each insertion/removal in - // some tests that use it, tests may be quadratic in testSize. - testSize = 8000 - - // valueOffset is the difference between the value and start of test - // segments. - valueOffset = 100000 -) - -func shuffle(xs []int) { - for i := range xs { - j := rand.Intn(i + 1) - xs[i], xs[j] = xs[j], xs[i] - } -} - -func randPermutation(size int) []int { - p := make([]int, size) - for i := range p { - p[i] = i - } - shuffle(p) - return p -} - -// checkSet returns an error if s is incorrectly sorted, does not contain -// exactly expectedSegments segments, or contains a segment for which val != -// key + valueOffset. -func checkSet(s *Set, expectedSegments int) error { - havePrev := false - prev := 0 - nrSegments := 0 - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - next := seg.Start() - if havePrev && prev >= next { - return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) - } - if got, want := seg.Value(), seg.Start()+valueOffset; got != want { - return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start, got, want) - } - prev = next - havePrev = true - nrSegments++ - } - if nrSegments != expectedSegments { - return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) - } - return nil -} - -// countSegmentsIn returns the number of segments in s. -func countSegmentsIn(s *Set) int { - var count int - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - count++ - } - return count -} - -func TestAddRandom(t *testing.T) { - var s Set - order := randPermutation(testSize) - var nrInsertions int - for i, j := range order { - if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) { - t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) - break - } - nrInsertions++ - if err := checkSet(&s, nrInsertions); err != nil { - t.Errorf("Iteration %d: %v", i, err) - break - } - } - if got, want := countSegmentsIn(&s), nrInsertions; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Insertion order: %v", order[:nrInsertions]) - t.Logf("Set contents:\n%v", &s) - } -} - -func TestRemoveRandom(t *testing.T) { - var s Set - for i := 0; i < testSize; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i+valueOffset) { - t.Fatalf("Failed to insert segment %d", i) - } - } - order := randPermutation(testSize) - var nrRemovals int - for i, j := range order { - seg := s.FindSegment(j) - if !seg.Ok() { - t.Errorf("Iteration %d: failed to find segment with key %d", i, j) - break - } - s.Remove(seg) - nrRemovals++ - if err := checkSet(&s, testSize-nrRemovals); err != nil { - t.Errorf("Iteration %d: %v", i, err) - break - } - } - if got, want := countSegmentsIn(&s), testSize-nrRemovals; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Removal order: %v", order[:nrRemovals]) - t.Logf("Set contents:\n%v", &s) - t.FailNow() - } -} - -func TestAddSequentialAdjacent(t *testing.T) { - var s Set - var nrInsertions int - for i := 0; i < testSize; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i+valueOffset) { - t.Fatalf("Failed to insert segment %d", i) - } - nrInsertions++ - if err := checkSet(&s, nrInsertions); err != nil { - t.Errorf("Iteration %d: %v", i, err) - break - } - } - if got, want := countSegmentsIn(&s), nrInsertions; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Set contents:\n%v", &s) - } - - first := s.FirstSegment() - gotSeg, gotGap := first.PrevNonEmpty() - if wantGap := s.FirstGap(); gotSeg.Ok() || gotGap != wantGap { - t.Errorf("FirstSegment().PrevNonEmpty(): got (%v, %v), wanted (<terminal iterator>, %v)", gotSeg, gotGap, wantGap) - } - gotSeg, gotGap = first.NextNonEmpty() - if wantSeg := first.NextSegment(); gotSeg != wantSeg || gotGap.Ok() { - t.Errorf("FirstSegment().NextNonEmpty(): got (%v, %v), wanted (%v, <terminal iterator>)", gotSeg, gotGap, wantSeg) - } - - last := s.LastSegment() - gotSeg, gotGap = last.PrevNonEmpty() - if wantSeg := last.PrevSegment(); gotSeg != wantSeg || gotGap.Ok() { - t.Errorf("LastSegment().PrevNonEmpty(): got (%v, %v), wanted (%v, <terminal iterator>)", gotSeg, gotGap, wantSeg) - } - gotSeg, gotGap = last.NextNonEmpty() - if wantGap := s.LastGap(); gotSeg.Ok() || gotGap != wantGap { - t.Errorf("LastSegment().NextNonEmpty(): got (%v, %v), wanted (<terminal iterator>, %v)", gotSeg, gotGap, wantGap) - } - - for seg := first.NextSegment(); seg != last; seg = seg.NextSegment() { - gotSeg, gotGap = seg.PrevNonEmpty() - if wantSeg := seg.PrevSegment(); gotSeg != wantSeg || gotGap.Ok() { - t.Errorf("%v.PrevNonEmpty(): got (%v, %v), wanted (%v, <terminal iterator>)", seg, gotSeg, gotGap, wantSeg) - } - gotSeg, gotGap = seg.NextNonEmpty() - if wantSeg := seg.NextSegment(); gotSeg != wantSeg || gotGap.Ok() { - t.Errorf("%v.NextNonEmpty(): got (%v, %v), wanted (%v, <terminal iterator>)", seg, gotSeg, gotGap, wantSeg) - } - } -} - -func TestAddSequentialNonAdjacent(t *testing.T) { - var s Set - var nrInsertions int - for i := 0; i < testSize; i++ { - // The range here differs from TestAddSequentialAdjacent so that - // consecutive segments are not adjacent. - if !s.AddWithoutMerging(Range{2 * i, 2*i + 1}, 2*i+valueOffset) { - t.Fatalf("Failed to insert segment %d", i) - } - nrInsertions++ - if err := checkSet(&s, nrInsertions); err != nil { - t.Errorf("Iteration %d: %v", i, err) - break - } - } - if got, want := countSegmentsIn(&s), nrInsertions; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Set contents:\n%v", &s) - } - - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - gotSeg, gotGap := seg.PrevNonEmpty() - if wantGap := seg.PrevGap(); gotSeg.Ok() || gotGap != wantGap { - t.Errorf("%v.PrevNonEmpty(): got (%v, %v), wanted (<terminal iterator>, %v)", seg, gotSeg, gotGap, wantGap) - } - gotSeg, gotGap = seg.NextNonEmpty() - if wantGap := seg.NextGap(); gotSeg.Ok() || gotGap != wantGap { - t.Errorf("%v.NextNonEmpty(): got (%v, %v), wanted (<terminal iterator>, %v)", seg, gotSeg, gotGap, wantGap) - } - } -} - -func TestMergeSplit(t *testing.T) { - tests := []struct { - name string - initial []Range - split bool - splitAddr int - final []Range - }{ - { - name: "Add merges after existing segment", - initial: []Range{{1000, 1100}, {1100, 1200}}, - final: []Range{{1000, 1200}}, - }, - { - name: "Add merges before existing segment", - initial: []Range{{1100, 1200}, {1000, 1100}}, - final: []Range{{1000, 1200}}, - }, - { - name: "Add merges between existing segments", - initial: []Range{{1000, 1100}, {1200, 1300}, {1100, 1200}}, - final: []Range{{1000, 1300}}, - }, - { - name: "SplitAt does nothing at a free address", - initial: []Range{{100, 200}}, - split: true, - splitAddr: 300, - final: []Range{{100, 200}}, - }, - { - name: "SplitAt does nothing at the beginning of a segment", - initial: []Range{{100, 200}}, - split: true, - splitAddr: 100, - final: []Range{{100, 200}}, - }, - { - name: "SplitAt does nothing at the end of a segment", - initial: []Range{{100, 200}}, - split: true, - splitAddr: 200, - final: []Range{{100, 200}}, - }, - { - name: "SplitAt splits in the middle of a segment", - initial: []Range{{100, 200}}, - split: true, - splitAddr: 150, - final: []Range{{100, 150}, {150, 200}}, - }, - } -Tests: - for _, test := range tests { - var s Set - for _, r := range test.initial { - if !s.Add(r, 0) { - t.Errorf("%s: Add(%v) failed; set contents:\n%v", test.name, r, &s) - continue Tests - } - } - if test.split { - s.SplitAt(test.splitAddr) - } - var i int - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - if i > len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, countSegmentsIn(&s), len(test.final), &s) - continue Tests - } - if got, want := seg.Range(), test.final[i]; got != want { - t.Errorf("%s: Segment %d mismatch: got %v, wanted %v; set contents:\n%v", test.name, i, got, want, &s) - continue Tests - } - i++ - } - if i < len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, i, len(test.final), &s) - } - } -} - -func TestIsolate(t *testing.T) { - tests := []struct { - name string - initial Range - bounds Range - final []Range - }{ - { - name: "Isolate does not split a segment that falls inside bounds", - initial: Range{100, 200}, - bounds: Range{100, 200}, - final: []Range{{100, 200}}, - }, - { - name: "Isolate splits at beginning of segment", - initial: Range{50, 200}, - bounds: Range{100, 200}, - final: []Range{{50, 100}, {100, 200}}, - }, - { - name: "Isolate splits at end of segment", - initial: Range{100, 250}, - bounds: Range{100, 200}, - final: []Range{{100, 200}, {200, 250}}, - }, - { - name: "Isolate splits at beginning and end of segment", - initial: Range{50, 250}, - bounds: Range{100, 200}, - final: []Range{{50, 100}, {100, 200}, {200, 250}}, - }, - } -Tests: - for _, test := range tests { - var s Set - seg := s.Insert(s.FirstGap(), test.initial, 0) - seg = s.Isolate(seg, test.bounds) - if !test.bounds.IsSupersetOf(seg.Range()) { - t.Errorf("%s: Isolated segment %v lies outside bounds %v; set contents:\n%v", test.name, seg.Range(), test.bounds, &s) - } - var i int - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - if i > len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, countSegmentsIn(&s), len(test.final), &s) - continue Tests - } - if got, want := seg.Range(), test.final[i]; got != want { - t.Errorf("%s: Segment %d mismatch: got %v, wanted %v; set contents:\n%v", test.name, i, got, want, &s) - continue Tests - } - i++ - } - if i < len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, i, len(test.final), &s) - } - } -} - -func benchmarkAddSequential(b *testing.B, size int) { - for n := 0; n < b.N; n++ { - var s Set - for i := 0; i < size; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - } -} - -func benchmarkAddRandom(b *testing.B, size int) { - order := randPermutation(size) - - b.ResetTimer() - for n := 0; n < b.N; n++ { - var s Set - for _, i := range order { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - } -} - -func benchmarkFindSequential(b *testing.B, size int) { - var s Set - for i := 0; i < size; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - - b.ResetTimer() - for n := 0; n < b.N; n++ { - for i := 0; i < size; i++ { - if seg := s.FindSegment(i); !seg.Ok() { - b.Fatalf("Failed to find segment %d", i) - } - } - } -} - -func benchmarkFindRandom(b *testing.B, size int) { - var s Set - for i := 0; i < size; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - order := randPermutation(size) - - b.ResetTimer() - for n := 0; n < b.N; n++ { - for _, i := range order { - if si := s.FindSegment(i); !si.Ok() { - b.Fatalf("Failed to find segment %d", i) - } - } - } -} - -func benchmarkIteration(b *testing.B, size int) { - var s Set - for i := 0; i < size; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - - b.ResetTimer() - var count uint64 - for n := 0; n < b.N; n++ { - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - count++ - } - } - if got, want := count, uint64(size)*uint64(b.N); got != want { - b.Fatalf("Iterated wrong number of segments: got %d, wanted %d", got, want) - } -} - -func benchmarkAddFindRemoveSequential(b *testing.B, size int) { - for n := 0; n < b.N; n++ { - var s Set - for i := 0; i < size; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - for i := 0; i < size; i++ { - seg := s.FindSegment(i) - if !seg.Ok() { - b.Fatalf("Failed to find segment %d", i) - } - s.Remove(seg) - } - if !s.IsEmpty() { - b.Fatalf("Set not empty after all removals:\n%v", &s) - } - } -} - -func benchmarkAddFindRemoveRandom(b *testing.B, size int) { - order := randPermutation(size) - - b.ResetTimer() - for n := 0; n < b.N; n++ { - var s Set - for _, i := range order { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - for _, i := range order { - seg := s.FindSegment(i) - if !seg.Ok() { - b.Fatalf("Failed to find segment %d", i) - } - s.Remove(seg) - } - if !s.IsEmpty() { - b.Fatalf("Set not empty after all removals:\n%v", &s) - } - } -} - -// Although we don't generally expect our segment sets to get this big, they're -// useful for emulating the effect of cache pressure. -var testSizes = []struct { - desc string - size int -}{ - {"64", 1 << 6}, - {"256", 1 << 8}, - {"1K", 1 << 10}, - {"4K", 1 << 12}, - {"16K", 1 << 14}, - {"64K", 1 << 16}, -} - -func BenchmarkAddSequential(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkAddSequential(b, test.size) - }) - } -} - -func BenchmarkAddRandom(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkAddRandom(b, test.size) - }) - } -} - -func BenchmarkFindSequential(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkFindSequential(b, test.size) - }) - } -} - -func BenchmarkFindRandom(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkFindRandom(b, test.size) - }) - } -} - -func BenchmarkIteration(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkIteration(b, test.size) - }) - } -} - -func BenchmarkAddFindRemoveSequential(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkAddFindRemoveSequential(b, test.size) - }) - } -} - -func BenchmarkAddFindRemoveRandom(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkAddFindRemoveRandom(b, test.size) - }) - } -} diff --git a/pkg/segment/test/set_functions.go b/pkg/segment/test/set_functions.go deleted file mode 100644 index bcddb39bb..000000000 --- a/pkg/segment/test/set_functions.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package segment - -// Basic numeric constants that we define because the math package doesn't. -// TODO(nlacasse): These should be Math.MaxInt64/MinInt64? -const ( - maxInt = int(^uint(0) >> 1) - minInt = -maxInt - 1 -) - -type setFunctions struct{} - -func (setFunctions) MinKey() int { - return minInt -} - -func (setFunctions) MaxKey() int { - return maxInt -} - -func (setFunctions) ClearValue(*int) {} - -func (setFunctions) Merge(_ Range, val1 int, _ Range, _ int) (int, bool) { - return val1, true -} - -func (setFunctions) Split(_ Range, val int, _ int) (int, int) { - return val, val -} diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD deleted file mode 100644 index e759dc36f..000000000 --- a/pkg/sentry/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -package(licenses = ["notice"]) - -# The "internal" package_group should be used as much as possible by packages -# that should remain Sentry-internal (i.e. not be exposed directly to command -# line tooling or APIs). -package_group( - name = "internal", - packages = [ - "//pkg/sentry/...", - "//runsc/...", - # Code generated by go_marshal relies on go_marshal libraries. - "//tools/go_marshal/...", - ], -) diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD deleted file mode 100644 index e27f21e5e..000000000 --- a/pkg/sentry/arch/BUILD +++ /dev/null @@ -1,48 +0,0 @@ -load("//tools:defs.bzl", "go_library", "proto_library") - -package(licenses = ["notice"]) - -go_library( - name = "arch", - srcs = [ - "aligned.go", - "arch.go", - "arch_aarch64.go", - "arch_amd64.go", - "arch_amd64.s", - "arch_arm64.go", - "arch_state_aarch64.go", - "arch_state_x86.go", - "arch_x86.go", - "arch_x86_impl.go", - "auxv.go", - "signal.go", - "signal_act.go", - "signal_amd64.go", - "signal_arm64.go", - "signal_info.go", - "signal_stack.go", - "stack.go", - "syscalls_amd64.go", - "syscalls_arm64.go", - ], - visibility = ["//:sandbox"], - deps = [ - ":registers_go_proto", - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/context", - "//pkg/cpuid", - "//pkg/log", - "//pkg/sentry/limits", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - ], -) - -proto_library( - name = "registers", - srcs = ["registers.proto"], - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index b998f84fc..b998f84fc 100644..100755 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go diff --git a/pkg/sentry/arch/arch_aarch64_state_autogen.go b/pkg/sentry/arch/arch_aarch64_state_autogen.go new file mode 100755 index 000000000..9c6dfdf2e --- /dev/null +++ b/pkg/sentry/arch/arch_aarch64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +// +build arm64 +// +build arm64 + +package arch diff --git a/pkg/sentry/arch/arch_amd64_state_autogen.go b/pkg/sentry/arch/arch_amd64_state_autogen.go new file mode 100755 index 000000000..73c523c90 --- /dev/null +++ b/pkg/sentry/arch/arch_amd64_state_autogen.go @@ -0,0 +1,28 @@ +// automatically generated by stateify. + +// +build amd64 +// +build amd64 +// +build amd64 + +package arch + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *context64) beforeSave() {} +func (x *context64) save(m state.Map) { + x.beforeSave() + m.Save("State", &x.State) + m.Save("sigFPState", &x.sigFPState) +} + +func (x *context64) afterLoad() {} +func (x *context64) load(m state.Map) { + m.Load("State", &x.State) + m.Load("sigFPState", &x.sigFPState) +} + +func init() { + state.Register("pkg/sentry/arch.context64", (*context64)(nil), state.Fns{Save: (*context64).save, Load: (*context64).load}) +} diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go index db99c5acb..db99c5acb 100644..100755 --- a/pkg/sentry/arch/arch_arm64.go +++ b/pkg/sentry/arch/arch_arm64.go diff --git a/pkg/sentry/arch/arch_arm64_state_autogen.go b/pkg/sentry/arch/arch_arm64_state_autogen.go new file mode 100755 index 000000000..49f2e3d67 --- /dev/null +++ b/pkg/sentry/arch/arch_arm64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build arm64 + +package arch diff --git a/pkg/sentry/arch/arch_impl_state_autogen.go b/pkg/sentry/arch/arch_impl_state_autogen.go new file mode 100755 index 000000000..8b567801f --- /dev/null +++ b/pkg/sentry/arch/arch_impl_state_autogen.go @@ -0,0 +1,29 @@ +// automatically generated by stateify. + +// +build amd64 i386 + +package arch + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *State) beforeSave() {} +func (x *State) save(m state.Map) { + x.beforeSave() + var Regs syscallPtraceRegs = x.saveRegs() + m.SaveValue("Regs", Regs) + m.Save("x86FPState", &x.x86FPState) + m.Save("FeatureSet", &x.FeatureSet) +} + +func (x *State) load(m state.Map) { + m.LoadWait("x86FPState", &x.x86FPState) + m.Load("FeatureSet", &x.FeatureSet) + m.LoadValue("Regs", new(syscallPtraceRegs), func(y interface{}) { x.loadRegs(y.(syscallPtraceRegs)) }) + m.AfterLoad(x.afterLoad) +} + +func init() { + state.Register("pkg/sentry/arch.State", (*State)(nil), state.Fns{Save: (*State).save, Load: (*State).load}) +} diff --git a/pkg/sentry/arch/arch_state_aarch64.go b/pkg/sentry/arch/arch_state_aarch64.go index 0136a85ad..0136a85ad 100644..100755 --- a/pkg/sentry/arch/arch_state_aarch64.go +++ b/pkg/sentry/arch/arch_state_aarch64.go diff --git a/pkg/sentry/arch/arch_state_autogen.go b/pkg/sentry/arch/arch_state_autogen.go new file mode 100755 index 000000000..a06d96c71 --- /dev/null +++ b/pkg/sentry/arch/arch_state_autogen.go @@ -0,0 +1,166 @@ +// automatically generated by stateify. + +// +build amd64 i386 +// +build amd64 i386 +// +build i386 amd64 arm64 + +package arch + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *MmapLayout) beforeSave() {} +func (x *MmapLayout) save(m state.Map) { + x.beforeSave() + m.Save("MinAddr", &x.MinAddr) + m.Save("MaxAddr", &x.MaxAddr) + m.Save("BottomUpBase", &x.BottomUpBase) + m.Save("TopDownBase", &x.TopDownBase) + m.Save("DefaultDirection", &x.DefaultDirection) + m.Save("MaxStackRand", &x.MaxStackRand) +} + +func (x *MmapLayout) afterLoad() {} +func (x *MmapLayout) load(m state.Map) { + m.Load("MinAddr", &x.MinAddr) + m.Load("MaxAddr", &x.MaxAddr) + m.Load("BottomUpBase", &x.BottomUpBase) + m.Load("TopDownBase", &x.TopDownBase) + m.Load("DefaultDirection", &x.DefaultDirection) + m.Load("MaxStackRand", &x.MaxStackRand) +} + +func (x *syscallPtraceRegs) beforeSave() {} +func (x *syscallPtraceRegs) save(m state.Map) { + x.beforeSave() + m.Save("R15", &x.R15) + m.Save("R14", &x.R14) + m.Save("R13", &x.R13) + m.Save("R12", &x.R12) + m.Save("Rbp", &x.Rbp) + m.Save("Rbx", &x.Rbx) + m.Save("R11", &x.R11) + m.Save("R10", &x.R10) + m.Save("R9", &x.R9) + m.Save("R8", &x.R8) + m.Save("Rax", &x.Rax) + m.Save("Rcx", &x.Rcx) + m.Save("Rdx", &x.Rdx) + m.Save("Rsi", &x.Rsi) + m.Save("Rdi", &x.Rdi) + m.Save("Orig_rax", &x.Orig_rax) + m.Save("Rip", &x.Rip) + m.Save("Cs", &x.Cs) + m.Save("Eflags", &x.Eflags) + m.Save("Rsp", &x.Rsp) + m.Save("Ss", &x.Ss) + m.Save("Fs_base", &x.Fs_base) + m.Save("Gs_base", &x.Gs_base) + m.Save("Ds", &x.Ds) + m.Save("Es", &x.Es) + m.Save("Fs", &x.Fs) + m.Save("Gs", &x.Gs) +} + +func (x *syscallPtraceRegs) afterLoad() {} +func (x *syscallPtraceRegs) load(m state.Map) { + m.Load("R15", &x.R15) + m.Load("R14", &x.R14) + m.Load("R13", &x.R13) + m.Load("R12", &x.R12) + m.Load("Rbp", &x.Rbp) + m.Load("Rbx", &x.Rbx) + m.Load("R11", &x.R11) + m.Load("R10", &x.R10) + m.Load("R9", &x.R9) + m.Load("R8", &x.R8) + m.Load("Rax", &x.Rax) + m.Load("Rcx", &x.Rcx) + m.Load("Rdx", &x.Rdx) + m.Load("Rsi", &x.Rsi) + m.Load("Rdi", &x.Rdi) + m.Load("Orig_rax", &x.Orig_rax) + m.Load("Rip", &x.Rip) + m.Load("Cs", &x.Cs) + m.Load("Eflags", &x.Eflags) + m.Load("Rsp", &x.Rsp) + m.Load("Ss", &x.Ss) + m.Load("Fs_base", &x.Fs_base) + m.Load("Gs_base", &x.Gs_base) + m.Load("Ds", &x.Ds) + m.Load("Es", &x.Es) + m.Load("Fs", &x.Fs) + m.Load("Gs", &x.Gs) +} + +func (x *AuxEntry) beforeSave() {} +func (x *AuxEntry) save(m state.Map) { + x.beforeSave() + m.Save("Key", &x.Key) + m.Save("Value", &x.Value) +} + +func (x *AuxEntry) afterLoad() {} +func (x *AuxEntry) load(m state.Map) { + m.Load("Key", &x.Key) + m.Load("Value", &x.Value) +} + +func (x *SignalAct) beforeSave() {} +func (x *SignalAct) save(m state.Map) { + x.beforeSave() + m.Save("Handler", &x.Handler) + m.Save("Flags", &x.Flags) + m.Save("Restorer", &x.Restorer) + m.Save("Mask", &x.Mask) +} + +func (x *SignalAct) afterLoad() {} +func (x *SignalAct) load(m state.Map) { + m.Load("Handler", &x.Handler) + m.Load("Flags", &x.Flags) + m.Load("Restorer", &x.Restorer) + m.Load("Mask", &x.Mask) +} + +func (x *SignalStack) beforeSave() {} +func (x *SignalStack) save(m state.Map) { + x.beforeSave() + m.Save("Addr", &x.Addr) + m.Save("Flags", &x.Flags) + m.Save("Size", &x.Size) +} + +func (x *SignalStack) afterLoad() {} +func (x *SignalStack) load(m state.Map) { + m.Load("Addr", &x.Addr) + m.Load("Flags", &x.Flags) + m.Load("Size", &x.Size) +} + +func (x *SignalInfo) beforeSave() {} +func (x *SignalInfo) save(m state.Map) { + x.beforeSave() + m.Save("Signo", &x.Signo) + m.Save("Errno", &x.Errno) + m.Save("Code", &x.Code) + m.Save("Fields", &x.Fields) +} + +func (x *SignalInfo) afterLoad() {} +func (x *SignalInfo) load(m state.Map) { + m.Load("Signo", &x.Signo) + m.Load("Errno", &x.Errno) + m.Load("Code", &x.Code) + m.Load("Fields", &x.Fields) +} + +func init() { + state.Register("pkg/sentry/arch.MmapLayout", (*MmapLayout)(nil), state.Fns{Save: (*MmapLayout).save, Load: (*MmapLayout).load}) + state.Register("pkg/sentry/arch.syscallPtraceRegs", (*syscallPtraceRegs)(nil), state.Fns{Save: (*syscallPtraceRegs).save, Load: (*syscallPtraceRegs).load}) + state.Register("pkg/sentry/arch.AuxEntry", (*AuxEntry)(nil), state.Fns{Save: (*AuxEntry).save, Load: (*AuxEntry).load}) + state.Register("pkg/sentry/arch.SignalAct", (*SignalAct)(nil), state.Fns{Save: (*SignalAct).save, Load: (*SignalAct).load}) + state.Register("pkg/sentry/arch.SignalStack", (*SignalStack)(nil), state.Fns{Save: (*SignalStack).save, Load: (*SignalStack).load}) + state.Register("pkg/sentry/arch.SignalInfo", (*SignalInfo)(nil), state.Fns{Save: (*SignalInfo).save, Load: (*SignalInfo).load}) +} diff --git a/pkg/sentry/arch/arch_x86_impl.go b/pkg/sentry/arch/arch_x86_impl.go index 04ac283c6..04ac283c6 100644..100755 --- a/pkg/sentry/arch/arch_x86_impl.go +++ b/pkg/sentry/arch/arch_x86_impl.go diff --git a/pkg/sentry/arch/registers.proto b/pkg/sentry/arch/registers.proto deleted file mode 100644 index 60c027aab..000000000 --- a/pkg/sentry/arch/registers.proto +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -message AMD64Registers { - uint64 rax = 1; - uint64 rbx = 2; - uint64 rcx = 3; - uint64 rdx = 4; - uint64 rsi = 5; - uint64 rdi = 6; - uint64 rsp = 7; - uint64 rbp = 8; - - uint64 r8 = 9; - uint64 r9 = 10; - uint64 r10 = 11; - uint64 r11 = 12; - uint64 r12 = 13; - uint64 r13 = 14; - uint64 r14 = 15; - uint64 r15 = 16; - - uint64 rip = 17; - uint64 rflags = 18; - uint64 orig_rax = 19; - uint64 cs = 20; - uint64 ds = 21; - uint64 es = 22; - uint64 fs = 23; - uint64 gs = 24; - uint64 ss = 25; - uint64 fs_base = 26; - uint64 gs_base = 27; -} - -message ARM64Registers { - uint64 r0 = 1; - uint64 r1 = 2; - uint64 r2 = 3; - uint64 r3 = 4; - uint64 r4 = 5; - uint64 r5 = 6; - uint64 r6 = 7; - uint64 r7 = 8; - uint64 r8 = 9; - uint64 r9 = 10; - uint64 r10 = 11; - uint64 r11 = 12; - uint64 r12 = 13; - uint64 r13 = 14; - uint64 r14 = 15; - uint64 r15 = 16; - uint64 r16 = 17; - uint64 r17 = 18; - uint64 r18 = 19; - uint64 r19 = 20; - uint64 r20 = 21; - uint64 r21 = 22; - uint64 r22 = 23; - uint64 r23 = 24; - uint64 r24 = 25; - uint64 r25 = 26; - uint64 r26 = 27; - uint64 r27 = 28; - uint64 r28 = 29; - uint64 r29 = 30; - uint64 r30 = 31; - uint64 sp = 32; - uint64 pc = 33; - uint64 pstate = 34; -} -message Registers { - oneof arch { - AMD64Registers amd64 = 1; - ARM64Registers arm64 = 2; - } -} diff --git a/pkg/sentry/arch/registers_go_proto/registers.pb.go b/pkg/sentry/arch/registers_go_proto/registers.pb.go new file mode 100755 index 000000000..c4e9584b5 --- /dev/null +++ b/pkg/sentry/arch/registers_go_proto/registers.pb.go @@ -0,0 +1,697 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/sentry/arch/registers.proto + +package gvisor + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type AMD64Registers struct { + Rax uint64 `protobuf:"varint,1,opt,name=rax,proto3" json:"rax,omitempty"` + Rbx uint64 `protobuf:"varint,2,opt,name=rbx,proto3" json:"rbx,omitempty"` + Rcx uint64 `protobuf:"varint,3,opt,name=rcx,proto3" json:"rcx,omitempty"` + Rdx uint64 `protobuf:"varint,4,opt,name=rdx,proto3" json:"rdx,omitempty"` + Rsi uint64 `protobuf:"varint,5,opt,name=rsi,proto3" json:"rsi,omitempty"` + Rdi uint64 `protobuf:"varint,6,opt,name=rdi,proto3" json:"rdi,omitempty"` + Rsp uint64 `protobuf:"varint,7,opt,name=rsp,proto3" json:"rsp,omitempty"` + Rbp uint64 `protobuf:"varint,8,opt,name=rbp,proto3" json:"rbp,omitempty"` + R8 uint64 `protobuf:"varint,9,opt,name=r8,proto3" json:"r8,omitempty"` + R9 uint64 `protobuf:"varint,10,opt,name=r9,proto3" json:"r9,omitempty"` + R10 uint64 `protobuf:"varint,11,opt,name=r10,proto3" json:"r10,omitempty"` + R11 uint64 `protobuf:"varint,12,opt,name=r11,proto3" json:"r11,omitempty"` + R12 uint64 `protobuf:"varint,13,opt,name=r12,proto3" json:"r12,omitempty"` + R13 uint64 `protobuf:"varint,14,opt,name=r13,proto3" json:"r13,omitempty"` + R14 uint64 `protobuf:"varint,15,opt,name=r14,proto3" json:"r14,omitempty"` + R15 uint64 `protobuf:"varint,16,opt,name=r15,proto3" json:"r15,omitempty"` + Rip uint64 `protobuf:"varint,17,opt,name=rip,proto3" json:"rip,omitempty"` + Rflags uint64 `protobuf:"varint,18,opt,name=rflags,proto3" json:"rflags,omitempty"` + OrigRax uint64 `protobuf:"varint,19,opt,name=orig_rax,json=origRax,proto3" json:"orig_rax,omitempty"` + Cs uint64 `protobuf:"varint,20,opt,name=cs,proto3" json:"cs,omitempty"` + Ds uint64 `protobuf:"varint,21,opt,name=ds,proto3" json:"ds,omitempty"` + Es uint64 `protobuf:"varint,22,opt,name=es,proto3" json:"es,omitempty"` + Fs uint64 `protobuf:"varint,23,opt,name=fs,proto3" json:"fs,omitempty"` + Gs uint64 `protobuf:"varint,24,opt,name=gs,proto3" json:"gs,omitempty"` + Ss uint64 `protobuf:"varint,25,opt,name=ss,proto3" json:"ss,omitempty"` + FsBase uint64 `protobuf:"varint,26,opt,name=fs_base,json=fsBase,proto3" json:"fs_base,omitempty"` + GsBase uint64 `protobuf:"varint,27,opt,name=gs_base,json=gsBase,proto3" json:"gs_base,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *AMD64Registers) Reset() { *m = AMD64Registers{} } +func (m *AMD64Registers) String() string { return proto.CompactTextString(m) } +func (*AMD64Registers) ProtoMessage() {} +func (*AMD64Registers) Descriptor() ([]byte, []int) { + return fileDescriptor_082b7510610e0457, []int{0} +} + +func (m *AMD64Registers) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_AMD64Registers.Unmarshal(m, b) +} +func (m *AMD64Registers) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_AMD64Registers.Marshal(b, m, deterministic) +} +func (m *AMD64Registers) XXX_Merge(src proto.Message) { + xxx_messageInfo_AMD64Registers.Merge(m, src) +} +func (m *AMD64Registers) XXX_Size() int { + return xxx_messageInfo_AMD64Registers.Size(m) +} +func (m *AMD64Registers) XXX_DiscardUnknown() { + xxx_messageInfo_AMD64Registers.DiscardUnknown(m) +} + +var xxx_messageInfo_AMD64Registers proto.InternalMessageInfo + +func (m *AMD64Registers) GetRax() uint64 { + if m != nil { + return m.Rax + } + return 0 +} + +func (m *AMD64Registers) GetRbx() uint64 { + if m != nil { + return m.Rbx + } + return 0 +} + +func (m *AMD64Registers) GetRcx() uint64 { + if m != nil { + return m.Rcx + } + return 0 +} + +func (m *AMD64Registers) GetRdx() uint64 { + if m != nil { + return m.Rdx + } + return 0 +} + +func (m *AMD64Registers) GetRsi() uint64 { + if m != nil { + return m.Rsi + } + return 0 +} + +func (m *AMD64Registers) GetRdi() uint64 { + if m != nil { + return m.Rdi + } + return 0 +} + +func (m *AMD64Registers) GetRsp() uint64 { + if m != nil { + return m.Rsp + } + return 0 +} + +func (m *AMD64Registers) GetRbp() uint64 { + if m != nil { + return m.Rbp + } + return 0 +} + +func (m *AMD64Registers) GetR8() uint64 { + if m != nil { + return m.R8 + } + return 0 +} + +func (m *AMD64Registers) GetR9() uint64 { + if m != nil { + return m.R9 + } + return 0 +} + +func (m *AMD64Registers) GetR10() uint64 { + if m != nil { + return m.R10 + } + return 0 +} + +func (m *AMD64Registers) GetR11() uint64 { + if m != nil { + return m.R11 + } + return 0 +} + +func (m *AMD64Registers) GetR12() uint64 { + if m != nil { + return m.R12 + } + return 0 +} + +func (m *AMD64Registers) GetR13() uint64 { + if m != nil { + return m.R13 + } + return 0 +} + +func (m *AMD64Registers) GetR14() uint64 { + if m != nil { + return m.R14 + } + return 0 +} + +func (m *AMD64Registers) GetR15() uint64 { + if m != nil { + return m.R15 + } + return 0 +} + +func (m *AMD64Registers) GetRip() uint64 { + if m != nil { + return m.Rip + } + return 0 +} + +func (m *AMD64Registers) GetRflags() uint64 { + if m != nil { + return m.Rflags + } + return 0 +} + +func (m *AMD64Registers) GetOrigRax() uint64 { + if m != nil { + return m.OrigRax + } + return 0 +} + +func (m *AMD64Registers) GetCs() uint64 { + if m != nil { + return m.Cs + } + return 0 +} + +func (m *AMD64Registers) GetDs() uint64 { + if m != nil { + return m.Ds + } + return 0 +} + +func (m *AMD64Registers) GetEs() uint64 { + if m != nil { + return m.Es + } + return 0 +} + +func (m *AMD64Registers) GetFs() uint64 { + if m != nil { + return m.Fs + } + return 0 +} + +func (m *AMD64Registers) GetGs() uint64 { + if m != nil { + return m.Gs + } + return 0 +} + +func (m *AMD64Registers) GetSs() uint64 { + if m != nil { + return m.Ss + } + return 0 +} + +func (m *AMD64Registers) GetFsBase() uint64 { + if m != nil { + return m.FsBase + } + return 0 +} + +func (m *AMD64Registers) GetGsBase() uint64 { + if m != nil { + return m.GsBase + } + return 0 +} + +type ARM64Registers struct { + R0 uint64 `protobuf:"varint,1,opt,name=r0,proto3" json:"r0,omitempty"` + R1 uint64 `protobuf:"varint,2,opt,name=r1,proto3" json:"r1,omitempty"` + R2 uint64 `protobuf:"varint,3,opt,name=r2,proto3" json:"r2,omitempty"` + R3 uint64 `protobuf:"varint,4,opt,name=r3,proto3" json:"r3,omitempty"` + R4 uint64 `protobuf:"varint,5,opt,name=r4,proto3" json:"r4,omitempty"` + R5 uint64 `protobuf:"varint,6,opt,name=r5,proto3" json:"r5,omitempty"` + R6 uint64 `protobuf:"varint,7,opt,name=r6,proto3" json:"r6,omitempty"` + R7 uint64 `protobuf:"varint,8,opt,name=r7,proto3" json:"r7,omitempty"` + R8 uint64 `protobuf:"varint,9,opt,name=r8,proto3" json:"r8,omitempty"` + R9 uint64 `protobuf:"varint,10,opt,name=r9,proto3" json:"r9,omitempty"` + R10 uint64 `protobuf:"varint,11,opt,name=r10,proto3" json:"r10,omitempty"` + R11 uint64 `protobuf:"varint,12,opt,name=r11,proto3" json:"r11,omitempty"` + R12 uint64 `protobuf:"varint,13,opt,name=r12,proto3" json:"r12,omitempty"` + R13 uint64 `protobuf:"varint,14,opt,name=r13,proto3" json:"r13,omitempty"` + R14 uint64 `protobuf:"varint,15,opt,name=r14,proto3" json:"r14,omitempty"` + R15 uint64 `protobuf:"varint,16,opt,name=r15,proto3" json:"r15,omitempty"` + R16 uint64 `protobuf:"varint,17,opt,name=r16,proto3" json:"r16,omitempty"` + R17 uint64 `protobuf:"varint,18,opt,name=r17,proto3" json:"r17,omitempty"` + R18 uint64 `protobuf:"varint,19,opt,name=r18,proto3" json:"r18,omitempty"` + R19 uint64 `protobuf:"varint,20,opt,name=r19,proto3" json:"r19,omitempty"` + R20 uint64 `protobuf:"varint,21,opt,name=r20,proto3" json:"r20,omitempty"` + R21 uint64 `protobuf:"varint,22,opt,name=r21,proto3" json:"r21,omitempty"` + R22 uint64 `protobuf:"varint,23,opt,name=r22,proto3" json:"r22,omitempty"` + R23 uint64 `protobuf:"varint,24,opt,name=r23,proto3" json:"r23,omitempty"` + R24 uint64 `protobuf:"varint,25,opt,name=r24,proto3" json:"r24,omitempty"` + R25 uint64 `protobuf:"varint,26,opt,name=r25,proto3" json:"r25,omitempty"` + R26 uint64 `protobuf:"varint,27,opt,name=r26,proto3" json:"r26,omitempty"` + R27 uint64 `protobuf:"varint,28,opt,name=r27,proto3" json:"r27,omitempty"` + R28 uint64 `protobuf:"varint,29,opt,name=r28,proto3" json:"r28,omitempty"` + R29 uint64 `protobuf:"varint,30,opt,name=r29,proto3" json:"r29,omitempty"` + R30 uint64 `protobuf:"varint,31,opt,name=r30,proto3" json:"r30,omitempty"` + Sp uint64 `protobuf:"varint,32,opt,name=sp,proto3" json:"sp,omitempty"` + Pc uint64 `protobuf:"varint,33,opt,name=pc,proto3" json:"pc,omitempty"` + Pstate uint64 `protobuf:"varint,34,opt,name=pstate,proto3" json:"pstate,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ARM64Registers) Reset() { *m = ARM64Registers{} } +func (m *ARM64Registers) String() string { return proto.CompactTextString(m) } +func (*ARM64Registers) ProtoMessage() {} +func (*ARM64Registers) Descriptor() ([]byte, []int) { + return fileDescriptor_082b7510610e0457, []int{1} +} + +func (m *ARM64Registers) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ARM64Registers.Unmarshal(m, b) +} +func (m *ARM64Registers) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ARM64Registers.Marshal(b, m, deterministic) +} +func (m *ARM64Registers) XXX_Merge(src proto.Message) { + xxx_messageInfo_ARM64Registers.Merge(m, src) +} +func (m *ARM64Registers) XXX_Size() int { + return xxx_messageInfo_ARM64Registers.Size(m) +} +func (m *ARM64Registers) XXX_DiscardUnknown() { + xxx_messageInfo_ARM64Registers.DiscardUnknown(m) +} + +var xxx_messageInfo_ARM64Registers proto.InternalMessageInfo + +func (m *ARM64Registers) GetR0() uint64 { + if m != nil { + return m.R0 + } + return 0 +} + +func (m *ARM64Registers) GetR1() uint64 { + if m != nil { + return m.R1 + } + return 0 +} + +func (m *ARM64Registers) GetR2() uint64 { + if m != nil { + return m.R2 + } + return 0 +} + +func (m *ARM64Registers) GetR3() uint64 { + if m != nil { + return m.R3 + } + return 0 +} + +func (m *ARM64Registers) GetR4() uint64 { + if m != nil { + return m.R4 + } + return 0 +} + +func (m *ARM64Registers) GetR5() uint64 { + if m != nil { + return m.R5 + } + return 0 +} + +func (m *ARM64Registers) GetR6() uint64 { + if m != nil { + return m.R6 + } + return 0 +} + +func (m *ARM64Registers) GetR7() uint64 { + if m != nil { + return m.R7 + } + return 0 +} + +func (m *ARM64Registers) GetR8() uint64 { + if m != nil { + return m.R8 + } + return 0 +} + +func (m *ARM64Registers) GetR9() uint64 { + if m != nil { + return m.R9 + } + return 0 +} + +func (m *ARM64Registers) GetR10() uint64 { + if m != nil { + return m.R10 + } + return 0 +} + +func (m *ARM64Registers) GetR11() uint64 { + if m != nil { + return m.R11 + } + return 0 +} + +func (m *ARM64Registers) GetR12() uint64 { + if m != nil { + return m.R12 + } + return 0 +} + +func (m *ARM64Registers) GetR13() uint64 { + if m != nil { + return m.R13 + } + return 0 +} + +func (m *ARM64Registers) GetR14() uint64 { + if m != nil { + return m.R14 + } + return 0 +} + +func (m *ARM64Registers) GetR15() uint64 { + if m != nil { + return m.R15 + } + return 0 +} + +func (m *ARM64Registers) GetR16() uint64 { + if m != nil { + return m.R16 + } + return 0 +} + +func (m *ARM64Registers) GetR17() uint64 { + if m != nil { + return m.R17 + } + return 0 +} + +func (m *ARM64Registers) GetR18() uint64 { + if m != nil { + return m.R18 + } + return 0 +} + +func (m *ARM64Registers) GetR19() uint64 { + if m != nil { + return m.R19 + } + return 0 +} + +func (m *ARM64Registers) GetR20() uint64 { + if m != nil { + return m.R20 + } + return 0 +} + +func (m *ARM64Registers) GetR21() uint64 { + if m != nil { + return m.R21 + } + return 0 +} + +func (m *ARM64Registers) GetR22() uint64 { + if m != nil { + return m.R22 + } + return 0 +} + +func (m *ARM64Registers) GetR23() uint64 { + if m != nil { + return m.R23 + } + return 0 +} + +func (m *ARM64Registers) GetR24() uint64 { + if m != nil { + return m.R24 + } + return 0 +} + +func (m *ARM64Registers) GetR25() uint64 { + if m != nil { + return m.R25 + } + return 0 +} + +func (m *ARM64Registers) GetR26() uint64 { + if m != nil { + return m.R26 + } + return 0 +} + +func (m *ARM64Registers) GetR27() uint64 { + if m != nil { + return m.R27 + } + return 0 +} + +func (m *ARM64Registers) GetR28() uint64 { + if m != nil { + return m.R28 + } + return 0 +} + +func (m *ARM64Registers) GetR29() uint64 { + if m != nil { + return m.R29 + } + return 0 +} + +func (m *ARM64Registers) GetR30() uint64 { + if m != nil { + return m.R30 + } + return 0 +} + +func (m *ARM64Registers) GetSp() uint64 { + if m != nil { + return m.Sp + } + return 0 +} + +func (m *ARM64Registers) GetPc() uint64 { + if m != nil { + return m.Pc + } + return 0 +} + +func (m *ARM64Registers) GetPstate() uint64 { + if m != nil { + return m.Pstate + } + return 0 +} + +type Registers struct { + // Types that are valid to be assigned to Arch: + // *Registers_Amd64 + // *Registers_Arm64 + Arch isRegisters_Arch `protobuf_oneof:"arch"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Registers) Reset() { *m = Registers{} } +func (m *Registers) String() string { return proto.CompactTextString(m) } +func (*Registers) ProtoMessage() {} +func (*Registers) Descriptor() ([]byte, []int) { + return fileDescriptor_082b7510610e0457, []int{2} +} + +func (m *Registers) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Registers.Unmarshal(m, b) +} +func (m *Registers) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Registers.Marshal(b, m, deterministic) +} +func (m *Registers) XXX_Merge(src proto.Message) { + xxx_messageInfo_Registers.Merge(m, src) +} +func (m *Registers) XXX_Size() int { + return xxx_messageInfo_Registers.Size(m) +} +func (m *Registers) XXX_DiscardUnknown() { + xxx_messageInfo_Registers.DiscardUnknown(m) +} + +var xxx_messageInfo_Registers proto.InternalMessageInfo + +type isRegisters_Arch interface { + isRegisters_Arch() +} + +type Registers_Amd64 struct { + Amd64 *AMD64Registers `protobuf:"bytes,1,opt,name=amd64,proto3,oneof"` +} + +type Registers_Arm64 struct { + Arm64 *ARM64Registers `protobuf:"bytes,2,opt,name=arm64,proto3,oneof"` +} + +func (*Registers_Amd64) isRegisters_Arch() {} + +func (*Registers_Arm64) isRegisters_Arch() {} + +func (m *Registers) GetArch() isRegisters_Arch { + if m != nil { + return m.Arch + } + return nil +} + +func (m *Registers) GetAmd64() *AMD64Registers { + if x, ok := m.GetArch().(*Registers_Amd64); ok { + return x.Amd64 + } + return nil +} + +func (m *Registers) GetArm64() *ARM64Registers { + if x, ok := m.GetArch().(*Registers_Arm64); ok { + return x.Arm64 + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*Registers) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*Registers_Amd64)(nil), + (*Registers_Arm64)(nil), + } +} + +func init() { + proto.RegisterType((*AMD64Registers)(nil), "gvisor.AMD64Registers") + proto.RegisterType((*ARM64Registers)(nil), "gvisor.ARM64Registers") + proto.RegisterType((*Registers)(nil), "gvisor.Registers") +} + +func init() { proto.RegisterFile("pkg/sentry/arch/registers.proto", fileDescriptor_082b7510610e0457) } + +var fileDescriptor_082b7510610e0457 = []byte{ + // 544 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xcc, 0x94, 0x4b, 0x72, 0xd3, 0x40, + 0x10, 0x86, 0xc9, 0xc4, 0xb1, 0x93, 0x09, 0x98, 0x20, 0xc0, 0xe9, 0xf0, 0x4a, 0xf0, 0x8a, 0x95, + 0xad, 0x97, 0x65, 0x7b, 0x49, 0x8a, 0x05, 0x9b, 0x6c, 0x7c, 0x81, 0x94, 0x2c, 0xcb, 0x42, 0x05, + 0xc1, 0x53, 0xd3, 0x2e, 0x4a, 0xac, 0xb9, 0x29, 0x27, 0xa1, 0xfa, 0x31, 0x90, 0xdc, 0x80, 0x5d, + 0x7f, 0xdf, 0xb4, 0x4a, 0x23, 0xfd, 0x3d, 0x63, 0x2f, 0xdd, 0xd7, 0x66, 0x8a, 0xf5, 0xf7, 0xbd, + 0xff, 0x39, 0x2d, 0x7d, 0xf5, 0x65, 0xea, 0xeb, 0xa6, 0xc5, 0x7d, 0xed, 0x71, 0xe2, 0xfc, 0x6e, + 0xbf, 0x8b, 0xfa, 0xcd, 0x8f, 0x16, 0x77, 0x7e, 0xfc, 0xab, 0x67, 0x87, 0x1f, 0x6f, 0x3e, 0x15, + 0xf9, 0x2a, 0x34, 0x44, 0x67, 0xf6, 0xd0, 0x97, 0x1d, 0x1c, 0x5c, 0x1d, 0x7c, 0xe8, 0xad, 0xa8, + 0x64, 0xb3, 0xee, 0xc0, 0xa8, 0x59, 0x8b, 0xa9, 0x3a, 0x38, 0x54, 0x53, 0x89, 0xd9, 0x74, 0xd0, + 0x53, 0xb3, 0x11, 0x83, 0x2d, 0x1c, 0xa9, 0xc1, 0x56, 0x7a, 0x5a, 0xe8, 0x87, 0x1e, 0x31, 0xe8, + 0x60, 0x10, 0x7a, 0x9c, 0xbc, 0xcb, 0xc1, 0x71, 0x78, 0x97, 0x8b, 0x86, 0xd6, 0xf8, 0x05, 0x9c, + 0xb0, 0x30, 0x7e, 0xc1, 0xbc, 0x04, 0xab, 0xbc, 0xe4, 0x27, 0x92, 0x18, 0x4e, 0xf5, 0x89, 0x24, + 0x16, 0x93, 0xc0, 0xe3, 0x60, 0x12, 0x31, 0x29, 0x3c, 0x09, 0x26, 0x15, 0x93, 0xc1, 0x30, 0x98, + 0x4c, 0x4c, 0x0e, 0x4f, 0x83, 0xc9, 0xc5, 0xcc, 0xe0, 0x2c, 0x98, 0x19, 0x9b, 0xd6, 0xc1, 0x33, + 0x35, 0xad, 0x8b, 0x46, 0xb6, 0xef, 0xb7, 0xdf, 0xca, 0x06, 0x21, 0x62, 0xa9, 0x14, 0x5d, 0xd8, + 0xe3, 0x9d, 0x6f, 0x9b, 0x5b, 0xfa, 0x95, 0xcf, 0x79, 0x65, 0x40, 0xbc, 0x2a, 0x3b, 0xfa, 0x80, + 0x0a, 0xe1, 0x85, 0x7c, 0x40, 0x85, 0xc4, 0x1b, 0x84, 0x97, 0xc2, 0x1b, 0xe6, 0x1a, 0x61, 0x24, + 0x5c, 0x33, 0x6f, 0x11, 0xce, 0x85, 0xb7, 0xcc, 0x0d, 0x02, 0x08, 0x37, 0xcc, 0x88, 0x70, 0x21, + 0x8c, 0x18, 0x9d, 0xdb, 0xc1, 0x16, 0x6f, 0xd7, 0x25, 0xd6, 0xf0, 0x4a, 0xf6, 0xb4, 0xc5, 0xeb, + 0x12, 0x6b, 0x5a, 0x68, 0x74, 0xe1, 0xb5, 0x2c, 0x34, 0xbc, 0x30, 0xfe, 0x4d, 0x53, 0xb0, 0xba, + 0xb9, 0x3f, 0x05, 0xf4, 0x97, 0x63, 0x1d, 0x02, 0xe3, 0x63, 0xe6, 0x44, 0x47, 0xc0, 0xf8, 0x84, + 0x39, 0xd5, 0x01, 0x30, 0x3e, 0x65, 0xce, 0x34, 0x7e, 0xe3, 0x33, 0xe6, 0x5c, 0xc3, 0x37, 0x3e, + 0x67, 0x9e, 0x69, 0xf4, 0xc6, 0xcf, 0x98, 0x0b, 0x0d, 0xde, 0xf8, 0x82, 0x79, 0xae, 0xb1, 0x1b, + 0x3f, 0xff, 0xef, 0x52, 0x4f, 0x8a, 0xbf, 0xa9, 0x27, 0x85, 0x98, 0xb9, 0x46, 0x4e, 0xa5, 0x98, + 0x85, 0x46, 0x4d, 0xa5, 0x98, 0xa5, 0xe6, 0x4c, 0x25, 0x9b, 0x34, 0xd6, 0xa4, 0xa9, 0x14, 0x93, + 0x68, 0xd6, 0x54, 0x8a, 0x49, 0x35, 0x6d, 0x2a, 0xc5, 0x64, 0x9a, 0x37, 0x95, 0x62, 0x72, 0x4d, + 0x9c, 0x4a, 0x31, 0x33, 0x8d, 0x9b, 0x4a, 0x31, 0x85, 0xe6, 0x4c, 0xa5, 0x98, 0x39, 0xbc, 0x09, + 0x46, 0xf6, 0x9c, 0x2e, 0xe0, 0x6d, 0x30, 0xb2, 0xe7, 0x74, 0x09, 0xef, 0x82, 0x91, 0x3d, 0x67, + 0x31, 0x5c, 0xaa, 0xc9, 0x78, 0x12, 0xd0, 0xc1, 0x95, 0x8e, 0x1b, 0x9f, 0x4f, 0x57, 0xc1, 0x7b, + 0x61, 0x57, 0xd1, 0x89, 0x70, 0xb8, 0x2f, 0xf7, 0x35, 0x8c, 0x65, 0xc8, 0x84, 0xc6, 0x68, 0x4f, + 0xfe, 0x8d, 0xd7, 0xc4, 0x1e, 0x95, 0x77, 0x9b, 0x22, 0xe7, 0x09, 0x3b, 0x4d, 0x47, 0x13, 0xb9, + 0x8f, 0x26, 0x0f, 0xef, 0xa2, 0xcf, 0x8f, 0x56, 0xd2, 0xc6, 0xfd, 0xfe, 0xae, 0xc8, 0x79, 0x02, + 0xef, 0xf7, 0x3f, 0x98, 0x5a, 0xee, 0xa7, 0xb6, 0xeb, 0xbe, 0xed, 0xd1, 0xbd, 0xb7, 0xee, 0xf3, + 0x75, 0x97, 0xfd, 0x09, 0x00, 0x00, 0xff, 0xff, 0x79, 0x30, 0x5f, 0x13, 0x11, 0x05, 0x00, 0x00, +} diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go index 8b03d0187..8b03d0187 100644..100755 --- a/pkg/sentry/arch/signal.go +++ b/pkg/sentry/arch/signal.go diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go index 0c1db4b13..0c1db4b13 100644..100755 --- a/pkg/sentry/arch/signal_arm64.go +++ b/pkg/sentry/arch/signal_arm64.go diff --git a/pkg/sentry/arch/syscalls_arm64.go b/pkg/sentry/arch/syscalls_arm64.go index 00d5ef461..00d5ef461 100644..100755 --- a/pkg/sentry/arch/syscalls_arm64.go +++ b/pkg/sentry/arch/syscalls_arm64.go diff --git a/pkg/sentry/contexttest/BUILD b/pkg/sentry/contexttest/BUILD deleted file mode 100644 index 6f4c86684..000000000 --- a/pkg/sentry/contexttest/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "contexttest", - testonly = 1, - srcs = ["contexttest.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/memutil", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/platform/ptrace", - "//pkg/sentry/uniqueid", - ], -) diff --git a/pkg/sentry/contexttest/contexttest.go b/pkg/sentry/contexttest/contexttest.go deleted file mode 100644 index 031fc64ec..000000000 --- a/pkg/sentry/contexttest/contexttest.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package contexttest builds a test context.Context. -package contexttest - -import ( - "os" - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/memutil" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/ptrace" - "gvisor.dev/gvisor/pkg/sentry/uniqueid" -) - -// Context returns a Context that may be used in tests. Uses ptrace as the -// platform.Platform. -// -// Note that some filesystems may require a minimal kernel for testing, which -// this test context does not provide. For such tests, see kernel/contexttest. -func Context(tb testing.TB) context.Context { - const memfileName = "contexttest-memory" - memfd, err := memutil.CreateMemFD(memfileName, 0) - if err != nil { - tb.Fatalf("error creating application memory file: %v", err) - } - memfile := os.NewFile(uintptr(memfd), memfileName) - mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{}) - if err != nil { - memfile.Close() - tb.Fatalf("error creating pgalloc.MemoryFile: %v", err) - } - p, err := ptrace.New() - if err != nil { - tb.Fatal(err) - } - // Test usage of context.Background is fine. - return &TestContext{ - Context: context.Background(), - l: limits.NewLimitSet(), - mf: mf, - platform: p, - creds: auth.NewAnonymousCredentials(), - otherValues: make(map[interface{}]interface{}), - } -} - -// TestContext represents a context with minimal functionality suitable for -// running tests. -type TestContext struct { - context.Context - l *limits.LimitSet - mf *pgalloc.MemoryFile - platform platform.Platform - creds *auth.Credentials - otherValues map[interface{}]interface{} -} - -// globalUniqueID tracks incremental unique identifiers for tests. -var globalUniqueID uint64 - -// globalUniqueIDProvider implements unix.UniqueIDProvider. -type globalUniqueIDProvider struct{} - -// UniqueID implements unix.UniqueIDProvider.UniqueID. -func (*globalUniqueIDProvider) UniqueID() uint64 { - return atomic.AddUint64(&globalUniqueID, 1) -} - -// lastInotifyCookie is a monotonically increasing counter for generating unique -// inotify cookies. Must be accessed using atomic ops. -var lastInotifyCookie uint32 - -// hostClock implements ktime.Clock. -type hostClock struct { - ktime.WallRateClock - ktime.NoClockEvents -} - -// Now implements ktime.Clock.Now. -func (hostClock) Now() ktime.Time { - return ktime.FromNanoseconds(time.Now().UnixNano()) -} - -// RegisterValue registers additional values with this test context. Useful for -// providing values from external packages that contexttest can't depend on. -func (t *TestContext) RegisterValue(key, value interface{}) { - t.otherValues[key] = value -} - -// Value implements context.Context. -func (t *TestContext) Value(key interface{}) interface{} { - switch key { - case auth.CtxCredentials: - return t.creds - case limits.CtxLimits: - return t.l - case pgalloc.CtxMemoryFile: - return t.mf - case pgalloc.CtxMemoryFileProvider: - return t - case platform.CtxPlatform: - return t.platform - case uniqueid.CtxGlobalUniqueID: - return (*globalUniqueIDProvider).UniqueID(nil) - case uniqueid.CtxGlobalUniqueIDProvider: - return &globalUniqueIDProvider{} - case uniqueid.CtxInotifyCookie: - return atomic.AddUint32(&lastInotifyCookie, 1) - case ktime.CtxRealtimeClock: - return hostClock{} - default: - if val, ok := t.otherValues[key]; ok { - return val - } - return t.Context.Value(key) - } -} - -// MemoryFile implements pgalloc.MemoryFileProvider.MemoryFile. -func (t *TestContext) MemoryFile() *pgalloc.MemoryFile { - return t.mf -} - -// RootContext returns a Context that may be used in tests that need root -// credentials. Uses ptrace as the platform.Platform. -func RootContext(tb testing.TB) context.Context { - return WithCreds(Context(tb), auth.NewRootCredentials(auth.NewRootUserNamespace())) -} - -// WithCreds returns a copy of ctx carrying creds. -func WithCreds(ctx context.Context, creds *auth.Credentials) context.Context { - return &authContext{ctx, creds} -} - -type authContext struct { - context.Context - creds *auth.Credentials -} - -// Value implements context.Context. -func (ac *authContext) Value(key interface{}) interface{} { - switch key { - case auth.CtxCredentials: - return ac.creds - default: - return ac.Context.Value(key) - } -} - -// WithLimitSet returns a copy of ctx carrying l. -func WithLimitSet(ctx context.Context, l *limits.LimitSet) context.Context { - return limitContext{ctx, l} -} - -type limitContext struct { - context.Context - l *limits.LimitSet -} - -// Value implements context.Context. -func (lc limitContext) Value(key interface{}) interface{} { - switch key { - case limits.CtxLimits: - return lc.l - default: - return lc.Context.Value(key) - } -} diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD deleted file mode 100644 index d16d78aa5..000000000 --- a/pkg/sentry/control/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "control", - srcs = [ - "control.go", - "logging.go", - "pprof.go", - "proc.go", - "state.go", - ], - visibility = [ - "//:sandbox", - ], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/fd", - "//pkg/fspath", - "//pkg/log", - "//pkg/sentry/fs", - "//pkg/sentry/fs/host", - "//pkg/sentry/fsbridge", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/state", - "//pkg/sentry/strace", - "//pkg/sentry/usage", - "//pkg/sentry/vfs", - "//pkg/sentry/watchdog", - "//pkg/sync", - "//pkg/syserror", - "//pkg/tcpip/link/sniffer", - "//pkg/urpc", - ], -) - -go_test( - name = "control_test", - size = "small", - srcs = ["proc_test.go"], - library = ":control", - deps = [ - "//pkg/log", - "//pkg/sentry/kernel/time", - "//pkg/sentry/usage", - ], -) diff --git a/pkg/sentry/control/control_state_autogen.go b/pkg/sentry/control/control_state_autogen.go new file mode 100755 index 000000000..bd5797221 --- /dev/null +++ b/pkg/sentry/control/control_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package control diff --git a/pkg/sentry/control/proc_test.go b/pkg/sentry/control/proc_test.go deleted file mode 100644 index 0a88459b2..000000000 --- a/pkg/sentry/control/proc_test.go +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package control - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/log" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/usage" -) - -func init() { - log.SetLevel(log.Debug) -} - -// Tests that ProcessData.Table() prints with the correct format. -func TestProcessListTable(t *testing.T) { - testCases := []struct { - pl []*Process - expected string - }{ - { - pl: []*Process{}, - expected: "UID PID PPID C TTY STIME TIME CMD", - }, - { - pl: []*Process{ - { - UID: 0, - PID: 0, - PPID: 0, - C: 0, - TTY: "?", - STime: "0", - Time: "0", - Cmd: "zero", - }, - { - UID: 1, - PID: 1, - PPID: 1, - C: 1, - TTY: "pts/4", - STime: "1", - Time: "1", - Cmd: "one", - }, - }, - expected: `UID PID PPID C TTY STIME TIME CMD -0 0 0 0 ? 0 0 zero -1 1 1 1 pts/4 1 1 one`, - }, - } - - for _, tc := range testCases { - output := ProcessListToTable(tc.pl) - - if tc.expected != output { - t.Errorf("PrintTable(%v): got:\n%s\nwant:\n%s", tc.pl, output, tc.expected) - } - } -} - -func TestProcessListJSON(t *testing.T) { - testCases := []struct { - pl []*Process - expected string - }{ - { - pl: []*Process{}, - expected: "[]", - }, - { - pl: []*Process{ - { - UID: 0, - PID: 0, - PPID: 0, - C: 0, - STime: "0", - Time: "0", - Cmd: "zero", - }, - { - UID: 1, - PID: 1, - PPID: 1, - C: 1, - STime: "1", - Time: "1", - Cmd: "one", - }, - }, - expected: "[0,1]", - }, - } - - for _, tc := range testCases { - output, err := PrintPIDsJSON(tc.pl) - if err != nil { - t.Errorf("failed to generate JSON: %v", err) - } - - if tc.expected != output { - t.Errorf("PrintJSON(%v): got:\n%s\nwant:\n%s", tc.pl, output, tc.expected) - } - } -} - -func TestPercentCPU(t *testing.T) { - testCases := []struct { - stats usage.CPUStats - startTime ktime.Time - now ktime.Time - expected int32 - }{ - { - // Verify that 100% use is capped at 99. - stats: usage.CPUStats{UserTime: 1e9, SysTime: 1e9}, - startTime: ktime.FromNanoseconds(7e9), - now: ktime.FromNanoseconds(9e9), - expected: 99, - }, - { - // Verify that if usage > lifetime, we get at most 99% - // usage. - stats: usage.CPUStats{UserTime: 2e9, SysTime: 2e9}, - startTime: ktime.FromNanoseconds(7e9), - now: ktime.FromNanoseconds(9e9), - expected: 99, - }, - { - // Verify that 50% usage is reported correctly. - stats: usage.CPUStats{UserTime: 1e9, SysTime: 1e9}, - startTime: ktime.FromNanoseconds(12e9), - now: ktime.FromNanoseconds(16e9), - expected: 50, - }, - { - // Verify that 0% usage is reported correctly. - stats: usage.CPUStats{UserTime: 0, SysTime: 0}, - startTime: ktime.FromNanoseconds(12e9), - now: ktime.FromNanoseconds(14e9), - expected: 0, - }, - } - - for _, tc := range testCases { - if pcpu := percentCPU(tc.stats, tc.startTime, tc.now); pcpu != tc.expected { - t.Errorf("percentCPU(%v, %v, %v): got %d, want %d", tc.stats, tc.startTime, tc.now, pcpu, tc.expected) - } - } -} diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD deleted file mode 100644 index e403cbd8b..000000000 --- a/pkg/sentry/device/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "device", - srcs = ["device.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/sync", - ], -) - -go_test( - name = "device_test", - size = "small", - srcs = ["device_test.go"], - library = ":device", -) diff --git a/pkg/sentry/device/device_state_autogen.go b/pkg/sentry/device/device_state_autogen.go new file mode 100755 index 000000000..dd41a5659 --- /dev/null +++ b/pkg/sentry/device/device_state_autogen.go @@ -0,0 +1,52 @@ +// automatically generated by stateify. + +package device + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Registry) beforeSave() {} +func (x *Registry) save(m state.Map) { + x.beforeSave() + m.Save("lastAnonDeviceMinor", &x.lastAnonDeviceMinor) + m.Save("devices", &x.devices) +} + +func (x *Registry) afterLoad() {} +func (x *Registry) load(m state.Map) { + m.Load("lastAnonDeviceMinor", &x.lastAnonDeviceMinor) + m.Load("devices", &x.devices) +} + +func (x *ID) beforeSave() {} +func (x *ID) save(m state.Map) { + x.beforeSave() + m.Save("Major", &x.Major) + m.Save("Minor", &x.Minor) +} + +func (x *ID) afterLoad() {} +func (x *ID) load(m state.Map) { + m.Load("Major", &x.Major) + m.Load("Minor", &x.Minor) +} + +func (x *Device) beforeSave() {} +func (x *Device) save(m state.Map) { + x.beforeSave() + m.Save("ID", &x.ID) + m.Save("last", &x.last) +} + +func (x *Device) afterLoad() {} +func (x *Device) load(m state.Map) { + m.Load("ID", &x.ID) + m.Load("last", &x.last) +} + +func init() { + state.Register("pkg/sentry/device.Registry", (*Registry)(nil), state.Fns{Save: (*Registry).save, Load: (*Registry).load}) + state.Register("pkg/sentry/device.ID", (*ID)(nil), state.Fns{Save: (*ID).save, Load: (*ID).load}) + state.Register("pkg/sentry/device.Device", (*Device)(nil), state.Fns{Save: (*Device).save, Load: (*Device).load}) +} diff --git a/pkg/sentry/device/device_test.go b/pkg/sentry/device/device_test.go deleted file mode 100644 index e3f51ce4f..000000000 --- a/pkg/sentry/device/device_test.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package device - -import ( - "testing" -) - -func TestMultiDevice(t *testing.T) { - device := &MultiDevice{} - - // Check that Load fails to install virtual inodes that are - // uninitialized. - if device.Load(MultiDeviceKey{}, 0) { - t.Fatalf("got load of invalid virtual inode 0, want unsuccessful") - } - - inode := device.Map(MultiDeviceKey{}) - - // Assert that the same raw device and inode map to - // a consistent virtual inode. - if i := device.Map(MultiDeviceKey{}); i != inode { - t.Fatalf("got inode %d, want %d in %s", i, inode, device) - } - - // Assert that a new inode or new device does not conflict. - if i := device.Map(MultiDeviceKey{Device: 0, Inode: 1}); i == inode { - t.Fatalf("got reused inode %d, want new distinct inode in %s", i, device) - } - last := device.Map(MultiDeviceKey{Device: 1, Inode: 0}) - if last == inode { - t.Fatalf("got reused inode %d, want new distinct inode in %s", last, device) - } - - // Virtual is the virtual inode we want to load. - virtual := last + 1 - - // Assert that we can load a virtual inode at a new place. - if !device.Load(MultiDeviceKey{Device: 0, Inode: 2}, virtual) { - t.Fatalf("got load of virtual inode %d failed, want success in %s", virtual, device) - } - - // Assert that the next inode skips over the loaded one. - if i := device.Map(MultiDeviceKey{Device: 0, Inode: 3}); i != virtual+1 { - t.Fatalf("got inode %d, want %d in %s", i, virtual+1, device) - } -} diff --git a/pkg/sentry/devices/memdev/BUILD b/pkg/sentry/devices/memdev/BUILD deleted file mode 100644 index abe58f818..000000000 --- a/pkg/sentry/devices/memdev/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "memdev", - srcs = [ - "full.go", - "memdev.go", - "null.go", - "random.go", - "zero.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/rand", - "//pkg/safemem", - "//pkg/sentry/fsimpl/devtmpfs", - "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/pgalloc", - "//pkg/sentry/vfs", - "//pkg/syserror", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/devices/memdev/full.go b/pkg/sentry/devices/memdev/full.go deleted file mode 100644 index c7e197691..000000000 --- a/pkg/sentry/devices/memdev/full.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package memdev - -import ( - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -const fullDevMinor = 7 - -// fullDevice implements vfs.Device for /dev/full. -type fullDevice struct{} - -// Open implements vfs.Device.Open. -func (fullDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd := &fullFD{} - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{ - UseDentryMetadata: true, - }); err != nil { - return nil, err - } - return &fd.vfsfd, nil -} - -// fullFD implements vfs.FileDescriptionImpl for /dev/full. -type fullFD struct { - vfsfd vfs.FileDescription - vfs.FileDescriptionDefaultImpl - vfs.DentryMetadataFileDescriptionImpl -} - -// Release implements vfs.FileDescriptionImpl.Release. -func (fd *fullFD) Release() { - // noop -} - -// PRead implements vfs.FileDescriptionImpl.PRead. -func (fd *fullFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { - return dst.ZeroOut(ctx, dst.NumBytes()) -} - -// Read implements vfs.FileDescriptionImpl.Read. -func (fd *fullFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - return dst.ZeroOut(ctx, dst.NumBytes()) -} - -// PWrite implements vfs.FileDescriptionImpl.PWrite. -func (fd *fullFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - return 0, syserror.ENOSPC -} - -// Write implements vfs.FileDescriptionImpl.Write. -func (fd *fullFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - return 0, syserror.ENOSPC -} - -// Seek implements vfs.FileDescriptionImpl.Seek. -func (fd *fullFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { - return 0, nil -} diff --git a/pkg/sentry/devices/memdev/memdev.go b/pkg/sentry/devices/memdev/memdev.go deleted file mode 100644 index 5759900c4..000000000 --- a/pkg/sentry/devices/memdev/memdev.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package memdev implements "mem" character devices, as implemented in Linux -// by drivers/char/mem.c and drivers/char/random.c. -package memdev - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -// Register registers all devices implemented by this package in vfsObj. -func Register(vfsObj *vfs.VirtualFilesystem) error { - for minor, dev := range map[uint32]vfs.Device{ - nullDevMinor: nullDevice{}, - zeroDevMinor: zeroDevice{}, - fullDevMinor: fullDevice{}, - randomDevMinor: randomDevice{}, - urandomDevMinor: randomDevice{}, - } { - if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MEM_MAJOR, minor, dev, &vfs.RegisterDeviceOptions{ - GroupName: "mem", - }); err != nil { - return err - } - } - return nil -} - -// CreateDevtmpfsFiles creates device special files in dev representing all -// devices implemented by this package. -func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error { - for minor, name := range map[uint32]string{ - nullDevMinor: "null", - zeroDevMinor: "zero", - fullDevMinor: "full", - randomDevMinor: "random", - urandomDevMinor: "urandom", - } { - if err := dev.CreateDeviceFile(ctx, name, vfs.CharDevice, linux.MEM_MAJOR, minor, 0666 /* mode */); err != nil { - return err - } - } - return nil -} diff --git a/pkg/sentry/devices/memdev/null.go b/pkg/sentry/devices/memdev/null.go deleted file mode 100644 index 33d060d02..000000000 --- a/pkg/sentry/devices/memdev/null.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package memdev - -import ( - "io" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/usermem" -) - -const nullDevMinor = 3 - -// nullDevice implements vfs.Device for /dev/null. -type nullDevice struct{} - -// Open implements vfs.Device.Open. -func (nullDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd := &nullFD{} - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{ - UseDentryMetadata: true, - }); err != nil { - return nil, err - } - return &fd.vfsfd, nil -} - -// nullFD implements vfs.FileDescriptionImpl for /dev/null. -type nullFD struct { - vfsfd vfs.FileDescription - vfs.FileDescriptionDefaultImpl - vfs.DentryMetadataFileDescriptionImpl -} - -// Release implements vfs.FileDescriptionImpl.Release. -func (fd *nullFD) Release() { - // noop -} - -// PRead implements vfs.FileDescriptionImpl.PRead. -func (fd *nullFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { - return 0, io.EOF -} - -// Read implements vfs.FileDescriptionImpl.Read. -func (fd *nullFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - return 0, io.EOF -} - -// PWrite implements vfs.FileDescriptionImpl.PWrite. -func (fd *nullFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - return src.NumBytes(), nil -} - -// Write implements vfs.FileDescriptionImpl.Write. -func (fd *nullFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - return src.NumBytes(), nil -} - -// Seek implements vfs.FileDescriptionImpl.Seek. -func (fd *nullFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { - return 0, nil -} diff --git a/pkg/sentry/devices/memdev/random.go b/pkg/sentry/devices/memdev/random.go deleted file mode 100644 index acfa23149..000000000 --- a/pkg/sentry/devices/memdev/random.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package memdev - -import ( - "sync/atomic" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/safemem" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/usermem" -) - -const ( - randomDevMinor = 8 - urandomDevMinor = 9 -) - -// randomDevice implements vfs.Device for /dev/random and /dev/urandom. -type randomDevice struct{} - -// Open implements vfs.Device.Open. -func (randomDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd := &randomFD{} - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{ - UseDentryMetadata: true, - }); err != nil { - return nil, err - } - return &fd.vfsfd, nil -} - -// randomFD implements vfs.FileDescriptionImpl for /dev/random. -type randomFD struct { - vfsfd vfs.FileDescription - vfs.FileDescriptionDefaultImpl - vfs.DentryMetadataFileDescriptionImpl - - // off is the "file offset". off is accessed using atomic memory - // operations. - off int64 -} - -// Release implements vfs.FileDescriptionImpl.Release. -func (fd *randomFD) Release() { - // noop -} - -// PRead implements vfs.FileDescriptionImpl.PRead. -func (fd *randomFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { - return dst.CopyOutFrom(ctx, safemem.FromIOReader{rand.Reader}) -} - -// Read implements vfs.FileDescriptionImpl.Read. -func (fd *randomFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - n, err := dst.CopyOutFrom(ctx, safemem.FromIOReader{rand.Reader}) - atomic.AddInt64(&fd.off, n) - return n, err -} - -// PWrite implements vfs.FileDescriptionImpl.PWrite. -func (fd *randomFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - // In Linux, this mixes the written bytes into the entropy pool; we just - // throw them away. - return src.NumBytes(), nil -} - -// Write implements vfs.FileDescriptionImpl.Write. -func (fd *randomFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - atomic.AddInt64(&fd.off, src.NumBytes()) - return src.NumBytes(), nil -} - -// Seek implements vfs.FileDescriptionImpl.Seek. -func (fd *randomFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { - // Linux: drivers/char/random.c:random_fops.llseek == urandom_fops.llseek - // == noop_llseek - return atomic.LoadInt64(&fd.off), nil -} diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go deleted file mode 100644 index 3b1372b9e..000000000 --- a/pkg/sentry/devices/memdev/zero.go +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package memdev - -import ( - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/usermem" -) - -const zeroDevMinor = 5 - -// zeroDevice implements vfs.Device for /dev/zero. -type zeroDevice struct{} - -// Open implements vfs.Device.Open. -func (zeroDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd := &zeroFD{} - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{ - UseDentryMetadata: true, - }); err != nil { - return nil, err - } - return &fd.vfsfd, nil -} - -// zeroFD implements vfs.FileDescriptionImpl for /dev/zero. -type zeroFD struct { - vfsfd vfs.FileDescription - vfs.FileDescriptionDefaultImpl - vfs.DentryMetadataFileDescriptionImpl -} - -// Release implements vfs.FileDescriptionImpl.Release. -func (fd *zeroFD) Release() { - // noop -} - -// PRead implements vfs.FileDescriptionImpl.PRead. -func (fd *zeroFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { - return dst.ZeroOut(ctx, dst.NumBytes()) -} - -// Read implements vfs.FileDescriptionImpl.Read. -func (fd *zeroFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - return dst.ZeroOut(ctx, dst.NumBytes()) -} - -// PWrite implements vfs.FileDescriptionImpl.PWrite. -func (fd *zeroFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - return src.NumBytes(), nil -} - -// Write implements vfs.FileDescriptionImpl.Write. -func (fd *zeroFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - return src.NumBytes(), nil -} - -// Seek implements vfs.FileDescriptionImpl.Seek. -func (fd *zeroFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { - return 0, nil -} - -// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. -func (fd *zeroFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { - m, err := mm.NewSharedAnonMappable(opts.Length, pgalloc.MemoryFileProviderFromContext(ctx)) - if err != nil { - return err - } - opts.MappingIdentity = m - opts.Mappable = m - return nil -} diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD deleted file mode 100644 index ea85ab33c..000000000 --- a/pkg/sentry/fs/BUILD +++ /dev/null @@ -1,135 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_library( - name = "fs", - srcs = [ - "attr.go", - "context.go", - "copy_up.go", - "dentry.go", - "dirent.go", - "dirent_cache.go", - "dirent_cache_limiter.go", - "dirent_list.go", - "dirent_state.go", - "event_list.go", - "file.go", - "file_operations.go", - "file_overlay.go", - "file_state.go", - "filesystems.go", - "flags.go", - "fs.go", - "inode.go", - "inode_inotify.go", - "inode_operations.go", - "inode_overlay.go", - "inotify.go", - "inotify_event.go", - "inotify_watch.go", - "mock.go", - "mount.go", - "mount_overlay.go", - "mounts.go", - "offset.go", - "overlay.go", - "path.go", - "restore.go", - "save.go", - "seek.go", - "splice.go", - "sync.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/amutex", - "//pkg/context", - "//pkg/log", - "//pkg/metric", - "//pkg/p9", - "//pkg/refs", - "//pkg/secio", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs/lock", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/memmap", - "//pkg/sentry/platform", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/uniqueid", - "//pkg/sentry/usage", - "//pkg/state", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_template_instance( - name = "dirent_list", - out = "dirent_list.go", - package = "fs", - prefix = "dirent", - template = "//pkg/ilist:generic_list", - types = { - "Linker": "*Dirent", - "Element": "*Dirent", - }, -) - -go_template_instance( - name = "event_list", - out = "event_list.go", - package = "fs", - prefix = "event", - template = "//pkg/ilist:generic_list", - types = { - "Linker": "*Event", - "Element": "*Event", - }, -) - -go_test( - name = "fs_x_test", - size = "small", - srcs = [ - "copy_up_test.go", - "file_overlay_test.go", - "inode_overlay_test.go", - "mounts_test.go", - ], - deps = [ - ":fs", - "//pkg/context", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/fs/tmpfs", - "//pkg/sentry/kernel/contexttest", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - ], -) - -go_test( - name = "fs_test", - size = "small", - srcs = [ - "dirent_cache_test.go", - "dirent_refs_test.go", - "mount_test.go", - "path_test.go", - ], - library = ":fs", - deps = [ - "//pkg/context", - "//pkg/sentry/contexttest", - ], -) diff --git a/pkg/sentry/fs/README.md b/pkg/sentry/fs/README.md deleted file mode 100644 index db4a1b730..000000000 --- a/pkg/sentry/fs/README.md +++ /dev/null @@ -1,229 +0,0 @@ -This package provides an implementation of the Linux virtual filesystem. - -[TOC] - -## Overview - -- An `fs.Dirent` caches an `fs.Inode` in memory at a path in the VFS, giving - the `fs.Inode` a relative position with respect to other `fs.Inode`s. - -- If an `fs.Dirent` is referenced by two file descriptors, then those file - descriptors are coherent with each other: they depend on the same - `fs.Inode`. - -- A mount point is an `fs.Dirent` for which `fs.Dirent.mounted` is true. It - exposes the root of a mounted filesystem. - -- The `fs.Inode` produced by a registered filesystem on mount(2) owns an - `fs.MountedFilesystem` from which other `fs.Inode`s will be looked up. For a - remote filesystem, the `fs.MountedFilesystem` owns the connection to that - remote filesystem. - -- In general: - -``` -fs.Inode <------------------------------ -| | -| | -produced by | -exactly one | -| responsible for the -| virtual identity of -v | -fs.MountedFilesystem ------------------- -``` - -Glossary: - -- VFS: virtual filesystem. - -- inode: a virtual file object holding a cached view of a file on a backing - filesystem (includes metadata and page caches). - -- superblock: the virtual state of a mounted filesystem (e.g. the virtual - inode number set). - -- mount namespace: a view of the mounts under a root (during path traversal, - the VFS makes visible/follows the mount point that is in the current task's - mount namespace). - -## Save and restore - -An application's hard dependencies on filesystem state can be broken down into -two categories: - -- The state necessary to execute a traversal on or view the *virtual* - filesystem hierarchy, regardless of what files an application has open. - -- The state necessary to represent open files. - -The first is always necessary to save and restore. An application may never have -any open file descriptors, but across save and restore it should see a coherent -view of any mount namespace. NOTE(b/63601033): Currently only one "initial" -mount namespace is supported. - -The second is so that system calls across save and restore are coherent with -each other (e.g. so that unintended re-reads or overwrites do not occur). - -Specifically this state is: - -- An `fs.MountManager` containing mount points. - -- A `kernel.FDTable` containing pointers to open files. - -Anything else managed by the VFS that can be easily loaded into memory from a -filesystem is synced back to those filesystems and is not saved. Examples are -pages in page caches used for optimizations (i.e. readahead and writeback), and -directory entries used to accelerate path lookups. - -### Mount points - -Saving and restoring a mount point means saving and restoring: - -- The root of the mounted filesystem. - -- Mount flags, which control how the VFS interacts with the mounted - filesystem. - -- Any relevant metadata about the mounted filesystem. - -- All `fs.Inode`s referenced by the application that reside under the mount - point. - -`fs.MountedFilesystem` is metadata about a filesystem that is mounted. It is -referenced by every `fs.Inode` loaded into memory under the mount point -including the `fs.Inode` of the mount point itself. The `fs.MountedFilesystem` -maps file objects on the filesystem to a virtualized `fs.Inode` number and vice -versa. - -To restore all `fs.Inode`s under a given mount point, each `fs.Inode` leverages -its dependency on an `fs.MountedFilesystem`. Since the `fs.MountedFilesystem` -knows how an `fs.Inode` maps to a file object on a backing filesystem, this -mapping can be trivially consulted by each `fs.Inode` when the `fs.Inode` is -restored. - -In detail, a mount point is saved in two steps: - -- First, after the kernel is paused but before state.Save, we walk all mount - namespaces and install a mapping from `fs.Inode` numbers to file paths - relative to the root of the mounted filesystem in each - `fs.MountedFilesystem`. This is subsequently called the set of `fs.Inode` - mappings. - -- Second, during state.Save, each `fs.MountedFilesystem` decides whether to - save the set of `fs.Inode` mappings. In-memory filesystems, like tmpfs, have - no need to save a set of `fs.Inode` mappings, since the `fs.Inode`s can be - entirely encoded in state file. Each `fs.MountedFilesystem` also optionally - saves the device name from when the filesystem was originally mounted. Each - `fs.Inode` saves its virtual identifier and a reference to a - `fs.MountedFilesystem`. - -A mount point is restored in two steps: - -- First, before state.Load, all mount configurations are stored in a global - `fs.RestoreEnvironment`. This tells us what mount points the user wants to - restore and how to re-establish pointers to backing filesystems. - -- Second, during state.Load, each `fs.MountedFilesystem` optionally searches - for a mount in the `fs.RestoreEnvironment` that matches its saved device - name. The `fs.MountedFilesystem` then reestablishes a pointer to the root of - the mounted filesystem. For example, the mount specification provides the - network connection for a mounted remote filesystem client to communicate - with its remote file server. The `fs.MountedFilesystem` also trivially loads - its set of `fs.Inode` mappings. When an `fs.Inode` is encountered, the - `fs.Inode` loads its virtual identifier and its reference a - `fs.MountedFilesystem`. It uses the `fs.MountedFilesystem` to obtain the - root of the mounted filesystem and the `fs.Inode` mappings to obtain the - relative file path to its data. With these, the `fs.Inode` re-establishes a - pointer to its file object. - -A mount point can trivially restore its `fs.Inode`s in parallel since -`fs.Inode`s have a restore dependency on their `fs.MountedFilesystem` and not on -each other. - -### Open files - -An `fs.File` references the following filesystem objects: - -```go -fs.File -> fs.Dirent -> fs.Inode -> fs.MountedFilesystem -``` - -The `fs.Inode` is restored using its `fs.MountedFilesystem`. The -[Mount points](#mount-points) section above describes how this happens in -detail. The `fs.Dirent` restores its pointer to an `fs.Inode`, pointers to -parent and children `fs.Dirents`, and the basename of the file. - -Otherwise an `fs.File` restores flags, an offset, and a unique identifier (only -used internally). - -It may use the `fs.Inode`, which it indirectly holds a reference on through the -`fs.Dirent`, to reestablish an open file handle on the backing filesystem (e.g. -to continue reading and writing). - -## Overlay - -The overlay implementation in the fs package takes Linux overlayfs as a frame of -reference but corrects for several POSIX consistency errors. - -In Linux overlayfs, the `struct inode` used for reading and writing to the same -file may be different. This is because the `struct inode` is dissociated with -the process of copying up the file from the upper to the lower directory. Since -flock(2) and fcntl(2) locks, inotify(7) watches, page caches, and a file's -identity are all stored directly or indirectly off the `struct inode`, these -properties of the `struct inode` may be stale after the first modification. This -can lead to file locking bugs, missed inotify events, and inconsistent data in -shared memory mappings of files, to name a few problems. - -The fs package maintains a single `fs.Inode` to represent a directory entry in -an overlay and defines operations on this `fs.Inode` which synchronize with the -copy up process. This achieves several things: - -+ File locks, inotify watches, and the identity of the file need not be copied - at all. - -+ Memory mappings of files coordinate with the copy up process so that if a - file in the lower directory is memory mapped, all references to it are - invalidated, forcing the application to re-fault on memory mappings of the - file under the upper directory. - -The `fs.Inode` holds metadata about files in the upper and/or lower directories -via an `fs.overlayEntry`. The `fs.overlayEntry` implements the `fs.Mappable` -interface. It multiplexes between upper and lower directory memory mappings and -stores a copy of memory references so they can be transferred to the upper -directory `fs.Mappable` when the file is copied up. - -The lower filesystem in an overlay may contain another (nested) overlay, but the -upper filesystem may not contain another overlay. In other words, nested -overlays form a tree structure that only allows branching in the lower -filesystem. - -Caching decisions in the overlay are delegated to the upper filesystem, meaning -that the Keep and Revalidate methods on the overlay return the same values as -the upper filesystem. A small wrinkle is that the lower filesystem is not -allowed to return `true` from Revalidate, as the overlay can not reload inodes -from the lower filesystem. A lower filesystem that does return `true` from -Revalidate will trigger a panic. - -The `fs.Inode` also holds a reference to a `fs.MountedFilesystem` that -normalizes across the mounted filesystem state of the upper and lower -directories. - -When a file is copied from the lower to the upper directory, attempts to -interact with the file block until the copy completes. All copying synchronizes -with rename(2). - -## Future Work - -### Overlay - -When a file is copied from a lower directory to an upper directory, several -locks are taken: the global renamuMu and the copyMu of the `fs.Inode` being -copied. This blocks operations on the file, including fault handling of memory -mappings. Performance could be improved by copying files into a temporary -directory that resides on the same filesystem as the upper directory and doing -an atomic rename, holding locks only during the rename operation. - -Additionally files are copied up synchronously. For large files, this causes a -noticeable latency. Performance could be improved by pipelining copies at -non-overlapping file offsets. diff --git a/pkg/sentry/fs/anon/BUILD b/pkg/sentry/fs/anon/BUILD deleted file mode 100644 index aedcecfa1..000000000 --- a/pkg/sentry/fs/anon/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "anon", - srcs = [ - "anon.go", - "device.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fs/anon/anon_state_autogen.go b/pkg/sentry/fs/anon/anon_state_autogen.go new file mode 100755 index 000000000..b2b1a466e --- /dev/null +++ b/pkg/sentry/fs/anon/anon_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package anon diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go deleted file mode 100644 index 91792d9fe..000000000 --- a/pkg/sentry/fs/copy_up_test.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs_test - -import ( - "bytes" - "crypto/rand" - "fmt" - "io" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/fs" - _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/usermem" -) - -const ( - // origFileSize is the original file size. This many bytes should be - // copied up before the test file is modified. - origFileSize = 4096 - - // truncatedFileSize is the size to truncate all test files. - truncateFileSize = 10 -) - -// TestConcurrentCopyUp is a copy up stress test for an overlay. -// -// It creates a 64-level deep directory tree in the lower filesystem and -// populates the last subdirectory with 64 files containing random content: -// -// /lower -// /sudir0/.../subdir63/ -// /file0 -// ... -// /file63 -// -// The files are truncated concurrently by 4 goroutines per file. -// These goroutines contend with copying up all parent 64 subdirectories -// as well as the final file content. -// -// At the end of the test, we assert that the files respect the new truncated -// size and contain the content we expect. -func TestConcurrentCopyUp(t *testing.T) { - ctx := contexttest.Context(t) - files := makeOverlayTestFiles(t) - - var wg sync.WaitGroup - for _, file := range files { - for i := 0; i < 4; i++ { - wg.Add(1) - go func(o *overlayTestFile) { - if err := o.File.Dirent.Inode.Truncate(ctx, o.File.Dirent, truncateFileSize); err != nil { - t.Fatalf("failed to copy up: %v", err) - } - wg.Done() - }(file) - } - } - wg.Wait() - - for _, file := range files { - got := make([]byte, origFileSize) - n, err := file.File.Readv(ctx, usermem.BytesIOSequence(got)) - if int(n) != truncateFileSize { - t.Fatalf("read %d bytes from file, want %d", n, truncateFileSize) - } - if err != nil && err != io.EOF { - t.Fatalf("read got error %v, want nil", err) - } - if !bytes.Equal(got[:n], file.content[:truncateFileSize]) { - t.Fatalf("file content is %v, want %v", got[:n], file.content[:truncateFileSize]) - } - } -} - -type overlayTestFile struct { - File *fs.File - name string - content []byte -} - -func makeOverlayTestFiles(t *testing.T) []*overlayTestFile { - ctx := contexttest.Context(t) - - // Create a lower tmpfs mount. - fsys, _ := fs.FindFilesystem("tmpfs") - lower, err := fsys.Mount(contexttest.Context(t), "", fs.MountSourceFlags{}, "", nil) - if err != nil { - t.Fatalf("failed to mount tmpfs: %v", err) - } - lowerRoot := fs.NewDirent(ctx, lower, "") - - // Make a deep set of subdirectories that everyone shares. - next := lowerRoot - for i := 0; i < 64; i++ { - name := fmt.Sprintf("subdir%d", i) - err := next.CreateDirectory(ctx, lowerRoot, name, fs.FilePermsFromMode(0777)) - if err != nil { - t.Fatalf("failed to create dir %q: %v", name, err) - } - next, err = next.Walk(ctx, lowerRoot, name) - if err != nil { - t.Fatalf("failed to walk to %q: %v", name, err) - } - } - - // Make a bunch of files in the last directory. - var files []*overlayTestFile - for i := 0; i < 64; i++ { - name := fmt.Sprintf("file%d", i) - f, err := next.Create(ctx, next, name, fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - t.Fatalf("failed to create file %q: %v", name, err) - } - defer f.DecRef() - - relname, _ := f.Dirent.FullName(lowerRoot) - - o := &overlayTestFile{ - name: relname, - content: make([]byte, origFileSize), - } - - if _, err := rand.Read(o.content); err != nil { - t.Fatalf("failed to read from /dev/urandom: %v", err) - } - - if _, err := f.Writev(ctx, usermem.BytesIOSequence(o.content)); err != nil { - t.Fatalf("failed to write content to file %q: %v", name, err) - } - - files = append(files, o) - } - - // Create an empty upper tmpfs mount which we will copy up into. - upper, err := fsys.Mount(ctx, "", fs.MountSourceFlags{}, "", nil) - if err != nil { - t.Fatalf("failed to mount tmpfs: %v", err) - } - - // Construct an overlay root. - overlay, err := fs.NewOverlayRoot(ctx, upper, lower, fs.MountSourceFlags{}) - if err != nil { - t.Fatalf("failed to construct overlay root: %v", err) - } - - // Create a MountNamespace to traverse the file system. - mns, err := fs.NewMountNamespace(ctx, overlay) - if err != nil { - t.Fatalf("failed to construct mount manager: %v", err) - } - - // Walk to all of the files in the overlay, open them readable. - for _, f := range files { - maxTraversals := uint(0) - d, err := mns.FindInode(ctx, mns.Root(), mns.Root(), f.name, &maxTraversals) - if err != nil { - t.Fatalf("failed to find %q: %v", f.name, err) - } - defer d.DecRef() - - f.File, err = d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("failed to open file %q readable: %v", f.name, err) - } - } - - return files -} diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD deleted file mode 100644 index 9379a4d7b..000000000 --- a/pkg/sentry/fs/dev/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "dev", - srcs = [ - "dev.go", - "device.go", - "fs.go", - "full.go", - "net_tun.go", - "null.go", - "random.go", - "tty.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/rand", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/fs/tmpfs", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/pgalloc", - "//pkg/sentry/socket/netstack", - "//pkg/syserror", - "//pkg/tcpip/link/tun", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fs/dev/dev_state_autogen.go b/pkg/sentry/fs/dev/dev_state_autogen.go new file mode 100755 index 000000000..272f02672 --- /dev/null +++ b/pkg/sentry/fs/dev/dev_state_autogen.go @@ -0,0 +1,154 @@ +// automatically generated by stateify. + +package dev + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *filesystem) beforeSave() {} +func (x *filesystem) save(m state.Map) { + x.beforeSave() +} + +func (x *filesystem) afterLoad() {} +func (x *filesystem) load(m state.Map) { +} + +func (x *fullDevice) beforeSave() {} +func (x *fullDevice) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *fullDevice) afterLoad() {} +func (x *fullDevice) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *fullFileOperations) beforeSave() {} +func (x *fullFileOperations) save(m state.Map) { + x.beforeSave() +} + +func (x *fullFileOperations) afterLoad() {} +func (x *fullFileOperations) load(m state.Map) { +} + +func (x *netTunInodeOperations) beforeSave() {} +func (x *netTunInodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *netTunInodeOperations) afterLoad() {} +func (x *netTunInodeOperations) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *netTunFileOperations) beforeSave() {} +func (x *netTunFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("device", &x.device) +} + +func (x *netTunFileOperations) afterLoad() {} +func (x *netTunFileOperations) load(m state.Map) { + m.Load("device", &x.device) +} + +func (x *nullDevice) beforeSave() {} +func (x *nullDevice) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *nullDevice) afterLoad() {} +func (x *nullDevice) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *nullFileOperations) beforeSave() {} +func (x *nullFileOperations) save(m state.Map) { + x.beforeSave() +} + +func (x *nullFileOperations) afterLoad() {} +func (x *nullFileOperations) load(m state.Map) { +} + +func (x *zeroDevice) beforeSave() {} +func (x *zeroDevice) save(m state.Map) { + x.beforeSave() + m.Save("nullDevice", &x.nullDevice) +} + +func (x *zeroDevice) afterLoad() {} +func (x *zeroDevice) load(m state.Map) { + m.Load("nullDevice", &x.nullDevice) +} + +func (x *zeroFileOperations) beforeSave() {} +func (x *zeroFileOperations) save(m state.Map) { + x.beforeSave() +} + +func (x *zeroFileOperations) afterLoad() {} +func (x *zeroFileOperations) load(m state.Map) { +} + +func (x *randomDevice) beforeSave() {} +func (x *randomDevice) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *randomDevice) afterLoad() {} +func (x *randomDevice) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *randomFileOperations) beforeSave() {} +func (x *randomFileOperations) save(m state.Map) { + x.beforeSave() +} + +func (x *randomFileOperations) afterLoad() {} +func (x *randomFileOperations) load(m state.Map) { +} + +func (x *ttyInodeOperations) beforeSave() {} +func (x *ttyInodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *ttyInodeOperations) afterLoad() {} +func (x *ttyInodeOperations) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *ttyFileOperations) beforeSave() {} +func (x *ttyFileOperations) save(m state.Map) { + x.beforeSave() +} + +func (x *ttyFileOperations) afterLoad() {} +func (x *ttyFileOperations) load(m state.Map) { +} + +func init() { + state.Register("pkg/sentry/fs/dev.filesystem", (*filesystem)(nil), state.Fns{Save: (*filesystem).save, Load: (*filesystem).load}) + state.Register("pkg/sentry/fs/dev.fullDevice", (*fullDevice)(nil), state.Fns{Save: (*fullDevice).save, Load: (*fullDevice).load}) + state.Register("pkg/sentry/fs/dev.fullFileOperations", (*fullFileOperations)(nil), state.Fns{Save: (*fullFileOperations).save, Load: (*fullFileOperations).load}) + state.Register("pkg/sentry/fs/dev.netTunInodeOperations", (*netTunInodeOperations)(nil), state.Fns{Save: (*netTunInodeOperations).save, Load: (*netTunInodeOperations).load}) + state.Register("pkg/sentry/fs/dev.netTunFileOperations", (*netTunFileOperations)(nil), state.Fns{Save: (*netTunFileOperations).save, Load: (*netTunFileOperations).load}) + state.Register("pkg/sentry/fs/dev.nullDevice", (*nullDevice)(nil), state.Fns{Save: (*nullDevice).save, Load: (*nullDevice).load}) + state.Register("pkg/sentry/fs/dev.nullFileOperations", (*nullFileOperations)(nil), state.Fns{Save: (*nullFileOperations).save, Load: (*nullFileOperations).load}) + state.Register("pkg/sentry/fs/dev.zeroDevice", (*zeroDevice)(nil), state.Fns{Save: (*zeroDevice).save, Load: (*zeroDevice).load}) + state.Register("pkg/sentry/fs/dev.zeroFileOperations", (*zeroFileOperations)(nil), state.Fns{Save: (*zeroFileOperations).save, Load: (*zeroFileOperations).load}) + state.Register("pkg/sentry/fs/dev.randomDevice", (*randomDevice)(nil), state.Fns{Save: (*randomDevice).save, Load: (*randomDevice).load}) + state.Register("pkg/sentry/fs/dev.randomFileOperations", (*randomFileOperations)(nil), state.Fns{Save: (*randomFileOperations).save, Load: (*randomFileOperations).load}) + state.Register("pkg/sentry/fs/dev.ttyInodeOperations", (*ttyInodeOperations)(nil), state.Fns{Save: (*ttyInodeOperations).save, Load: (*ttyInodeOperations).load}) + state.Register("pkg/sentry/fs/dev.ttyFileOperations", (*ttyFileOperations)(nil), state.Fns{Save: (*ttyFileOperations).save, Load: (*ttyFileOperations).load}) +} diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go index dc7ad075a..dc7ad075a 100644..100755 --- a/pkg/sentry/fs/dev/net_tun.go +++ b/pkg/sentry/fs/dev/net_tun.go diff --git a/pkg/sentry/fs/dirent_cache_test.go b/pkg/sentry/fs/dirent_cache_test.go deleted file mode 100644 index 395c879f5..000000000 --- a/pkg/sentry/fs/dirent_cache_test.go +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs - -import ( - "testing" -) - -func TestDirentCache(t *testing.T) { - const maxSize = 5 - - c := NewDirentCache(maxSize) - - // Size starts at 0. - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // Create a Dirent d. - d := NewNegativeDirent("") - - // c does not contain d. - if got, want := c.contains(d), false; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // Add d to the cache. - c.Add(d) - - // Size is now 1. - if got, want := c.Size(), uint64(1); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // c contains d. - if got, want := c.contains(d), true; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // Add maxSize-1 more elements. d should be oldest element. - for i := 0; i < maxSize-1; i++ { - c.Add(NewNegativeDirent("")) - } - - // Size is maxSize. - if got, want := c.Size(), uint64(maxSize); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // c contains d. - if got, want := c.contains(d), true; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // "Bump" d to the front by re-adding it. - c.Add(d) - - // Size is maxSize. - if got, want := c.Size(), uint64(maxSize); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // c contains d. - if got, want := c.contains(d), true; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // Add maxSize-1 more elements. d should again be oldest element. - for i := 0; i < maxSize-1; i++ { - c.Add(NewNegativeDirent("")) - } - - // Size is maxSize. - if got, want := c.Size(), uint64(maxSize); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // c contains d. - if got, want := c.contains(d), true; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // Add one more element, which will bump d from the cache. - c.Add(NewNegativeDirent("")) - - // Size is maxSize. - if got, want := c.Size(), uint64(maxSize); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // c does not contain d. - if got, want := c.contains(d), false; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // Invalidating causes size to be 0 and list to be empty. - c.Invalidate() - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - if got, want := c.list.Empty(), true; got != want { - t.Errorf("c.list.Empty() got %v, want %v", got, want) - } - - // Fill cache with maxSize dirents. - for i := 0; i < maxSize; i++ { - c.Add(NewNegativeDirent("")) - } -} - -func TestDirentCacheLimiter(t *testing.T) { - const ( - globalMaxSize = 5 - maxSize = 3 - ) - - limit := NewDirentCacheLimiter(globalMaxSize) - c1 := NewDirentCache(maxSize) - c1.limit = limit - c2 := NewDirentCache(maxSize) - c2.limit = limit - - // Create a Dirent d. - d := NewNegativeDirent("") - - // Add d to the cache. - c1.Add(d) - if got, want := c1.Size(), uint64(1); got != want { - t.Errorf("c1.Size() got %v, want %v", got, want) - } - - // Add maxSize-1 more elements. d should be oldest element. - for i := 0; i < maxSize-1; i++ { - c1.Add(NewNegativeDirent("")) - } - if got, want := c1.Size(), uint64(maxSize); got != want { - t.Errorf("c1.Size() got %v, want %v", got, want) - } - - // Check that d is still there. - if got, want := c1.contains(d), true; got != want { - t.Errorf("c1.contains(d) got %v want %v", got, want) - } - - // Fill up the other cache, it will start dropping old entries from the cache - // when the global limit is reached. - for i := 0; i < maxSize; i++ { - c2.Add(NewNegativeDirent("")) - } - - // Check is what's remaining from global max. - if got, want := c2.Size(), globalMaxSize-maxSize; int(got) != want { - t.Errorf("c2.Size() got %v, want %v", got, want) - } - - // Check that d was not dropped. - if got, want := c1.contains(d), true; got != want { - t.Errorf("c1.contains(d) got %v want %v", got, want) - } - - // Add an entry that will eventually be dropped. Check is done later... - drop := NewNegativeDirent("") - c1.Add(drop) - - // Check that d is bumped to front even when global limit is reached. - c1.Add(d) - if got, want := c1.contains(d), true; got != want { - t.Errorf("c1.contains(d) got %v want %v", got, want) - } - - // Add 2 more element and check that: - // - d is still in the list: to verify that d was bumped - // - d2/d3 are in the list: older entries are dropped when global limit is - // reached. - // - drop is not in the list: indeed older elements are dropped. - d2 := NewNegativeDirent("") - c1.Add(d2) - d3 := NewNegativeDirent("") - c1.Add(d3) - if got, want := c1.contains(d), true; got != want { - t.Errorf("c1.contains(d) got %v want %v", got, want) - } - if got, want := c1.contains(d2), true; got != want { - t.Errorf("c1.contains(d2) got %v want %v", got, want) - } - if got, want := c1.contains(d3), true; got != want { - t.Errorf("c1.contains(d3) got %v want %v", got, want) - } - if got, want := c1.contains(drop), false; got != want { - t.Errorf("c1.contains(drop) got %v want %v", got, want) - } - - // Drop all entries from one cache. The other will be allowed to grow. - c1.Invalidate() - c2.Add(NewNegativeDirent("")) - if got, want := c2.Size(), uint64(maxSize); got != want { - t.Errorf("c2.Size() got %v, want %v", got, want) - } -} - -// TestNilDirentCache tests that a nil cache supports all cache operations, but -// treats them as noop. -func TestNilDirentCache(t *testing.T) { - // Create a nil cache. - var c *DirentCache - - // Size is zero. - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // Call Add. - c.Add(NewNegativeDirent("")) - - // Size is zero. - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // Call Remove. - c.Remove(NewNegativeDirent("")) - - // Size is zero. - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // Call Invalidate. - c.Invalidate() - - // Size is zero. - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } -} diff --git a/pkg/sentry/fs/dirent_list.go b/pkg/sentry/fs/dirent_list.go new file mode 100755 index 000000000..acdce100c --- /dev/null +++ b/pkg/sentry/fs/dirent_list.go @@ -0,0 +1,186 @@ +package fs + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type direntElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (direntElementMapper) linkerFor(elem *Dirent) *Dirent { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type direntList struct { + head *Dirent + tail *Dirent +} + +// Reset resets list l to the empty state. +func (l *direntList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *direntList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *direntList) Front() *Dirent { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *direntList) Back() *Dirent { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *direntList) PushFront(e *Dirent) { + linker := direntElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + direntElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *direntList) PushBack(e *Dirent) { + linker := direntElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + direntElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *direntList) PushBackList(m *direntList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + direntElementMapper{}.linkerFor(l.tail).SetNext(m.head) + direntElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *direntList) InsertAfter(b, e *Dirent) { + bLinker := direntElementMapper{}.linkerFor(b) + eLinker := direntElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + direntElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *direntList) InsertBefore(a, e *Dirent) { + aLinker := direntElementMapper{}.linkerFor(a) + eLinker := direntElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + direntElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *direntList) Remove(e *Dirent) { + linker := direntElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + direntElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + direntElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type direntEntry struct { + next *Dirent + prev *Dirent +} + +// Next returns the entry that follows e in the list. +func (e *direntEntry) Next() *Dirent { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *direntEntry) Prev() *Dirent { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *direntEntry) SetNext(elem *Dirent) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *direntEntry) SetPrev(elem *Dirent) { + e.prev = elem +} diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go deleted file mode 100644 index 98d69c6f2..000000000 --- a/pkg/sentry/fs/dirent_refs_test.go +++ /dev/null @@ -1,418 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs - -import ( - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" -) - -func newMockDirInode(ctx context.Context, cache *DirentCache) *Inode { - return NewMockInode(ctx, NewMockMountSource(cache), StableAttr{Type: Directory}) -} - -func TestWalkPositive(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - root := NewDirent(ctx, newMockDirInode(ctx, nil), "root") - - if got := root.ReadRefs(); got != 1 { - t.Fatalf("root has a ref count of %d, want %d", got, 1) - } - - name := "d" - d, err := root.walk(ctx, root, name, false) - if err != nil { - t.Fatalf("root.walk(root, %q) got %v, want nil", name, err) - } - - if got := root.ReadRefs(); got != 2 { - t.Fatalf("root has a ref count of %d, want %d", got, 2) - } - - if got := d.ReadRefs(); got != 1 { - t.Fatalf("child name = %q has a ref count of %d, want %d", d.name, got, 1) - } - - d.DecRef() - - if got := root.ReadRefs(); got != 1 { - t.Fatalf("root has a ref count of %d, want %d", got, 1) - } - - if got := d.ReadRefs(); got != 0 { - t.Fatalf("child name = %q has a ref count of %d, want %d", d.name, got, 0) - } - - root.flush() - - if got := len(root.children); got != 0 { - t.Fatalf("root has %d children, want %d", got, 0) - } -} - -func TestWalkNegative(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - root := NewDirent(ctx, NewEmptyDir(ctx, nil), "root") - mn := root.Inode.InodeOperations.(*mockInodeOperationsLookupNegative) - - if got := root.ReadRefs(); got != 1 { - t.Fatalf("root has a ref count of %d, want %d", got, 1) - } - - name := "d" - for i := 0; i < 100; i++ { - _, err := root.walk(ctx, root, name, false) - if err != syscall.ENOENT { - t.Fatalf("root.walk(root, %q) got %v, want %v", name, err, syscall.ENOENT) - } - } - - if got := root.ReadRefs(); got != 1 { - t.Fatalf("root has a ref count of %d, want %d", got, 1) - } - - if got := len(root.children); got != 1 { - t.Fatalf("root has %d children, want %d", got, 1) - } - - w, ok := root.children[name] - if !ok { - t.Fatalf("root wants child at %q", name) - } - - child := w.Get() - if child == nil { - t.Fatalf("root wants to resolve weak reference") - } - - if !child.(*Dirent).IsNegative() { - t.Fatalf("root found positive child at %q, want negative", name) - } - - if got := child.(*Dirent).ReadRefs(); got != 2 { - t.Fatalf("child has a ref count of %d, want %d", got, 2) - } - - child.DecRef() - - if got := child.(*Dirent).ReadRefs(); got != 1 { - t.Fatalf("child has a ref count of %d, want %d", got, 1) - } - - if got := len(root.children); got != 1 { - t.Fatalf("root has %d children, want %d", got, 1) - } - - root.DecRef() - - if got := root.ReadRefs(); got != 0 { - t.Fatalf("root has a ref count of %d, want %d", got, 0) - } - - AsyncBarrier() - - if got := mn.releaseCalled; got != true { - t.Fatalf("root.Close was called %v, want true", got) - } -} - -type mockInodeOperationsLookupNegative struct { - *MockInodeOperations - releaseCalled bool -} - -func NewEmptyDir(ctx context.Context, cache *DirentCache) *Inode { - m := NewMockMountSource(cache) - return NewInode(ctx, &mockInodeOperationsLookupNegative{ - MockInodeOperations: NewMockInodeOperations(ctx), - }, m, StableAttr{Type: Directory}) -} - -func (m *mockInodeOperationsLookupNegative) Lookup(ctx context.Context, dir *Inode, p string) (*Dirent, error) { - return NewNegativeDirent(p), nil -} - -func (m *mockInodeOperationsLookupNegative) Release(context.Context) { - m.releaseCalled = true -} - -func TestHashNegativeToPositive(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - root := NewDirent(ctx, NewEmptyDir(ctx, nil), "root") - - name := "d" - _, err := root.walk(ctx, root, name, false) - if err != syscall.ENOENT { - t.Fatalf("root.walk(root, %q) got %v, want %v", name, err, syscall.ENOENT) - } - - if got := root.exists(ctx, root, name); got != false { - t.Fatalf("got %q exists, want does not exist", name) - } - - f, err := root.Create(ctx, root, name, FileFlags{}, FilePermissions{}) - if err != nil { - t.Fatalf("root.Create(%q, _), got error %v, want nil", name, err) - } - d := f.Dirent - - if d.IsNegative() { - t.Fatalf("got negative Dirent, want positive") - } - - if got := d.ReadRefs(); got != 1 { - t.Fatalf("child %q has a ref count of %d, want %d", name, got, 1) - } - - if got := root.ReadRefs(); got != 2 { - t.Fatalf("root has a ref count of %d, want %d", got, 2) - } - - if got := len(root.children); got != 1 { - t.Fatalf("got %d children, want %d", got, 1) - } - - w, ok := root.children[name] - if !ok { - t.Fatalf("failed to find weak reference to %q", name) - } - - child := w.Get() - if child == nil { - t.Fatalf("want to resolve weak reference") - } - - if child.(*Dirent) != d { - t.Fatalf("got foreign child") - } -} - -func TestRevalidate(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - for _, test := range []struct { - // desc is the test's description. - desc string - - // Whether to make negative Dirents. - makeNegative bool - }{ - { - desc: "Revalidate negative Dirent", - makeNegative: true, - }, - { - desc: "Revalidate positive Dirent", - makeNegative: false, - }, - } { - t.Run(test.desc, func(t *testing.T) { - ctx := contexttest.Context(t) - root := NewDirent(ctx, NewMockInodeRevalidate(ctx, test.makeNegative), "root") - - name := "d" - d1, err := root.walk(ctx, root, name, false) - if !test.makeNegative && err != nil { - t.Fatalf("root.walk(root, %q) got %v, want nil", name, err) - } - d2, err := root.walk(ctx, root, name, false) - if !test.makeNegative && err != nil { - t.Fatalf("root.walk(root, %q) got %v, want nil", name, err) - } - if !test.makeNegative && d1 == d2 { - t.Fatalf("revalidating walk got same *Dirent, want different") - } - if got := len(root.children); got != 1 { - t.Errorf("revalidating walk got %d children, want %d", got, 1) - } - }) - } -} - -type MockInodeOperationsRevalidate struct { - *MockInodeOperations - makeNegative bool -} - -func NewMockInodeRevalidate(ctx context.Context, makeNegative bool) *Inode { - mn := NewMockInodeOperations(ctx) - m := NewMockMountSource(nil) - m.MountSourceOperations.(*MockMountSourceOps).revalidate = true - return NewInode(ctx, &MockInodeOperationsRevalidate{MockInodeOperations: mn, makeNegative: makeNegative}, m, StableAttr{Type: Directory}) -} - -func (m *MockInodeOperationsRevalidate) Lookup(ctx context.Context, dir *Inode, p string) (*Dirent, error) { - if !m.makeNegative { - return m.MockInodeOperations.Lookup(ctx, dir, p) - } - return NewNegativeDirent(p), nil -} - -func TestCreateExtraRefs(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - for _, test := range []struct { - // desc is the test's description. - desc string - - // root is the Dirent to create from. - root *Dirent - - // expected references on walked Dirent. - refs int64 - }{ - { - desc: "Create caching", - root: NewDirent(ctx, NewEmptyDir(ctx, NewDirentCache(1)), "root"), - refs: 2, - }, - { - desc: "Create not caching", - root: NewDirent(ctx, NewEmptyDir(ctx, nil), "root"), - refs: 1, - }, - } { - t.Run(test.desc, func(t *testing.T) { - name := "d" - f, err := test.root.Create(ctx, test.root, name, FileFlags{}, FilePermissions{}) - if err != nil { - t.Fatalf("root.Create(root, %q) failed: %v", name, err) - } - d := f.Dirent - - if got := d.ReadRefs(); got != test.refs { - t.Errorf("dirent has a ref count of %d, want %d", got, test.refs) - } - }) - } -} - -func TestRemoveExtraRefs(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - for _, test := range []struct { - // desc is the test's description. - desc string - - // root is the Dirent to make and remove from. - root *Dirent - }{ - { - desc: "Remove caching", - root: NewDirent(ctx, NewEmptyDir(ctx, NewDirentCache(1)), "root"), - }, - { - desc: "Remove not caching", - root: NewDirent(ctx, NewEmptyDir(ctx, nil), "root"), - }, - } { - t.Run(test.desc, func(t *testing.T) { - name := "d" - f, err := test.root.Create(ctx, test.root, name, FileFlags{}, FilePermissions{}) - if err != nil { - t.Fatalf("root.Create(%q, _) failed: %v", name, err) - } - d := f.Dirent - - if err := test.root.Remove(contexttest.Context(t), test.root, name, false /* dirPath */); err != nil { - t.Fatalf("root.Remove(root, %q) failed: %v", name, err) - } - - if got := d.ReadRefs(); got != 1 { - t.Fatalf("dirent has a ref count of %d, want %d", got, 1) - } - - d.DecRef() - - test.root.flush() - - if got := len(test.root.children); got != 0 { - t.Errorf("root has %d children, want %d", got, 0) - } - }) - } -} - -func TestRenameExtraRefs(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - for _, test := range []struct { - // desc is the test's description. - desc string - - // cache of extra Dirent references, may be nil. - cache *DirentCache - }{ - { - desc: "Rename no caching", - cache: nil, - }, - { - desc: "Rename caching", - cache: NewDirentCache(5), - }, - } { - t.Run(test.desc, func(t *testing.T) { - ctx := contexttest.Context(t) - - dirAttr := StableAttr{Type: Directory} - - oldParent := NewDirent(ctx, NewMockInode(ctx, NewMockMountSource(test.cache), dirAttr), "old_parent") - newParent := NewDirent(ctx, NewMockInode(ctx, NewMockMountSource(test.cache), dirAttr), "new_parent") - - renamed, err := oldParent.Walk(ctx, oldParent, "old_child") - if err != nil { - t.Fatalf("Walk(oldParent, %q) got error %v, want nil", "old_child", err) - } - replaced, err := newParent.Walk(ctx, oldParent, "new_child") - if err != nil { - t.Fatalf("Walk(newParent, %q) got error %v, want nil", "new_child", err) - } - - if err := Rename(contexttest.RootContext(t), oldParent /*root */, oldParent, "old_child", newParent, "new_child"); err != nil { - t.Fatalf("Rename got error %v, want nil", err) - } - - oldParent.flush() - newParent.flush() - - // Expect to have only active references. - if got := renamed.ReadRefs(); got != 1 { - t.Errorf("renamed has ref count %d, want only active references %d", got, 1) - } - if got := replaced.ReadRefs(); got != 1 { - t.Errorf("replaced has ref count %d, want only active references %d", got, 1) - } - }) - } -} diff --git a/pkg/sentry/fs/event_list.go b/pkg/sentry/fs/event_list.go new file mode 100755 index 000000000..0274f41a2 --- /dev/null +++ b/pkg/sentry/fs/event_list.go @@ -0,0 +1,186 @@ +package fs + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type eventElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (eventElementMapper) linkerFor(elem *Event) *Event { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type eventList struct { + head *Event + tail *Event +} + +// Reset resets list l to the empty state. +func (l *eventList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *eventList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *eventList) Front() *Event { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *eventList) Back() *Event { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *eventList) PushFront(e *Event) { + linker := eventElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + eventElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *eventList) PushBack(e *Event) { + linker := eventElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + eventElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *eventList) PushBackList(m *eventList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + eventElementMapper{}.linkerFor(l.tail).SetNext(m.head) + eventElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *eventList) InsertAfter(b, e *Event) { + bLinker := eventElementMapper{}.linkerFor(b) + eLinker := eventElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + eventElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *eventList) InsertBefore(a, e *Event) { + aLinker := eventElementMapper{}.linkerFor(a) + eLinker := eventElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + eventElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *eventList) Remove(e *Event) { + linker := eventElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + eventElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + eventElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type eventEntry struct { + next *Event + prev *Event +} + +// Next returns the entry that follows e in the list. +func (e *eventEntry) Next() *Event { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *eventEntry) Prev() *Event { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *eventEntry) SetNext(elem *Event) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *eventEntry) SetPrev(elem *Event) { + e.prev = elem +} diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD deleted file mode 100644 index 1d09e983c..000000000 --- a/pkg/sentry/fs/fdpipe/BUILD +++ /dev/null @@ -1,48 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "fdpipe", - srcs = [ - "pipe.go", - "pipe_opener.go", - "pipe_state.go", - ], - imports = ["gvisor.dev/gvisor/pkg/sentry/fs"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/fd", - "//pkg/fdnotifier", - "//pkg/log", - "//pkg/safemem", - "//pkg/secio", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "fdpipe_test", - size = "small", - srcs = [ - "pipe_opener_test.go", - "pipe_test.go", - ], - library = ":fdpipe", - deps = [ - "//pkg/context", - "//pkg/fd", - "//pkg/fdnotifier", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - "//pkg/syserror", - "//pkg/usermem", - "@com_github_google_uuid//:go_default_library", - ], -) diff --git a/pkg/sentry/fs/fdpipe/fdpipe_state_autogen.go b/pkg/sentry/fs/fdpipe/fdpipe_state_autogen.go new file mode 100755 index 000000000..9ed7a3d41 --- /dev/null +++ b/pkg/sentry/fs/fdpipe/fdpipe_state_autogen.go @@ -0,0 +1,27 @@ +// automatically generated by stateify. + +package fdpipe + +import ( + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/state" +) + +func (x *pipeOperations) save(m state.Map) { + x.beforeSave() + var flags fs.FileFlags = x.saveFlags() + m.SaveValue("flags", flags) + m.Save("opener", &x.opener) + m.Save("readAheadBuffer", &x.readAheadBuffer) +} + +func (x *pipeOperations) load(m state.Map) { + m.LoadWait("opener", &x.opener) + m.Load("readAheadBuffer", &x.readAheadBuffer) + m.LoadValue("flags", new(fs.FileFlags), func(y interface{}) { x.loadFlags(y.(fs.FileFlags)) }) + m.AfterLoad(x.afterLoad) +} + +func init() { + state.Register("pkg/sentry/fs/fdpipe.pipeOperations", (*pipeOperations)(nil), state.Fns{Save: (*pipeOperations).save, Load: (*pipeOperations).load}) +} diff --git a/pkg/sentry/fs/fdpipe/pipe_opener_test.go b/pkg/sentry/fs/fdpipe/pipe_opener_test.go deleted file mode 100644 index e556da48a..000000000 --- a/pkg/sentry/fs/fdpipe/pipe_opener_test.go +++ /dev/null @@ -1,523 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fdpipe - -import ( - "bytes" - "fmt" - "io" - "os" - "path" - "syscall" - "testing" - "time" - - "github.com/google/uuid" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -type hostOpener struct { - name string -} - -func (h *hostOpener) NonBlockingOpen(_ context.Context, p fs.PermMask) (*fd.FD, error) { - var flags int - switch { - case p.Read && p.Write: - flags = syscall.O_RDWR - case p.Write: - flags = syscall.O_WRONLY - case p.Read: - flags = syscall.O_RDONLY - default: - return nil, syscall.EINVAL - } - f, err := syscall.Open(h.name, flags|syscall.O_NONBLOCK, 0666) - if err != nil { - return nil, err - } - return fd.New(f), nil -} - -func pipename() string { - return fmt.Sprintf(path.Join(os.TempDir(), "test-named-pipe-%s"), uuid.New()) -} - -func mkpipe(name string) error { - return syscall.Mknod(name, syscall.S_IFIFO|0666, 0) -} - -func TestTryOpen(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // makePipe is true if the test case should create the pipe. - makePipe bool - - // flags are the fs.FileFlags used to open the pipe. - flags fs.FileFlags - - // expectFile is true if a fs.File is expected. - expectFile bool - - // err is the expected error - err error - }{ - { - desc: "FileFlags lacking Read and Write are invalid", - makePipe: false, - flags: fs.FileFlags{}, /* bogus */ - expectFile: false, - err: syscall.EINVAL, - }, - { - desc: "NonBlocking Read only error returns immediately", - makePipe: false, /* causes the error */ - flags: fs.FileFlags{Read: true, NonBlocking: true}, - expectFile: false, - err: syscall.ENOENT, - }, - { - desc: "NonBlocking Read only success returns immediately", - makePipe: true, - flags: fs.FileFlags{Read: true, NonBlocking: true}, - expectFile: true, - err: nil, - }, - { - desc: "NonBlocking Write only error returns immediately", - makePipe: false, /* causes the error */ - flags: fs.FileFlags{Write: true, NonBlocking: true}, - expectFile: false, - err: syscall.ENOENT, - }, - { - desc: "NonBlocking Write only no reader error returns immediately", - makePipe: true, - flags: fs.FileFlags{Write: true, NonBlocking: true}, - expectFile: false, - err: syscall.ENXIO, - }, - { - desc: "ReadWrite error returns immediately", - makePipe: false, /* causes the error */ - flags: fs.FileFlags{Read: true, Write: true}, - expectFile: false, - err: syscall.ENOENT, - }, - { - desc: "ReadWrite returns immediately", - makePipe: true, - flags: fs.FileFlags{Read: true, Write: true}, - expectFile: true, - err: nil, - }, - { - desc: "Blocking Write only returns open error", - makePipe: false, /* causes the error */ - flags: fs.FileFlags{Write: true}, - expectFile: false, - err: syscall.ENOENT, /* from bogus perms */ - }, - { - desc: "Blocking Read only returns open error", - makePipe: false, /* causes the error */ - flags: fs.FileFlags{Read: true}, - expectFile: false, - err: syscall.ENOENT, - }, - { - desc: "Blocking Write only returns with syserror.ErrWouldBlock", - makePipe: true, - flags: fs.FileFlags{Write: true}, - expectFile: false, - err: syserror.ErrWouldBlock, - }, - { - desc: "Blocking Read only returns with syserror.ErrWouldBlock", - makePipe: true, - flags: fs.FileFlags{Read: true}, - expectFile: false, - err: syserror.ErrWouldBlock, - }, - } { - name := pipename() - if test.makePipe { - // Create the pipe. We do this per-test case to keep tests independent. - if err := mkpipe(name); err != nil { - t.Errorf("%s: failed to make host pipe: %v", test.desc, err) - continue - } - defer syscall.Unlink(name) - } - - // Use a host opener to keep things simple. - opener := &hostOpener{name: name} - - pipeOpenState := &pipeOpenState{} - ctx := contexttest.Context(t) - pipeOps, err := pipeOpenState.TryOpen(ctx, opener, test.flags) - if unwrapError(err) != test.err { - t.Errorf("%s: got error %v, want %v", test.desc, err, test.err) - if pipeOps != nil { - // Cleanup the state of the pipe, and remove the fd from the - // fdnotifier. Sadly this needed to maintain the correctness - // of other tests because the fdnotifier is global. - pipeOps.Release() - } - continue - } - if (pipeOps != nil) != test.expectFile { - t.Errorf("%s: got non-nil file %v, want %v", test.desc, pipeOps != nil, test.expectFile) - } - if pipeOps != nil { - // Same as above. - pipeOps.Release() - } - } -} - -func TestPipeOpenUnblocksEventually(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // partnerIsReader is true if the goroutine opening the same pipe as the test case - // should open the pipe read only. Otherwise write only. This also means that the - // test case will open the pipe in the opposite way. - partnerIsReader bool - - // partnerIsBlocking is true if the goroutine opening the same pipe as the test case - // should do so without the O_NONBLOCK flag, otherwise opens the pipe with O_NONBLOCK - // until ENXIO is not returned. - partnerIsBlocking bool - }{ - { - desc: "Blocking Read with blocking writer partner opens eventually", - partnerIsReader: false, - partnerIsBlocking: true, - }, - { - desc: "Blocking Write with blocking reader partner opens eventually", - partnerIsReader: true, - partnerIsBlocking: true, - }, - { - desc: "Blocking Read with non-blocking writer partner opens eventually", - partnerIsReader: false, - partnerIsBlocking: false, - }, - { - desc: "Blocking Write with non-blocking reader partner opens eventually", - partnerIsReader: true, - partnerIsBlocking: false, - }, - } { - // Create the pipe. We do this per-test case to keep tests independent. - name := pipename() - if err := mkpipe(name); err != nil { - t.Errorf("%s: failed to make host pipe: %v", test.desc, err) - continue - } - defer syscall.Unlink(name) - - // Spawn the partner. - type fderr struct { - fd int - err error - } - errch := make(chan fderr, 1) - go func() { - var flags int - if test.partnerIsReader { - flags = syscall.O_RDONLY - } else { - flags = syscall.O_WRONLY - } - if test.partnerIsBlocking { - fd, err := syscall.Open(name, flags, 0666) - errch <- fderr{fd: fd, err: err} - } else { - var fd int - err := error(syscall.ENXIO) - for err == syscall.ENXIO { - fd, err = syscall.Open(name, flags|syscall.O_NONBLOCK, 0666) - time.Sleep(1 * time.Second) - } - errch <- fderr{fd: fd, err: err} - } - }() - - // Setup file flags for either a read only or write only open. - flags := fs.FileFlags{ - Read: !test.partnerIsReader, - Write: test.partnerIsReader, - } - - // Open the pipe in a blocking way, which should succeed eventually. - opener := &hostOpener{name: name} - ctx := contexttest.Context(t) - pipeOps, err := Open(ctx, opener, flags) - if pipeOps != nil { - // Same as TestTryOpen. - pipeOps.Release() - } - - // Check that the partner opened the file successfully. - e := <-errch - if e.err != nil { - t.Errorf("%s: partner got error %v, wanted nil", test.desc, e.err) - continue - } - // If so, then close the partner fd to avoid leaking an fd. - syscall.Close(e.fd) - - // Check that our blocking open was successful. - if err != nil { - t.Errorf("%s: blocking open got error %v, wanted nil", test.desc, err) - continue - } - if pipeOps == nil { - t.Errorf("%s: blocking open got nil file, wanted non-nil", test.desc) - continue - } - } -} - -func TestCopiedReadAheadBuffer(t *testing.T) { - // Create the pipe. - name := pipename() - if err := mkpipe(name); err != nil { - t.Fatalf("failed to make host pipe: %v", err) - } - defer syscall.Unlink(name) - - // We're taking advantage of the fact that pipes opened read only always return - // success, but internally they are not deemed "opened" until we're sure that - // another writer comes along. This means we can open the same pipe write only - // with no problems + write to it, given that opener.Open already tried to open - // the pipe RDONLY and succeeded, which we know happened if TryOpen returns - // syserror.ErrwouldBlock. - // - // This simulates the open(RDONLY) <-> open(WRONLY)+write race we care about, but - // does not cause our test to be racy (which would be terrible). - opener := &hostOpener{name: name} - pipeOpenState := &pipeOpenState{} - ctx := contexttest.Context(t) - pipeOps, err := pipeOpenState.TryOpen(ctx, opener, fs.FileFlags{Read: true}) - if pipeOps != nil { - pipeOps.Release() - t.Fatalf("open(%s, %o) got file, want nil", name, syscall.O_RDONLY) - } - if err != syserror.ErrWouldBlock { - t.Fatalf("open(%s, %o) got error %v, want %v", name, syscall.O_RDONLY, err, syserror.ErrWouldBlock) - } - - // Then open the same pipe write only and write some bytes to it. The next - // time we try to open the pipe read only again via the pipeOpenState, we should - // succeed and buffer some of the bytes written. - fd, err := syscall.Open(name, syscall.O_WRONLY, 0666) - if err != nil { - t.Fatalf("open(%s, %o) got error %v, want nil", name, syscall.O_WRONLY, err) - } - defer syscall.Close(fd) - - data := []byte("hello") - if n, err := syscall.Write(fd, data); n != len(data) || err != nil { - t.Fatalf("write(%v) got (%d, %v), want (%d, nil)", data, n, err, len(data)) - } - - // Try the read again, knowing that it should succeed this time. - pipeOps, err = pipeOpenState.TryOpen(ctx, opener, fs.FileFlags{Read: true}) - if pipeOps == nil { - t.Fatalf("open(%s, %o) got nil file, want not nil", name, syscall.O_RDONLY) - } - defer pipeOps.Release() - - if err != nil { - t.Fatalf("open(%s, %o) got error %v, want nil", name, syscall.O_RDONLY, err) - } - - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ - Type: fs.Pipe, - }) - file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, pipeOps) - - // Check that the file we opened points to a pipe with a non-empty read ahead buffer. - bufsize := len(pipeOps.readAheadBuffer) - if bufsize != 1 { - t.Fatalf("read ahead buffer got %d bytes, want %d", bufsize, 1) - } - - // Now for the final test, try to read everything in, expecting to get back all of - // the bytes that were written at once. Note that in the wild there is no atomic - // read size so expecting to get all bytes from a single writer when there are - // multiple readers is a bad expectation. - buf := make([]byte, len(data)) - ioseq := usermem.BytesIOSequence(buf) - n, err := pipeOps.Read(ctx, file, ioseq, 0) - if err != nil { - t.Fatalf("read request got error %v, want nil", err) - } - if n != int64(len(data)) { - t.Fatalf("read request got %d bytes, want %d", n, len(data)) - } - if !bytes.Equal(buf, data) { - t.Errorf("read request got bytes [%v], want [%v]", buf, data) - } -} - -func TestPipeHangup(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // flags control how we open our end of the pipe and must be read - // only or write only. They also dicate how a coordinating partner - // fd is opened, which is their inverse (read only -> write only, etc). - flags fs.FileFlags - - // hangupSelf if true causes the test case to close our end of the pipe - // and causes hangup errors to be asserted on our coordinating partner's - // fd. If hangupSelf is false, then our partner's fd is closed and the - // hangup errors are expected on our end of the pipe. - hangupSelf bool - }{ - { - desc: "Read only gets hangup error", - flags: fs.FileFlags{Read: true}, - }, - { - desc: "Write only gets hangup error", - flags: fs.FileFlags{Write: true}, - }, - { - desc: "Read only generates hangup error", - flags: fs.FileFlags{Read: true}, - hangupSelf: true, - }, - { - desc: "Write only generates hangup error", - flags: fs.FileFlags{Write: true}, - hangupSelf: true, - }, - } { - if test.flags.Read == test.flags.Write { - t.Errorf("%s: test requires a single reader or writer", test.desc) - continue - } - - // Create the pipe. We do this per-test case to keep tests independent. - name := pipename() - if err := mkpipe(name); err != nil { - t.Errorf("%s: failed to make host pipe: %v", test.desc, err) - continue - } - defer syscall.Unlink(name) - - // Fire off a partner routine which tries to open the same pipe blocking, - // which will synchronize with us. The channel allows us to get back the - // fd once we expect this partner routine to succeed, so we can manifest - // hangup events more directly. - fdchan := make(chan int, 1) - go func() { - // Be explicit about the flags to protect the test from - // misconfiguration. - var flags int - if test.flags.Read { - flags = syscall.O_WRONLY - } else { - flags = syscall.O_RDONLY - } - fd, err := syscall.Open(name, flags, 0666) - if err != nil { - t.Logf("Open(%q, %o, 0666) partner failed: %v", name, flags, err) - } - fdchan <- fd - }() - - // Open our end in a blocking way to ensure that we coordinate. - opener := &hostOpener{name: name} - ctx := contexttest.Context(t) - pipeOps, err := Open(ctx, opener, test.flags) - if err != nil { - t.Errorf("%s: Open got error %v, want nil", test.desc, err) - continue - } - // Don't defer file.DecRef here because that causes the hangup we're - // trying to test for. - - // Expect the partner routine to have coordinated with us and get back - // its open fd. - f := <-fdchan - if f < 0 { - t.Errorf("%s: partner routine got fd %d, want > 0", test.desc, f) - pipeOps.Release() - continue - } - - if test.hangupSelf { - // Hangup self and assert that our partner got the expected hangup - // error. - pipeOps.Release() - - if test.flags.Read { - // Partner is writer. - assertWriterHungup(t, test.desc, fd.NewReadWriter(f)) - } else { - // Partner is reader. - assertReaderHungup(t, test.desc, fd.NewReadWriter(f)) - } - } else { - // Hangup our partner and expect us to get the hangup error. - syscall.Close(f) - defer pipeOps.Release() - - if test.flags.Read { - assertReaderHungup(t, test.desc, pipeOps.(*pipeOperations).file) - } else { - assertWriterHungup(t, test.desc, pipeOps.(*pipeOperations).file) - } - } - } -} - -func assertReaderHungup(t *testing.T, desc string, reader io.Reader) bool { - // Drain the pipe completely, it might have crap in it, but expect EOF eventually. - var err error - for err == nil { - _, err = reader.Read(make([]byte, 10)) - } - if err != io.EOF { - t.Errorf("%s: read from self after hangup got error %v, want %v", desc, err, io.EOF) - return false - } - return true -} - -func assertWriterHungup(t *testing.T, desc string, writer io.Writer) bool { - if _, err := writer.Write([]byte("hello")); unwrapError(err) != syscall.EPIPE { - t.Errorf("%s: write to self after hangup got error %v, want %v", desc, err, syscall.EPIPE) - return false - } - return true -} diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go deleted file mode 100644 index 5aff0cc95..000000000 --- a/pkg/sentry/fs/fdpipe/pipe_test.go +++ /dev/null @@ -1,505 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fdpipe - -import ( - "bytes" - "io" - "os" - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/fdnotifier" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -func singlePipeFD() (int, error) { - fds := make([]int, 2) - if err := syscall.Pipe(fds); err != nil { - return -1, err - } - syscall.Close(fds[1]) - return fds[0], nil -} - -func singleDirFD() (int, error) { - return syscall.Open(os.TempDir(), syscall.O_RDONLY, 0666) -} - -func mockPipeDirent(t *testing.T) *fs.Dirent { - ctx := contexttest.Context(t) - node := fs.NewMockInodeOperations(ctx) - node.UAttr = fs.UnstableAttr{ - Perms: fs.FilePermissions{ - User: fs.PermMask{Read: true, Write: true}, - }, - } - inode := fs.NewInode(ctx, node, fs.NewMockMountSource(nil), fs.StableAttr{ - Type: fs.Pipe, - BlockSize: usermem.PageSize, - }) - return fs.NewDirent(ctx, inode, "") -} - -func TestNewPipe(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // getfd generates the fd to pass to newPipeOperations. - getfd func() (int, error) - - // flags are the fs.FileFlags passed to newPipeOperations. - flags fs.FileFlags - - // readAheadBuffer is the buffer passed to newPipeOperations. - readAheadBuffer []byte - - // err is the expected error. - err error - }{ - { - desc: "Cannot make new pipe from bad fd", - getfd: func() (int, error) { return -1, nil }, - err: syscall.EINVAL, - }, - { - desc: "Cannot make new pipe from non-pipe fd", - getfd: singleDirFD, - err: syscall.EINVAL, - }, - { - desc: "Can make new pipe from pipe fd", - getfd: singlePipeFD, - flags: fs.FileFlags{Read: true}, - readAheadBuffer: []byte("hello"), - }, - } { - gfd, err := test.getfd() - if err != nil { - t.Errorf("%s: getfd got (%d, %v), want (fd, nil)", test.desc, gfd, err) - continue - } - f := fd.New(gfd) - - p, err := newPipeOperations(contexttest.Context(t), nil, test.flags, f, test.readAheadBuffer) - if p != nil { - // This is necessary to remove the fd from the global fd notifier. - defer p.Release() - } else { - // If there is no p to DecRef on, because newPipeOperations failed, then the - // file still needs to be closed. - defer f.Close() - } - - if err != test.err { - t.Errorf("%s: got error %v, want %v", test.desc, err, test.err) - continue - } - // Check the state of the pipe given that it was successfully opened. - if err == nil { - if p == nil { - t.Errorf("%s: got nil pipe and nil error, want (pipe, nil)", test.desc) - continue - } - if flags := p.flags; test.flags != flags { - t.Errorf("%s: got file flags %s, want %s", test.desc, flags, test.flags) - continue - } - if len(test.readAheadBuffer) != len(p.readAheadBuffer) { - t.Errorf("%s: got read ahead buffer length %d, want %d", test.desc, len(p.readAheadBuffer), len(test.readAheadBuffer)) - continue - } - fileFlags, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(p.file.FD()), syscall.F_GETFL, 0) - if errno != 0 { - t.Errorf("%s: failed to get file flags for fd %d, got %v, want 0", test.desc, p.file.FD(), errno) - continue - } - if fileFlags&syscall.O_NONBLOCK == 0 { - t.Errorf("%s: pipe is blocking, expected non-blocking", test.desc) - continue - } - if !fdnotifier.HasFD(int32(f.FD())) { - t.Errorf("%s: pipe fd %d is not registered for events", test.desc, f.FD) - } - } - } -} - -func TestPipeDestruction(t *testing.T) { - fds := make([]int, 2) - if err := syscall.Pipe(fds); err != nil { - t.Fatalf("failed to create pipes: got %v, want nil", err) - } - f := fd.New(fds[0]) - - // We don't care about the other end, just use the read end. - syscall.Close(fds[1]) - - // Test the read end, but it doesn't really matter which. - p, err := newPipeOperations(contexttest.Context(t), nil, fs.FileFlags{Read: true}, f, nil) - if err != nil { - f.Close() - t.Fatalf("newPipeOperations got error %v, want nil", err) - } - // Drop our only reference, which should trigger the destructor. - p.Release() - - if fdnotifier.HasFD(int32(fds[0])) { - t.Fatalf("after DecRef fdnotifier has fd %d, want no longer registered", fds[0]) - } - if p.file != nil { - t.Errorf("after DecRef got file, want nil") - } -} - -type Seek struct{} - -type ReadDir struct{} - -type Writev struct { - Src usermem.IOSequence -} - -type Readv struct { - Dst usermem.IOSequence -} - -type Fsync struct{} - -func TestPipeRequest(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // request to execute. - context interface{} - - // flags determines whether to use the read or write end - // of the pipe, for this test it can only be Read or Write. - flags fs.FileFlags - - // keepOpenPartner if false closes the other end of the pipe, - // otherwise this is delayed until the end of the test. - keepOpenPartner bool - - // expected error - err error - }{ - { - desc: "ReadDir on pipe returns ENOTDIR", - context: &ReadDir{}, - err: syscall.ENOTDIR, - }, - { - desc: "Fsync on pipe returns EINVAL", - context: &Fsync{}, - err: syscall.EINVAL, - }, - { - desc: "Seek on pipe returns ESPIPE", - context: &Seek{}, - err: syscall.ESPIPE, - }, - { - desc: "Readv on pipe from empty buffer returns nil", - context: &Readv{Dst: usermem.BytesIOSequence(nil)}, - flags: fs.FileFlags{Read: true}, - }, - { - desc: "Readv on pipe from non-empty buffer and closed partner returns EOF", - context: &Readv{Dst: usermem.BytesIOSequence(make([]byte, 10))}, - flags: fs.FileFlags{Read: true}, - err: io.EOF, - }, - { - desc: "Readv on pipe from non-empty buffer and open partner returns EWOULDBLOCK", - context: &Readv{Dst: usermem.BytesIOSequence(make([]byte, 10))}, - flags: fs.FileFlags{Read: true}, - keepOpenPartner: true, - err: syserror.ErrWouldBlock, - }, - { - desc: "Writev on pipe from empty buffer returns nil", - context: &Writev{Src: usermem.BytesIOSequence(nil)}, - flags: fs.FileFlags{Write: true}, - }, - { - desc: "Writev on pipe from non-empty buffer and closed partner returns EPIPE", - context: &Writev{Src: usermem.BytesIOSequence([]byte("hello"))}, - flags: fs.FileFlags{Write: true}, - err: syscall.EPIPE, - }, - { - desc: "Writev on pipe from non-empty buffer and open partner succeeds", - context: &Writev{Src: usermem.BytesIOSequence([]byte("hello"))}, - flags: fs.FileFlags{Write: true}, - keepOpenPartner: true, - }, - } { - if test.flags.Read && test.flags.Write { - panic("both read and write not supported for this test") - } - - fds := make([]int, 2) - if err := syscall.Pipe(fds); err != nil { - t.Errorf("%s: failed to create pipes: got %v, want nil", test.desc, err) - continue - } - - // Configure the fd and partner fd based on the file flags. - testFd, partnerFd := fds[0], fds[1] - if test.flags.Write { - testFd, partnerFd = fds[1], fds[0] - } - - // Configure closing the fds. - if test.keepOpenPartner { - defer syscall.Close(partnerFd) - } else { - syscall.Close(partnerFd) - } - - // Create the pipe. - ctx := contexttest.Context(t) - p, err := newPipeOperations(ctx, nil, test.flags, fd.New(testFd), nil) - if err != nil { - t.Fatalf("%s: newPipeOperations got error %v, want nil", test.desc, err) - } - defer p.Release() - - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe}) - file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, p) - - // Issue request via the appropriate function. - switch c := test.context.(type) { - case *Seek: - _, err = p.Seek(ctx, file, 0, 0) - case *ReadDir: - _, err = p.Readdir(ctx, file, nil) - case *Readv: - _, err = p.Read(ctx, file, c.Dst, 0) - case *Writev: - _, err = p.Write(ctx, file, c.Src, 0) - case *Fsync: - err = p.Fsync(ctx, file, 0, fs.FileMaxOffset, fs.SyncAll) - default: - t.Errorf("%s: unknown request type %T", test.desc, test.context) - } - - if unwrapError(err) != test.err { - t.Errorf("%s: got error %v, want %v", test.desc, err, test.err) - } - } -} - -func TestPipeReadAheadBuffer(t *testing.T) { - fds := make([]int, 2) - if err := syscall.Pipe(fds); err != nil { - t.Fatalf("failed to create pipes: got %v, want nil", err) - } - rfile := fd.New(fds[0]) - - // Eventually close the write end, which is not wrapped in a pipe object. - defer syscall.Close(fds[1]) - - // Write some bytes to this end. - data := []byte("world") - if n, err := syscall.Write(fds[1], data); n != len(data) || err != nil { - rfile.Close() - t.Fatalf("write to pipe got (%d, %v), want (%d, nil)", n, err, len(data)) - } - // Close the write end immediately, we don't care about it. - - buffered := []byte("hello ") - ctx := contexttest.Context(t) - p, err := newPipeOperations(ctx, nil, fs.FileFlags{Read: true}, rfile, buffered) - if err != nil { - rfile.Close() - t.Fatalf("newPipeOperations got error %v, want nil", err) - } - defer p.Release() - - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ - Type: fs.Pipe, - }) - file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, p) - - // In total we expect to read data + buffered. - total := append(buffered, data...) - - buf := make([]byte, len(total)) - iov := usermem.BytesIOSequence(buf) - n, err := p.Read(contexttest.Context(t), file, iov, 0) - if err != nil { - t.Fatalf("read request got error %v, want nil", err) - } - if n != int64(len(total)) { - t.Fatalf("read request got %d bytes, want %d", n, len(total)) - } - if !bytes.Equal(buf, total) { - t.Errorf("read request got bytes [%v], want [%v]", buf, total) - } -} - -// This is very important for pipes in general because they can return -// EWOULDBLOCK and for those that block they must continue until they have read -// all of the data (and report it as such). -func TestPipeReadsAccumulate(t *testing.T) { - fds := make([]int, 2) - if err := syscall.Pipe(fds); err != nil { - t.Fatalf("failed to create pipes: got %v, want nil", err) - } - rfile := fd.New(fds[0]) - - // Eventually close the write end, it doesn't depend on a pipe object. - defer syscall.Close(fds[1]) - - // Get a new read only pipe reference. - ctx := contexttest.Context(t) - p, err := newPipeOperations(ctx, nil, fs.FileFlags{Read: true}, rfile, nil) - if err != nil { - rfile.Close() - t.Fatalf("newPipeOperations got error %v, want nil", err) - } - // Don't forget to remove the fd from the fd notifier. Otherwise other tests will - // likely be borked, because it's global :( - defer p.Release() - - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ - Type: fs.Pipe, - }) - file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, p) - - // Write some some bytes to the pipe. - data := []byte("some message") - if n, err := syscall.Write(fds[1], data); n != len(data) || err != nil { - t.Fatalf("write to pipe got (%d, %v), want (%d, nil)", n, err, len(data)) - } - - // Construct a segment vec that is a bit more than we have written so we - // trigger an EWOULDBLOCK. - wantBytes := len(data) + 1 - readBuffer := make([]byte, wantBytes) - iov := usermem.BytesIOSequence(readBuffer) - n, err := p.Read(ctx, file, iov, 0) - total := n - iov = iov.DropFirst64(n) - if err != syserror.ErrWouldBlock { - t.Fatalf("Readv got error %v, want %v", err, syserror.ErrWouldBlock) - } - - // Write a few more bytes to allow us to read more/accumulate. - extra := []byte("extra") - if n, err := syscall.Write(fds[1], extra); n != len(extra) || err != nil { - t.Fatalf("write to pipe got (%d, %v), want (%d, nil)", n, err, len(extra)) - } - - // This time, using the same request, we should not block. - n, err = p.Read(ctx, file, iov, 0) - total += n - if err != nil { - t.Fatalf("Readv got error %v, want nil", err) - } - - // Assert that the result we got back is cumulative. - if total != int64(wantBytes) { - t.Fatalf("Readv sequence got %d bytes, want %d", total, wantBytes) - } - - if want := append(data, extra[0]); !bytes.Equal(readBuffer, want) { - t.Errorf("Readv sequence got %v, want %v", readBuffer, want) - } -} - -// Same as TestReadsAccumulate. -func TestPipeWritesAccumulate(t *testing.T) { - fds := make([]int, 2) - if err := syscall.Pipe(fds); err != nil { - t.Fatalf("failed to create pipes: got %v, want nil", err) - } - wfile := fd.New(fds[1]) - - // Eventually close the read end, it doesn't depend on a pipe object. - defer syscall.Close(fds[0]) - - // Get a new write only pipe reference. - ctx := contexttest.Context(t) - p, err := newPipeOperations(ctx, nil, fs.FileFlags{Write: true}, wfile, nil) - if err != nil { - wfile.Close() - t.Fatalf("newPipeOperations got error %v, want nil", err) - } - // Don't forget to remove the fd from the fd notifier. Otherwise other tests - // will likely be borked, because it's global :( - defer p.Release() - - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ - Type: fs.Pipe, - }) - file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, p) - - pipeSize, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(wfile.FD()), syscall.F_GETPIPE_SZ, 0) - if errno != 0 { - t.Fatalf("fcntl(F_GETPIPE_SZ) failed: %v", errno) - } - t.Logf("Pipe buffer size: %d", pipeSize) - - // Construct a segment vec that is larger than the pipe size to trigger an - // EWOULDBLOCK. - wantBytes := int(pipeSize) * 2 - writeBuffer := make([]byte, wantBytes) - for i := 0; i < wantBytes; i++ { - writeBuffer[i] = 'a' - } - iov := usermem.BytesIOSequence(writeBuffer) - n, err := p.Write(ctx, file, iov, 0) - if err != syserror.ErrWouldBlock { - t.Fatalf("Writev got error %v, want %v", err, syserror.ErrWouldBlock) - } - if n != int64(pipeSize) { - t.Fatalf("Writev partial write, got: %v, want %v", n, pipeSize) - } - total := n - iov = iov.DropFirst64(n) - - // Read the entire pipe buf size to make space for the second half. - readBuffer := make([]byte, n) - if n, err := syscall.Read(fds[0], readBuffer); n != len(readBuffer) || err != nil { - t.Fatalf("write to pipe got (%d, %v), want (%d, nil)", n, err, len(readBuffer)) - } - if !bytes.Equal(readBuffer, writeBuffer[:len(readBuffer)]) { - t.Fatalf("wrong data read from pipe, got: %v, want: %v", readBuffer, writeBuffer) - } - - // This time we should not block. - n, err = p.Write(ctx, file, iov, 0) - if err != nil { - t.Fatalf("Writev got error %v, want nil", err) - } - if n != int64(pipeSize) { - t.Fatalf("Writev partial write, got: %v, want %v", n, pipeSize) - } - total += n - - // Assert that the result we got back is cumulative. - if total != int64(wantBytes) { - t.Fatalf("Writev sequence got %d bytes, want %d", total, wantBytes) - } -} diff --git a/pkg/sentry/fs/file_overlay_test.go b/pkg/sentry/fs/file_overlay_test.go deleted file mode 100644 index a76d87e3a..000000000 --- a/pkg/sentry/fs/file_overlay_test.go +++ /dev/null @@ -1,276 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs_test - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" -) - -func TestReaddir(t *testing.T) { - ctx := contexttest.Context(t) - ctx = &rootContext{ - Context: ctx, - root: fs.NewDirent(ctx, newTestRamfsDir(ctx, nil, nil), "root"), - } - for _, test := range []struct { - // Test description. - desc string - - // Lookup parameters. - dir *fs.Inode - - // Want from lookup. - err error - names []string - }{ - { - desc: "no upper, lower has entries", - dir: fs.NewTestOverlayDir(ctx, - nil, /* upper */ - newTestRamfsDir(ctx, []dirContent{ - {name: "a"}, - {name: "b"}, - }, nil), /* lower */ - false /* revalidate */), - names: []string{".", "..", "a", "b"}, - }, - { - desc: "upper has entries, no lower", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - {name: "a"}, - {name: "b"}, - }, nil), /* upper */ - nil, /* lower */ - false /* revalidate */), - names: []string{".", "..", "a", "b"}, - }, - { - desc: "upper and lower, entries combine", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - {name: "a"}, - }, nil), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - {name: "b"}, - }, nil), /* lower */ - false /* revalidate */), - names: []string{".", "..", "a", "b"}, - }, - { - desc: "upper and lower, entries combine, none are masked", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - {name: "a"}, - }, []string{"b"}), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - {name: "c"}, - }, nil), /* lower */ - false /* revalidate */), - names: []string{".", "..", "a", "c"}, - }, - { - desc: "upper and lower, entries combine, upper masks some of lower", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - {name: "a"}, - }, []string{"b"}), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - {name: "b"}, /* will be masked */ - {name: "c"}, - }, nil), /* lower */ - false /* revalidate */), - names: []string{".", "..", "a", "c"}, - }, - } { - t.Run(test.desc, func(t *testing.T) { - openDir, err := test.dir.GetFile(ctx, fs.NewDirent(ctx, test.dir, "stub"), fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("GetFile got error %v, want nil", err) - } - stubSerializer := &fs.CollectEntriesSerializer{} - err = openDir.Readdir(ctx, stubSerializer) - if err != test.err { - t.Fatalf("Readdir got error %v, want nil", err) - } - if err != nil { - return - } - if !reflect.DeepEqual(stubSerializer.Order, test.names) { - t.Errorf("Readdir got names %v, want %v", stubSerializer.Order, test.names) - } - }) - } -} - -func TestReaddirRevalidation(t *testing.T) { - ctx := contexttest.Context(t) - ctx = &rootContext{ - Context: ctx, - root: fs.NewDirent(ctx, newTestRamfsDir(ctx, nil, nil), "root"), - } - - // Create an overlay with two directories, each with one file. - upper := newTestRamfsDir(ctx, []dirContent{{name: "a"}}, nil) - lower := newTestRamfsDir(ctx, []dirContent{{name: "b"}}, nil) - overlay := fs.NewTestOverlayDir(ctx, upper, lower, true /* revalidate */) - - // Get a handle to the dirent in the upper filesystem so that we can - // modify it without going through the dirent. - upperDir := upper.InodeOperations.(*dir).InodeOperations.(*ramfs.Dir) - - // Check that overlay returns the files from both upper and lower. - openDir, err := overlay.GetFile(ctx, fs.NewDirent(ctx, overlay, "stub"), fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("GetFile got error %v, want nil", err) - } - ser := &fs.CollectEntriesSerializer{} - if err := openDir.Readdir(ctx, ser); err != nil { - t.Fatalf("Readdir got error %v, want nil", err) - } - got, want := ser.Order, []string{".", "..", "a", "b"} - if !reflect.DeepEqual(got, want) { - t.Errorf("Readdir got names %v, want %v", got, want) - } - - // Remove "a" from the upper and add "c". - if err := upperDir.Remove(ctx, upper, "a"); err != nil { - t.Fatalf("error removing child: %v", err) - } - upperDir.AddChild(ctx, "c", fs.NewInode(ctx, fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermissions{}, 0), - upper.MountSource, fs.StableAttr{Type: fs.RegularFile})) - - // Seek to beginning of the directory and do the readdir again. - if _, err := openDir.Seek(ctx, fs.SeekSet, 0); err != nil { - t.Fatalf("error seeking to beginning of dir: %v", err) - } - ser = &fs.CollectEntriesSerializer{} - if err := openDir.Readdir(ctx, ser); err != nil { - t.Fatalf("Readdir got error %v, want nil", err) - } - - // Readdir should return the updated children. - got, want = ser.Order, []string{".", "..", "b", "c"} - if !reflect.DeepEqual(got, want) { - t.Errorf("Readdir got names %v, want %v", got, want) - } -} - -// TestReaddirOverlayFrozen tests that calling Readdir on an overlay file with -// a frozen dirent tree does not make Readdir calls to the underlying files. -// This is a regression test for b/114808269. -func TestReaddirOverlayFrozen(t *testing.T) { - ctx := contexttest.Context(t) - - // Create an overlay with two directories, each with two files. - upper := newTestRamfsDir(ctx, []dirContent{{name: "upper-file1"}, {name: "upper-file2"}}, nil) - lower := newTestRamfsDir(ctx, []dirContent{{name: "lower-file1"}, {name: "lower-file2"}}, nil) - overlayInode := fs.NewTestOverlayDir(ctx, upper, lower, false) - - // Set that overlay as the root. - root := fs.NewDirent(ctx, overlayInode, "root") - ctx = &rootContext{ - Context: ctx, - root: root, - } - - // Check that calling Readdir on the root now returns all 4 files (2 - // from each layer in the overlay). - rootFile, err := root.Inode.GetFile(ctx, root, fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("root.Inode.GetFile failed: %v", err) - } - defer rootFile.DecRef() - ser := &fs.CollectEntriesSerializer{} - if err := rootFile.Readdir(ctx, ser); err != nil { - t.Fatalf("rootFile.Readdir failed: %v", err) - } - if got, want := ser.Order, []string{".", "..", "lower-file1", "lower-file2", "upper-file1", "upper-file2"}; !reflect.DeepEqual(got, want) { - t.Errorf("Readdir got names %v, want %v", got, want) - } - - // Readdir should have been called on upper and lower. - upperDir := upper.InodeOperations.(*dir) - lowerDir := lower.InodeOperations.(*dir) - if !upperDir.ReaddirCalled { - t.Errorf("upperDir.ReaddirCalled got %v, want true", upperDir.ReaddirCalled) - } - if !lowerDir.ReaddirCalled { - t.Errorf("lowerDir.ReaddirCalled got %v, want true", lowerDir.ReaddirCalled) - } - - // Reset. - upperDir.ReaddirCalled = false - lowerDir.ReaddirCalled = false - - // Take references on "upper-file1" and "lower-file1", pinning them in - // the dirent tree. - for _, name := range []string{"upper-file1", "lower-file1"} { - if _, err := root.Walk(ctx, root, name); err != nil { - t.Fatalf("root.Walk(%q) failed: %v", name, err) - } - // Don't drop a reference on the returned dirent so that it - // will stay in the tree. - } - - // Freeze the dirent tree. - root.Freeze() - - // Seek back to the beginning of the file. - if _, err := rootFile.Seek(ctx, fs.SeekSet, 0); err != nil { - t.Fatalf("error seeking to beginning of directory: %v", err) - } - - // Calling Readdir on the root now will return only the pinned - // children. - ser = &fs.CollectEntriesSerializer{} - if err := rootFile.Readdir(ctx, ser); err != nil { - t.Fatalf("rootFile.Readdir failed: %v", err) - } - if got, want := ser.Order, []string{".", "..", "lower-file1", "upper-file1"}; !reflect.DeepEqual(got, want) { - t.Errorf("Readdir got names %v, want %v", got, want) - } - - // Readdir should NOT have been called on upper or lower. - if upperDir.ReaddirCalled { - t.Errorf("upperDir.ReaddirCalled got %v, want false", upperDir.ReaddirCalled) - } - if lowerDir.ReaddirCalled { - t.Errorf("lowerDir.ReaddirCalled got %v, want false", lowerDir.ReaddirCalled) - } -} - -type rootContext struct { - context.Context - root *fs.Dirent -} - -// Value implements context.Context. -func (r *rootContext) Value(key interface{}) interface{} { - switch key { - case fs.CtxRoot: - r.root.IncRef() - return r.root - default: - return r.Context.Value(key) - } -} diff --git a/pkg/sentry/fs/filetest/BUILD b/pkg/sentry/fs/filetest/BUILD deleted file mode 100644 index a8000e010..000000000 --- a/pkg/sentry/fs/filetest/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "filetest", - testonly = 1, - srcs = ["filetest.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fs/filetest/filetest.go b/pkg/sentry/fs/filetest/filetest.go deleted file mode 100644 index 8049538f2..000000000 --- a/pkg/sentry/fs/filetest/filetest.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package filetest provides a test implementation of an fs.File. -package filetest - -import ( - "fmt" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/anon" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -// TestFileOperations is an implementation of the File interface. It provides all -// required methods. -type TestFileOperations struct { - fsutil.FileNoopRelease `state:"nosave"` - fsutil.FilePipeSeek `state:"nosave"` - fsutil.FileNotDirReaddir `state:"nosave"` - fsutil.FileNoFsync `state:"nosave"` - fsutil.FileNoopFlush `state:"nosave"` - fsutil.FileNoMMap `state:"nosave"` - fsutil.FileNoIoctl `state:"nosave"` - fsutil.FileNoSplice `state:"nosave"` - fsutil.FileUseInodeUnstableAttr `state:"nosave"` - waiter.AlwaysReady `state:"nosave"` -} - -// NewTestFile creates and initializes a new test file. -func NewTestFile(tb testing.TB) *fs.File { - ctx := contexttest.Context(tb) - dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "test") - return fs.NewFile(ctx, dirent, fs.FileFlags{}, &TestFileOperations{}) -} - -// Read just fails the request. -func (*TestFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, fmt.Errorf("Readv not implemented") -} - -// Write just fails the request. -func (*TestFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, fmt.Errorf("Writev not implemented") -} diff --git a/pkg/sentry/fs/fs_state_autogen.go b/pkg/sentry/fs/fs_state_autogen.go new file mode 100755 index 000000000..74d56c30a --- /dev/null +++ b/pkg/sentry/fs/fs_state_autogen.go @@ -0,0 +1,640 @@ +// automatically generated by stateify. + +package fs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *StableAttr) beforeSave() {} +func (x *StableAttr) save(m state.Map) { + x.beforeSave() + m.Save("Type", &x.Type) + m.Save("DeviceID", &x.DeviceID) + m.Save("InodeID", &x.InodeID) + m.Save("BlockSize", &x.BlockSize) + m.Save("DeviceFileMajor", &x.DeviceFileMajor) + m.Save("DeviceFileMinor", &x.DeviceFileMinor) +} + +func (x *StableAttr) afterLoad() {} +func (x *StableAttr) load(m state.Map) { + m.Load("Type", &x.Type) + m.Load("DeviceID", &x.DeviceID) + m.Load("InodeID", &x.InodeID) + m.Load("BlockSize", &x.BlockSize) + m.Load("DeviceFileMajor", &x.DeviceFileMajor) + m.Load("DeviceFileMinor", &x.DeviceFileMinor) +} + +func (x *UnstableAttr) beforeSave() {} +func (x *UnstableAttr) save(m state.Map) { + x.beforeSave() + m.Save("Size", &x.Size) + m.Save("Usage", &x.Usage) + m.Save("Perms", &x.Perms) + m.Save("Owner", &x.Owner) + m.Save("AccessTime", &x.AccessTime) + m.Save("ModificationTime", &x.ModificationTime) + m.Save("StatusChangeTime", &x.StatusChangeTime) + m.Save("Links", &x.Links) +} + +func (x *UnstableAttr) afterLoad() {} +func (x *UnstableAttr) load(m state.Map) { + m.Load("Size", &x.Size) + m.Load("Usage", &x.Usage) + m.Load("Perms", &x.Perms) + m.Load("Owner", &x.Owner) + m.Load("AccessTime", &x.AccessTime) + m.Load("ModificationTime", &x.ModificationTime) + m.Load("StatusChangeTime", &x.StatusChangeTime) + m.Load("Links", &x.Links) +} + +func (x *AttrMask) beforeSave() {} +func (x *AttrMask) save(m state.Map) { + x.beforeSave() + m.Save("Type", &x.Type) + m.Save("DeviceID", &x.DeviceID) + m.Save("InodeID", &x.InodeID) + m.Save("BlockSize", &x.BlockSize) + m.Save("Size", &x.Size) + m.Save("Usage", &x.Usage) + m.Save("Perms", &x.Perms) + m.Save("UID", &x.UID) + m.Save("GID", &x.GID) + m.Save("AccessTime", &x.AccessTime) + m.Save("ModificationTime", &x.ModificationTime) + m.Save("StatusChangeTime", &x.StatusChangeTime) + m.Save("Links", &x.Links) +} + +func (x *AttrMask) afterLoad() {} +func (x *AttrMask) load(m state.Map) { + m.Load("Type", &x.Type) + m.Load("DeviceID", &x.DeviceID) + m.Load("InodeID", &x.InodeID) + m.Load("BlockSize", &x.BlockSize) + m.Load("Size", &x.Size) + m.Load("Usage", &x.Usage) + m.Load("Perms", &x.Perms) + m.Load("UID", &x.UID) + m.Load("GID", &x.GID) + m.Load("AccessTime", &x.AccessTime) + m.Load("ModificationTime", &x.ModificationTime) + m.Load("StatusChangeTime", &x.StatusChangeTime) + m.Load("Links", &x.Links) +} + +func (x *PermMask) beforeSave() {} +func (x *PermMask) save(m state.Map) { + x.beforeSave() + m.Save("Read", &x.Read) + m.Save("Write", &x.Write) + m.Save("Execute", &x.Execute) +} + +func (x *PermMask) afterLoad() {} +func (x *PermMask) load(m state.Map) { + m.Load("Read", &x.Read) + m.Load("Write", &x.Write) + m.Load("Execute", &x.Execute) +} + +func (x *FilePermissions) beforeSave() {} +func (x *FilePermissions) save(m state.Map) { + x.beforeSave() + m.Save("User", &x.User) + m.Save("Group", &x.Group) + m.Save("Other", &x.Other) + m.Save("Sticky", &x.Sticky) + m.Save("SetUID", &x.SetUID) + m.Save("SetGID", &x.SetGID) +} + +func (x *FilePermissions) afterLoad() {} +func (x *FilePermissions) load(m state.Map) { + m.Load("User", &x.User) + m.Load("Group", &x.Group) + m.Load("Other", &x.Other) + m.Load("Sticky", &x.Sticky) + m.Load("SetUID", &x.SetUID) + m.Load("SetGID", &x.SetGID) +} + +func (x *FileOwner) beforeSave() {} +func (x *FileOwner) save(m state.Map) { + x.beforeSave() + m.Save("UID", &x.UID) + m.Save("GID", &x.GID) +} + +func (x *FileOwner) afterLoad() {} +func (x *FileOwner) load(m state.Map) { + m.Load("UID", &x.UID) + m.Load("GID", &x.GID) +} + +func (x *DentAttr) beforeSave() {} +func (x *DentAttr) save(m state.Map) { + x.beforeSave() + m.Save("Type", &x.Type) + m.Save("InodeID", &x.InodeID) +} + +func (x *DentAttr) afterLoad() {} +func (x *DentAttr) load(m state.Map) { + m.Load("Type", &x.Type) + m.Load("InodeID", &x.InodeID) +} + +func (x *SortedDentryMap) beforeSave() {} +func (x *SortedDentryMap) save(m state.Map) { + x.beforeSave() + m.Save("names", &x.names) + m.Save("entries", &x.entries) +} + +func (x *SortedDentryMap) afterLoad() {} +func (x *SortedDentryMap) load(m state.Map) { + m.Load("names", &x.names) + m.Load("entries", &x.entries) +} + +func (x *Dirent) save(m state.Map) { + x.beforeSave() + var children map[string]*Dirent = x.saveChildren() + m.SaveValue("children", children) + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("userVisible", &x.userVisible) + m.Save("Inode", &x.Inode) + m.Save("name", &x.name) + m.Save("parent", &x.parent) + m.Save("deleted", &x.deleted) + m.Save("frozen", &x.frozen) + m.Save("mounted", &x.mounted) +} + +func (x *Dirent) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("userVisible", &x.userVisible) + m.Load("Inode", &x.Inode) + m.Load("name", &x.name) + m.Load("parent", &x.parent) + m.Load("deleted", &x.deleted) + m.Load("frozen", &x.frozen) + m.Load("mounted", &x.mounted) + m.LoadValue("children", new(map[string]*Dirent), func(y interface{}) { x.loadChildren(y.(map[string]*Dirent)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *DirentCache) beforeSave() {} +func (x *DirentCache) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.currentSize) { + m.Failf("currentSize is %v, expected zero", x.currentSize) + } + if !state.IsZeroValue(x.list) { + m.Failf("list is %v, expected zero", x.list) + } + m.Save("maxSize", &x.maxSize) + m.Save("limit", &x.limit) +} + +func (x *DirentCache) afterLoad() {} +func (x *DirentCache) load(m state.Map) { + m.Load("maxSize", &x.maxSize) + m.Load("limit", &x.limit) +} + +func (x *DirentCacheLimiter) beforeSave() {} +func (x *DirentCacheLimiter) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.count) { + m.Failf("count is %v, expected zero", x.count) + } + m.Save("max", &x.max) +} + +func (x *DirentCacheLimiter) afterLoad() {} +func (x *DirentCacheLimiter) load(m state.Map) { + m.Load("max", &x.max) +} + +func (x *direntList) beforeSave() {} +func (x *direntList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *direntList) afterLoad() {} +func (x *direntList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *direntEntry) beforeSave() {} +func (x *direntEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *direntEntry) afterLoad() {} +func (x *direntEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *eventList) beforeSave() {} +func (x *eventList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *eventList) afterLoad() {} +func (x *eventList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *eventEntry) beforeSave() {} +func (x *eventEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *eventEntry) afterLoad() {} +func (x *eventEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *File) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("UniqueID", &x.UniqueID) + m.Save("Dirent", &x.Dirent) + m.Save("flags", &x.flags) + m.Save("async", &x.async) + m.Save("FileOperations", &x.FileOperations) + m.Save("offset", &x.offset) +} + +func (x *File) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("UniqueID", &x.UniqueID) + m.Load("Dirent", &x.Dirent) + m.Load("flags", &x.flags) + m.Load("async", &x.async) + m.LoadWait("FileOperations", &x.FileOperations) + m.Load("offset", &x.offset) + m.AfterLoad(x.afterLoad) +} + +func (x *overlayFileOperations) beforeSave() {} +func (x *overlayFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("upper", &x.upper) + m.Save("lower", &x.lower) + m.Save("dirCursor", &x.dirCursor) +} + +func (x *overlayFileOperations) afterLoad() {} +func (x *overlayFileOperations) load(m state.Map) { + m.Load("upper", &x.upper) + m.Load("lower", &x.lower) + m.Load("dirCursor", &x.dirCursor) +} + +func (x *overlayMappingIdentity) beforeSave() {} +func (x *overlayMappingIdentity) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("id", &x.id) + m.Save("overlayFile", &x.overlayFile) +} + +func (x *overlayMappingIdentity) afterLoad() {} +func (x *overlayMappingIdentity) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("id", &x.id) + m.Load("overlayFile", &x.overlayFile) +} + +func (x *MountSourceFlags) beforeSave() {} +func (x *MountSourceFlags) save(m state.Map) { + x.beforeSave() + m.Save("ReadOnly", &x.ReadOnly) + m.Save("NoAtime", &x.NoAtime) + m.Save("ForcePageCache", &x.ForcePageCache) + m.Save("NoExec", &x.NoExec) +} + +func (x *MountSourceFlags) afterLoad() {} +func (x *MountSourceFlags) load(m state.Map) { + m.Load("ReadOnly", &x.ReadOnly) + m.Load("NoAtime", &x.NoAtime) + m.Load("ForcePageCache", &x.ForcePageCache) + m.Load("NoExec", &x.NoExec) +} + +func (x *FileFlags) beforeSave() {} +func (x *FileFlags) save(m state.Map) { + x.beforeSave() + m.Save("Direct", &x.Direct) + m.Save("NonBlocking", &x.NonBlocking) + m.Save("DSync", &x.DSync) + m.Save("Sync", &x.Sync) + m.Save("Append", &x.Append) + m.Save("Read", &x.Read) + m.Save("Write", &x.Write) + m.Save("Pread", &x.Pread) + m.Save("Pwrite", &x.Pwrite) + m.Save("Directory", &x.Directory) + m.Save("Async", &x.Async) + m.Save("LargeFile", &x.LargeFile) + m.Save("NonSeekable", &x.NonSeekable) + m.Save("Truncate", &x.Truncate) +} + +func (x *FileFlags) afterLoad() {} +func (x *FileFlags) load(m state.Map) { + m.Load("Direct", &x.Direct) + m.Load("NonBlocking", &x.NonBlocking) + m.Load("DSync", &x.DSync) + m.Load("Sync", &x.Sync) + m.Load("Append", &x.Append) + m.Load("Read", &x.Read) + m.Load("Write", &x.Write) + m.Load("Pread", &x.Pread) + m.Load("Pwrite", &x.Pwrite) + m.Load("Directory", &x.Directory) + m.Load("Async", &x.Async) + m.Load("LargeFile", &x.LargeFile) + m.Load("NonSeekable", &x.NonSeekable) + m.Load("Truncate", &x.Truncate) +} + +func (x *Inode) beforeSave() {} +func (x *Inode) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("InodeOperations", &x.InodeOperations) + m.Save("StableAttr", &x.StableAttr) + m.Save("LockCtx", &x.LockCtx) + m.Save("Watches", &x.Watches) + m.Save("MountSource", &x.MountSource) + m.Save("overlay", &x.overlay) +} + +func (x *Inode) afterLoad() {} +func (x *Inode) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("InodeOperations", &x.InodeOperations) + m.Load("StableAttr", &x.StableAttr) + m.Load("LockCtx", &x.LockCtx) + m.Load("Watches", &x.Watches) + m.Load("MountSource", &x.MountSource) + m.Load("overlay", &x.overlay) +} + +func (x *LockCtx) beforeSave() {} +func (x *LockCtx) save(m state.Map) { + x.beforeSave() + m.Save("Posix", &x.Posix) + m.Save("BSD", &x.BSD) +} + +func (x *LockCtx) afterLoad() {} +func (x *LockCtx) load(m state.Map) { + m.Load("Posix", &x.Posix) + m.Load("BSD", &x.BSD) +} + +func (x *Watches) beforeSave() {} +func (x *Watches) save(m state.Map) { + x.beforeSave() + m.Save("ws", &x.ws) + m.Save("unlinked", &x.unlinked) +} + +func (x *Watches) afterLoad() {} +func (x *Watches) load(m state.Map) { + m.Load("ws", &x.ws) + m.Load("unlinked", &x.unlinked) +} + +func (x *Inotify) beforeSave() {} +func (x *Inotify) save(m state.Map) { + x.beforeSave() + m.Save("id", &x.id) + m.Save("events", &x.events) + m.Save("scratch", &x.scratch) + m.Save("nextWatch", &x.nextWatch) + m.Save("watches", &x.watches) +} + +func (x *Inotify) afterLoad() {} +func (x *Inotify) load(m state.Map) { + m.Load("id", &x.id) + m.Load("events", &x.events) + m.Load("scratch", &x.scratch) + m.Load("nextWatch", &x.nextWatch) + m.Load("watches", &x.watches) +} + +func (x *Event) beforeSave() {} +func (x *Event) save(m state.Map) { + x.beforeSave() + m.Save("eventEntry", &x.eventEntry) + m.Save("wd", &x.wd) + m.Save("mask", &x.mask) + m.Save("cookie", &x.cookie) + m.Save("len", &x.len) + m.Save("name", &x.name) +} + +func (x *Event) afterLoad() {} +func (x *Event) load(m state.Map) { + m.Load("eventEntry", &x.eventEntry) + m.Load("wd", &x.wd) + m.Load("mask", &x.mask) + m.Load("cookie", &x.cookie) + m.Load("len", &x.len) + m.Load("name", &x.name) +} + +func (x *Watch) beforeSave() {} +func (x *Watch) save(m state.Map) { + x.beforeSave() + m.Save("owner", &x.owner) + m.Save("wd", &x.wd) + m.Save("target", &x.target) + m.Save("unpinned", &x.unpinned) + m.Save("mask", &x.mask) + m.Save("pins", &x.pins) +} + +func (x *Watch) afterLoad() {} +func (x *Watch) load(m state.Map) { + m.Load("owner", &x.owner) + m.Load("wd", &x.wd) + m.Load("target", &x.target) + m.Load("unpinned", &x.unpinned) + m.Load("mask", &x.mask) + m.Load("pins", &x.pins) +} + +func (x *MountSource) beforeSave() {} +func (x *MountSource) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("MountSourceOperations", &x.MountSourceOperations) + m.Save("FilesystemType", &x.FilesystemType) + m.Save("Flags", &x.Flags) + m.Save("fscache", &x.fscache) + m.Save("direntRefs", &x.direntRefs) +} + +func (x *MountSource) afterLoad() {} +func (x *MountSource) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("MountSourceOperations", &x.MountSourceOperations) + m.Load("FilesystemType", &x.FilesystemType) + m.Load("Flags", &x.Flags) + m.Load("fscache", &x.fscache) + m.Load("direntRefs", &x.direntRefs) +} + +func (x *SimpleMountSourceOperations) beforeSave() {} +func (x *SimpleMountSourceOperations) save(m state.Map) { + x.beforeSave() + m.Save("keep", &x.keep) + m.Save("revalidate", &x.revalidate) + m.Save("cacheReaddir", &x.cacheReaddir) +} + +func (x *SimpleMountSourceOperations) afterLoad() {} +func (x *SimpleMountSourceOperations) load(m state.Map) { + m.Load("keep", &x.keep) + m.Load("revalidate", &x.revalidate) + m.Load("cacheReaddir", &x.cacheReaddir) +} + +func (x *overlayMountSourceOperations) beforeSave() {} +func (x *overlayMountSourceOperations) save(m state.Map) { + x.beforeSave() + m.Save("upper", &x.upper) + m.Save("lower", &x.lower) +} + +func (x *overlayMountSourceOperations) afterLoad() {} +func (x *overlayMountSourceOperations) load(m state.Map) { + m.Load("upper", &x.upper) + m.Load("lower", &x.lower) +} + +func (x *overlayFilesystem) beforeSave() {} +func (x *overlayFilesystem) save(m state.Map) { + x.beforeSave() +} + +func (x *overlayFilesystem) afterLoad() {} +func (x *overlayFilesystem) load(m state.Map) { +} + +func (x *Mount) beforeSave() {} +func (x *Mount) save(m state.Map) { + x.beforeSave() + m.Save("ID", &x.ID) + m.Save("ParentID", &x.ParentID) + m.Save("root", &x.root) + m.Save("previous", &x.previous) +} + +func (x *Mount) afterLoad() {} +func (x *Mount) load(m state.Map) { + m.Load("ID", &x.ID) + m.Load("ParentID", &x.ParentID) + m.Load("root", &x.root) + m.Load("previous", &x.previous) +} + +func (x *MountNamespace) beforeSave() {} +func (x *MountNamespace) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("userns", &x.userns) + m.Save("root", &x.root) + m.Save("mounts", &x.mounts) + m.Save("mountID", &x.mountID) +} + +func (x *MountNamespace) afterLoad() {} +func (x *MountNamespace) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("userns", &x.userns) + m.Load("root", &x.root) + m.Load("mounts", &x.mounts) + m.Load("mountID", &x.mountID) +} + +func (x *overlayEntry) beforeSave() {} +func (x *overlayEntry) save(m state.Map) { + x.beforeSave() + m.Save("lowerExists", &x.lowerExists) + m.Save("lower", &x.lower) + m.Save("mappings", &x.mappings) + m.Save("upper", &x.upper) + m.Save("dirCache", &x.dirCache) +} + +func (x *overlayEntry) afterLoad() {} +func (x *overlayEntry) load(m state.Map) { + m.Load("lowerExists", &x.lowerExists) + m.Load("lower", &x.lower) + m.Load("mappings", &x.mappings) + m.Load("upper", &x.upper) + m.Load("dirCache", &x.dirCache) +} + +func init() { + state.Register("pkg/sentry/fs.StableAttr", (*StableAttr)(nil), state.Fns{Save: (*StableAttr).save, Load: (*StableAttr).load}) + state.Register("pkg/sentry/fs.UnstableAttr", (*UnstableAttr)(nil), state.Fns{Save: (*UnstableAttr).save, Load: (*UnstableAttr).load}) + state.Register("pkg/sentry/fs.AttrMask", (*AttrMask)(nil), state.Fns{Save: (*AttrMask).save, Load: (*AttrMask).load}) + state.Register("pkg/sentry/fs.PermMask", (*PermMask)(nil), state.Fns{Save: (*PermMask).save, Load: (*PermMask).load}) + state.Register("pkg/sentry/fs.FilePermissions", (*FilePermissions)(nil), state.Fns{Save: (*FilePermissions).save, Load: (*FilePermissions).load}) + state.Register("pkg/sentry/fs.FileOwner", (*FileOwner)(nil), state.Fns{Save: (*FileOwner).save, Load: (*FileOwner).load}) + state.Register("pkg/sentry/fs.DentAttr", (*DentAttr)(nil), state.Fns{Save: (*DentAttr).save, Load: (*DentAttr).load}) + state.Register("pkg/sentry/fs.SortedDentryMap", (*SortedDentryMap)(nil), state.Fns{Save: (*SortedDentryMap).save, Load: (*SortedDentryMap).load}) + state.Register("pkg/sentry/fs.Dirent", (*Dirent)(nil), state.Fns{Save: (*Dirent).save, Load: (*Dirent).load}) + state.Register("pkg/sentry/fs.DirentCache", (*DirentCache)(nil), state.Fns{Save: (*DirentCache).save, Load: (*DirentCache).load}) + state.Register("pkg/sentry/fs.DirentCacheLimiter", (*DirentCacheLimiter)(nil), state.Fns{Save: (*DirentCacheLimiter).save, Load: (*DirentCacheLimiter).load}) + state.Register("pkg/sentry/fs.direntList", (*direntList)(nil), state.Fns{Save: (*direntList).save, Load: (*direntList).load}) + state.Register("pkg/sentry/fs.direntEntry", (*direntEntry)(nil), state.Fns{Save: (*direntEntry).save, Load: (*direntEntry).load}) + state.Register("pkg/sentry/fs.eventList", (*eventList)(nil), state.Fns{Save: (*eventList).save, Load: (*eventList).load}) + state.Register("pkg/sentry/fs.eventEntry", (*eventEntry)(nil), state.Fns{Save: (*eventEntry).save, Load: (*eventEntry).load}) + state.Register("pkg/sentry/fs.File", (*File)(nil), state.Fns{Save: (*File).save, Load: (*File).load}) + state.Register("pkg/sentry/fs.overlayFileOperations", (*overlayFileOperations)(nil), state.Fns{Save: (*overlayFileOperations).save, Load: (*overlayFileOperations).load}) + state.Register("pkg/sentry/fs.overlayMappingIdentity", (*overlayMappingIdentity)(nil), state.Fns{Save: (*overlayMappingIdentity).save, Load: (*overlayMappingIdentity).load}) + state.Register("pkg/sentry/fs.MountSourceFlags", (*MountSourceFlags)(nil), state.Fns{Save: (*MountSourceFlags).save, Load: (*MountSourceFlags).load}) + state.Register("pkg/sentry/fs.FileFlags", (*FileFlags)(nil), state.Fns{Save: (*FileFlags).save, Load: (*FileFlags).load}) + state.Register("pkg/sentry/fs.Inode", (*Inode)(nil), state.Fns{Save: (*Inode).save, Load: (*Inode).load}) + state.Register("pkg/sentry/fs.LockCtx", (*LockCtx)(nil), state.Fns{Save: (*LockCtx).save, Load: (*LockCtx).load}) + state.Register("pkg/sentry/fs.Watches", (*Watches)(nil), state.Fns{Save: (*Watches).save, Load: (*Watches).load}) + state.Register("pkg/sentry/fs.Inotify", (*Inotify)(nil), state.Fns{Save: (*Inotify).save, Load: (*Inotify).load}) + state.Register("pkg/sentry/fs.Event", (*Event)(nil), state.Fns{Save: (*Event).save, Load: (*Event).load}) + state.Register("pkg/sentry/fs.Watch", (*Watch)(nil), state.Fns{Save: (*Watch).save, Load: (*Watch).load}) + state.Register("pkg/sentry/fs.MountSource", (*MountSource)(nil), state.Fns{Save: (*MountSource).save, Load: (*MountSource).load}) + state.Register("pkg/sentry/fs.SimpleMountSourceOperations", (*SimpleMountSourceOperations)(nil), state.Fns{Save: (*SimpleMountSourceOperations).save, Load: (*SimpleMountSourceOperations).load}) + state.Register("pkg/sentry/fs.overlayMountSourceOperations", (*overlayMountSourceOperations)(nil), state.Fns{Save: (*overlayMountSourceOperations).save, Load: (*overlayMountSourceOperations).load}) + state.Register("pkg/sentry/fs.overlayFilesystem", (*overlayFilesystem)(nil), state.Fns{Save: (*overlayFilesystem).save, Load: (*overlayFilesystem).load}) + state.Register("pkg/sentry/fs.Mount", (*Mount)(nil), state.Fns{Save: (*Mount).save, Load: (*Mount).load}) + state.Register("pkg/sentry/fs.MountNamespace", (*MountNamespace)(nil), state.Fns{Save: (*MountNamespace).save, Load: (*MountNamespace).load}) + state.Register("pkg/sentry/fs.overlayEntry", (*overlayEntry)(nil), state.Fns{Save: (*overlayEntry).save, Load: (*overlayEntry).load}) +} diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD deleted file mode 100644 index 789369220..000000000 --- a/pkg/sentry/fs/fsutil/BUILD +++ /dev/null @@ -1,118 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "dirty_set_impl", - out = "dirty_set_impl.go", - imports = { - "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", - }, - package = "fsutil", - prefix = "Dirty", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "memmap.MappableRange", - "Value": "DirtyInfo", - "Functions": "dirtySetFunctions", - }, -) - -go_template_instance( - name = "frame_ref_set_impl", - out = "frame_ref_set_impl.go", - imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", - }, - package = "fsutil", - prefix = "FrameRef", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "platform.FileRange", - "Value": "uint64", - "Functions": "FrameRefSetFunctions", - }, -) - -go_template_instance( - name = "file_range_set_impl", - out = "file_range_set_impl.go", - imports = { - "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", - }, - package = "fsutil", - prefix = "FileRange", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "memmap.MappableRange", - "Value": "uint64", - "Functions": "FileRangeSetFunctions", - }, -) - -go_library( - name = "fsutil", - srcs = [ - "dirty_set.go", - "dirty_set_impl.go", - "file.go", - "file_range_set.go", - "file_range_set_impl.go", - "frame_ref_set.go", - "frame_ref_set_impl.go", - "fsutil.go", - "host_file_mapper.go", - "host_file_mapper_state.go", - "host_file_mapper_unsafe.go", - "host_mappable.go", - "inode.go", - "inode_cached.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/log", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usage", - "//pkg/state", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "fsutil_test", - size = "small", - srcs = [ - "dirty_set_test.go", - "inode_cached_test.go", - ], - library = ":fsutil", - deps = [ - "//pkg/context", - "//pkg/safemem", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/syserror", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fs/fsutil/README.md b/pkg/sentry/fs/fsutil/README.md deleted file mode 100644 index 8be367334..000000000 --- a/pkg/sentry/fs/fsutil/README.md +++ /dev/null @@ -1,207 +0,0 @@ -This package provides utilities for implementing virtual filesystem objects. - -[TOC] - -## Page cache - -`CachingInodeOperations` implements a page cache for files that cannot use the -host page cache. Normally these are files that store their data in a remote -filesystem. This also applies to files that are accessed on a platform that does -not support directly memory mapping host file descriptors (e.g. the ptrace -platform). - -An `CachingInodeOperations` buffers regions of a single file into memory. It is -owned by an `fs.Inode`, the in-memory representation of a file (all open file -descriptors are backed by an `fs.Inode`). The `fs.Inode` provides operations for -reading memory into an `CachingInodeOperations`, to represent the contents of -the file in-memory, and for writing memory out, to relieve memory pressure on -the kernel and to synchronize in-memory changes to filesystems. - -An `CachingInodeOperations` enables readable and/or writable memory access to -file content. Files can be mapped shared or private, see mmap(2). When a file is -mapped shared, changes to the file via write(2) and truncate(2) are reflected in -the shared memory region. Conversely, when the shared memory region is modified, -changes to the file are visible via read(2). Multiple shared mappings of the -same file are coherent with each other. This is consistent with Linux. - -When a file is mapped private, updates to the mapped memory are not visible to -other memory mappings. Updates to the mapped memory are also not reflected in -the file content as seen by read(2). If the file is changed after a private -mapping is created, for instance by write(2), the change to the file may or may -not be reflected in the private mapping. This is consistent with Linux. - -An `CachingInodeOperations` keeps track of ranges of memory that were modified -(or "dirtied"). When the file is explicitly synced via fsync(2), only the dirty -ranges are written out to the filesystem. Any error returned indicates a failure -to write all dirty memory of an `CachingInodeOperations` to the filesystem. In -this case the filesystem may be in an inconsistent state. The same operation can -be performed on the shared memory itself using msync(2). If neither fsync(2) nor -msync(2) is performed, then the dirty memory is written out in accordance with -the `CachingInodeOperations` eviction strategy (see below) and there is no -guarantee that memory will be written out successfully in full. - -### Memory allocation and eviction - -An `CachingInodeOperations` implements the following allocation and eviction -strategy: - -- Memory is allocated and brought up to date with the contents of a file when - a region of mapped memory is accessed (or "faulted on"). - -- Dirty memory is written out to filesystems when an fsync(2) or msync(2) - operation is performed on a memory mapped file, for all memory mapped files - when saved, and/or when there are no longer any memory mappings of a range - of a file, see munmap(2). As the latter implies, in the absence of a panic - or SIGKILL, dirty memory is written out for all memory mapped files when an - application exits. - -- Memory is freed when there are no longer any memory mappings of a range of a - file (e.g. when an application exits). This behavior is consistent with - Linux for shared memory that has been locked via mlock(2). - -Notably, memory is not allocated for read(2) or write(2) operations. This means -that reads and writes to the file are only accelerated by an -`CachingInodeOperations` if the file being read or written has been memory -mapped *and* if the shared memory has been accessed at the region being read or -written. This diverges from Linux which buffers memory into a page cache on -read(2) proactively (i.e. readahead) and delays writing it out to filesystems on -write(2) (i.e. writeback). The absence of these optimizations is not visible to -applications beyond less than optimal performance when repeatedly reading and/or -writing to same region of a file. See [Future Work](#future-work) for plans to -implement these optimizations. - -Additionally, memory held by `CachingInodeOperationss` is currently unbounded in -size. An `CachingInodeOperations` does not write out dirty memory and free it -under system memory pressure. This can cause pathological memory usage. - -When memory is written back, an `CachingInodeOperations` may write regions of -shared memory that were never modified. This is due to the strategy of -minimizing page faults (see below) and handling only a subset of memory write -faults. In the absence of an application or sentry crash, it is guaranteed that -if a region of shared memory was written to, it is written back to a filesystem. - -### Life of a shared memory mapping - -A file is memory mapped via mmap(2). For example, if `A` is an address, an -application may execute: - -``` -mmap(A, 0x1000, PROT_READ|PROT_WRITE, MAP_SHARED, fd, 0); -``` - -This creates a shared mapping of fd that reflects 4k of the contents of fd -starting at offset 0, accessible at address `A`. This in turn creates a virtual -memory area region ("vma") which indicates that [`A`, `A`+0x1000) is now a valid -address range for this application to access. - -At this point, memory has not been allocated in the file's -`CachingInodeOperations`. It is also the case that the address range [`A`, -`A`+0x1000) has not been mapped on the host on behalf of the application. If the -application then tries to modify 8 bytes of the shared memory: - -``` -char buffer[] = "aaaaaaaa"; -memcpy(A, buffer, 8); -``` - -The host then sends a `SIGSEGV` to the sentry because the address range [`A`, -`A`+8) is not mapped on the host. The `SIGSEGV` indicates that the memory was -accessed writable. The sentry looks up the vma associated with [`A`, `A`+8), -finds the file that was mapped and its `CachingInodeOperations`. It then calls -`CachingInodeOperations.Translate` which allocates memory to back [`A`, `A`+8). -It may choose to allocate more memory (i.e. do "readahead") to minimize -subsequent faults. - -Memory that is allocated comes from a host tmpfs file (see -`pgalloc.MemoryFile`). The host tmpfs file memory is brought up to date with the -contents of the mapped file on its filesystem. The region of the host tmpfs file -that reflects the mapped file is then mapped into the host address space of the -application so that subsequent memory accesses do not repeatedly generate a -`SIGSEGV`. - -The range that was allocated, including any extra memory allocation to minimize -faults, is marked dirty due to the write fault. This overcounts dirty memory if -the extra memory allocated is never modified. - -To make the scenario more interesting, imagine that this application spawns -another process and maps the same file in the exact same way: - -``` -mmap(A, 0x1000, PROT_READ|PROT_WRITE, MAP_SHARED, fd, 0); -``` - -Imagine that this process then tries to modify the file again but with only 4 -bytes: - -``` -char buffer[] = "bbbb"; -memcpy(A, buffer, 4); -``` - -Since the first process has already mapped and accessed the same region of the -file writable, `CachingInodeOperations.Translate` is called but returns the -memory that has already been allocated rather than allocating new memory. The -address range [`A`, `A`+0x1000) reflects the same cached view of the file as the -first process sees. For example, reading 8 bytes from the file from either -process via read(2) starting at offset 0 returns a consistent "bbbbaaaa". - -When this process no longer needs the shared memory, it may do: - -``` -munmap(A, 0x1000); -``` - -At this point, the modified memory cached by the `CachingInodeOperations` is not -written back to the file because it is still in use by the first process that -mapped it. When the first process also does: - -``` -munmap(A, 0x1000); -``` - -Then the last memory mapping of the file at the range [0, 0x1000) is gone. The -file's `CachingInodeOperations` then starts writing back memory marked dirty to -the file on its filesystem. Once writing completes, regardless of whether it was -successful, the `CachingInodeOperations` frees the memory cached at the range -[0, 0x1000). - -Subsequent read(2) or write(2) operations on the file go directly to the -filesystem since there no longer exists memory for it in its -`CachingInodeOperations`. - -## Future Work - -### Page cache - -The sentry does not yet implement the readahead and writeback optimizations for -read(2) and write(2) respectively. To do so, on read(2) and/or write(2) the -sentry must ensure that memory is allocated in a page cache to read or write -into. However, the sentry cannot boundlessly allocate memory. If it did, the -host would eventually OOM-kill the sentry+application process. This means that -the sentry must implement a page cache memory allocation strategy that is -bounded by a global user or container imposed limit. When this limit is -approached, the sentry must decide from which page cache memory should be freed -so that it can allocate more memory. If it makes a poor decision, the sentry may -end up freeing and re-allocating memory to back regions of files that are -frequently used, nullifying the optimization (and in some cases causing worse -performance due to the overhead of memory allocation and general management). -This is a form of "cache thrashing". - -In Linux, much research has been done to select and implement a lightweight but -optimal page cache eviction algorithm. Linux makes use of hardware page bits to -keep track of whether memory has been accessed. The sentry does not have direct -access to hardware. Implementing a similarly lightweight and optimal page cache -eviction algorithm will need to either introduce a kernel interface to obtain -these page bits or find a suitable alternative proxy for access events. - -In Linux, readahead happens by default but is not always ideal. For instance, -for files that are not read sequentially, it would be more ideal to simply read -from only those regions of the file rather than to optimistically cache some -number of bytes ahead of the read (up to 2MB in Linux) if the bytes cached won't -be accessed. Linux implements the fadvise64(2) system call for applications to -specify that a range of a file will not be accessed sequentially. The advice bit -FADV_RANDOM turns off the readahead optimization for the given range in the -given file. However fadvise64 is rarely used by applications so Linux implements -a readahead backoff strategy if reads are not sequential. To ensure that -application performance is not degraded, the sentry must implement a similar -backoff strategy. diff --git a/pkg/sentry/fs/fsutil/dirty_set_impl.go b/pkg/sentry/fs/fsutil/dirty_set_impl.go new file mode 100755 index 000000000..2510b81b3 --- /dev/null +++ b/pkg/sentry/fs/fsutil/dirty_set_impl.go @@ -0,0 +1,1274 @@ +package fsutil + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/memmap" +) + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + DirtyminDegree = 3 + + DirtymaxDegree = 2 * DirtyminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type DirtySet struct { + root Dirtynode `state:".(*DirtySegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *DirtySet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *DirtySet) IsEmptyRange(r __generics_imported0.MappableRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *DirtySet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *DirtySet) SpanRange(r __generics_imported0.MappableRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *DirtySet) FirstSegment() DirtyIterator { + if s.root.nrSegments == 0 { + return DirtyIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *DirtySet) LastSegment() DirtyIterator { + if s.root.nrSegments == 0 { + return DirtyIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *DirtySet) FirstGap() DirtyGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return DirtyGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *DirtySet) LastGap() DirtyGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return DirtyGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *DirtySet) Find(key uint64) (DirtyIterator, DirtyGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return DirtyIterator{n, i}, DirtyGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return DirtyIterator{}, DirtyGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *DirtySet) FindSegment(key uint64) DirtyIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *DirtySet) LowerBoundSegment(min uint64) DirtyIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *DirtySet) UpperBoundSegment(max uint64) DirtyIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *DirtySet) FindGap(key uint64) DirtyGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *DirtySet) LowerBoundGap(min uint64) DirtyGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *DirtySet) UpperBoundGap(max uint64) DirtyGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *DirtySet) Add(r __generics_imported0.MappableRange, val DirtyInfo) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *DirtySet) AddWithoutMerging(r __generics_imported0.MappableRange, val DirtyInfo) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *DirtySet) Insert(gap DirtyGapIterator, r __generics_imported0.MappableRange, val DirtyInfo) DirtyIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (dirtySetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (dirtySetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (dirtySetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *DirtySet) InsertWithoutMerging(gap DirtyGapIterator, r __generics_imported0.MappableRange, val DirtyInfo) DirtyIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *DirtySet) InsertWithoutMergingUnchecked(gap DirtyGapIterator, r __generics_imported0.MappableRange, val DirtyInfo) DirtyIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return DirtyIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *DirtySet) Remove(seg DirtyIterator) DirtyGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + dirtySetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(DirtyGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *DirtySet) RemoveAll() { + s.root = Dirtynode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *DirtySet) RemoveRange(r __generics_imported0.MappableRange) DirtyGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *DirtySet) Merge(first, second DirtyIterator) DirtyIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *DirtySet) MergeUnchecked(first, second DirtyIterator) DirtyIterator { + if first.End() == second.Start() { + if mval, ok := (dirtySetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return DirtyIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *DirtySet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *DirtySet) MergeRange(r __generics_imported0.MappableRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *DirtySet) MergeAdjacent(r __generics_imported0.MappableRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *DirtySet) Split(seg DirtyIterator, split uint64) (DirtyIterator, DirtyIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *DirtySet) SplitUnchecked(seg DirtyIterator, split uint64) (DirtyIterator, DirtyIterator) { + val1, val2 := (dirtySetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.MappableRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *DirtySet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *DirtySet) Isolate(seg DirtyIterator, r __generics_imported0.MappableRange) DirtyIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *DirtySet) ApplyContiguous(r __generics_imported0.MappableRange, fn func(seg DirtyIterator)) DirtyGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return DirtyGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return DirtyGapIterator{} + } + } +} + +// +stateify savable +type Dirtynode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *Dirtynode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [DirtymaxDegree - 1]__generics_imported0.MappableRange + values [DirtymaxDegree - 1]DirtyInfo + children [DirtymaxDegree]*Dirtynode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Dirtynode) firstSegment() DirtyIterator { + for n.hasChildren { + n = n.children[0] + } + return DirtyIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Dirtynode) lastSegment() DirtyIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return DirtyIterator{n, n.nrSegments - 1} +} + +func (n *Dirtynode) prevSibling() *Dirtynode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *Dirtynode) nextSibling() *Dirtynode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *Dirtynode) rebalanceBeforeInsert(gap DirtyGapIterator) DirtyGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < DirtymaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &Dirtynode{ + nrSegments: DirtyminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &Dirtynode{ + nrSegments: DirtyminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:DirtyminDegree-1], n.keys[:DirtyminDegree-1]) + copy(left.values[:DirtyminDegree-1], n.values[:DirtyminDegree-1]) + copy(right.keys[:DirtyminDegree-1], n.keys[DirtyminDegree:]) + copy(right.values[:DirtyminDegree-1], n.values[DirtyminDegree:]) + n.keys[0], n.values[0] = n.keys[DirtyminDegree-1], n.values[DirtyminDegree-1] + DirtyzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:DirtyminDegree], n.children[:DirtyminDegree]) + copy(right.children[:DirtyminDegree], n.children[DirtyminDegree:]) + DirtyzeroNodeSlice(n.children[2:]) + for i := 0; i < DirtyminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < DirtyminDegree { + return DirtyGapIterator{left, gap.index} + } + return DirtyGapIterator{right, gap.index - DirtyminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[DirtyminDegree-1], n.values[DirtyminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &Dirtynode{ + nrSegments: DirtyminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:DirtyminDegree-1], n.keys[DirtyminDegree:]) + copy(sibling.values[:DirtyminDegree-1], n.values[DirtyminDegree:]) + DirtyzeroValueSlice(n.values[DirtyminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:DirtyminDegree], n.children[DirtyminDegree:]) + DirtyzeroNodeSlice(n.children[DirtyminDegree:]) + for i := 0; i < DirtyminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = DirtyminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < DirtyminDegree { + return gap + } + return DirtyGapIterator{sibling, gap.index - DirtyminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *Dirtynode) rebalanceAfterRemove(gap DirtyGapIterator) DirtyGapIterator { + for { + if n.nrSegments >= DirtyminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= DirtyminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + dirtySetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return DirtyGapIterator{n, 0} + } + if gap.node == n { + return DirtyGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= DirtyminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + dirtySetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return DirtyGapIterator{n, n.nrSegments} + } + return DirtyGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return DirtyGapIterator{p, gap.index} + } + if gap.node == right { + return DirtyGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *Dirtynode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = DirtyGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + dirtySetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type DirtyIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *Dirtynode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg DirtyIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg DirtyIterator) Range() __generics_imported0.MappableRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg DirtyIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg DirtyIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg DirtyIterator) SetRangeUnchecked(r __generics_imported0.MappableRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg DirtyIterator) SetRange(r __generics_imported0.MappableRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg DirtyIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg DirtyIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg DirtyIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg DirtyIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg DirtyIterator) Value() DirtyInfo { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg DirtyIterator) ValuePtr() *DirtyInfo { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg DirtyIterator) SetValue(val DirtyInfo) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg DirtyIterator) PrevSegment() DirtyIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return DirtyIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return DirtyIterator{} + } + return DirtysegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg DirtyIterator) NextSegment() DirtyIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return DirtyIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return DirtyIterator{} + } + return DirtysegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg DirtyIterator) PrevGap() DirtyGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return DirtyGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg DirtyIterator) NextGap() DirtyGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return DirtyGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg DirtyIterator) PrevNonEmpty() (DirtyIterator, DirtyGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return DirtyIterator{}, gap + } + return gap.PrevSegment(), DirtyGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg DirtyIterator) NextNonEmpty() (DirtyIterator, DirtyGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return DirtyIterator{}, gap + } + return gap.NextSegment(), DirtyGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type DirtyGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *Dirtynode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap DirtyGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap DirtyGapIterator) Range() __generics_imported0.MappableRange { + return __generics_imported0.MappableRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap DirtyGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return dirtySetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap DirtyGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return dirtySetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap DirtyGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap DirtyGapIterator) PrevSegment() DirtyIterator { + return DirtysegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap DirtyGapIterator) NextSegment() DirtyIterator { + return DirtysegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap DirtyGapIterator) PrevGap() DirtyGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return DirtyGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap DirtyGapIterator) NextGap() DirtyGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return DirtyGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func DirtysegmentBeforePosition(n *Dirtynode, i int) DirtyIterator { + for i == 0 { + if n.parent == nil { + return DirtyIterator{} + } + n, i = n.parent, n.parentIndex + } + return DirtyIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func DirtysegmentAfterPosition(n *Dirtynode, i int) DirtyIterator { + for i == n.nrSegments { + if n.parent == nil { + return DirtyIterator{} + } + n, i = n.parent, n.parentIndex + } + return DirtyIterator{n, i} +} + +func DirtyzeroValueSlice(slice []DirtyInfo) { + + for i := range slice { + dirtySetFunctions{}.ClearValue(&slice[i]) + } +} + +func DirtyzeroNodeSlice(slice []*Dirtynode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *DirtySet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *Dirtynode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *Dirtynode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type DirtySegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []DirtyInfo +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *DirtySet) ExportSortedSlices() *DirtySegmentDataSlices { + var sds DirtySegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *DirtySet) ImportSortedSlices(sds *DirtySegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.MappableRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *DirtySet) saveRoot() *DirtySegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *DirtySet) loadRoot(sds *DirtySegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/fs/fsutil/dirty_set_test.go b/pkg/sentry/fs/fsutil/dirty_set_test.go deleted file mode 100644 index e3579c23c..000000000 --- a/pkg/sentry/fs/fsutil/dirty_set_test.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fsutil - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/usermem" -) - -func TestDirtySet(t *testing.T) { - var set DirtySet - set.MarkDirty(memmap.MappableRange{0, 2 * usermem.PageSize}) - set.KeepDirty(memmap.MappableRange{usermem.PageSize, 2 * usermem.PageSize}) - set.MarkClean(memmap.MappableRange{0, 2 * usermem.PageSize}) - want := &DirtySegmentDataSlices{ - Start: []uint64{usermem.PageSize}, - End: []uint64{2 * usermem.PageSize}, - Values: []DirtyInfo{{Keep: true}}, - } - if got := set.ExportSortedSlices(); !reflect.DeepEqual(got, want) { - t.Errorf("set:\n\tgot %v,\n\twant %v", got, want) - } -} diff --git a/pkg/segment/set.go b/pkg/sentry/fs/fsutil/file_range_set_impl.go index 03e4f258f..01e7a2401 100644..100755 --- a/pkg/segment/set.go +++ b/pkg/sentry/fs/fsutil/file_range_set_impl.go @@ -1,73 +1,14 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package segment provides tools for working with collections of segments. A -// segment is a key-value mapping, where the key is a non-empty contiguous -// range of values of type Key, and the value is a single value of type Value. -// -// Clients using this package must use the go_template_instance rule in -// tools/go_generics/defs.bzl to create an instantiation of this -// template package, providing types to use in place of Key, Range, Value, and -// Functions. See pkg/segment/test/BUILD for a usage example. -package segment +package fsutil + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/memmap" +) import ( "bytes" "fmt" ) -// Key is a required type parameter that must be an integral type. -type Key uint64 - -// Range is a required type parameter equivalent to Range<Key>. -type Range interface{} - -// Value is a required type parameter. -type Value interface{} - -// Functions is a required type parameter that must be a struct implementing -// the methods defined by Functions. -type Functions interface { - // MinKey returns the minimum allowed key. - MinKey() Key - - // MaxKey returns the maximum allowed key + 1. - MaxKey() Key - - // ClearValue deinitializes the given value. (For example, if Value is a - // pointer or interface type, ClearValue should set it to nil.) - ClearValue(*Value) - - // Merge attempts to merge the values corresponding to two consecutive - // segments. If successful, Merge returns (merged value, true). Otherwise, - // it returns (unspecified, false). - // - // Preconditions: r1.End == r2.Start. - // - // Postconditions: If merging succeeds, val1 and val2 are invalidated. - Merge(r1 Range, val1 Value, r2 Range, val2 Value) (Value, bool) - - // Split splits a segment's value at a key within its range, such that the - // first returned value corresponds to the range [r.Start, split) and the - // second returned value corresponds to the range [split, r.End). - // - // Preconditions: r.Start < split < r.End. - // - // Postconditions: The original value val is invalidated. - Split(r Range, val Value, split Key) (Value, Value) -} - const ( // minDegree is the minimum degree of an internal node in a Set B-tree. // @@ -80,9 +21,9 @@ const ( // // Our implementation requires minDegree >= 3. Higher values of minDegree // usually improve performance, but increase memory usage for small sets. - minDegree = 3 + FileRangeminDegree = 3 - maxDegree = 2 * minDegree + FileRangemaxDegree = 2 * FileRangeminDegree ) // A Set is a mapping of segments with non-overlapping Range keys. The zero @@ -90,19 +31,19 @@ const ( // copyable. Set is thread-compatible. // // +stateify savable -type Set struct { - root node `state:".(*SegmentDataSlices)"` +type FileRangeSet struct { + root FileRangenode `state:".(*FileRangeSegmentDataSlices)"` } // IsEmpty returns true if the set contains no segments. -func (s *Set) IsEmpty() bool { +func (s *FileRangeSet) IsEmpty() bool { return s.root.nrSegments == 0 } // IsEmptyRange returns true iff no segments in the set overlap the given // range. This is semantically equivalent to s.SpanRange(r) == 0, but may be // more efficient. -func (s *Set) IsEmptyRange(r Range) bool { +func (s *FileRangeSet) IsEmptyRange(r __generics_imported0.MappableRange) bool { switch { case r.Length() < 0: panic(fmt.Sprintf("invalid range %v", r)) @@ -117,8 +58,8 @@ func (s *Set) IsEmptyRange(r Range) bool { } // Span returns the total size of all segments in the set. -func (s *Set) Span() Key { - var sz Key +func (s *FileRangeSet) Span() uint64 { + var sz uint64 for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { sz += seg.Range().Length() } @@ -127,14 +68,14 @@ func (s *Set) Span() Key { // SpanRange returns the total size of the intersection of segments in the set // with the given range. -func (s *Set) SpanRange(r Range) Key { +func (s *FileRangeSet) SpanRange(r __generics_imported0.MappableRange) uint64 { switch { case r.Length() < 0: panic(fmt.Sprintf("invalid range %v", r)) case r.Length() == 0: return 0 } - var sz Key + var sz uint64 for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { sz += seg.Range().Intersect(r).Length() } @@ -143,56 +84,55 @@ func (s *Set) SpanRange(r Range) Key { // FirstSegment returns the first segment in the set. If the set is empty, // FirstSegment returns a terminal iterator. -func (s *Set) FirstSegment() Iterator { +func (s *FileRangeSet) FirstSegment() FileRangeIterator { if s.root.nrSegments == 0 { - return Iterator{} + return FileRangeIterator{} } return s.root.firstSegment() } // LastSegment returns the last segment in the set. If the set is empty, // LastSegment returns a terminal iterator. -func (s *Set) LastSegment() Iterator { +func (s *FileRangeSet) LastSegment() FileRangeIterator { if s.root.nrSegments == 0 { - return Iterator{} + return FileRangeIterator{} } return s.root.lastSegment() } // FirstGap returns the first gap in the set. -func (s *Set) FirstGap() GapIterator { +func (s *FileRangeSet) FirstGap() FileRangeGapIterator { n := &s.root for n.hasChildren { n = n.children[0] } - return GapIterator{n, 0} + return FileRangeGapIterator{n, 0} } // LastGap returns the last gap in the set. -func (s *Set) LastGap() GapIterator { +func (s *FileRangeSet) LastGap() FileRangeGapIterator { n := &s.root for n.hasChildren { n = n.children[n.nrSegments] } - return GapIterator{n, n.nrSegments} + return FileRangeGapIterator{n, n.nrSegments} } // Find returns the segment or gap whose range contains the given key. If a // segment is found, the returned Iterator is non-terminal and the // returned GapIterator is terminal. Otherwise, the returned Iterator is // terminal and the returned GapIterator is non-terminal. -func (s *Set) Find(key Key) (Iterator, GapIterator) { +func (s *FileRangeSet) Find(key uint64) (FileRangeIterator, FileRangeGapIterator) { n := &s.root for { - // Binary search invariant: the correct value of i lies within [lower, - // upper]. + lower := 0 upper := n.nrSegments for lower < upper { i := lower + (upper-lower)/2 if r := n.keys[i]; key < r.End { if key >= r.Start { - return Iterator{n, i}, GapIterator{} + return FileRangeIterator{n, i}, FileRangeGapIterator{} } upper = i } else { @@ -201,7 +141,7 @@ func (s *Set) Find(key Key) (Iterator, GapIterator) { } i := lower if !n.hasChildren { - return Iterator{}, GapIterator{n, i} + return FileRangeIterator{}, FileRangeGapIterator{n, i} } n = n.children[i] } @@ -209,7 +149,7 @@ func (s *Set) Find(key Key) (Iterator, GapIterator) { // FindSegment returns the segment whose range contains the given key. If no // such segment exists, FindSegment returns a terminal iterator. -func (s *Set) FindSegment(key Key) Iterator { +func (s *FileRangeSet) FindSegment(key uint64) FileRangeIterator { seg, _ := s.Find(key) return seg } @@ -217,7 +157,7 @@ func (s *Set) FindSegment(key Key) Iterator { // LowerBoundSegment returns the segment with the lowest range that contains a // key greater than or equal to min. If no such segment exists, // LowerBoundSegment returns a terminal iterator. -func (s *Set) LowerBoundSegment(min Key) Iterator { +func (s *FileRangeSet) LowerBoundSegment(min uint64) FileRangeIterator { seg, gap := s.Find(min) if seg.Ok() { return seg @@ -228,7 +168,7 @@ func (s *Set) LowerBoundSegment(min Key) Iterator { // UpperBoundSegment returns the segment with the highest range that contains a // key less than or equal to max. If no such segment exists, UpperBoundSegment // returns a terminal iterator. -func (s *Set) UpperBoundSegment(max Key) Iterator { +func (s *FileRangeSet) UpperBoundSegment(max uint64) FileRangeIterator { seg, gap := s.Find(max) if seg.Ok() { return seg @@ -239,14 +179,14 @@ func (s *Set) UpperBoundSegment(max Key) Iterator { // FindGap returns the gap containing the given key. If no such gap exists // (i.e. the set contains a segment containing that key), FindGap returns a // terminal iterator. -func (s *Set) FindGap(key Key) GapIterator { +func (s *FileRangeSet) FindGap(key uint64) FileRangeGapIterator { _, gap := s.Find(key) return gap } // LowerBoundGap returns the gap with the lowest range that is greater than or // equal to min. -func (s *Set) LowerBoundGap(min Key) GapIterator { +func (s *FileRangeSet) LowerBoundGap(min uint64) FileRangeGapIterator { seg, gap := s.Find(min) if gap.Ok() { return gap @@ -256,7 +196,7 @@ func (s *Set) LowerBoundGap(min Key) GapIterator { // UpperBoundGap returns the gap with the highest range that is less than or // equal to max. -func (s *Set) UpperBoundGap(max Key) GapIterator { +func (s *FileRangeSet) UpperBoundGap(max uint64) FileRangeGapIterator { seg, gap := s.Find(max) if gap.Ok() { return gap @@ -268,7 +208,7 @@ func (s *Set) UpperBoundGap(max Key) GapIterator { // segment can be merged with adjacent segments, Add will do so. If the new // segment would overlap an existing segment, Add returns false. If Add // succeeds, all existing iterators are invalidated. -func (s *Set) Add(r Range, val Value) bool { +func (s *FileRangeSet) Add(r __generics_imported0.MappableRange, val uint64) bool { if r.Length() <= 0 { panic(fmt.Sprintf("invalid segment range %v", r)) } @@ -287,7 +227,7 @@ func (s *Set) Add(r Range, val Value) bool { // If it would overlap an existing segment, AddWithoutMerging does nothing and // returns false. If AddWithoutMerging succeeds, all existing iterators are // invalidated. -func (s *Set) AddWithoutMerging(r Range, val Value) bool { +func (s *FileRangeSet) AddWithoutMerging(r __generics_imported0.MappableRange, val uint64) bool { if r.Length() <= 0 { panic(fmt.Sprintf("invalid segment range %v", r)) } @@ -314,7 +254,7 @@ func (s *Set) AddWithoutMerging(r Range, val Value) bool { // Merge, but may be more efficient. Note that there is no unchecked variant of // Insert since Insert must retrieve and inspect gap's predecessor and // successor segments regardless. -func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { +func (s *FileRangeSet) Insert(gap FileRangeGapIterator, r __generics_imported0.MappableRange, val uint64) FileRangeIterator { if r.Length() <= 0 { panic(fmt.Sprintf("invalid segment range %v", r)) } @@ -326,12 +266,12 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) } if prev.Ok() && prev.End() == r.Start { - if mval, ok := (Functions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + if mval, ok := (FileRangeSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { prev.SetEndUnchecked(r.End) prev.SetValue(mval) if next.Ok() && next.Start() == r.End { val = mval - if mval, ok := (Functions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + if mval, ok := (FileRangeSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { prev.SetEndUnchecked(next.End()) prev.SetValue(mval) return s.Remove(next).PrevSegment() @@ -341,7 +281,7 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { } } if next.Ok() && next.Start() == r.End { - if mval, ok := (Functions{}).Merge(r, val, next.Range(), next.Value()); ok { + if mval, ok := (FileRangeSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { next.SetStartUnchecked(r.Start) next.SetValue(mval) return next @@ -356,7 +296,7 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { // // If the gap cannot accommodate the segment, or if r is invalid, // InsertWithoutMerging panics. -func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val Value) Iterator { +func (s *FileRangeSet) InsertWithoutMerging(gap FileRangeGapIterator, r __generics_imported0.MappableRange, val uint64) FileRangeIterator { if r.Length() <= 0 { panic(fmt.Sprintf("invalid segment range %v", r)) } @@ -371,52 +311,45 @@ func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val Value) Iterator // (including gap, but not including the returned iterator) are invalidated. // // Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). -func (s *Set) InsertWithoutMergingUnchecked(gap GapIterator, r Range, val Value) Iterator { +func (s *FileRangeSet) InsertWithoutMergingUnchecked(gap FileRangeGapIterator, r __generics_imported0.MappableRange, val uint64) FileRangeIterator { gap = gap.node.rebalanceBeforeInsert(gap) copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) gap.node.keys[gap.index] = r gap.node.values[gap.index] = val gap.node.nrSegments++ - return Iterator{gap.node, gap.index} + return FileRangeIterator{gap.node, gap.index} } // Remove removes the given segment and returns an iterator to the vacated gap. // All existing iterators (including seg, but not including the returned // iterator) are invalidated. -func (s *Set) Remove(seg Iterator) GapIterator { - // We only want to remove directly from a leaf node. +func (s *FileRangeSet) Remove(seg FileRangeIterator) FileRangeGapIterator { + if seg.node.hasChildren { - // Since seg.node has children, the removed segment must have a - // predecessor (at the end of the rightmost leaf of its left child - // subtree). Move the contents of that predecessor into the removed - // segment's position, and remove that predecessor instead. (We choose - // to steal the predecessor rather than the successor because removing - // from the end of a leaf node doesn't involve any copying unless - // merging is required.) + victim := seg.PrevSegment() - // This must be unchecked since until victim is removed, seg and victim - // overlap. + seg.SetRangeUnchecked(victim.Range()) seg.SetValue(victim.Value()) return s.Remove(victim).NextGap() } copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) - Functions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + FileRangeSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) seg.node.nrSegments-- - return seg.node.rebalanceAfterRemove(GapIterator{seg.node, seg.index}) + return seg.node.rebalanceAfterRemove(FileRangeGapIterator{seg.node, seg.index}) } // RemoveAll removes all segments from the set. All existing iterators are // invalidated. -func (s *Set) RemoveAll() { - s.root = node{} +func (s *FileRangeSet) RemoveAll() { + s.root = FileRangenode{} } // RemoveRange removes all segments in the given range. An iterator to the // newly formed gap is returned, and all existing iterators are invalidated. -func (s *Set) RemoveRange(r Range) GapIterator { +func (s *FileRangeSet) RemoveRange(r __generics_imported0.MappableRange) FileRangeGapIterator { seg, gap := s.Find(r.Start) if seg.Ok() { seg = s.Isolate(seg, r) @@ -434,7 +367,7 @@ func (s *Set) RemoveRange(r Range) GapIterator { // invalidated. Otherwise, Merge returns a terminal iterator. // // If first is not the predecessor of second, Merge panics. -func (s *Set) Merge(first, second Iterator) Iterator { +func (s *FileRangeSet) Merge(first, second FileRangeIterator) FileRangeIterator { if first.NextSegment() != second { panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) } @@ -448,22 +381,21 @@ func (s *Set) Merge(first, second Iterator) Iterator { // // Precondition: first is the predecessor of second: first.NextSegment() == // second, first == second.PrevSegment(). -func (s *Set) MergeUnchecked(first, second Iterator) Iterator { +func (s *FileRangeSet) MergeUnchecked(first, second FileRangeIterator) FileRangeIterator { if first.End() == second.Start() { - if mval, ok := (Functions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { - // N.B. This must be unchecked because until s.Remove(second), first - // overlaps second. + if mval, ok := (FileRangeSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + first.SetEndUnchecked(second.End()) first.SetValue(mval) return s.Remove(second).PrevSegment() } } - return Iterator{} + return FileRangeIterator{} } // MergeAll attempts to merge all adjacent segments in the set. All existing // iterators are invalidated. -func (s *Set) MergeAll() { +func (s *FileRangeSet) MergeAll() { seg := s.FirstSegment() if !seg.Ok() { return @@ -480,7 +412,7 @@ func (s *Set) MergeAll() { // MergeRange attempts to merge all adjacent segments that contain a key in the // specific range. All existing iterators are invalidated. -func (s *Set) MergeRange(r Range) { +func (s *FileRangeSet) MergeRange(r __generics_imported0.MappableRange) { seg := s.LowerBoundSegment(r.Start) if !seg.Ok() { return @@ -497,7 +429,7 @@ func (s *Set) MergeRange(r Range) { // MergeAdjacent attempts to merge the segment containing r.Start with its // predecessor, and the segment containing r.End-1 with its successor. -func (s *Set) MergeAdjacent(r Range) { +func (s *FileRangeSet) MergeAdjacent(r __generics_imported0.MappableRange) { first := s.FindSegment(r.Start) if first.Ok() { if prev := first.PrevSegment(); prev.Ok() { @@ -520,7 +452,7 @@ func (s *Set) MergeAdjacent(r Range) { // end of the segment's range, so splitting would produce a segment with zero // length, or because split falls outside the segment's range altogether), // Split panics. -func (s *Set) Split(seg Iterator, split Key) (Iterator, Iterator) { +func (s *FileRangeSet) Split(seg FileRangeIterator, split uint64) (FileRangeIterator, FileRangeIterator) { if !seg.Range().CanSplitAt(split) { panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) } @@ -532,20 +464,20 @@ func (s *Set) Split(seg Iterator, split Key) (Iterator, Iterator) { // seg, but not including the returned iterators) are invalidated. // // Preconditions: seg.Start() < key < seg.End(). -func (s *Set) SplitUnchecked(seg Iterator, split Key) (Iterator, Iterator) { - val1, val2 := (Functions{}).Split(seg.Range(), seg.Value(), split) +func (s *FileRangeSet) SplitUnchecked(seg FileRangeIterator, split uint64) (FileRangeIterator, FileRangeIterator) { + val1, val2 := (FileRangeSetFunctions{}).Split(seg.Range(), seg.Value(), split) end2 := seg.End() seg.SetEndUnchecked(split) seg.SetValue(val1) - seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), Range{split, end2}, val2) - // seg may now be invalid due to the Insert. + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.MappableRange{split, end2}, val2) + return seg2.PrevSegment(), seg2 } // SplitAt splits the segment straddling split, if one exists. SplitAt returns // true if a segment was split and false otherwise. If SplitAt splits a // segment, all existing iterators are invalidated. -func (s *Set) SplitAt(split Key) bool { +func (s *FileRangeSet) SplitAt(split uint64) bool { if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { s.SplitUnchecked(seg, split) return true @@ -557,7 +489,7 @@ func (s *Set) SplitAt(split Key) bool { // splitting at r.Start and r.End if necessary, and returns an updated iterator // to the bounded segment. All existing iterators (including seg, but not // including the returned iterators) are invalidated. -func (s *Set) Isolate(seg Iterator, r Range) Iterator { +func (s *FileRangeSet) Isolate(seg FileRangeIterator, r __generics_imported0.MappableRange) FileRangeIterator { if seg.Range().CanSplitAt(r.Start) { _, seg = s.SplitUnchecked(seg, r.Start) } @@ -574,7 +506,7 @@ func (s *Set) Isolate(seg Iterator, r Range) Iterator { // are invalidated. // // N.B. The Iterator must not be invalidated by the function. -func (s *Set) ApplyContiguous(r Range, fn func(seg Iterator)) GapIterator { +func (s *FileRangeSet) ApplyContiguous(r __generics_imported0.MappableRange, fn func(seg FileRangeIterator)) FileRangeGapIterator { seg, gap := s.Find(r.Start) if !seg.Ok() { return gap @@ -583,7 +515,7 @@ func (s *Set) ApplyContiguous(r Range, fn func(seg Iterator)) GapIterator { seg = s.Isolate(seg, r) fn(seg) if seg.End() >= r.End { - return GapIterator{} + return FileRangeGapIterator{} } gap = seg.NextGap() if !gap.IsEmpty() { @@ -591,15 +523,14 @@ func (s *Set) ApplyContiguous(r Range, fn func(seg Iterator)) GapIterator { } seg = gap.NextSegment() if !seg.Ok() { - // This implies that the last segment extended all the - // way to the maximum value, since the gap was empty. - return GapIterator{} + + return FileRangeGapIterator{} } } } // +stateify savable -type node struct { +type FileRangenode struct { // An internal binary tree node looks like: // // K @@ -621,7 +552,7 @@ type node struct { // parent is a pointer to this node's parent. If this node is root, parent // is nil. - parent *node + parent *FileRangenode // parentIndex is the index of this node in parent.children. parentIndex int @@ -633,39 +564,39 @@ type node struct { // Nodes store keys and values in separate arrays to maximize locality in // the common case (scanning keys for lookup). - keys [maxDegree - 1]Range - values [maxDegree - 1]Value - children [maxDegree]*node + keys [FileRangemaxDegree - 1]__generics_imported0.MappableRange + values [FileRangemaxDegree - 1]uint64 + children [FileRangemaxDegree]*FileRangenode } // firstSegment returns the first segment in the subtree rooted by n. // // Preconditions: n.nrSegments != 0. -func (n *node) firstSegment() Iterator { +func (n *FileRangenode) firstSegment() FileRangeIterator { for n.hasChildren { n = n.children[0] } - return Iterator{n, 0} + return FileRangeIterator{n, 0} } // lastSegment returns the last segment in the subtree rooted by n. // // Preconditions: n.nrSegments != 0. -func (n *node) lastSegment() Iterator { +func (n *FileRangenode) lastSegment() FileRangeIterator { for n.hasChildren { n = n.children[n.nrSegments] } - return Iterator{n, n.nrSegments - 1} + return FileRangeIterator{n, n.nrSegments - 1} } -func (n *node) prevSibling() *node { +func (n *FileRangenode) prevSibling() *FileRangenode { if n.parent == nil || n.parentIndex == 0 { return nil } return n.parent.children[n.parentIndex-1] } -func (n *node) nextSibling() *node { +func (n *FileRangenode) nextSibling() *FileRangenode { if n.parent == nil || n.parentIndex == n.parent.nrSegments { return nil } @@ -675,40 +606,38 @@ func (n *node) nextSibling() *node { // rebalanceBeforeInsert splits n and its ancestors if they are full, as // required for insertion, and returns an updated iterator to the position // represented by gap. -func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { +func (n *FileRangenode) rebalanceBeforeInsert(gap FileRangeGapIterator) FileRangeGapIterator { if n.parent != nil { gap = n.parent.rebalanceBeforeInsert(gap) } - if n.nrSegments < maxDegree-1 { + if n.nrSegments < FileRangemaxDegree-1 { return gap } if n.parent == nil { - // n is root. Move all segments before and after n's median segment - // into new child nodes adjacent to the median segment, which is now - // the only segment in root. - left := &node{ - nrSegments: minDegree - 1, + + left := &FileRangenode{ + nrSegments: FileRangeminDegree - 1, parent: n, parentIndex: 0, hasChildren: n.hasChildren, } - right := &node{ - nrSegments: minDegree - 1, + right := &FileRangenode{ + nrSegments: FileRangeminDegree - 1, parent: n, parentIndex: 1, hasChildren: n.hasChildren, } - copy(left.keys[:minDegree-1], n.keys[:minDegree-1]) - copy(left.values[:minDegree-1], n.values[:minDegree-1]) - copy(right.keys[:minDegree-1], n.keys[minDegree:]) - copy(right.values[:minDegree-1], n.values[minDegree:]) - n.keys[0], n.values[0] = n.keys[minDegree-1], n.values[minDegree-1] - zeroValueSlice(n.values[1:]) + copy(left.keys[:FileRangeminDegree-1], n.keys[:FileRangeminDegree-1]) + copy(left.values[:FileRangeminDegree-1], n.values[:FileRangeminDegree-1]) + copy(right.keys[:FileRangeminDegree-1], n.keys[FileRangeminDegree:]) + copy(right.values[:FileRangeminDegree-1], n.values[FileRangeminDegree:]) + n.keys[0], n.values[0] = n.keys[FileRangeminDegree-1], n.values[FileRangeminDegree-1] + FileRangezeroValueSlice(n.values[1:]) if n.hasChildren { - copy(left.children[:minDegree], n.children[:minDegree]) - copy(right.children[:minDegree], n.children[minDegree:]) - zeroNodeSlice(n.children[2:]) - for i := 0; i < minDegree; i++ { + copy(left.children[:FileRangeminDegree], n.children[:FileRangeminDegree]) + copy(right.children[:FileRangeminDegree], n.children[FileRangeminDegree:]) + FileRangezeroNodeSlice(n.children[2:]) + for i := 0; i < FileRangeminDegree; i++ { left.children[i].parent = left left.children[i].parentIndex = i right.children[i].parent = right @@ -722,50 +651,47 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { if gap.node != n { return gap } - if gap.index < minDegree { - return GapIterator{left, gap.index} + if gap.index < FileRangeminDegree { + return FileRangeGapIterator{left, gap.index} } - return GapIterator{right, gap.index - minDegree} + return FileRangeGapIterator{right, gap.index - FileRangeminDegree} } - // n is non-root. Move n's median segment into its parent node (which can't - // be full because we've already invoked n.parent.rebalanceBeforeInsert) - // and move all segments after n's median into a new sibling node (the - // median segment's right child subtree). + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) - n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[minDegree-1], n.values[minDegree-1] + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[FileRangeminDegree-1], n.values[FileRangeminDegree-1] copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { n.parent.children[i].parentIndex = i } - sibling := &node{ - nrSegments: minDegree - 1, + sibling := &FileRangenode{ + nrSegments: FileRangeminDegree - 1, parent: n.parent, parentIndex: n.parentIndex + 1, hasChildren: n.hasChildren, } n.parent.children[n.parentIndex+1] = sibling n.parent.nrSegments++ - copy(sibling.keys[:minDegree-1], n.keys[minDegree:]) - copy(sibling.values[:minDegree-1], n.values[minDegree:]) - zeroValueSlice(n.values[minDegree-1:]) + copy(sibling.keys[:FileRangeminDegree-1], n.keys[FileRangeminDegree:]) + copy(sibling.values[:FileRangeminDegree-1], n.values[FileRangeminDegree:]) + FileRangezeroValueSlice(n.values[FileRangeminDegree-1:]) if n.hasChildren { - copy(sibling.children[:minDegree], n.children[minDegree:]) - zeroNodeSlice(n.children[minDegree:]) - for i := 0; i < minDegree; i++ { + copy(sibling.children[:FileRangeminDegree], n.children[FileRangeminDegree:]) + FileRangezeroNodeSlice(n.children[FileRangeminDegree:]) + for i := 0; i < FileRangeminDegree; i++ { sibling.children[i].parent = sibling sibling.children[i].parentIndex = i } } - n.nrSegments = minDegree - 1 - // gap.node can't be n.parent because gaps are always in leaf nodes. + n.nrSegments = FileRangeminDegree - 1 + if gap.node != n { return gap } - if gap.index < minDegree { + if gap.index < FileRangeminDegree { return gap } - return GapIterator{sibling, gap.index - minDegree} + return FileRangeGapIterator{sibling, gap.index - FileRangeminDegree} } // rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient @@ -774,41 +700,24 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { // // Precondition: n is the only node in the tree that may currently violate a // B-tree invariant. -func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { +func (n *FileRangenode) rebalanceAfterRemove(gap FileRangeGapIterator) FileRangeGapIterator { for { - if n.nrSegments >= minDegree-1 { + if n.nrSegments >= FileRangeminDegree-1 { return gap } if n.parent == nil { - // Root is allowed to be deficient. + return gap } - // There's one other thing we can do before resorting to unsplitting. - // If either sibling node has at least minDegree segments, rotate that - // sibling's closest segment through the segment in the parent that - // separates us. That is, given: - // - // ... D ... - // / \ - // ... B C] [E ... - // - // where the node containing E is deficient, end up with: - // - // ... C ... - // / \ - // ... B] [D E ... - // - // As in Set.Remove, prefer rotating from the end of the sibling to the - // left: by precondition, n.node has fewer segments (to memcpy) than - // the sibling does. - if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= minDegree { + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= FileRangeminDegree { copy(n.keys[1:], n.keys[:n.nrSegments]) copy(n.values[1:], n.values[:n.nrSegments]) n.keys[0] = n.parent.keys[n.parentIndex-1] n.values[0] = n.parent.values[n.parentIndex-1] n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] - Functions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + FileRangeSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) if n.hasChildren { copy(n.children[1:], n.children[:n.nrSegments+1]) n.children[0] = sibling.children[sibling.nrSegments] @@ -822,21 +731,21 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { n.nrSegments++ sibling.nrSegments-- if gap.node == sibling && gap.index == sibling.nrSegments { - return GapIterator{n, 0} + return FileRangeGapIterator{n, 0} } if gap.node == n { - return GapIterator{n, gap.index + 1} + return FileRangeGapIterator{n, gap.index + 1} } return gap } - if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= minDegree { + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= FileRangeminDegree { n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] n.values[n.nrSegments] = n.parent.values[n.parentIndex] n.parent.keys[n.parentIndex] = sibling.keys[0] n.parent.values[n.parentIndex] = sibling.values[0] copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) - Functions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + FileRangeSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) if n.hasChildren { n.children[n.nrSegments+1] = sibling.children[0] copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) @@ -851,21 +760,16 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { sibling.nrSegments-- if gap.node == sibling { if gap.index == 0 { - return GapIterator{n, n.nrSegments} + return FileRangeGapIterator{n, n.nrSegments} } - return GapIterator{sibling, gap.index - 1} + return FileRangeGapIterator{sibling, gap.index - 1} } return gap } - // Otherwise, we must unsplit. + p := n.parent if p.nrSegments == 1 { - // Merge all segments in both n and its sibling back into n.parent. - // This is the reverse of the root splitting case in - // node.rebalanceBeforeInsert. (Because we require minDegree >= 3, - // only root can have 1 segment in this path, so this reduces the - // height of the tree by 1, without violating the constraint that - // all leaf nodes remain at the same depth.) + left, right := p.children[0], p.children[1] p.nrSegments = left.nrSegments + right.nrSegments + 1 p.hasChildren = left.hasChildren @@ -887,10 +791,10 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { p.children[1] = nil } if gap.node == left { - return GapIterator{p, gap.index} + return FileRangeGapIterator{p, gap.index} } if gap.node == right { - return GapIterator{p, gap.index + left.nrSegments + 1} + return FileRangeGapIterator{p, gap.index + left.nrSegments + 1} } return gap } @@ -898,7 +802,7 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { // two, into whichever of the two nodes comes first. This is the // reverse of the non-root splitting case in // node.rebalanceBeforeInsert. - var left, right *node + var left, right *FileRangenode if n.parentIndex > 0 { left = n.prevSibling() right = n @@ -906,10 +810,9 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { left = n right = n.nextSibling() } - // Fix up gap first since we need the old left.nrSegments, which - // merging will change. + if gap.node == right { - gap = GapIterator{left, gap.index + left.nrSegments + 1} + gap = FileRangeGapIterator{left, gap.index + left.nrSegments + 1} } left.keys[left.nrSegments] = p.keys[left.parentIndex] left.values[left.nrSegments] = p.values[left.parentIndex] @@ -925,14 +828,14 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { left.nrSegments += right.nrSegments + 1 copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) - Functions{}.ClearValue(&p.values[p.nrSegments-1]) + FileRangeSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) for i := 0; i < p.nrSegments; i++ { p.children[i].parentIndex = i } p.children[p.nrSegments] = nil p.nrSegments-- - // This process robs p of one segment, so recurse into rebalancing p. + n = p } } @@ -949,10 +852,10 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { // // Unless otherwise specified, any mutation of a set invalidates all existing // iterators into the set. -type Iterator struct { +type FileRangeIterator struct { // node is the node containing the iterated segment. If the iterator is // terminal, node is nil. - node *node + node *FileRangenode // index is the index of the segment in node.keys/values. index int @@ -960,24 +863,24 @@ type Iterator struct { // Ok returns true if the iterator is not terminal. All other methods are only // valid for non-terminal iterators. -func (seg Iterator) Ok() bool { +func (seg FileRangeIterator) Ok() bool { return seg.node != nil } // Range returns the iterated segment's range key. -func (seg Iterator) Range() Range { +func (seg FileRangeIterator) Range() __generics_imported0.MappableRange { return seg.node.keys[seg.index] } // Start is equivalent to Range().Start, but should be preferred if only the // start of the range is needed. -func (seg Iterator) Start() Key { +func (seg FileRangeIterator) Start() uint64 { return seg.node.keys[seg.index].Start } // End is equivalent to Range().End, but should be preferred if only the end of // the range is needed. -func (seg Iterator) End() Key { +func (seg FileRangeIterator) End() uint64 { return seg.node.keys[seg.index].End } @@ -991,7 +894,7 @@ func (seg Iterator) End() Key { // - The new range must not overlap an existing one: If seg.NextSegment().Ok(), // then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then // r.start >= seg.PrevSegment().End(). -func (seg Iterator) SetRangeUnchecked(r Range) { +func (seg FileRangeIterator) SetRangeUnchecked(r __generics_imported0.MappableRange) { seg.node.keys[seg.index] = r } @@ -999,7 +902,7 @@ func (seg Iterator) SetRangeUnchecked(r Range) { // cause the iterated segment to overlap another segment, or if the new range // is invalid, SetRange panics. This operation does not invalidate any // iterators. -func (seg Iterator) SetRange(r Range) { +func (seg FileRangeIterator) SetRange(r __generics_imported0.MappableRange) { if r.Length() <= 0 { panic(fmt.Sprintf("invalid segment range %v", r)) } @@ -1017,7 +920,7 @@ func (seg Iterator) SetRange(r Range) { // // Preconditions: The new start must be valid: start < seg.End(); if // seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). -func (seg Iterator) SetStartUnchecked(start Key) { +func (seg FileRangeIterator) SetStartUnchecked(start uint64) { seg.node.keys[seg.index].Start = start } @@ -1025,7 +928,7 @@ func (seg Iterator) SetStartUnchecked(start Key) { // cause the iterated segment to overlap another segment, or would result in an // invalid range, SetStart panics. This operation does not invalidate any // iterators. -func (seg Iterator) SetStart(start Key) { +func (seg FileRangeIterator) SetStart(start uint64) { if start >= seg.End() { panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) } @@ -1040,7 +943,7 @@ func (seg Iterator) SetStart(start Key) { // // Preconditions: The new end must be valid: end > seg.Start(); if // seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). -func (seg Iterator) SetEndUnchecked(end Key) { +func (seg FileRangeIterator) SetEndUnchecked(end uint64) { seg.node.keys[seg.index].End = end } @@ -1048,7 +951,7 @@ func (seg Iterator) SetEndUnchecked(end Key) { // the iterated segment to overlap another segment, or would result in an // invalid range, SetEnd panics. This operation does not invalidate any // iterators. -func (seg Iterator) SetEnd(end Key) { +func (seg FileRangeIterator) SetEnd(end uint64) { if end <= seg.Start() { panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) } @@ -1059,69 +962,68 @@ func (seg Iterator) SetEnd(end Key) { } // Value returns a copy of the iterated segment's value. -func (seg Iterator) Value() Value { +func (seg FileRangeIterator) Value() uint64 { return seg.node.values[seg.index] } // ValuePtr returns a pointer to the iterated segment's value. The pointer is // invalidated if the iterator is invalidated. This operation does not // invalidate any iterators. -func (seg Iterator) ValuePtr() *Value { +func (seg FileRangeIterator) ValuePtr() *uint64 { return &seg.node.values[seg.index] } // SetValue mutates the iterated segment's value. This operation does not // invalidate any iterators. -func (seg Iterator) SetValue(val Value) { +func (seg FileRangeIterator) SetValue(val uint64) { seg.node.values[seg.index] = val } // PrevSegment returns the iterated segment's predecessor. If there is no // preceding segment, PrevSegment returns a terminal iterator. -func (seg Iterator) PrevSegment() Iterator { +func (seg FileRangeIterator) PrevSegment() FileRangeIterator { if seg.node.hasChildren { return seg.node.children[seg.index].lastSegment() } if seg.index > 0 { - return Iterator{seg.node, seg.index - 1} + return FileRangeIterator{seg.node, seg.index - 1} } if seg.node.parent == nil { - return Iterator{} + return FileRangeIterator{} } - return segmentBeforePosition(seg.node.parent, seg.node.parentIndex) + return FileRangesegmentBeforePosition(seg.node.parent, seg.node.parentIndex) } // NextSegment returns the iterated segment's successor. If there is no // succeeding segment, NextSegment returns a terminal iterator. -func (seg Iterator) NextSegment() Iterator { +func (seg FileRangeIterator) NextSegment() FileRangeIterator { if seg.node.hasChildren { return seg.node.children[seg.index+1].firstSegment() } if seg.index < seg.node.nrSegments-1 { - return Iterator{seg.node, seg.index + 1} + return FileRangeIterator{seg.node, seg.index + 1} } if seg.node.parent == nil { - return Iterator{} + return FileRangeIterator{} } - return segmentAfterPosition(seg.node.parent, seg.node.parentIndex) + return FileRangesegmentAfterPosition(seg.node.parent, seg.node.parentIndex) } // PrevGap returns the gap immediately before the iterated segment. -func (seg Iterator) PrevGap() GapIterator { +func (seg FileRangeIterator) PrevGap() FileRangeGapIterator { if seg.node.hasChildren { - // Note that this isn't recursive because the last segment in a subtree - // must be in a leaf node. + return seg.node.children[seg.index].lastSegment().NextGap() } - return GapIterator{seg.node, seg.index} + return FileRangeGapIterator{seg.node, seg.index} } // NextGap returns the gap immediately after the iterated segment. -func (seg Iterator) NextGap() GapIterator { +func (seg FileRangeIterator) NextGap() FileRangeGapIterator { if seg.node.hasChildren { return seg.node.children[seg.index+1].firstSegment().PrevGap() } - return GapIterator{seg.node, seg.index + 1} + return FileRangeGapIterator{seg.node, seg.index + 1} } // PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, @@ -1129,12 +1031,12 @@ func (seg Iterator) NextGap() GapIterator { // Functions.MinKey(), PrevNonEmpty will return two terminal iterators. // Otherwise, exactly one of the iterators returned by PrevNonEmpty will be // non-terminal. -func (seg Iterator) PrevNonEmpty() (Iterator, GapIterator) { +func (seg FileRangeIterator) PrevNonEmpty() (FileRangeIterator, FileRangeGapIterator) { gap := seg.PrevGap() if gap.Range().Length() != 0 { - return Iterator{}, gap + return FileRangeIterator{}, gap } - return gap.PrevSegment(), GapIterator{} + return gap.PrevSegment(), FileRangeGapIterator{} } // NextNonEmpty returns the iterated segment's successor if it is adjacent, or @@ -1142,12 +1044,12 @@ func (seg Iterator) PrevNonEmpty() (Iterator, GapIterator) { // Functions.MaxKey(), NextNonEmpty will return two terminal iterators. // Otherwise, exactly one of the iterators returned by NextNonEmpty will be // non-terminal. -func (seg Iterator) NextNonEmpty() (Iterator, GapIterator) { +func (seg FileRangeIterator) NextNonEmpty() (FileRangeIterator, FileRangeGapIterator) { gap := seg.NextGap() if gap.Range().Length() != 0 { - return Iterator{}, gap + return FileRangeIterator{}, gap } - return gap.NextSegment(), GapIterator{} + return gap.NextSegment(), FileRangeGapIterator{} } // A GapIterator is conceptually one of: @@ -1168,77 +1070,77 @@ func (seg Iterator) NextNonEmpty() (Iterator, GapIterator) { // // Unless otherwise specified, any mutation of a set invalidates all existing // iterators into the set. -type GapIterator struct { +type FileRangeGapIterator struct { // The representation of a GapIterator is identical to that of an Iterator, // except that index corresponds to positions between segments in the same // way as for node.children (see comment for node.nrSegments). - node *node + node *FileRangenode index int } // Ok returns true if the iterator is not terminal. All other methods are only // valid for non-terminal iterators. -func (gap GapIterator) Ok() bool { +func (gap FileRangeGapIterator) Ok() bool { return gap.node != nil } // Range returns the range spanned by the iterated gap. -func (gap GapIterator) Range() Range { - return Range{gap.Start(), gap.End()} +func (gap FileRangeGapIterator) Range() __generics_imported0.MappableRange { + return __generics_imported0.MappableRange{gap.Start(), gap.End()} } // Start is equivalent to Range().Start, but should be preferred if only the // start of the range is needed. -func (gap GapIterator) Start() Key { +func (gap FileRangeGapIterator) Start() uint64 { if ps := gap.PrevSegment(); ps.Ok() { return ps.End() } - return Functions{}.MinKey() + return FileRangeSetFunctions{}.MinKey() } // End is equivalent to Range().End, but should be preferred if only the end of // the range is needed. -func (gap GapIterator) End() Key { +func (gap FileRangeGapIterator) End() uint64 { if ns := gap.NextSegment(); ns.Ok() { return ns.Start() } - return Functions{}.MaxKey() + return FileRangeSetFunctions{}.MaxKey() } // IsEmpty returns true if the iterated gap is empty (that is, the "gap" is // between two adjacent segments.) -func (gap GapIterator) IsEmpty() bool { +func (gap FileRangeGapIterator) IsEmpty() bool { return gap.Range().Length() == 0 } // PrevSegment returns the segment immediately before the iterated gap. If no // such segment exists, PrevSegment returns a terminal iterator. -func (gap GapIterator) PrevSegment() Iterator { - return segmentBeforePosition(gap.node, gap.index) +func (gap FileRangeGapIterator) PrevSegment() FileRangeIterator { + return FileRangesegmentBeforePosition(gap.node, gap.index) } // NextSegment returns the segment immediately after the iterated gap. If no // such segment exists, NextSegment returns a terminal iterator. -func (gap GapIterator) NextSegment() Iterator { - return segmentAfterPosition(gap.node, gap.index) +func (gap FileRangeGapIterator) NextSegment() FileRangeIterator { + return FileRangesegmentAfterPosition(gap.node, gap.index) } // PrevGap returns the iterated gap's predecessor. If no such gap exists, // PrevGap returns a terminal iterator. -func (gap GapIterator) PrevGap() GapIterator { +func (gap FileRangeGapIterator) PrevGap() FileRangeGapIterator { seg := gap.PrevSegment() if !seg.Ok() { - return GapIterator{} + return FileRangeGapIterator{} } return seg.PrevGap() } // NextGap returns the iterated gap's successor. If no such gap exists, NextGap // returns a terminal iterator. -func (gap GapIterator) NextGap() GapIterator { +func (gap FileRangeGapIterator) NextGap() FileRangeGapIterator { seg := gap.NextSegment() if !seg.Ok() { - return GapIterator{} + return FileRangeGapIterator{} } return seg.NextGap() } @@ -1246,56 +1148,55 @@ func (gap GapIterator) NextGap() GapIterator { // segmentBeforePosition returns the predecessor segment of the position given // by n.children[i], which may or may not contain a child. If no such segment // exists, segmentBeforePosition returns a terminal iterator. -func segmentBeforePosition(n *node, i int) Iterator { +func FileRangesegmentBeforePosition(n *FileRangenode, i int) FileRangeIterator { for i == 0 { if n.parent == nil { - return Iterator{} + return FileRangeIterator{} } n, i = n.parent, n.parentIndex } - return Iterator{n, i - 1} + return FileRangeIterator{n, i - 1} } // segmentAfterPosition returns the successor segment of the position given by // n.children[i], which may or may not contain a child. If no such segment // exists, segmentAfterPosition returns a terminal iterator. -func segmentAfterPosition(n *node, i int) Iterator { +func FileRangesegmentAfterPosition(n *FileRangenode, i int) FileRangeIterator { for i == n.nrSegments { if n.parent == nil { - return Iterator{} + return FileRangeIterator{} } n, i = n.parent, n.parentIndex } - return Iterator{n, i} + return FileRangeIterator{n, i} } -func zeroValueSlice(slice []Value) { - // TODO(jamieliu): check if Go is actually smart enough to optimize a - // ClearValue that assigns nil to a memset here +func FileRangezeroValueSlice(slice []uint64) { + for i := range slice { - Functions{}.ClearValue(&slice[i]) + FileRangeSetFunctions{}.ClearValue(&slice[i]) } } -func zeroNodeSlice(slice []*node) { +func FileRangezeroNodeSlice(slice []*FileRangenode) { for i := range slice { slice[i] = nil } } // String stringifies a Set for debugging. -func (s *Set) String() string { +func (s *FileRangeSet) String() string { return s.root.String() } // String stringifies a node (and all of its children) for debugging. -func (n *node) String() string { +func (n *FileRangenode) String() string { var buf bytes.Buffer n.writeDebugString(&buf, "") return buf.String() } -func (n *node) writeDebugString(buf *bytes.Buffer, prefix string) { +func (n *FileRangenode) writeDebugString(buf *bytes.Buffer, prefix string) { if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { buf.WriteString(prefix) buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) @@ -1322,16 +1223,16 @@ func (n *node) writeDebugString(buf *bytes.Buffer, prefix string) { // for save/restore and the layout here is optimized for that. // // +stateify savable -type SegmentDataSlices struct { - Start []Key - End []Key - Values []Value +type FileRangeSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []uint64 } // ExportSortedSlice returns a copy of all segments in the given set, in ascending // key order. -func (s *Set) ExportSortedSlices() *SegmentDataSlices { - var sds SegmentDataSlices +func (s *FileRangeSet) ExportSortedSlices() *FileRangeSegmentDataSlices { + var sds FileRangeSegmentDataSlices for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { sds.Start = append(sds.Start, seg.Start()) sds.End = append(sds.End, seg.End()) @@ -1348,13 +1249,13 @@ func (s *Set) ExportSortedSlices() *SegmentDataSlices { // Preconditions: s must be empty. sds must represent a valid set (the segments // in sds must have valid lengths that do not overlap). The segments in sds // must be sorted in ascending key order. -func (s *Set) ImportSortedSlices(sds *SegmentDataSlices) error { +func (s *FileRangeSet) ImportSortedSlices(sds *FileRangeSegmentDataSlices) error { if !s.IsEmpty() { return fmt.Errorf("cannot import into non-empty set %v", s) } gap := s.FirstGap() for i := range sds.Start { - r := Range{sds.Start[i], sds.End[i]} + r := __generics_imported0.MappableRange{sds.Start[i], sds.End[i]} if !gap.Range().IsSupersetOf(r) { return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) } @@ -1362,3 +1263,12 @@ func (s *Set) ImportSortedSlices(sds *SegmentDataSlices) error { } return nil } +func (s *FileRangeSet) saveRoot() *FileRangeSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *FileRangeSet) loadRoot(sds *FileRangeSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/fs/fsutil/frame_ref_set_impl.go b/pkg/sentry/fs/fsutil/frame_ref_set_impl.go new file mode 100755 index 000000000..88695dbd1 --- /dev/null +++ b/pkg/sentry/fs/fsutil/frame_ref_set_impl.go @@ -0,0 +1,1274 @@ +package fsutil + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/platform" +) + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + FrameRefminDegree = 3 + + FrameRefmaxDegree = 2 * FrameRefminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type FrameRefSet struct { + root FrameRefnode `state:".(*FrameRefSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *FrameRefSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *FrameRefSet) IsEmptyRange(r __generics_imported0.FileRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *FrameRefSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *FrameRefSet) SpanRange(r __generics_imported0.FileRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *FrameRefSet) FirstSegment() FrameRefIterator { + if s.root.nrSegments == 0 { + return FrameRefIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *FrameRefSet) LastSegment() FrameRefIterator { + if s.root.nrSegments == 0 { + return FrameRefIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *FrameRefSet) FirstGap() FrameRefGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return FrameRefGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *FrameRefSet) LastGap() FrameRefGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return FrameRefGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *FrameRefSet) Find(key uint64) (FrameRefIterator, FrameRefGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return FrameRefIterator{n, i}, FrameRefGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return FrameRefIterator{}, FrameRefGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *FrameRefSet) FindSegment(key uint64) FrameRefIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *FrameRefSet) LowerBoundSegment(min uint64) FrameRefIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *FrameRefSet) UpperBoundSegment(max uint64) FrameRefIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *FrameRefSet) FindGap(key uint64) FrameRefGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *FrameRefSet) LowerBoundGap(min uint64) FrameRefGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *FrameRefSet) UpperBoundGap(max uint64) FrameRefGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *FrameRefSet) Add(r __generics_imported0.FileRange, val uint64) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *FrameRefSet) AddWithoutMerging(r __generics_imported0.FileRange, val uint64) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *FrameRefSet) Insert(gap FrameRefGapIterator, r __generics_imported0.FileRange, val uint64) FrameRefIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (FrameRefSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (FrameRefSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (FrameRefSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *FrameRefSet) InsertWithoutMerging(gap FrameRefGapIterator, r __generics_imported0.FileRange, val uint64) FrameRefIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *FrameRefSet) InsertWithoutMergingUnchecked(gap FrameRefGapIterator, r __generics_imported0.FileRange, val uint64) FrameRefIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return FrameRefIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *FrameRefSet) Remove(seg FrameRefIterator) FrameRefGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + FrameRefSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(FrameRefGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *FrameRefSet) RemoveAll() { + s.root = FrameRefnode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *FrameRefSet) RemoveRange(r __generics_imported0.FileRange) FrameRefGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *FrameRefSet) Merge(first, second FrameRefIterator) FrameRefIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *FrameRefSet) MergeUnchecked(first, second FrameRefIterator) FrameRefIterator { + if first.End() == second.Start() { + if mval, ok := (FrameRefSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return FrameRefIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *FrameRefSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *FrameRefSet) MergeRange(r __generics_imported0.FileRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *FrameRefSet) MergeAdjacent(r __generics_imported0.FileRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *FrameRefSet) Split(seg FrameRefIterator, split uint64) (FrameRefIterator, FrameRefIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *FrameRefSet) SplitUnchecked(seg FrameRefIterator, split uint64) (FrameRefIterator, FrameRefIterator) { + val1, val2 := (FrameRefSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.FileRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *FrameRefSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *FrameRefSet) Isolate(seg FrameRefIterator, r __generics_imported0.FileRange) FrameRefIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *FrameRefSet) ApplyContiguous(r __generics_imported0.FileRange, fn func(seg FrameRefIterator)) FrameRefGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return FrameRefGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return FrameRefGapIterator{} + } + } +} + +// +stateify savable +type FrameRefnode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *FrameRefnode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [FrameRefmaxDegree - 1]__generics_imported0.FileRange + values [FrameRefmaxDegree - 1]uint64 + children [FrameRefmaxDegree]*FrameRefnode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *FrameRefnode) firstSegment() FrameRefIterator { + for n.hasChildren { + n = n.children[0] + } + return FrameRefIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *FrameRefnode) lastSegment() FrameRefIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return FrameRefIterator{n, n.nrSegments - 1} +} + +func (n *FrameRefnode) prevSibling() *FrameRefnode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *FrameRefnode) nextSibling() *FrameRefnode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *FrameRefnode) rebalanceBeforeInsert(gap FrameRefGapIterator) FrameRefGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < FrameRefmaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &FrameRefnode{ + nrSegments: FrameRefminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &FrameRefnode{ + nrSegments: FrameRefminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:FrameRefminDegree-1], n.keys[:FrameRefminDegree-1]) + copy(left.values[:FrameRefminDegree-1], n.values[:FrameRefminDegree-1]) + copy(right.keys[:FrameRefminDegree-1], n.keys[FrameRefminDegree:]) + copy(right.values[:FrameRefminDegree-1], n.values[FrameRefminDegree:]) + n.keys[0], n.values[0] = n.keys[FrameRefminDegree-1], n.values[FrameRefminDegree-1] + FrameRefzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:FrameRefminDegree], n.children[:FrameRefminDegree]) + copy(right.children[:FrameRefminDegree], n.children[FrameRefminDegree:]) + FrameRefzeroNodeSlice(n.children[2:]) + for i := 0; i < FrameRefminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < FrameRefminDegree { + return FrameRefGapIterator{left, gap.index} + } + return FrameRefGapIterator{right, gap.index - FrameRefminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[FrameRefminDegree-1], n.values[FrameRefminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &FrameRefnode{ + nrSegments: FrameRefminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:FrameRefminDegree-1], n.keys[FrameRefminDegree:]) + copy(sibling.values[:FrameRefminDegree-1], n.values[FrameRefminDegree:]) + FrameRefzeroValueSlice(n.values[FrameRefminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:FrameRefminDegree], n.children[FrameRefminDegree:]) + FrameRefzeroNodeSlice(n.children[FrameRefminDegree:]) + for i := 0; i < FrameRefminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = FrameRefminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < FrameRefminDegree { + return gap + } + return FrameRefGapIterator{sibling, gap.index - FrameRefminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *FrameRefnode) rebalanceAfterRemove(gap FrameRefGapIterator) FrameRefGapIterator { + for { + if n.nrSegments >= FrameRefminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= FrameRefminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + FrameRefSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return FrameRefGapIterator{n, 0} + } + if gap.node == n { + return FrameRefGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= FrameRefminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + FrameRefSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return FrameRefGapIterator{n, n.nrSegments} + } + return FrameRefGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return FrameRefGapIterator{p, gap.index} + } + if gap.node == right { + return FrameRefGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *FrameRefnode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = FrameRefGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + FrameRefSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type FrameRefIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *FrameRefnode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg FrameRefIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg FrameRefIterator) Range() __generics_imported0.FileRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg FrameRefIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg FrameRefIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg FrameRefIterator) SetRangeUnchecked(r __generics_imported0.FileRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg FrameRefIterator) SetRange(r __generics_imported0.FileRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg FrameRefIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg FrameRefIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg FrameRefIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg FrameRefIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg FrameRefIterator) Value() uint64 { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg FrameRefIterator) ValuePtr() *uint64 { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg FrameRefIterator) SetValue(val uint64) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg FrameRefIterator) PrevSegment() FrameRefIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return FrameRefIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return FrameRefIterator{} + } + return FrameRefsegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg FrameRefIterator) NextSegment() FrameRefIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return FrameRefIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return FrameRefIterator{} + } + return FrameRefsegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg FrameRefIterator) PrevGap() FrameRefGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return FrameRefGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg FrameRefIterator) NextGap() FrameRefGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return FrameRefGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg FrameRefIterator) PrevNonEmpty() (FrameRefIterator, FrameRefGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return FrameRefIterator{}, gap + } + return gap.PrevSegment(), FrameRefGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg FrameRefIterator) NextNonEmpty() (FrameRefIterator, FrameRefGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return FrameRefIterator{}, gap + } + return gap.NextSegment(), FrameRefGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type FrameRefGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *FrameRefnode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap FrameRefGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap FrameRefGapIterator) Range() __generics_imported0.FileRange { + return __generics_imported0.FileRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap FrameRefGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return FrameRefSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap FrameRefGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return FrameRefSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap FrameRefGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap FrameRefGapIterator) PrevSegment() FrameRefIterator { + return FrameRefsegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap FrameRefGapIterator) NextSegment() FrameRefIterator { + return FrameRefsegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap FrameRefGapIterator) PrevGap() FrameRefGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return FrameRefGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap FrameRefGapIterator) NextGap() FrameRefGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return FrameRefGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func FrameRefsegmentBeforePosition(n *FrameRefnode, i int) FrameRefIterator { + for i == 0 { + if n.parent == nil { + return FrameRefIterator{} + } + n, i = n.parent, n.parentIndex + } + return FrameRefIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func FrameRefsegmentAfterPosition(n *FrameRefnode, i int) FrameRefIterator { + for i == n.nrSegments { + if n.parent == nil { + return FrameRefIterator{} + } + n, i = n.parent, n.parentIndex + } + return FrameRefIterator{n, i} +} + +func FrameRefzeroValueSlice(slice []uint64) { + + for i := range slice { + FrameRefSetFunctions{}.ClearValue(&slice[i]) + } +} + +func FrameRefzeroNodeSlice(slice []*FrameRefnode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *FrameRefSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *FrameRefnode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *FrameRefnode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type FrameRefSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []uint64 +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *FrameRefSet) ExportSortedSlices() *FrameRefSegmentDataSlices { + var sds FrameRefSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *FrameRefSet) ImportSortedSlices(sds *FrameRefSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.FileRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *FrameRefSet) saveRoot() *FrameRefSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *FrameRefSet) loadRoot(sds *FrameRefSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/fs/fsutil/fsutil_impl_state_autogen.go b/pkg/sentry/fs/fsutil/fsutil_impl_state_autogen.go new file mode 100755 index 000000000..a0baca0c5 --- /dev/null +++ b/pkg/sentry/fs/fsutil/fsutil_impl_state_autogen.go @@ -0,0 +1,169 @@ +// automatically generated by stateify. + +package fsutil + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *DirtySet) beforeSave() {} +func (x *DirtySet) save(m state.Map) { + x.beforeSave() + var root *DirtySegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *DirtySet) afterLoad() {} +func (x *DirtySet) load(m state.Map) { + m.LoadValue("root", new(*DirtySegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*DirtySegmentDataSlices)) }) +} + +func (x *Dirtynode) beforeSave() {} +func (x *Dirtynode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *Dirtynode) afterLoad() {} +func (x *Dirtynode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *DirtySegmentDataSlices) beforeSave() {} +func (x *DirtySegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *DirtySegmentDataSlices) afterLoad() {} +func (x *DirtySegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func (x *FileRangeSet) beforeSave() {} +func (x *FileRangeSet) save(m state.Map) { + x.beforeSave() + var root *FileRangeSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *FileRangeSet) afterLoad() {} +func (x *FileRangeSet) load(m state.Map) { + m.LoadValue("root", new(*FileRangeSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*FileRangeSegmentDataSlices)) }) +} + +func (x *FileRangenode) beforeSave() {} +func (x *FileRangenode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *FileRangenode) afterLoad() {} +func (x *FileRangenode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *FileRangeSegmentDataSlices) beforeSave() {} +func (x *FileRangeSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *FileRangeSegmentDataSlices) afterLoad() {} +func (x *FileRangeSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func (x *FrameRefSet) beforeSave() {} +func (x *FrameRefSet) save(m state.Map) { + x.beforeSave() + var root *FrameRefSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *FrameRefSet) afterLoad() {} +func (x *FrameRefSet) load(m state.Map) { + m.LoadValue("root", new(*FrameRefSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*FrameRefSegmentDataSlices)) }) +} + +func (x *FrameRefnode) beforeSave() {} +func (x *FrameRefnode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *FrameRefnode) afterLoad() {} +func (x *FrameRefnode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *FrameRefSegmentDataSlices) beforeSave() {} +func (x *FrameRefSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *FrameRefSegmentDataSlices) afterLoad() {} +func (x *FrameRefSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func init() { + state.Register("pkg/sentry/fs/fsutil.DirtySet", (*DirtySet)(nil), state.Fns{Save: (*DirtySet).save, Load: (*DirtySet).load}) + state.Register("pkg/sentry/fs/fsutil.Dirtynode", (*Dirtynode)(nil), state.Fns{Save: (*Dirtynode).save, Load: (*Dirtynode).load}) + state.Register("pkg/sentry/fs/fsutil.DirtySegmentDataSlices", (*DirtySegmentDataSlices)(nil), state.Fns{Save: (*DirtySegmentDataSlices).save, Load: (*DirtySegmentDataSlices).load}) + state.Register("pkg/sentry/fs/fsutil.FileRangeSet", (*FileRangeSet)(nil), state.Fns{Save: (*FileRangeSet).save, Load: (*FileRangeSet).load}) + state.Register("pkg/sentry/fs/fsutil.FileRangenode", (*FileRangenode)(nil), state.Fns{Save: (*FileRangenode).save, Load: (*FileRangenode).load}) + state.Register("pkg/sentry/fs/fsutil.FileRangeSegmentDataSlices", (*FileRangeSegmentDataSlices)(nil), state.Fns{Save: (*FileRangeSegmentDataSlices).save, Load: (*FileRangeSegmentDataSlices).load}) + state.Register("pkg/sentry/fs/fsutil.FrameRefSet", (*FrameRefSet)(nil), state.Fns{Save: (*FrameRefSet).save, Load: (*FrameRefSet).load}) + state.Register("pkg/sentry/fs/fsutil.FrameRefnode", (*FrameRefnode)(nil), state.Fns{Save: (*FrameRefnode).save, Load: (*FrameRefnode).load}) + state.Register("pkg/sentry/fs/fsutil.FrameRefSegmentDataSlices", (*FrameRefSegmentDataSlices)(nil), state.Fns{Save: (*FrameRefSegmentDataSlices).save, Load: (*FrameRefSegmentDataSlices).load}) +} diff --git a/pkg/sentry/fs/fsutil/fsutil_state_autogen.go b/pkg/sentry/fs/fsutil/fsutil_state_autogen.go new file mode 100755 index 000000000..80b93ad25 --- /dev/null +++ b/pkg/sentry/fs/fsutil/fsutil_state_autogen.go @@ -0,0 +1,204 @@ +// automatically generated by stateify. + +package fsutil + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *DirtyInfo) beforeSave() {} +func (x *DirtyInfo) save(m state.Map) { + x.beforeSave() + m.Save("Keep", &x.Keep) +} + +func (x *DirtyInfo) afterLoad() {} +func (x *DirtyInfo) load(m state.Map) { + m.Load("Keep", &x.Keep) +} + +func (x *StaticDirFileOperations) beforeSave() {} +func (x *StaticDirFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("dentryMap", &x.dentryMap) + m.Save("dirCursor", &x.dirCursor) +} + +func (x *StaticDirFileOperations) afterLoad() {} +func (x *StaticDirFileOperations) load(m state.Map) { + m.Load("dentryMap", &x.dentryMap) + m.Load("dirCursor", &x.dirCursor) +} + +func (x *NoReadWriteFile) beforeSave() {} +func (x *NoReadWriteFile) save(m state.Map) { + x.beforeSave() +} + +func (x *NoReadWriteFile) afterLoad() {} +func (x *NoReadWriteFile) load(m state.Map) { +} + +func (x *FileStaticContentReader) beforeSave() {} +func (x *FileStaticContentReader) save(m state.Map) { + x.beforeSave() + m.Save("content", &x.content) +} + +func (x *FileStaticContentReader) afterLoad() {} +func (x *FileStaticContentReader) load(m state.Map) { + m.Load("content", &x.content) +} + +func (x *HostFileMapper) beforeSave() {} +func (x *HostFileMapper) save(m state.Map) { + x.beforeSave() + m.Save("refs", &x.refs) +} + +func (x *HostFileMapper) load(m state.Map) { + m.Load("refs", &x.refs) + m.AfterLoad(x.afterLoad) +} + +func (x *HostMappable) beforeSave() {} +func (x *HostMappable) save(m state.Map) { + x.beforeSave() + m.Save("hostFileMapper", &x.hostFileMapper) + m.Save("backingFile", &x.backingFile) + m.Save("mappings", &x.mappings) +} + +func (x *HostMappable) afterLoad() {} +func (x *HostMappable) load(m state.Map) { + m.Load("hostFileMapper", &x.hostFileMapper) + m.Load("backingFile", &x.backingFile) + m.Load("mappings", &x.mappings) +} + +func (x *SimpleFileInode) beforeSave() {} +func (x *SimpleFileInode) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *SimpleFileInode) afterLoad() {} +func (x *SimpleFileInode) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *NoReadWriteFileInode) beforeSave() {} +func (x *NoReadWriteFileInode) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *NoReadWriteFileInode) afterLoad() {} +func (x *NoReadWriteFileInode) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *InodeSimpleAttributes) beforeSave() {} +func (x *InodeSimpleAttributes) save(m state.Map) { + x.beforeSave() + m.Save("fsType", &x.fsType) + m.Save("unstable", &x.unstable) +} + +func (x *InodeSimpleAttributes) afterLoad() {} +func (x *InodeSimpleAttributes) load(m state.Map) { + m.Load("fsType", &x.fsType) + m.Load("unstable", &x.unstable) +} + +func (x *InodeSimpleExtendedAttributes) beforeSave() {} +func (x *InodeSimpleExtendedAttributes) save(m state.Map) { + x.beforeSave() + m.Save("xattrs", &x.xattrs) +} + +func (x *InodeSimpleExtendedAttributes) afterLoad() {} +func (x *InodeSimpleExtendedAttributes) load(m state.Map) { + m.Load("xattrs", &x.xattrs) +} + +func (x *staticFile) beforeSave() {} +func (x *staticFile) save(m state.Map) { + x.beforeSave() + m.Save("FileStaticContentReader", &x.FileStaticContentReader) +} + +func (x *staticFile) afterLoad() {} +func (x *staticFile) load(m state.Map) { + m.Load("FileStaticContentReader", &x.FileStaticContentReader) +} + +func (x *InodeStaticFileGetter) beforeSave() {} +func (x *InodeStaticFileGetter) save(m state.Map) { + x.beforeSave() + m.Save("Contents", &x.Contents) +} + +func (x *InodeStaticFileGetter) afterLoad() {} +func (x *InodeStaticFileGetter) load(m state.Map) { + m.Load("Contents", &x.Contents) +} + +func (x *CachingInodeOperations) beforeSave() {} +func (x *CachingInodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("backingFile", &x.backingFile) + m.Save("mfp", &x.mfp) + m.Save("opts", &x.opts) + m.Save("attr", &x.attr) + m.Save("dirtyAttr", &x.dirtyAttr) + m.Save("mappings", &x.mappings) + m.Save("cache", &x.cache) + m.Save("dirty", &x.dirty) + m.Save("hostFileMapper", &x.hostFileMapper) + m.Save("refs", &x.refs) +} + +func (x *CachingInodeOperations) afterLoad() {} +func (x *CachingInodeOperations) load(m state.Map) { + m.Load("backingFile", &x.backingFile) + m.Load("mfp", &x.mfp) + m.Load("opts", &x.opts) + m.Load("attr", &x.attr) + m.Load("dirtyAttr", &x.dirtyAttr) + m.Load("mappings", &x.mappings) + m.Load("cache", &x.cache) + m.Load("dirty", &x.dirty) + m.Load("hostFileMapper", &x.hostFileMapper) + m.Load("refs", &x.refs) +} + +func (x *CachingInodeOperationsOptions) beforeSave() {} +func (x *CachingInodeOperationsOptions) save(m state.Map) { + x.beforeSave() + m.Save("ForcePageCache", &x.ForcePageCache) + m.Save("LimitHostFDTranslation", &x.LimitHostFDTranslation) +} + +func (x *CachingInodeOperationsOptions) afterLoad() {} +func (x *CachingInodeOperationsOptions) load(m state.Map) { + m.Load("ForcePageCache", &x.ForcePageCache) + m.Load("LimitHostFDTranslation", &x.LimitHostFDTranslation) +} + +func init() { + state.Register("pkg/sentry/fs/fsutil.DirtyInfo", (*DirtyInfo)(nil), state.Fns{Save: (*DirtyInfo).save, Load: (*DirtyInfo).load}) + state.Register("pkg/sentry/fs/fsutil.StaticDirFileOperations", (*StaticDirFileOperations)(nil), state.Fns{Save: (*StaticDirFileOperations).save, Load: (*StaticDirFileOperations).load}) + state.Register("pkg/sentry/fs/fsutil.NoReadWriteFile", (*NoReadWriteFile)(nil), state.Fns{Save: (*NoReadWriteFile).save, Load: (*NoReadWriteFile).load}) + state.Register("pkg/sentry/fs/fsutil.FileStaticContentReader", (*FileStaticContentReader)(nil), state.Fns{Save: (*FileStaticContentReader).save, Load: (*FileStaticContentReader).load}) + state.Register("pkg/sentry/fs/fsutil.HostFileMapper", (*HostFileMapper)(nil), state.Fns{Save: (*HostFileMapper).save, Load: (*HostFileMapper).load}) + state.Register("pkg/sentry/fs/fsutil.HostMappable", (*HostMappable)(nil), state.Fns{Save: (*HostMappable).save, Load: (*HostMappable).load}) + state.Register("pkg/sentry/fs/fsutil.SimpleFileInode", (*SimpleFileInode)(nil), state.Fns{Save: (*SimpleFileInode).save, Load: (*SimpleFileInode).load}) + state.Register("pkg/sentry/fs/fsutil.NoReadWriteFileInode", (*NoReadWriteFileInode)(nil), state.Fns{Save: (*NoReadWriteFileInode).save, Load: (*NoReadWriteFileInode).load}) + state.Register("pkg/sentry/fs/fsutil.InodeSimpleAttributes", (*InodeSimpleAttributes)(nil), state.Fns{Save: (*InodeSimpleAttributes).save, Load: (*InodeSimpleAttributes).load}) + state.Register("pkg/sentry/fs/fsutil.InodeSimpleExtendedAttributes", (*InodeSimpleExtendedAttributes)(nil), state.Fns{Save: (*InodeSimpleExtendedAttributes).save, Load: (*InodeSimpleExtendedAttributes).load}) + state.Register("pkg/sentry/fs/fsutil.staticFile", (*staticFile)(nil), state.Fns{Save: (*staticFile).save, Load: (*staticFile).load}) + state.Register("pkg/sentry/fs/fsutil.InodeStaticFileGetter", (*InodeStaticFileGetter)(nil), state.Fns{Save: (*InodeStaticFileGetter).save, Load: (*InodeStaticFileGetter).load}) + state.Register("pkg/sentry/fs/fsutil.CachingInodeOperations", (*CachingInodeOperations)(nil), state.Fns{Save: (*CachingInodeOperations).save, Load: (*CachingInodeOperations).load}) + state.Register("pkg/sentry/fs/fsutil.CachingInodeOperationsOptions", (*CachingInodeOperationsOptions)(nil), state.Fns{Save: (*CachingInodeOperationsOptions).save, Load: (*CachingInodeOperationsOptions).load}) +} diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go deleted file mode 100644 index 1547584c5..000000000 --- a/pkg/sentry/fs/fsutil/inode_cached_test.go +++ /dev/null @@ -1,389 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fsutil - -import ( - "bytes" - "io" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/safemem" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -type noopBackingFile struct{} - -func (noopBackingFile) ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) { - return dsts.NumBytes(), nil -} - -func (noopBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) { - return srcs.NumBytes(), nil -} - -func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error { - return nil -} - -func (noopBackingFile) Sync(context.Context) error { - return nil -} - -func (noopBackingFile) FD() int { - return -1 -} - -func (noopBackingFile) Allocate(ctx context.Context, offset int64, length int64) error { - return nil -} - -func TestSetPermissions(t *testing.T) { - ctx := contexttest.Context(t) - - uattr := fs.WithCurrentTime(ctx, fs.UnstableAttr{ - Perms: fs.FilePermsFromMode(0444), - }) - iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{}) - defer iops.Release() - - perms := fs.FilePermsFromMode(0777) - if !iops.SetPermissions(ctx, nil, perms) { - t.Fatalf("SetPermissions failed, want success") - } - - // Did permissions change? - if iops.attr.Perms != perms { - t.Fatalf("got perms +%v, want +%v", iops.attr.Perms, perms) - } - - // Did status change time change? - if !iops.dirtyAttr.StatusChangeTime { - t.Fatalf("got status change time not dirty, want dirty") - } - if iops.attr.StatusChangeTime.Equal(uattr.StatusChangeTime) { - t.Fatalf("got status change time unchanged") - } -} - -func TestSetTimestamps(t *testing.T) { - ctx := contexttest.Context(t) - for _, test := range []struct { - desc string - ts fs.TimeSpec - wantChanged fs.AttrMask - }{ - { - desc: "noop", - ts: fs.TimeSpec{ - ATimeOmit: true, - MTimeOmit: true, - }, - wantChanged: fs.AttrMask{}, - }, - { - desc: "access time only", - ts: fs.TimeSpec{ - ATime: ktime.NowFromContext(ctx), - MTimeOmit: true, - }, - wantChanged: fs.AttrMask{ - AccessTime: true, - }, - }, - { - desc: "modification time only", - ts: fs.TimeSpec{ - ATimeOmit: true, - MTime: ktime.NowFromContext(ctx), - }, - wantChanged: fs.AttrMask{ - ModificationTime: true, - }, - }, - { - desc: "access and modification time", - ts: fs.TimeSpec{ - ATime: ktime.NowFromContext(ctx), - MTime: ktime.NowFromContext(ctx), - }, - wantChanged: fs.AttrMask{ - AccessTime: true, - ModificationTime: true, - }, - }, - { - desc: "system time access and modification time", - ts: fs.TimeSpec{ - ATimeSetSystemTime: true, - MTimeSetSystemTime: true, - }, - wantChanged: fs.AttrMask{ - AccessTime: true, - ModificationTime: true, - }, - }, - } { - t.Run(test.desc, func(t *testing.T) { - ctx := contexttest.Context(t) - - epoch := ktime.ZeroTime - uattr := fs.UnstableAttr{ - AccessTime: epoch, - ModificationTime: epoch, - StatusChangeTime: epoch, - } - iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{}) - defer iops.Release() - - if err := iops.SetTimestamps(ctx, nil, test.ts); err != nil { - t.Fatalf("SetTimestamps got error %v, want nil", err) - } - if test.wantChanged.AccessTime { - if !iops.attr.AccessTime.After(uattr.AccessTime) { - t.Fatalf("diritied access time did not advance, want %v > %v", iops.attr.AccessTime, uattr.AccessTime) - } - if !iops.dirtyAttr.StatusChangeTime { - t.Fatalf("dirty access time requires dirty status change time") - } - if !iops.attr.StatusChangeTime.After(uattr.StatusChangeTime) { - t.Fatalf("dirtied status change time did not advance") - } - } - if test.wantChanged.ModificationTime { - if !iops.attr.ModificationTime.After(uattr.ModificationTime) { - t.Fatalf("diritied modification time did not advance") - } - if !iops.dirtyAttr.StatusChangeTime { - t.Fatalf("dirty modification time requires dirty status change time") - } - if !iops.attr.StatusChangeTime.After(uattr.StatusChangeTime) { - t.Fatalf("dirtied status change time did not advance") - } - } - }) - } -} - -func TestTruncate(t *testing.T) { - ctx := contexttest.Context(t) - - uattr := fs.UnstableAttr{ - Size: 0, - } - iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{}) - defer iops.Release() - - if err := iops.Truncate(ctx, nil, uattr.Size); err != nil { - t.Fatalf("Truncate got error %v, want nil", err) - } - var size int64 = 4096 - if err := iops.Truncate(ctx, nil, size); err != nil { - t.Fatalf("Truncate got error %v, want nil", err) - } - if iops.attr.Size != size { - t.Fatalf("Truncate got %d, want %d", iops.attr.Size, size) - } - if !iops.dirtyAttr.ModificationTime || !iops.dirtyAttr.StatusChangeTime { - t.Fatalf("Truncate did not dirty modification and status change time") - } - if !iops.attr.ModificationTime.After(uattr.ModificationTime) { - t.Fatalf("dirtied modification time did not change") - } - if !iops.attr.StatusChangeTime.After(uattr.StatusChangeTime) { - t.Fatalf("dirtied status change time did not change") - } -} - -type sliceBackingFile struct { - data []byte -} - -func newSliceBackingFile(data []byte) *sliceBackingFile { - return &sliceBackingFile{data} -} - -func (f *sliceBackingFile) ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) { - r := safemem.BlockSeqReader{safemem.BlockSeqOf(safemem.BlockFromSafeSlice(f.data)).DropFirst64(offset)} - return r.ReadToBlocks(dsts) -} - -func (f *sliceBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) { - w := safemem.BlockSeqWriter{safemem.BlockSeqOf(safemem.BlockFromSafeSlice(f.data)).DropFirst64(offset)} - return w.WriteFromBlocks(srcs) -} - -func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error { - return nil -} - -func (*sliceBackingFile) Sync(context.Context) error { - return nil -} - -func (*sliceBackingFile) FD() int { - return -1 -} - -func (f *sliceBackingFile) Allocate(ctx context.Context, offset int64, length int64) error { - return syserror.EOPNOTSUPP -} - -type noopMappingSpace struct{} - -// Invalidate implements memmap.MappingSpace.Invalidate. -func (noopMappingSpace) Invalidate(ar usermem.AddrRange, opts memmap.InvalidateOpts) { -} - -func anonInode(ctx context.Context) *fs.Inode { - return fs.NewInode(ctx, &SimpleFileInode{ - InodeSimpleAttributes: NewInodeSimpleAttributes(ctx, fs.FileOwnerFromContext(ctx), fs.FilePermissions{ - User: fs.PermMask{Read: true, Write: true}, - }, 0), - }, fs.NewPseudoMountSource(ctx), fs.StableAttr{ - Type: fs.Anonymous, - BlockSize: usermem.PageSize, - }) -} - -func pagesOf(bs ...byte) []byte { - buf := make([]byte, 0, len(bs)*usermem.PageSize) - for _, b := range bs { - buf = append(buf, bytes.Repeat([]byte{b}, usermem.PageSize)...) - } - return buf -} - -func TestRead(t *testing.T) { - ctx := contexttest.Context(t) - - // Construct a 3-page file. - buf := pagesOf('a', 'b', 'c') - file := fs.NewFile(ctx, fs.NewDirent(ctx, anonInode(ctx), "anon"), fs.FileFlags{}, nil) - uattr := fs.UnstableAttr{ - Size: int64(len(buf)), - } - iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{}) - defer iops.Release() - - // Expect the cache to be initially empty. - if cached := iops.cache.Span(); cached != 0 { - t.Errorf("Span got %d, want 0", cached) - } - - // Create a memory mapping of the second page (as CachingInodeOperations - // expects to only cache mapped pages), then call Translate to force it to - // be cached. - var ms noopMappingSpace - ar := usermem.AddrRange{usermem.PageSize, 2 * usermem.PageSize} - if err := iops.AddMapping(ctx, ms, ar, usermem.PageSize, true); err != nil { - t.Fatalf("AddMapping got %v, want nil", err) - } - mr := memmap.MappableRange{usermem.PageSize, 2 * usermem.PageSize} - if _, err := iops.Translate(ctx, mr, mr, usermem.Read); err != nil { - t.Fatalf("Translate got %v, want nil", err) - } - if cached := iops.cache.Span(); cached != usermem.PageSize { - t.Errorf("SpanRange got %d, want %d", cached, usermem.PageSize) - } - - // Try to read 4 pages. The first and third pages should be read directly - // from the "file", the second page should be read from the cache, and only - // 3 pages (the size of the file) should be readable. - rbuf := make([]byte, 4*usermem.PageSize) - dst := usermem.BytesIOSequence(rbuf) - n, err := iops.Read(ctx, file, dst, 0) - if n != 3*usermem.PageSize || (err != nil && err != io.EOF) { - t.Fatalf("Read got (%d, %v), want (%d, nil or EOF)", n, err, 3*usermem.PageSize) - } - rbuf = rbuf[:3*usermem.PageSize] - - // Did we get the bytes we expect? - if !bytes.Equal(rbuf, buf) { - t.Errorf("Read back bytes %v, want %v", rbuf, buf) - } - - // Delete the memory mapping before iops.Release(). The cached page will - // either be evicted by ctx's pgalloc.MemoryFile, or dropped by - // iops.Release(). - iops.RemoveMapping(ctx, ms, ar, usermem.PageSize, true) -} - -func TestWrite(t *testing.T) { - ctx := contexttest.Context(t) - - // Construct a 4-page file. - buf := pagesOf('a', 'b', 'c', 'd') - orig := append([]byte(nil), buf...) - inode := anonInode(ctx) - uattr := fs.UnstableAttr{ - Size: int64(len(buf)), - } - iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{}) - defer iops.Release() - - // Expect the cache to be initially empty. - if cached := iops.cache.Span(); cached != 0 { - t.Errorf("Span got %d, want 0", cached) - } - - // Create a memory mapping of the second and third pages (as - // CachingInodeOperations expects to only cache mapped pages), then call - // Translate to force them to be cached. - var ms noopMappingSpace - ar := usermem.AddrRange{usermem.PageSize, 3 * usermem.PageSize} - if err := iops.AddMapping(ctx, ms, ar, usermem.PageSize, true); err != nil { - t.Fatalf("AddMapping got %v, want nil", err) - } - defer iops.RemoveMapping(ctx, ms, ar, usermem.PageSize, true) - mr := memmap.MappableRange{usermem.PageSize, 3 * usermem.PageSize} - if _, err := iops.Translate(ctx, mr, mr, usermem.Read); err != nil { - t.Fatalf("Translate got %v, want nil", err) - } - if cached := iops.cache.Span(); cached != 2*usermem.PageSize { - t.Errorf("SpanRange got %d, want %d", cached, 2*usermem.PageSize) - } - - // Write to the first 2 pages. - wbuf := pagesOf('e', 'f') - src := usermem.BytesIOSequence(wbuf) - n, err := iops.Write(ctx, src, 0) - if n != 2*usermem.PageSize || err != nil { - t.Fatalf("Write got (%d, %v), want (%d, nil)", n, err, 2*usermem.PageSize) - } - - // The first page should have been written directly, since it was not cached. - want := append([]byte(nil), orig...) - copy(want, pagesOf('e')) - if !bytes.Equal(buf, want) { - t.Errorf("File contents are %v, want %v", buf, want) - } - - // Sync back to the "backing file". - if err := iops.WriteOut(ctx, inode); err != nil { - t.Errorf("Sync got %v, want nil", err) - } - - // Now the second page should have been written as well. - copy(want[usermem.PageSize:], pagesOf('f')) - if !bytes.Equal(buf, want) { - t.Errorf("File contents are %v, want %v", buf, want) - } -} diff --git a/pkg/sentry/fs/g3doc/inotify.md b/pkg/sentry/fs/g3doc/inotify.md deleted file mode 100644 index 85063d4e6..000000000 --- a/pkg/sentry/fs/g3doc/inotify.md +++ /dev/null @@ -1,122 +0,0 @@ -# Inotify - -Inotify implements the like-named filesystem event notification system for the -sentry, see `inotify(7)`. - -## Architecture - -For the most part, the sentry implementation of inotify mirrors the Linux -architecture. Inotify instances (i.e. the fd returned by inotify_init(2)) are -backed by a pseudo-filesystem. Events are generated from various places in the -sentry, including the [syscall layer][syscall_dir], the [vfs layer][dirent] and -the [process fd table][fd_table]. Watches are stored in inodes and generated -events are queued to the inotify instance owning the watches for delivery to the -user. - -## Objects - -Here is a brief description of the existing and new objects involved in the -sentry inotify mechanism, and how they interact: - -### [`fs.Inotify`][inotify] - -- An inotify instances, created by inotify_init(2)/inotify_init1(2). -- The inotify fd has a `fs.Dirent`, supports filesystem syscalls to read - events. -- Has multiple `fs.Watch`es, with at most one watch per target inode, per - inotify instance. -- Has an instance `id` which is globally unique. This is *not* the fd number - for this instance, since the fd can be duped. This `id` is not externally - visible. - -### [`fs.Watch`][watch] - -- An inotify watch, created/deleted by - inotify_add_watch(2)/inotify_rm_watch(2). -- Owned by an `fs.Inotify` instance, each watch keeps a pointer to the - `owner`. -- Associated with a single `fs.Inode`, which is the watch `target`. While the - watch is active, it indirectly pins `target` to memory. See the "Reference - Model" section for a detailed explanation. -- Filesystem operations on `target` generate `fs.Event`s. - -### [`fs.Event`][event] - -- A simple struct encapsulating all the fields for an inotify event. -- Generated by `fs.Watch`es and forwarded to the watches' `owner`s. -- Serialized to the user during read(2) syscalls on the associated - `fs.Inotify`'s fd. - -### [`fs.Dirent`][dirent] - -- Many inotify events are generated inside dirent methods. Events are - generated in the dirent methods rather than `fs.Inode` methods because some - events carry the name of the subject node, and node names are generally - unavailable in an `fs.Inode`. -- Dirents do not directly contain state for any watches. Instead, they forward - notifications to the underlying `fs.Inode`. - -### [`fs.Inode`][inode] - -- Interacts with inotify through `fs.Watch`es. -- Inodes contain a map of all active `fs.Watch`es on them. -- An `fs.Inotify` instance can have at most one `fs.Watch` per inode. - `fs.Watch`es on an inode are indexed by their `owner`'s `id`. -- All inotify logic is encapsulated in the [`Watches`][inode_watches] struct - in an inode. Logically, `Watches` is the set of inotify watches on the - inode. - -## Reference Model - -The sentry inotify implementation has a complex reference model. An inotify -watch observes a single inode. For efficient lookup, the state for a watch is -stored directly on the target inode. This state needs to be persistent for the -lifetime of watch. Unlike usual filesystem metadata, the watch state has no -"on-disk" representation, so they cannot be reconstructed by the filesystem if -the inode is flushed from memory. This effectively means we need to keep any -inodes with actives watches pinned to memory. - -We can't just hold an extra ref on the inode to pin it to memory because some -filesystems (such as gofer-based filesystems) don't have persistent inodes. In -such a filesystem, if we just pin the inode, nothing prevents the enclosing -dirent from being GCed. Once the dirent is GCed, the pinned inode is -unreachable -- these filesystems generate a new inode by re-reading the node -state on the next walk. Incidentally, hardlinks also don't work on these -filesystems for this reason. - -To prevent the above scenario, when a new watch is added on an inode, we *pin* -the dirent we used to reach the inode. Note that due to hardlinks, this dirent -may not be the only dirent pointing to the inode. Attempting to set an inotify -watch via multiple hardlinks to the same file results in the same watch being -returned for both links. However, for each new dirent we use to reach the same -inode, we add a new pin. We need a new pin for each new dirent used to reach the -inode because we have no guarantees about the deletion order of the different -links to the inode. - -## Lock Ordering - -There are 4 locks related to the inotify implementation: - -- `Inotify.mu`: the inotify instance lock. -- `Inotify.evMu`: the inotify event queue lock. -- `Watch.mu`: the watch lock, used to protect pins. -- `fs.Watches.mu`: the inode watch set mu, used to protect the collection of - watches on the inode. - -The correct lock ordering for inotify code is: - -`Inotify.mu` -> `fs.Watches.mu` -> `Watch.mu` -> `Inotify.evMu`. - -We need a distinct lock for the event queue because by the time a goroutine -attempts to queue a new event, it is already holding `fs.Watches.mu`. If we used -`Inotify.mu` to also protect the event queue, this would violate the above lock -ordering. - -[dirent]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/dirent.go -[event]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inotify_event.go -[fd_table]: https://github.com/google/gvisor/blob/master/pkg/sentry/kernel/fd_table.go -[inode]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inode.go -[inode_watches]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inode_inotify.go -[inotify]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inotify.go -[syscall_dir]: https://github.com/google/gvisor/blob/master/pkg/sentry/syscalls/linux/ -[watch]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inotify_watch.go diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD deleted file mode 100644 index fea135eea..000000000 --- a/pkg/sentry/fs/gofer/BUILD +++ /dev/null @@ -1,67 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "gofer", - srcs = [ - "attr.go", - "cache_policy.go", - "context_file.go", - "device.go", - "fifo.go", - "file.go", - "file_state.go", - "fs.go", - "handles.go", - "inode.go", - "inode_state.go", - "path.go", - "session.go", - "session_state.go", - "socket.go", - "util.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/fd", - "//pkg/log", - "//pkg/metric", - "//pkg/p9", - "//pkg/refs", - "//pkg/safemem", - "//pkg/secio", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fdpipe", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/host", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/pipe", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/socket/unix/transport", - "//pkg/sync", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/unet", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "gofer_test", - size = "small", - srcs = ["gofer_test.go"], - library = ":gofer", - deps = [ - "//pkg/context", - "//pkg/p9", - "//pkg/p9/p9test", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - ], -) diff --git a/pkg/sentry/fs/gofer/fifo.go b/pkg/sentry/fs/gofer/fifo.go index 456557058..456557058 100644..100755 --- a/pkg/sentry/fs/gofer/fifo.go +++ b/pkg/sentry/fs/gofer/fifo.go diff --git a/pkg/sentry/fs/gofer/gofer_state_autogen.go b/pkg/sentry/fs/gofer/gofer_state_autogen.go new file mode 100755 index 000000000..7db9211b4 --- /dev/null +++ b/pkg/sentry/fs/gofer/gofer_state_autogen.go @@ -0,0 +1,145 @@ +// automatically generated by stateify. + +package gofer + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *fifo) beforeSave() {} +func (x *fifo) save(m state.Map) { + x.beforeSave() + m.Save("InodeOperations", &x.InodeOperations) + m.Save("fileIops", &x.fileIops) +} + +func (x *fifo) afterLoad() {} +func (x *fifo) load(m state.Map) { + m.Load("InodeOperations", &x.InodeOperations) + m.Load("fileIops", &x.fileIops) +} + +func (x *fileOperations) beforeSave() {} +func (x *fileOperations) save(m state.Map) { + x.beforeSave() + m.Save("inodeOperations", &x.inodeOperations) + m.Save("dirCursor", &x.dirCursor) + m.Save("flags", &x.flags) +} + +func (x *fileOperations) load(m state.Map) { + m.LoadWait("inodeOperations", &x.inodeOperations) + m.Load("dirCursor", &x.dirCursor) + m.LoadWait("flags", &x.flags) + m.AfterLoad(x.afterLoad) +} + +func (x *filesystem) beforeSave() {} +func (x *filesystem) save(m state.Map) { + x.beforeSave() +} + +func (x *filesystem) afterLoad() {} +func (x *filesystem) load(m state.Map) { +} + +func (x *inodeOperations) beforeSave() {} +func (x *inodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("fileState", &x.fileState) + m.Save("cachingInodeOps", &x.cachingInodeOps) +} + +func (x *inodeOperations) afterLoad() {} +func (x *inodeOperations) load(m state.Map) { + m.LoadWait("fileState", &x.fileState) + m.Load("cachingInodeOps", &x.cachingInodeOps) +} + +func (x *inodeFileState) save(m state.Map) { + x.beforeSave() + var loading struct{} = x.saveLoading() + m.SaveValue("loading", loading) + m.Save("s", &x.s) + m.Save("sattr", &x.sattr) + m.Save("savedUAttr", &x.savedUAttr) + m.Save("hostMappable", &x.hostMappable) +} + +func (x *inodeFileState) load(m state.Map) { + m.LoadWait("s", &x.s) + m.LoadWait("sattr", &x.sattr) + m.Load("savedUAttr", &x.savedUAttr) + m.Load("hostMappable", &x.hostMappable) + m.LoadValue("loading", new(struct{}), func(y interface{}) { x.loadLoading(y.(struct{})) }) + m.AfterLoad(x.afterLoad) +} + +func (x *overrideInfo) beforeSave() {} +func (x *overrideInfo) save(m state.Map) { + x.beforeSave() + m.Save("dirent", &x.dirent) + m.Save("endpoint", &x.endpoint) + m.Save("inode", &x.inode) +} + +func (x *overrideInfo) afterLoad() {} +func (x *overrideInfo) load(m state.Map) { + m.Load("dirent", &x.dirent) + m.Load("endpoint", &x.endpoint) + m.Load("inode", &x.inode) +} + +func (x *overrideMaps) beforeSave() {} +func (x *overrideMaps) save(m state.Map) { + x.beforeSave() + m.Save("pathMap", &x.pathMap) +} + +func (x *overrideMaps) afterLoad() {} +func (x *overrideMaps) load(m state.Map) { + m.Load("pathMap", &x.pathMap) +} + +func (x *session) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("msize", &x.msize) + m.Save("version", &x.version) + m.Save("cachePolicy", &x.cachePolicy) + m.Save("aname", &x.aname) + m.Save("superBlockFlags", &x.superBlockFlags) + m.Save("limitHostFDTranslation", &x.limitHostFDTranslation) + m.Save("overlayfsStaleRead", &x.overlayfsStaleRead) + m.Save("connID", &x.connID) + m.Save("inodeMappings", &x.inodeMappings) + m.Save("mounter", &x.mounter) + m.Save("overrides", &x.overrides) +} + +func (x *session) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.LoadWait("msize", &x.msize) + m.LoadWait("version", &x.version) + m.LoadWait("cachePolicy", &x.cachePolicy) + m.LoadWait("aname", &x.aname) + m.LoadWait("superBlockFlags", &x.superBlockFlags) + m.Load("limitHostFDTranslation", &x.limitHostFDTranslation) + m.Load("overlayfsStaleRead", &x.overlayfsStaleRead) + m.LoadWait("connID", &x.connID) + m.LoadWait("inodeMappings", &x.inodeMappings) + m.LoadWait("mounter", &x.mounter) + m.LoadWait("overrides", &x.overrides) + m.AfterLoad(x.afterLoad) +} + +func init() { + state.Register("pkg/sentry/fs/gofer.fifo", (*fifo)(nil), state.Fns{Save: (*fifo).save, Load: (*fifo).load}) + state.Register("pkg/sentry/fs/gofer.fileOperations", (*fileOperations)(nil), state.Fns{Save: (*fileOperations).save, Load: (*fileOperations).load}) + state.Register("pkg/sentry/fs/gofer.filesystem", (*filesystem)(nil), state.Fns{Save: (*filesystem).save, Load: (*filesystem).load}) + state.Register("pkg/sentry/fs/gofer.inodeOperations", (*inodeOperations)(nil), state.Fns{Save: (*inodeOperations).save, Load: (*inodeOperations).load}) + state.Register("pkg/sentry/fs/gofer.inodeFileState", (*inodeFileState)(nil), state.Fns{Save: (*inodeFileState).save, Load: (*inodeFileState).load}) + state.Register("pkg/sentry/fs/gofer.overrideInfo", (*overrideInfo)(nil), state.Fns{Save: (*overrideInfo).save, Load: (*overrideInfo).load}) + state.Register("pkg/sentry/fs/gofer.overrideMaps", (*overrideMaps)(nil), state.Fns{Save: (*overrideMaps).save, Load: (*overrideMaps).load}) + state.Register("pkg/sentry/fs/gofer.session", (*session)(nil), state.Fns{Save: (*session).save, Load: (*session).load}) +} diff --git a/pkg/sentry/fs/gofer/gofer_test.go b/pkg/sentry/fs/gofer/gofer_test.go deleted file mode 100644 index 2df2fe889..000000000 --- a/pkg/sentry/fs/gofer/gofer_test.go +++ /dev/null @@ -1,310 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gofer - -import ( - "fmt" - "syscall" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/p9/p9test" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" -) - -// rootTest runs a test with a p9 mock and an fs.InodeOperations created from -// the attached root directory. The root file will be closed and client -// disconnected, but additional files must be closed manually. -func rootTest(t *testing.T, name string, cp cachePolicy, fn func(context.Context, *p9test.Harness, *p9test.Mock, *fs.Inode)) { - t.Run(name, func(t *testing.T) { - h, c := p9test.NewHarness(t) - defer h.Finish() - - // Create a new root. Note that we pass an empty, but non-nil - // map here. This allows tests to extend the root children - // dynamically. - root := h.NewDirectory(map[string]p9test.Generator{})(nil) - - // Return this as the root. - h.Attacher.EXPECT().Attach().Return(root, nil).Times(1) - - // ... and open via the client. - rootFile, err := c.Attach("/") - if err != nil { - t.Fatalf("unable to attach: %v", err) - } - defer rootFile.Close() - - // Wrap an a session. - s := &session{ - mounter: fs.RootOwner, - cachePolicy: cp, - client: c, - } - - // ... and an INode, with only the mode being explicitly valid for now. - ctx := contexttest.Context(t) - sattr, rootInodeOperations := newInodeOperations(ctx, s, contextFile{ - file: rootFile, - }, root.QID, p9.AttrMaskAll(), root.Attr) - m := fs.NewMountSource(ctx, s, &filesystem{}, fs.MountSourceFlags{}) - rootInode := fs.NewInode(ctx, rootInodeOperations, m, sattr) - - // Ensure that the cache is fully invalidated, so that any - // close actions actually take place before the full harness is - // torn down. - defer func() { - m.FlushDirentRefs() - - // Wait for all resources to be released, otherwise the - // operations may fail after we close the rootFile. - fs.AsyncBarrier() - }() - - // Execute the test. - fn(ctx, h, root, rootInode) - }) -} - -func TestLookup(t *testing.T) { - type lookupTest struct { - // Name of the test. - name string - - // Expected return value. - want error - } - - tests := []lookupTest{ - { - name: "mock Walk passes (function succeeds)", - want: nil, - }, - { - name: "mock Walk fails (function fails)", - want: syscall.ENOENT, - }, - } - - const file = "file" // The walked target file. - - for _, test := range tests { - rootTest(t, test.name, cacheNone, func(ctx context.Context, h *p9test.Harness, rootFile *p9test.Mock, rootInode *fs.Inode) { - // Setup the appropriate result. - rootFile.WalkCallback = func() error { - return test.want - } - if test.want == nil { - // Set the contents of the root. We expect a - // normal file generator for ppp above. This is - // overriden by setting WalkErr in the mock. - rootFile.AddChild(file, h.NewFile()) - } - - // Call function. - dirent, err := rootInode.Lookup(ctx, file) - - // Unwrap the InodeOperations. - var newInodeOperations fs.InodeOperations - if dirent != nil { - if dirent.IsNegative() { - err = syscall.ENOENT - } else { - newInodeOperations = dirent.Inode.InodeOperations - } - } - - // Check return values. - if err != test.want { - t.Errorf("Lookup got err %v, want %v", err, test.want) - } - if err == nil && newInodeOperations == nil { - t.Errorf("Lookup got non-nil err and non-nil node, wanted at least one non-nil") - } - }) - } -} - -func TestRevalidation(t *testing.T) { - type revalidationTest struct { - cachePolicy cachePolicy - - // Whether dirent should be reloaded before any modifications. - preModificationWantReload bool - - // Whether dirent should be reloaded after updating an unstable - // attribute on the remote fs. - postModificationWantReload bool - - // Whether dirent unstable attributes should be updated after - // updating an attribute on the remote fs. - postModificationWantUpdatedAttrs bool - - // Whether dirent should be reloaded after the remote has - // removed the file. - postRemovalWantReload bool - } - - tests := []revalidationTest{ - { - // Policy cacheNone causes Revalidate to always return - // true. - cachePolicy: cacheNone, - preModificationWantReload: true, - postModificationWantReload: true, - postModificationWantUpdatedAttrs: true, - postRemovalWantReload: true, - }, - { - // Policy cacheAll causes Revalidate to always return - // false. - cachePolicy: cacheAll, - preModificationWantReload: false, - postModificationWantReload: false, - postModificationWantUpdatedAttrs: false, - postRemovalWantReload: false, - }, - { - // Policy cacheAllWritethrough causes Revalidate to - // always return false. - cachePolicy: cacheAllWritethrough, - preModificationWantReload: false, - postModificationWantReload: false, - postModificationWantUpdatedAttrs: false, - postRemovalWantReload: false, - }, - { - // Policy cacheRemoteRevalidating causes Revalidate to - // return update cached unstable attrs, and returns - // true only when the remote inode itself has been - // removed or replaced. - cachePolicy: cacheRemoteRevalidating, - preModificationWantReload: false, - postModificationWantReload: false, - postModificationWantUpdatedAttrs: true, - postRemovalWantReload: true, - }, - } - - const file = "file" // The file walked below. - - for _, test := range tests { - name := fmt.Sprintf("cachepolicy=%s", test.cachePolicy) - rootTest(t, name, test.cachePolicy, func(ctx context.Context, h *p9test.Harness, rootFile *p9test.Mock, rootInode *fs.Inode) { - // Wrap in a dirent object. - rootDir := fs.NewDirent(ctx, rootInode, "root") - - // Create a mock file a child of the root. We save when - // this is generated, so that when the time changed, we - // can update the original entry. - var origMocks []*p9test.Mock - rootFile.AddChild(file, func(parent *p9test.Mock) *p9test.Mock { - // Regular a regular file that has a consistent - // path number. This might be used by - // validation so we don't change it. - m := h.NewMock(parent, 0, p9.Attr{ - Mode: p9.ModeRegular, - }) - origMocks = append(origMocks, m) - return m - }) - - // Do the walk. - dirent, err := rootDir.Walk(ctx, rootDir, file) - if err != nil { - t.Fatalf("Lookup failed: %v", err) - } - - // We must release the dirent, of the test will fail - // with a reference leak. This is tracked by p9test. - defer dirent.DecRef() - - // Walk again. Depending on the cache policy, we may - // get a new dirent. - newDirent, err := rootDir.Walk(ctx, rootDir, file) - if err != nil { - t.Fatalf("Lookup failed: %v", err) - } - if test.preModificationWantReload && dirent == newDirent { - t.Errorf("Lookup with cachePolicy=%s got old dirent %+v, wanted a new dirent", test.cachePolicy, dirent) - } - if !test.preModificationWantReload && dirent != newDirent { - t.Errorf("Lookup with cachePolicy=%s got new dirent %+v, wanted old dirent %+v", test.cachePolicy, newDirent, dirent) - } - newDirent.DecRef() // See above. - - // Modify the underlying mocked file's modification - // time for the next walk that occurs. - nowSeconds := time.Now().Unix() - rootFile.AddChild(file, func(parent *p9test.Mock) *p9test.Mock { - // Ensure that the path is the same as above, - // but we change only the modification time of - // the file. - return h.NewMock(parent, 0, p9.Attr{ - Mode: p9.ModeRegular, - MTimeSeconds: uint64(nowSeconds), - }) - }) - - // We also modify the original time, so that GetAttr - // behaves as expected for the caching case. - for _, m := range origMocks { - m.Attr.MTimeSeconds = uint64(nowSeconds) - } - - // Walk again. Depending on the cache policy, we may - // get a new dirent. - newDirent, err = rootDir.Walk(ctx, rootDir, file) - if err != nil { - t.Fatalf("Lookup failed: %v", err) - } - if test.postModificationWantReload && dirent == newDirent { - t.Errorf("Lookup with cachePolicy=%s got old dirent, wanted a new dirent", test.cachePolicy) - } - if !test.postModificationWantReload && dirent != newDirent { - t.Errorf("Lookup with cachePolicy=%s got new dirent, wanted old dirent", test.cachePolicy) - } - uattrs, err := newDirent.Inode.UnstableAttr(ctx) - if err != nil { - t.Fatalf("Error getting unstable attrs: %v", err) - } - gotModTimeSeconds := uattrs.ModificationTime.Seconds() - if test.postModificationWantUpdatedAttrs && gotModTimeSeconds != nowSeconds { - t.Fatalf("Lookup with cachePolicy=%s got new modification time %v, wanted %v", test.cachePolicy, gotModTimeSeconds, nowSeconds) - } - newDirent.DecRef() // See above. - - // Remove the file from the remote fs, subsequent walks - // should now fail to find anything. - rootFile.RemoveChild(file) - - // Walk again. Depending on the cache policy, we may - // get ENOENT. - newDirent, err = rootDir.Walk(ctx, rootDir, file) - if test.postRemovalWantReload && err == nil { - t.Errorf("Lookup with cachePolicy=%s got nil error, wanted ENOENT", test.cachePolicy) - } - if !test.postRemovalWantReload && (err != nil || dirent != newDirent) { - t.Errorf("Lookup with cachePolicy=%s got new dirent and error %v, wanted old dirent and nil error", test.cachePolicy, err) - } - if err == nil { - newDirent.DecRef() // See above. - } - }) - } -} diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD deleted file mode 100644 index 21003ea45..000000000 --- a/pkg/sentry/fs/host/BUILD +++ /dev/null @@ -1,85 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "host", - srcs = [ - "control.go", - "descriptor.go", - "descriptor_state.go", - "device.go", - "file.go", - "fs.go", - "inode.go", - "inode_state.go", - "ioctl_unsafe.go", - "socket.go", - "socket_iovec.go", - "socket_state.go", - "socket_unsafe.go", - "tty.go", - "util.go", - "util_amd64_unsafe.go", - "util_arm64_unsafe.go", - "util_unsafe.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/fd", - "//pkg/fdnotifier", - "//pkg/log", - "//pkg/refs", - "//pkg/safemem", - "//pkg/secio", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/socket/control", - "//pkg/sentry/socket/unix", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/unimpl", - "//pkg/sentry/uniqueid", - "//pkg/sync", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/unet", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "host_test", - size = "small", - srcs = [ - "descriptor_test.go", - "fs_test.go", - "inode_test.go", - "socket_test.go", - "wait_test.go", - ], - library = ":host", - deps = [ - "//pkg/context", - "//pkg/fd", - "//pkg/fdnotifier", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/unix/transport", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fs/host/descriptor_test.go b/pkg/sentry/fs/host/descriptor_test.go deleted file mode 100644 index 4205981f5..000000000 --- a/pkg/sentry/fs/host/descriptor_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "io/ioutil" - "path/filepath" - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/fdnotifier" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestDescriptorRelease(t *testing.T) { - for _, tc := range []struct { - name string - saveable bool - wouldBlock bool - }{ - {name: "all false"}, - {name: "saveable", saveable: true}, - {name: "wouldBlock", wouldBlock: true}, - } { - t.Run(tc.name, func(t *testing.T) { - dir, err := ioutil.TempDir("", "descriptor_test") - if err != nil { - t.Fatal("ioutil.TempDir() failed:", err) - } - - fd, err := syscall.Open(filepath.Join(dir, "file"), syscall.O_RDWR|syscall.O_CREAT, 0666) - if err != nil { - t.Fatal("failed to open temp file:", err) - } - - // FD ownership is transferred to the descritor. - queue := &waiter.Queue{} - d, err := newDescriptor(fd, false /* donated*/, tc.saveable, tc.wouldBlock, queue) - if err != nil { - syscall.Close(fd) - t.Fatalf("newDescriptor(%d, %t, false, %t, queue) failed, err: %v", fd, tc.saveable, tc.wouldBlock, err) - } - if tc.saveable { - if d.origFD < 0 { - t.Errorf("saveable descriptor must preserve origFD, desc: %+v", d) - } - } - if tc.wouldBlock { - if !fdnotifier.HasFD(int32(d.value)) { - t.Errorf("FD not registered with notifier, desc: %+v", d) - } - } - - oldVal := d.value - d.Release() - if d.value != -1 { - t.Errorf("d.value want: -1, got: %d", d.value) - } - if tc.wouldBlock { - if fdnotifier.HasFD(int32(oldVal)) { - t.Errorf("FD not unregistered with notifier, desc: %+v", d) - } - } - }) - } -} diff --git a/pkg/sentry/fs/host/fs_test.go b/pkg/sentry/fs/host/fs_test.go deleted file mode 100644 index 3111d2df9..000000000 --- a/pkg/sentry/fs/host/fs_test.go +++ /dev/null @@ -1,380 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "fmt" - "io/ioutil" - "os" - "path" - "reflect" - "sort" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" -) - -// newTestMountNamespace creates a MountNamespace with a ramfs root. -// It returns the host folder created, which should be removed when done. -func newTestMountNamespace(t *testing.T) (*fs.MountNamespace, string, error) { - p, err := ioutil.TempDir("", "root") - if err != nil { - return nil, "", err - } - - fd, err := open(nil, p) - if err != nil { - os.RemoveAll(p) - return nil, "", err - } - ctx := contexttest.Context(t) - root, err := newInode(ctx, newMountSource(ctx, p, fs.RootOwner, &Filesystem{}, fs.MountSourceFlags{}, false), fd, false, false) - if err != nil { - os.RemoveAll(p) - return nil, "", err - } - mm, err := fs.NewMountNamespace(ctx, root) - if err != nil { - os.RemoveAll(p) - return nil, "", err - } - return mm, p, nil -} - -// createTestDirs populates the root with some test files and directories. -// /a/a1.txt -// /a/a2.txt -// /b/b1.txt -// /b/c/c1.txt -// /symlinks/normal.txt -// /symlinks/to_normal.txt -> /symlinks/normal.txt -// /symlinks/recursive -> /symlinks -func createTestDirs(ctx context.Context, t *testing.T, m *fs.MountNamespace) error { - r := m.Root() - defer r.DecRef() - - if err := r.CreateDirectory(ctx, r, "a", fs.FilePermsFromMode(0777)); err != nil { - return err - } - - a, err := r.Walk(ctx, r, "a") - if err != nil { - return err - } - defer a.DecRef() - - a1, err := a.Create(ctx, r, "a1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - return err - } - a1.DecRef() - - a2, err := a.Create(ctx, r, "a2.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - return err - } - a2.DecRef() - - if err := r.CreateDirectory(ctx, r, "b", fs.FilePermsFromMode(0777)); err != nil { - return err - } - - b, err := r.Walk(ctx, r, "b") - if err != nil { - return err - } - defer b.DecRef() - - b1, err := b.Create(ctx, r, "b1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - return err - } - b1.DecRef() - - if err := b.CreateDirectory(ctx, r, "c", fs.FilePermsFromMode(0777)); err != nil { - return err - } - - c, err := b.Walk(ctx, r, "c") - if err != nil { - return err - } - defer c.DecRef() - - c1, err := c.Create(ctx, r, "c1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - return err - } - c1.DecRef() - - if err := r.CreateDirectory(ctx, r, "symlinks", fs.FilePermsFromMode(0777)); err != nil { - return err - } - - symlinks, err := r.Walk(ctx, r, "symlinks") - if err != nil { - return err - } - defer symlinks.DecRef() - - normal, err := symlinks.Create(ctx, r, "normal.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - return err - } - normal.DecRef() - - if err := symlinks.CreateLink(ctx, r, "/symlinks/normal.txt", "to_normal.txt"); err != nil { - return err - } - - return symlinks.CreateLink(ctx, r, "/symlinks", "recursive") -} - -// allPaths returns a slice of all paths of entries visible in the rootfs. -func allPaths(ctx context.Context, t *testing.T, m *fs.MountNamespace, base string) ([]string, error) { - var paths []string - root := m.Root() - defer root.DecRef() - - maxTraversals := uint(1) - d, err := m.FindLink(ctx, root, nil, base, &maxTraversals) - if err != nil { - t.Logf("FindLink failed for %q", base) - return paths, err - } - defer d.DecRef() - - if fs.IsDir(d.Inode.StableAttr) { - dir, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true}) - if err != nil { - return nil, fmt.Errorf("failed to open directory %q: %v", base, err) - } - iter, ok := dir.FileOperations.(fs.DirIterator) - if !ok { - return nil, fmt.Errorf("cannot directly iterate on host directory %q", base) - } - dirCtx := &fs.DirCtx{ - Serializer: noopDentrySerializer{}, - } - if _, err := fs.DirentReaddir(ctx, d, iter, root, dirCtx, 0); err != nil { - return nil, err - } - for name := range dirCtx.DentAttrs() { - if name == "." || name == ".." { - continue - } - - fullName := path.Join(base, name) - paths = append(paths, fullName) - - // Recurse. - subpaths, err := allPaths(ctx, t, m, fullName) - if err != nil { - return paths, err - } - paths = append(paths, subpaths...) - } - } - - return paths, nil -} - -type noopDentrySerializer struct{} - -func (noopDentrySerializer) CopyOut(string, fs.DentAttr) error { - return nil -} -func (noopDentrySerializer) Written() int { - return 4096 -} - -// pathsEqual returns true if the two string slices contain the same entries. -func pathsEqual(got, want []string) bool { - sort.Strings(got) - sort.Strings(want) - - if len(got) != len(want) { - return false - } - - for i := range got { - if got[i] != want[i] { - return false - } - } - - return true -} - -func TestWhitelist(t *testing.T) { - for _, test := range []struct { - // description of the test. - desc string - // paths are the paths to whitelist - paths []string - // want are all of the directory entries that should be - // visible (nothing beyond this set should be visible). - want []string - }{ - { - desc: "root", - paths: []string{"/"}, - want: []string{"/a", "/a/a1.txt", "/a/a2.txt", "/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt", "/symlinks", "/symlinks/normal.txt", "/symlinks/to_normal.txt", "/symlinks/recursive"}, - }, - { - desc: "top-level directories", - paths: []string{"/a", "/b"}, - want: []string{"/a", "/a/a1.txt", "/a/a2.txt", "/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"}, - }, - { - desc: "nested directories (1/2)", - paths: []string{"/b", "/b/c"}, - want: []string{"/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"}, - }, - { - desc: "nested directories (2/2)", - paths: []string{"/b/c", "/b"}, - want: []string{"/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"}, - }, - { - desc: "single file", - paths: []string{"/b/c/c1.txt"}, - want: []string{"/b", "/b/c", "/b/c/c1.txt"}, - }, - { - desc: "single file and directory", - paths: []string{"/a/a1.txt", "/b/c"}, - want: []string{"/a", "/a/a1.txt", "/b", "/b/c", "/b/c/c1.txt"}, - }, - { - desc: "symlink", - paths: []string{"/symlinks/to_normal.txt"}, - want: []string{"/symlinks", "/symlinks/normal.txt", "/symlinks/to_normal.txt"}, - }, - { - desc: "recursive symlink", - paths: []string{"/symlinks/recursive/normal.txt"}, - want: []string{"/symlinks", "/symlinks/normal.txt", "/symlinks/recursive"}, - }, - } { - t.Run(test.desc, func(t *testing.T) { - m, p, err := newTestMountNamespace(t) - if err != nil { - t.Errorf("Failed to create MountNamespace: %v", err) - } - defer os.RemoveAll(p) - - ctx := withRoot(contexttest.RootContext(t), m.Root()) - if err := createTestDirs(ctx, t, m); err != nil { - t.Errorf("Failed to create test dirs: %v", err) - } - - if err := installWhitelist(ctx, m, test.paths); err != nil { - t.Errorf("installWhitelist(%v) err got %v want nil", test.paths, err) - } - - got, err := allPaths(ctx, t, m, "/") - if err != nil { - t.Fatalf("Failed to lookup paths (whitelisted: %v): %v", test.paths, err) - } - - if !pathsEqual(got, test.want) { - t.Errorf("For paths %v got %v want %v", test.paths, got, test.want) - } - }) - } -} - -func TestRootPath(t *testing.T) { - // Create a temp dir, which will be the root of our mounted fs. - rootPath, err := ioutil.TempDir(os.TempDir(), "root") - if err != nil { - t.Fatalf("TempDir failed: %v", err) - } - defer os.RemoveAll(rootPath) - - // Create two files inside the new root, one which will be whitelisted - // and one not. - whitelisted, err := ioutil.TempFile(rootPath, "white") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - if _, err := ioutil.TempFile(rootPath, "black"); err != nil { - t.Fatalf("TempFile failed: %v", err) - } - - // Create a mount with a root path and single whitelisted file. - hostFS := &Filesystem{} - ctx := contexttest.Context(t) - data := fmt.Sprintf("%s=%s,%s=%s", rootPathKey, rootPath, whitelistKey, whitelisted.Name()) - inode, err := hostFS.Mount(ctx, "", fs.MountSourceFlags{}, data, nil) - if err != nil { - t.Fatalf("Mount failed: %v", err) - } - mm, err := fs.NewMountNamespace(ctx, inode) - if err != nil { - t.Fatalf("NewMountNamespace failed: %v", err) - } - if err := hostFS.InstallWhitelist(ctx, mm); err != nil { - t.Fatalf("InstallWhitelist failed: %v", err) - } - - // Get the contents of the root directory. - rootDir := mm.Root() - rctx := withRoot(ctx, rootDir) - f, err := rootDir.Inode.GetFile(rctx, rootDir, fs.FileFlags{}) - if err != nil { - t.Fatalf("GetFile failed: %v", err) - } - c := &fs.CollectEntriesSerializer{} - if err := f.Readdir(rctx, c); err != nil { - t.Fatalf("Readdir failed: %v", err) - } - - // We should have only our whitelisted file, plus the dots. - want := []string{path.Base(whitelisted.Name()), ".", ".."} - got := c.Order - sort.Strings(want) - sort.Strings(got) - if !reflect.DeepEqual(got, want) { - t.Errorf("Readdir got %v, wanted %v", got, want) - } -} - -type rootContext struct { - context.Context - root *fs.Dirent -} - -// withRoot returns a copy of ctx with the given root. -func withRoot(ctx context.Context, root *fs.Dirent) context.Context { - return &rootContext{ - Context: ctx, - root: root, - } -} - -// Value implements Context.Value. -func (rc rootContext) Value(key interface{}) interface{} { - switch key { - case fs.CtxRoot: - rc.root.IncRef() - return rc.root - default: - return rc.Context.Value(key) - } -} diff --git a/pkg/sentry/fs/host/host_amd64_unsafe_state_autogen.go b/pkg/sentry/fs/host/host_amd64_unsafe_state_autogen.go new file mode 100755 index 000000000..488cbdfcf --- /dev/null +++ b/pkg/sentry/fs/host/host_amd64_unsafe_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build amd64 + +package host diff --git a/pkg/sentry/fs/host/host_arm64_unsafe_state_autogen.go b/pkg/sentry/fs/host/host_arm64_unsafe_state_autogen.go new file mode 100755 index 000000000..7371b44db --- /dev/null +++ b/pkg/sentry/fs/host/host_arm64_unsafe_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build arm64 + +package host diff --git a/pkg/sentry/fs/host/host_state_autogen.go b/pkg/sentry/fs/host/host_state_autogen.go new file mode 100755 index 000000000..e689cd52c --- /dev/null +++ b/pkg/sentry/fs/host/host_state_autogen.go @@ -0,0 +1,142 @@ +// automatically generated by stateify. + +package host + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *descriptor) save(m state.Map) { + x.beforeSave() + m.Save("donated", &x.donated) + m.Save("origFD", &x.origFD) + m.Save("wouldBlock", &x.wouldBlock) +} + +func (x *descriptor) load(m state.Map) { + m.Load("donated", &x.donated) + m.Load("origFD", &x.origFD) + m.Load("wouldBlock", &x.wouldBlock) + m.AfterLoad(x.afterLoad) +} + +func (x *fileOperations) beforeSave() {} +func (x *fileOperations) save(m state.Map) { + x.beforeSave() + m.Save("iops", &x.iops) + m.Save("dirCursor", &x.dirCursor) +} + +func (x *fileOperations) afterLoad() {} +func (x *fileOperations) load(m state.Map) { + m.LoadWait("iops", &x.iops) + m.Load("dirCursor", &x.dirCursor) +} + +func (x *Filesystem) beforeSave() {} +func (x *Filesystem) save(m state.Map) { + x.beforeSave() + m.Save("paths", &x.paths) +} + +func (x *Filesystem) afterLoad() {} +func (x *Filesystem) load(m state.Map) { + m.Load("paths", &x.paths) +} + +func (x *superOperations) beforeSave() {} +func (x *superOperations) save(m state.Map) { + x.beforeSave() + m.Save("SimpleMountSourceOperations", &x.SimpleMountSourceOperations) + m.Save("root", &x.root) + m.Save("inodeMappings", &x.inodeMappings) + m.Save("mounter", &x.mounter) + m.Save("dontTranslateOwnership", &x.dontTranslateOwnership) +} + +func (x *superOperations) afterLoad() {} +func (x *superOperations) load(m state.Map) { + m.Load("SimpleMountSourceOperations", &x.SimpleMountSourceOperations) + m.Load("root", &x.root) + m.Load("inodeMappings", &x.inodeMappings) + m.Load("mounter", &x.mounter) + m.Load("dontTranslateOwnership", &x.dontTranslateOwnership) +} + +func (x *inodeOperations) beforeSave() {} +func (x *inodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("fileState", &x.fileState) + m.Save("cachingInodeOps", &x.cachingInodeOps) +} + +func (x *inodeOperations) afterLoad() {} +func (x *inodeOperations) load(m state.Map) { + m.LoadWait("fileState", &x.fileState) + m.Load("cachingInodeOps", &x.cachingInodeOps) +} + +func (x *inodeFileState) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.queue) { + m.Failf("queue is %v, expected zero", x.queue) + } + m.Save("mops", &x.mops) + m.Save("descriptor", &x.descriptor) + m.Save("sattr", &x.sattr) + m.Save("savedUAttr", &x.savedUAttr) +} + +func (x *inodeFileState) load(m state.Map) { + m.LoadWait("mops", &x.mops) + m.LoadWait("descriptor", &x.descriptor) + m.LoadWait("sattr", &x.sattr) + m.Load("savedUAttr", &x.savedUAttr) + m.AfterLoad(x.afterLoad) +} + +func (x *ConnectedEndpoint) save(m state.Map) { + x.beforeSave() + m.Save("ref", &x.ref) + m.Save("queue", &x.queue) + m.Save("path", &x.path) + m.Save("srfd", &x.srfd) + m.Save("stype", &x.stype) +} + +func (x *ConnectedEndpoint) load(m state.Map) { + m.Load("ref", &x.ref) + m.Load("queue", &x.queue) + m.Load("path", &x.path) + m.LoadWait("srfd", &x.srfd) + m.Load("stype", &x.stype) + m.AfterLoad(x.afterLoad) +} + +func (x *TTYFileOperations) beforeSave() {} +func (x *TTYFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("fileOperations", &x.fileOperations) + m.Save("session", &x.session) + m.Save("fgProcessGroup", &x.fgProcessGroup) + m.Save("termios", &x.termios) +} + +func (x *TTYFileOperations) afterLoad() {} +func (x *TTYFileOperations) load(m state.Map) { + m.Load("fileOperations", &x.fileOperations) + m.Load("session", &x.session) + m.Load("fgProcessGroup", &x.fgProcessGroup) + m.Load("termios", &x.termios) +} + +func init() { + state.Register("pkg/sentry/fs/host.descriptor", (*descriptor)(nil), state.Fns{Save: (*descriptor).save, Load: (*descriptor).load}) + state.Register("pkg/sentry/fs/host.fileOperations", (*fileOperations)(nil), state.Fns{Save: (*fileOperations).save, Load: (*fileOperations).load}) + state.Register("pkg/sentry/fs/host.Filesystem", (*Filesystem)(nil), state.Fns{Save: (*Filesystem).save, Load: (*Filesystem).load}) + state.Register("pkg/sentry/fs/host.superOperations", (*superOperations)(nil), state.Fns{Save: (*superOperations).save, Load: (*superOperations).load}) + state.Register("pkg/sentry/fs/host.inodeOperations", (*inodeOperations)(nil), state.Fns{Save: (*inodeOperations).save, Load: (*inodeOperations).load}) + state.Register("pkg/sentry/fs/host.inodeFileState", (*inodeFileState)(nil), state.Fns{Save: (*inodeFileState).save, Load: (*inodeFileState).load}) + state.Register("pkg/sentry/fs/host.ConnectedEndpoint", (*ConnectedEndpoint)(nil), state.Fns{Save: (*ConnectedEndpoint).save, Load: (*ConnectedEndpoint).load}) + state.Register("pkg/sentry/fs/host.TTYFileOperations", (*TTYFileOperations)(nil), state.Fns{Save: (*TTYFileOperations).save, Load: (*TTYFileOperations).load}) +} diff --git a/pkg/sentry/fs/host/inode_test.go b/pkg/sentry/fs/host/inode_test.go deleted file mode 100644 index 7221bc825..000000000 --- a/pkg/sentry/fs/host/inode_test.go +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "io/ioutil" - "os" - "path" - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" -) - -// TestMultipleReaddir verifies that multiple Readdir calls return the same -// thing if they use different dir contexts. -func TestMultipleReaddir(t *testing.T) { - p, err := ioutil.TempDir("", "readdir") - if err != nil { - t.Fatalf("Failed to create test dir: %v", err) - } - defer os.RemoveAll(p) - - f, err := os.Create(path.Join(p, "a.txt")) - if err != nil { - t.Fatalf("Failed to create a.txt: %v", err) - } - f.Close() - - f, err = os.Create(path.Join(p, "b.txt")) - if err != nil { - t.Fatalf("Failed to create b.txt: %v", err) - } - f.Close() - - fd, err := open(nil, p) - if err != nil { - t.Fatalf("Failed to open %q: %v", p, err) - } - ctx := contexttest.Context(t) - n, err := newInode(ctx, newMountSource(ctx, p, fs.RootOwner, &Filesystem{}, fs.MountSourceFlags{}, false), fd, false, false) - if err != nil { - t.Fatalf("Failed to create inode: %v", err) - } - - dirent := fs.NewDirent(ctx, n, "readdir") - openFile, err := n.GetFile(ctx, dirent, fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("Failed to get file: %v", err) - } - defer openFile.DecRef() - - c1 := &fs.DirCtx{DirCursor: new(string)} - if _, err := openFile.FileOperations.(*fileOperations).IterateDir(ctx, dirent, c1, 0); err != nil { - t.Fatalf("First Readdir failed: %v", err) - } - - c2 := &fs.DirCtx{DirCursor: new(string)} - if _, err := openFile.FileOperations.(*fileOperations).IterateDir(ctx, dirent, c2, 0); err != nil { - t.Errorf("Second Readdir failed: %v", err) - } - - if _, ok := c1.DentAttrs()["a.txt"]; !ok { - t.Errorf("want a.txt in first Readdir, got %v", c1.DentAttrs()) - } - if _, ok := c1.DentAttrs()["b.txt"]; !ok { - t.Errorf("want b.txt in first Readdir, got %v", c1.DentAttrs()) - } - - if _, ok := c2.DentAttrs()["a.txt"]; !ok { - t.Errorf("want a.txt in second Readdir, got %v", c2.DentAttrs()) - } - if _, ok := c2.DentAttrs()["b.txt"]; !ok { - t.Errorf("want b.txt in second Readdir, got %v", c2.DentAttrs()) - } -} - -// TestCloseFD verifies fds will be closed. -func TestCloseFD(t *testing.T) { - var p [2]int - if err := syscall.Pipe(p[0:]); err != nil { - t.Fatalf("Failed to create pipe %v", err) - } - defer syscall.Close(p[0]) - defer syscall.Close(p[1]) - - // Use the write-end because we will detect if it's closed on the read end. - ctx := contexttest.Context(t) - file, err := NewFile(ctx, p[1], fs.RootOwner) - if err != nil { - t.Fatalf("Failed to create File: %v", err) - } - file.DecRef() - - s := make([]byte, 10) - if c, err := syscall.Read(p[0], s); c != 0 || err != nil { - t.Errorf("want 0, nil (EOF) from read end, got %v, %v", c, err) - } -} diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go deleted file mode 100644 index eb4afe520..000000000 --- a/pkg/sentry/fs/host/socket_test.go +++ /dev/null @@ -1,246 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "reflect" - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/fdnotifier" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -var ( - // Make sure that ConnectedEndpoint implements transport.ConnectedEndpoint. - _ = transport.ConnectedEndpoint(new(ConnectedEndpoint)) - - // Make sure that ConnectedEndpoint implements transport.Receiver. - _ = transport.Receiver(new(ConnectedEndpoint)) -) - -func getFl(fd int) (uint32, error) { - fl, _, err := syscall.RawSyscall(syscall.SYS_FCNTL, uintptr(fd), syscall.F_GETFL, 0) - if err == 0 { - return uint32(fl), nil - } - return 0, err -} - -func TestSocketIsBlocking(t *testing.T) { - // Using socketpair here because it's already connected. - pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("host socket creation failed: %v", err) - } - - fl, err := getFl(pair[0]) - if err != nil { - t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err) - } - if fl&syscall.O_NONBLOCK == syscall.O_NONBLOCK { - t.Fatalf("Expected socket %v to be blocking", pair[0]) - } - if fl, err = getFl(pair[1]); err != nil { - t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[1], err) - } - if fl&syscall.O_NONBLOCK == syscall.O_NONBLOCK { - t.Fatalf("Expected socket %v to be blocking", pair[1]) - } - sock, err := newSocket(contexttest.Context(t), pair[0], false) - if err != nil { - t.Fatalf("newSocket(%v) failed => %v", pair[0], err) - } - defer sock.DecRef() - // Test that the socket now is non-blocking. - if fl, err = getFl(pair[0]); err != nil { - t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err) - } - if fl&syscall.O_NONBLOCK != syscall.O_NONBLOCK { - t.Errorf("Expected socket %v to have become non-blocking", pair[0]) - } - if fl, err = getFl(pair[1]); err != nil { - t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[1], err) - } - if fl&syscall.O_NONBLOCK == syscall.O_NONBLOCK { - t.Errorf("Did not expect socket %v to become non-blocking", pair[1]) - } -} - -func TestSocketWritev(t *testing.T) { - // Using socketpair here because it's already connected. - pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("host socket creation failed: %v", err) - } - socket, err := newSocket(contexttest.Context(t), pair[0], false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", pair[0], err) - } - defer socket.DecRef() - buf := []byte("hello world\n") - n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(buf)) - if err != nil { - t.Fatalf("socket writev failed: %v", err) - } - - if n != int64(len(buf)) { - t.Fatalf("socket writev wrote incorrect bytes: %d", n) - } -} - -func TestSocketWritevLen0(t *testing.T) { - // Using socketpair here because it's already connected. - pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("host socket creation failed: %v", err) - } - socket, err := newSocket(contexttest.Context(t), pair[0], false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", pair[0], err) - } - defer socket.DecRef() - n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(nil)) - if err != nil { - t.Fatalf("socket writev failed: %v", err) - } - - if n != 0 { - t.Fatalf("socket writev wrote incorrect bytes: %d", n) - } -} - -func TestSocketSendMsgLen0(t *testing.T) { - // Using socketpair here because it's already connected. - pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("host socket creation failed: %v", err) - } - sfile, err := newSocket(contexttest.Context(t), pair[0], false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", pair[0], err) - } - defer sfile.DecRef() - - s := sfile.FileOperations.(socket.Socket) - n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, false, ktime.Time{}, socket.ControlMessages{}) - if n != 0 { - t.Fatalf("socket sendmsg() failed: %v wrote: %d", terr, n) - } - - if terr != nil { - t.Fatalf("socket sendmsg() failed: %v", terr) - } -} - -func TestListen(t *testing.T) { - pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) => %v", err) - } - sfile1, err := newSocket(contexttest.Context(t), pair[0], false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", pair[0], err) - } - defer sfile1.DecRef() - socket1 := sfile1.FileOperations.(socket.Socket) - - sfile2, err := newSocket(contexttest.Context(t), pair[1], false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", pair[1], err) - } - defer sfile2.DecRef() - socket2 := sfile2.FileOperations.(socket.Socket) - - // Socketpairs can not be listened to. - if err := socket1.Listen(nil, 64); err != syserr.ErrInvalidEndpointState { - t.Fatalf("socket1.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err) - } - if err := socket2.Listen(nil, 64); err != syserr.ErrInvalidEndpointState { - t.Fatalf("socket2.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err) - } - - // Create a Unix socket, do not bind it. - sock, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) => %v", err) - } - sfile3, err := newSocket(contexttest.Context(t), sock, false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", sock, err) - } - defer sfile3.DecRef() - socket3 := sfile3.FileOperations.(socket.Socket) - - // This socket is not bound so we can't listen on it. - if err := socket3.Listen(nil, 64); err != syserr.ErrInvalidEndpointState { - t.Fatalf("socket3.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err) - } -} - -func TestPasscred(t *testing.T) { - e := ConnectedEndpoint{} - if got, want := e.Passcred(), false; got != want { - t.Errorf("Got %#v.Passcred() = %t, want = %t", e, got, want) - } -} - -func TestGetLocalAddress(t *testing.T) { - e := ConnectedEndpoint{path: "foo"} - want := tcpip.FullAddress{Addr: tcpip.Address("foo")} - if got, err := e.GetLocalAddress(); err != nil || got != want { - t.Errorf("Got %#v.GetLocalAddress() = %#v, %v, want = %#v, %v", e, got, err, want, nil) - } -} - -func TestQueuedSize(t *testing.T) { - e := ConnectedEndpoint{} - tests := []struct { - name string - f func() int64 - }{ - {"SendQueuedSize", e.SendQueuedSize}, - {"RecvQueuedSize", e.RecvQueuedSize}, - } - - for _, test := range tests { - if got, want := test.f(), int64(-1); got != want { - t.Errorf("Got %#v.%s() = %d, want = %d", e, test.name, got, want) - } - } -} - -func TestRelease(t *testing.T) { - f, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} - want := &ConnectedEndpoint{queue: c.queue} - want.ref.DecRef() - fdnotifier.AddFD(int32(c.file.FD()), nil) - c.Release() - if !reflect.DeepEqual(c, want) { - t.Errorf("got = %#v, want = %#v", c, want) - } -} diff --git a/pkg/sentry/fs/host/util_amd64_unsafe.go b/pkg/sentry/fs/host/util_amd64_unsafe.go index 66da6e9f5..66da6e9f5 100644..100755 --- a/pkg/sentry/fs/host/util_amd64_unsafe.go +++ b/pkg/sentry/fs/host/util_amd64_unsafe.go diff --git a/pkg/sentry/fs/host/util_arm64_unsafe.go b/pkg/sentry/fs/host/util_arm64_unsafe.go index e8cb94aeb..e8cb94aeb 100644..100755 --- a/pkg/sentry/fs/host/util_arm64_unsafe.go +++ b/pkg/sentry/fs/host/util_arm64_unsafe.go diff --git a/pkg/sentry/fs/host/wait_test.go b/pkg/sentry/fs/host/wait_test.go deleted file mode 100644 index d49c3a635..000000000 --- a/pkg/sentry/fs/host/wait_test.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "syscall" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestWait(t *testing.T) { - var fds [2]int - err := syscall.Pipe(fds[:]) - if err != nil { - t.Fatalf("Unable to create pipe: %v", err) - } - - defer syscall.Close(fds[1]) - - ctx := contexttest.Context(t) - file, err := NewFile(ctx, fds[0], fs.RootOwner) - if err != nil { - syscall.Close(fds[0]) - t.Fatalf("NewFile failed: %v", err) - } - - defer file.DecRef() - - r := file.Readiness(waiter.EventIn) - if r != 0 { - t.Fatalf("File is ready for read when it shouldn't be.") - } - - e, ch := waiter.NewChannelEntry(nil) - file.EventRegister(&e, waiter.EventIn) - defer file.EventUnregister(&e) - - // Check that there are no notifications yet. - if len(ch) != 0 { - t.Fatalf("Channel is non-empty") - } - - // Write to the pipe, so it should be writable now. - syscall.Write(fds[1], []byte{1}) - - // Check that we get a notification. We need to yield the current thread - // so that the fdnotifier can deliver notifications, so we use a - // 1-second timeout instead of just checking the length of the channel. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Channel not notified") - } -} diff --git a/pkg/sentry/fs/inode_overlay_test.go b/pkg/sentry/fs/inode_overlay_test.go deleted file mode 100644 index 389c219d6..000000000 --- a/pkg/sentry/fs/inode_overlay_test.go +++ /dev/null @@ -1,470 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" - "gvisor.dev/gvisor/pkg/syserror" -) - -func TestLookup(t *testing.T) { - ctx := contexttest.Context(t) - for _, test := range []struct { - // Test description. - desc string - - // Lookup parameters. - dir *fs.Inode - name string - - // Want from lookup. - found bool - hasUpper bool - hasLower bool - }{ - { - desc: "no upper, lower has name", - dir: fs.NewTestOverlayDir(ctx, - nil, /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: false, - hasLower: true, - }, - { - desc: "no lower, upper has name", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* upper */ - nil, /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: true, - hasLower: false, - }, - { - desc: "upper and lower, only lower has name", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - { - name: "b", - dir: false, - }, - }, nil), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: false, - hasLower: true, - }, - { - desc: "upper and lower, only upper has name", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "b", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: true, - hasLower: false, - }, - { - desc: "upper and lower, both have file", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: true, - hasLower: false, - }, - { - desc: "upper and lower, both have directory", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: true, - }, - }, nil), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: true, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: true, - hasLower: true, - }, - { - desc: "upper and lower, upper negative masks lower file", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, nil, []string{"a"}), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: false, - hasUpper: false, - hasLower: false, - }, - { - desc: "upper and lower, upper negative does not mask lower file", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, nil, []string{"b"}), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: false, - hasLower: true, - }, - } { - t.Run(test.desc, func(t *testing.T) { - dirent, err := test.dir.Lookup(ctx, test.name) - if test.found && (err == syserror.ENOENT || dirent.IsNegative()) { - t.Fatalf("lookup %q expected to find positive dirent, got dirent %v err %v", test.name, dirent, err) - } - if !test.found { - if err != syserror.ENOENT && !dirent.IsNegative() { - t.Errorf("lookup %q expected to return ENOENT or negative dirent, got dirent %v err %v", test.name, dirent, err) - } - // Nothing more to check. - return - } - if hasUpper := dirent.Inode.TestHasUpperFS(); hasUpper != test.hasUpper { - t.Fatalf("lookup got upper filesystem %v, want %v", hasUpper, test.hasUpper) - } - if hasLower := dirent.Inode.TestHasLowerFS(); hasLower != test.hasLower { - t.Errorf("lookup got lower filesystem %v, want %v", hasLower, test.hasLower) - } - }) - } -} - -func TestLookupRevalidation(t *testing.T) { - // File name used in the tests. - fileName := "foofile" - ctx := contexttest.Context(t) - for _, tc := range []struct { - // Test description. - desc string - - // Upper and lower fs for the overlay. - upper *fs.Inode - lower *fs.Inode - - // Whether the upper requires revalidation. - revalidate bool - - // Whether we should get the same dirent on second lookup. - wantSame bool - }{ - { - desc: "file from upper with no revalidation", - upper: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - lower: newTestRamfsDir(ctx, nil, nil), - revalidate: false, - wantSame: true, - }, - { - desc: "file from upper with revalidation", - upper: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - lower: newTestRamfsDir(ctx, nil, nil), - revalidate: true, - wantSame: false, - }, - { - desc: "file from lower with no revalidation", - upper: newTestRamfsDir(ctx, nil, nil), - lower: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - revalidate: false, - wantSame: true, - }, - { - desc: "file from lower with revalidation", - upper: newTestRamfsDir(ctx, nil, nil), - lower: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - revalidate: true, - // The file does not exist in the upper, so we do not - // need to revalidate it. - wantSame: true, - }, - { - desc: "file from upper and lower with no revalidation", - upper: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - lower: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - revalidate: false, - wantSame: true, - }, - { - desc: "file from upper and lower with revalidation", - upper: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - lower: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - revalidate: true, - wantSame: false, - }, - } { - t.Run(tc.desc, func(t *testing.T) { - root := fs.NewDirent(ctx, newTestRamfsDir(ctx, nil, nil), "root") - ctx = &rootContext{ - Context: ctx, - root: root, - } - overlay := fs.NewDirent(ctx, fs.NewTestOverlayDir(ctx, tc.upper, tc.lower, tc.revalidate), "overlay") - // Lookup the file twice through the overlay. - first, err := overlay.Walk(ctx, root, fileName) - if err != nil { - t.Fatalf("overlay.Walk(%q) failed: %v", fileName, err) - } - second, err := overlay.Walk(ctx, root, fileName) - if err != nil { - t.Fatalf("overlay.Walk(%q) failed: %v", fileName, err) - } - - if tc.wantSame && first != second { - t.Errorf("dirent lookup got different dirents, wanted same\nfirst=%+v\nsecond=%+v", first, second) - } else if !tc.wantSame && first == second { - t.Errorf("dirent lookup got the same dirent, wanted different: %+v", first) - } - }) - } -} - -func TestCacheFlush(t *testing.T) { - ctx := contexttest.Context(t) - - // Upper and lower each have a file. - upperFileName := "file-from-upper" - lowerFileName := "file-from-lower" - upper := newTestRamfsDir(ctx, []dirContent{{name: upperFileName}}, nil) - lower := newTestRamfsDir(ctx, []dirContent{{name: lowerFileName}}, nil) - - overlay := fs.NewTestOverlayDir(ctx, upper, lower, true /* revalidate */) - - mns, err := fs.NewMountNamespace(ctx, overlay) - if err != nil { - t.Fatalf("NewMountNamespace failed: %v", err) - } - root := mns.Root() - defer root.DecRef() - - ctx = &rootContext{ - Context: ctx, - root: root, - } - - for _, fileName := range []string{upperFileName, lowerFileName} { - // Walk to the file. - maxTraversals := uint(0) - dirent, err := mns.FindInode(ctx, root, nil, fileName, &maxTraversals) - if err != nil { - t.Fatalf("FindInode(%q) failed: %v", fileName, err) - } - - // Get a file from the dirent. - file, err := dirent.Inode.GetFile(ctx, dirent, fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("GetFile() failed: %v", err) - } - - // The dirent should have 3 refs, one from us, one from the - // file, and one from the dirent cache. - // dirent cache. - if got, want := dirent.ReadRefs(), 3; int(got) != want { - t.Errorf("dirent.ReadRefs() got %d want %d", got, want) - } - - // Drop the file reference. - file.DecRef() - - // Dirent should have 2 refs left. - if got, want := dirent.ReadRefs(), 2; int(got) != want { - t.Errorf("dirent.ReadRefs() got %d want %d", got, want) - } - - // Flush the dirent cache. - mns.FlushMountSourceRefs() - - // Dirent should have 1 ref left from the dirent cache. - if got, want := dirent.ReadRefs(), 1; int(got) != want { - t.Errorf("dirent.ReadRefs() got %d want %d", got, want) - } - - // Drop our ref. - dirent.DecRef() - - // We should be back to zero refs. - if got, want := dirent.ReadRefs(), 0; int(got) != want { - t.Errorf("dirent.ReadRefs() got %d want %d", got, want) - } - } - -} - -type dir struct { - fs.InodeOperations - - // List of negative child names. - negative []string - - // ReaddirCalled records whether Readdir was called on a file - // corresponding to this inode. - ReaddirCalled bool -} - -// GetXattr implements InodeOperations.GetXattr. -func (d *dir) GetXattr(_ context.Context, _ *fs.Inode, name string, _ uint64) (string, error) { - for _, n := range d.negative { - if name == fs.XattrOverlayWhiteout(n) { - return "y", nil - } - } - return "", syserror.ENOATTR -} - -// GetFile implements InodeOperations.GetFile. -func (d *dir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { - file, err := d.InodeOperations.GetFile(ctx, dirent, flags) - if err != nil { - return nil, err - } - defer file.DecRef() - // Wrap the file's FileOperations in a dirFile. - fops := &dirFile{ - FileOperations: file.FileOperations, - inode: d, - } - return fs.NewFile(ctx, dirent, flags, fops), nil -} - -type dirContent struct { - name string - dir bool -} - -type dirFile struct { - fs.FileOperations - inode *dir -} - -type inode struct { - fsutil.InodeGenericChecker `state:"nosave"` - fsutil.InodeNoExtendedAttributes `state:"nosave"` - fsutil.InodeNoopRelease `state:"nosave"` - fsutil.InodeNoopWriteOut `state:"nosave"` - fsutil.InodeNotAllocatable `state:"nosave"` - fsutil.InodeNotDirectory `state:"nosave"` - fsutil.InodeNotMappable `state:"nosave"` - fsutil.InodeNotSocket `state:"nosave"` - fsutil.InodeNotSymlink `state:"nosave"` - fsutil.InodeNotTruncatable `state:"nosave"` - fsutil.InodeNotVirtual `state:"nosave"` - - fsutil.InodeSimpleAttributes - fsutil.InodeStaticFileGetter -} - -// Readdir implements fs.FileOperations.Readdir. It sets the ReaddirCalled -// field on the inode. -func (f *dirFile) Readdir(ctx context.Context, file *fs.File, ser fs.DentrySerializer) (int64, error) { - f.inode.ReaddirCalled = true - return f.FileOperations.Readdir(ctx, file, ser) -} - -func newTestRamfsInode(ctx context.Context, msrc *fs.MountSource) *fs.Inode { - inode := fs.NewInode(ctx, &inode{ - InodeStaticFileGetter: fsutil.InodeStaticFileGetter{ - Contents: []byte("foobar"), - }, - }, msrc, fs.StableAttr{Type: fs.RegularFile}) - return inode -} - -func newTestRamfsDir(ctx context.Context, contains []dirContent, negative []string) *fs.Inode { - msrc := fs.NewPseudoMountSource(ctx) - contents := make(map[string]*fs.Inode) - for _, c := range contains { - if c.dir { - contents[c.name] = newTestRamfsDir(ctx, nil, nil) - } else { - contents[c.name] = newTestRamfsInode(ctx, msrc) - } - } - dops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermissions{ - User: fs.PermMask{Read: true, Execute: true}, - }) - return fs.NewInode(ctx, &dir{ - InodeOperations: dops, - negative: negative, - }, msrc, fs.StableAttr{Type: fs.Directory}) -} diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD deleted file mode 100644 index ae3331737..000000000 --- a/pkg/sentry/fs/lock/BUILD +++ /dev/null @@ -1,58 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "lock_range", - out = "lock_range.go", - package = "lock", - prefix = "Lock", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - -go_template_instance( - name = "lock_set", - out = "lock_set.go", - consts = { - "minDegree": "3", - }, - package = "lock", - prefix = "Lock", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "LockRange", - "Value": "Lock", - "Functions": "lockSetFunctions", - }, -) - -go_library( - name = "lock", - srcs = [ - "lock.go", - "lock_range.go", - "lock_set.go", - "lock_set_functions.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/log", - "//pkg/sync", - "//pkg/waiter", - ], -) - -go_test( - name = "lock_test", - size = "small", - srcs = [ - "lock_range_test.go", - "lock_test.go", - ], - library = ":lock", -) diff --git a/pkg/segment/range.go b/pkg/sentry/fs/lock/lock_range.go index 4d4aeffef..7a6f77640 100644..100755 --- a/pkg/segment/range.go +++ b/pkg/sentry/fs/lock/lock_range.go @@ -1,64 +1,47 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package segment - -// T is a required type parameter that must be an integral type. -type T uint64 +package lock // A Range represents a contiguous range of T. // // +stateify savable -type Range struct { +type LockRange struct { // Start is the inclusive start of the range. - Start T + Start uint64 // End is the exclusive end of the range. - End T + End uint64 } // WellFormed returns true if r.Start <= r.End. All other methods on a Range // require that the Range is well-formed. -func (r Range) WellFormed() bool { +func (r LockRange) WellFormed() bool { return r.Start <= r.End } // Length returns the length of the range. -func (r Range) Length() T { +func (r LockRange) Length() uint64 { return r.End - r.Start } // Contains returns true if r contains x. -func (r Range) Contains(x T) bool { +func (r LockRange) Contains(x uint64) bool { return r.Start <= x && x < r.End } // Overlaps returns true if r and r2 overlap. -func (r Range) Overlaps(r2 Range) bool { +func (r LockRange) Overlaps(r2 LockRange) bool { return r.Start < r2.End && r2.Start < r.End } // IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is // contained within r. -func (r Range) IsSupersetOf(r2 Range) bool { +func (r LockRange) IsSupersetOf(r2 LockRange) bool { return r.Start <= r2.Start && r.End >= r2.End } // Intersect returns a range consisting of the intersection between r and r2. // If r and r2 do not overlap, Intersect returns a range with unspecified // bounds, but for which Length() == 0. -func (r Range) Intersect(r2 Range) Range { +func (r LockRange) Intersect(r2 LockRange) LockRange { if r.Start < r2.Start { r.Start = r2.Start } @@ -74,6 +57,6 @@ func (r Range) Intersect(r2 Range) Range { // CanSplitAt returns true if it is legal to split a segment spanning the range // r at x; that is, splitting at x would produce two ranges, both of which have // non-zero length. -func (r Range) CanSplitAt(x T) bool { +func (r LockRange) CanSplitAt(x uint64) bool { return r.Contains(x) && r.Start < x } diff --git a/pkg/sentry/fs/lock/lock_range_test.go b/pkg/sentry/fs/lock/lock_range_test.go deleted file mode 100644 index 6221199d1..000000000 --- a/pkg/sentry/fs/lock/lock_range_test.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package lock - -import ( - "syscall" - "testing" -) - -func TestComputeRange(t *testing.T) { - tests := []struct { - // Description of test. - name string - - // Requested start of the lock range. - start int64 - - // Requested length of the lock range, - // can be negative :( - length int64 - - // Pre-computed file offset based on whence. - // Will be added to start. - offset int64 - - // Expected error. - err error - - // If error is nil, the expected LockRange. - LockRange - }{ - { - name: "offset, start, and length all zero", - LockRange: LockRange{Start: 0, End: LockEOF}, - }, - { - name: "zero offset, zero start, positive length", - start: 0, - length: 4096, - offset: 0, - LockRange: LockRange{Start: 0, End: 4096}, - }, - { - name: "zero offset, negative start", - start: -4096, - offset: 0, - err: syscall.EINVAL, - }, - { - name: "large offset, negative start, positive length", - start: -2048, - length: 2048, - offset: 4096, - LockRange: LockRange{Start: 2048, End: 4096}, - }, - { - name: "large offset, negative start, zero length", - start: -2048, - length: 0, - offset: 4096, - LockRange: LockRange{Start: 2048, End: LockEOF}, - }, - { - name: "zero offset, zero start, negative length", - start: 0, - length: -4096, - offset: 0, - err: syscall.EINVAL, - }, - { - name: "large offset, zero start, negative length", - start: 0, - length: -4096, - offset: 4096, - LockRange: LockRange{Start: 0, End: 4096}, - }, - { - name: "offset, start, and length equal, length is negative", - start: 1024, - length: -1024, - offset: 1024, - LockRange: LockRange{Start: 1024, End: 2048}, - }, - { - name: "offset, start, and length equal, start is negative", - start: -1024, - length: 1024, - offset: 1024, - LockRange: LockRange{Start: 0, End: 1024}, - }, - { - name: "offset, start, and length equal, offset is negative", - start: 1024, - length: 1024, - offset: -1024, - LockRange: LockRange{Start: 0, End: 1024}, - }, - { - name: "offset, start, and length equal, all negative", - start: -1024, - length: -1024, - offset: -1024, - err: syscall.EINVAL, - }, - { - name: "offset, start, and length equal, all positive", - start: 1024, - length: 1024, - offset: 1024, - LockRange: LockRange{Start: 2048, End: 3072}, - }, - } - - for _, test := range tests { - rng, err := ComputeRange(test.start, test.length, test.offset) - if err != test.err { - t.Errorf("%s: lockRange(%d, %d, %d) got error %v, want %v", test.name, test.start, test.length, test.offset, err, test.err) - continue - } - if err == nil && rng != test.LockRange { - t.Errorf("%s: lockRange(%d, %d, %d) got LockRange %v, want %v", test.name, test.start, test.length, test.offset, rng, test.LockRange) - } - } -} diff --git a/pkg/sentry/fs/lock/lock_set.go b/pkg/sentry/fs/lock/lock_set.go new file mode 100755 index 000000000..2343ca0b4 --- /dev/null +++ b/pkg/sentry/fs/lock/lock_set.go @@ -0,0 +1,1270 @@ +package lock + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + LockminDegree = 3 + + LockmaxDegree = 2 * LockminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type LockSet struct { + root Locknode `state:".(*LockSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *LockSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *LockSet) IsEmptyRange(r LockRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *LockSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *LockSet) SpanRange(r LockRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *LockSet) FirstSegment() LockIterator { + if s.root.nrSegments == 0 { + return LockIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *LockSet) LastSegment() LockIterator { + if s.root.nrSegments == 0 { + return LockIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *LockSet) FirstGap() LockGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return LockGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *LockSet) LastGap() LockGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return LockGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *LockSet) Find(key uint64) (LockIterator, LockGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return LockIterator{n, i}, LockGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return LockIterator{}, LockGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *LockSet) FindSegment(key uint64) LockIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *LockSet) LowerBoundSegment(min uint64) LockIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *LockSet) UpperBoundSegment(max uint64) LockIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *LockSet) FindGap(key uint64) LockGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *LockSet) LowerBoundGap(min uint64) LockGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *LockSet) UpperBoundGap(max uint64) LockGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *LockSet) Add(r LockRange, val Lock) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *LockSet) AddWithoutMerging(r LockRange, val Lock) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *LockSet) Insert(gap LockGapIterator, r LockRange, val Lock) LockIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (lockSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (lockSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (lockSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *LockSet) InsertWithoutMerging(gap LockGapIterator, r LockRange, val Lock) LockIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *LockSet) InsertWithoutMergingUnchecked(gap LockGapIterator, r LockRange, val Lock) LockIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return LockIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *LockSet) Remove(seg LockIterator) LockGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + lockSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(LockGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *LockSet) RemoveAll() { + s.root = Locknode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *LockSet) RemoveRange(r LockRange) LockGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *LockSet) Merge(first, second LockIterator) LockIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *LockSet) MergeUnchecked(first, second LockIterator) LockIterator { + if first.End() == second.Start() { + if mval, ok := (lockSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return LockIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *LockSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *LockSet) MergeRange(r LockRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *LockSet) MergeAdjacent(r LockRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *LockSet) Split(seg LockIterator, split uint64) (LockIterator, LockIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *LockSet) SplitUnchecked(seg LockIterator, split uint64) (LockIterator, LockIterator) { + val1, val2 := (lockSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), LockRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *LockSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *LockSet) Isolate(seg LockIterator, r LockRange) LockIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *LockSet) ApplyContiguous(r LockRange, fn func(seg LockIterator)) LockGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return LockGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return LockGapIterator{} + } + } +} + +// +stateify savable +type Locknode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *Locknode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [LockmaxDegree - 1]LockRange + values [LockmaxDegree - 1]Lock + children [LockmaxDegree]*Locknode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Locknode) firstSegment() LockIterator { + for n.hasChildren { + n = n.children[0] + } + return LockIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Locknode) lastSegment() LockIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return LockIterator{n, n.nrSegments - 1} +} + +func (n *Locknode) prevSibling() *Locknode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *Locknode) nextSibling() *Locknode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *Locknode) rebalanceBeforeInsert(gap LockGapIterator) LockGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < LockmaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &Locknode{ + nrSegments: LockminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &Locknode{ + nrSegments: LockminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:LockminDegree-1], n.keys[:LockminDegree-1]) + copy(left.values[:LockminDegree-1], n.values[:LockminDegree-1]) + copy(right.keys[:LockminDegree-1], n.keys[LockminDegree:]) + copy(right.values[:LockminDegree-1], n.values[LockminDegree:]) + n.keys[0], n.values[0] = n.keys[LockminDegree-1], n.values[LockminDegree-1] + LockzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:LockminDegree], n.children[:LockminDegree]) + copy(right.children[:LockminDegree], n.children[LockminDegree:]) + LockzeroNodeSlice(n.children[2:]) + for i := 0; i < LockminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < LockminDegree { + return LockGapIterator{left, gap.index} + } + return LockGapIterator{right, gap.index - LockminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[LockminDegree-1], n.values[LockminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &Locknode{ + nrSegments: LockminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:LockminDegree-1], n.keys[LockminDegree:]) + copy(sibling.values[:LockminDegree-1], n.values[LockminDegree:]) + LockzeroValueSlice(n.values[LockminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:LockminDegree], n.children[LockminDegree:]) + LockzeroNodeSlice(n.children[LockminDegree:]) + for i := 0; i < LockminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = LockminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < LockminDegree { + return gap + } + return LockGapIterator{sibling, gap.index - LockminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *Locknode) rebalanceAfterRemove(gap LockGapIterator) LockGapIterator { + for { + if n.nrSegments >= LockminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= LockminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + lockSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return LockGapIterator{n, 0} + } + if gap.node == n { + return LockGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= LockminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + lockSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return LockGapIterator{n, n.nrSegments} + } + return LockGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return LockGapIterator{p, gap.index} + } + if gap.node == right { + return LockGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *Locknode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = LockGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + lockSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type LockIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *Locknode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg LockIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg LockIterator) Range() LockRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg LockIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg LockIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg LockIterator) SetRangeUnchecked(r LockRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg LockIterator) SetRange(r LockRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg LockIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg LockIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg LockIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg LockIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg LockIterator) Value() Lock { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg LockIterator) ValuePtr() *Lock { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg LockIterator) SetValue(val Lock) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg LockIterator) PrevSegment() LockIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return LockIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return LockIterator{} + } + return LocksegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg LockIterator) NextSegment() LockIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return LockIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return LockIterator{} + } + return LocksegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg LockIterator) PrevGap() LockGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return LockGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg LockIterator) NextGap() LockGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return LockGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg LockIterator) PrevNonEmpty() (LockIterator, LockGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return LockIterator{}, gap + } + return gap.PrevSegment(), LockGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg LockIterator) NextNonEmpty() (LockIterator, LockGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return LockIterator{}, gap + } + return gap.NextSegment(), LockGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type LockGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *Locknode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap LockGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap LockGapIterator) Range() LockRange { + return LockRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap LockGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return lockSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap LockGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return lockSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap LockGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap LockGapIterator) PrevSegment() LockIterator { + return LocksegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap LockGapIterator) NextSegment() LockIterator { + return LocksegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap LockGapIterator) PrevGap() LockGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return LockGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap LockGapIterator) NextGap() LockGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return LockGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func LocksegmentBeforePosition(n *Locknode, i int) LockIterator { + for i == 0 { + if n.parent == nil { + return LockIterator{} + } + n, i = n.parent, n.parentIndex + } + return LockIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func LocksegmentAfterPosition(n *Locknode, i int) LockIterator { + for i == n.nrSegments { + if n.parent == nil { + return LockIterator{} + } + n, i = n.parent, n.parentIndex + } + return LockIterator{n, i} +} + +func LockzeroValueSlice(slice []Lock) { + + for i := range slice { + lockSetFunctions{}.ClearValue(&slice[i]) + } +} + +func LockzeroNodeSlice(slice []*Locknode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *LockSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *Locknode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *Locknode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type LockSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []Lock +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *LockSet) ExportSortedSlices() *LockSegmentDataSlices { + var sds LockSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *LockSet) ImportSortedSlices(sds *LockSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := LockRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *LockSet) saveRoot() *LockSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *LockSet) loadRoot(sds *LockSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/fs/lock/lock_state_autogen.go b/pkg/sentry/fs/lock/lock_state_autogen.go new file mode 100755 index 000000000..aabf3d570 --- /dev/null +++ b/pkg/sentry/fs/lock/lock_state_autogen.go @@ -0,0 +1,108 @@ +// automatically generated by stateify. + +package lock + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Lock) beforeSave() {} +func (x *Lock) save(m state.Map) { + x.beforeSave() + m.Save("Readers", &x.Readers) + m.Save("HasWriter", &x.HasWriter) + m.Save("Writer", &x.Writer) +} + +func (x *Lock) afterLoad() {} +func (x *Lock) load(m state.Map) { + m.Load("Readers", &x.Readers) + m.Load("HasWriter", &x.HasWriter) + m.Load("Writer", &x.Writer) +} + +func (x *Locks) beforeSave() {} +func (x *Locks) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.blockedQueue) { + m.Failf("blockedQueue is %v, expected zero", x.blockedQueue) + } + m.Save("locks", &x.locks) +} + +func (x *Locks) afterLoad() {} +func (x *Locks) load(m state.Map) { + m.Load("locks", &x.locks) +} + +func (x *LockRange) beforeSave() {} +func (x *LockRange) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) +} + +func (x *LockRange) afterLoad() {} +func (x *LockRange) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) +} + +func (x *LockSet) beforeSave() {} +func (x *LockSet) save(m state.Map) { + x.beforeSave() + var root *LockSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *LockSet) afterLoad() {} +func (x *LockSet) load(m state.Map) { + m.LoadValue("root", new(*LockSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*LockSegmentDataSlices)) }) +} + +func (x *Locknode) beforeSave() {} +func (x *Locknode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *Locknode) afterLoad() {} +func (x *Locknode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *LockSegmentDataSlices) beforeSave() {} +func (x *LockSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *LockSegmentDataSlices) afterLoad() {} +func (x *LockSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func init() { + state.Register("pkg/sentry/fs/lock.Lock", (*Lock)(nil), state.Fns{Save: (*Lock).save, Load: (*Lock).load}) + state.Register("pkg/sentry/fs/lock.Locks", (*Locks)(nil), state.Fns{Save: (*Locks).save, Load: (*Locks).load}) + state.Register("pkg/sentry/fs/lock.LockRange", (*LockRange)(nil), state.Fns{Save: (*LockRange).save, Load: (*LockRange).load}) + state.Register("pkg/sentry/fs/lock.LockSet", (*LockSet)(nil), state.Fns{Save: (*LockSet).save, Load: (*LockSet).load}) + state.Register("pkg/sentry/fs/lock.Locknode", (*Locknode)(nil), state.Fns{Save: (*Locknode).save, Load: (*Locknode).load}) + state.Register("pkg/sentry/fs/lock.LockSegmentDataSlices", (*LockSegmentDataSlices)(nil), state.Fns{Save: (*LockSegmentDataSlices).save, Load: (*LockSegmentDataSlices).load}) +} diff --git a/pkg/sentry/fs/lock/lock_test.go b/pkg/sentry/fs/lock/lock_test.go deleted file mode 100644 index ba002aeb7..000000000 --- a/pkg/sentry/fs/lock/lock_test.go +++ /dev/null @@ -1,1059 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package lock - -import ( - "reflect" - "testing" -) - -type entry struct { - Lock - LockRange -} - -func equals(e0, e1 []entry) bool { - if len(e0) != len(e1) { - return false - } - for i := range e0 { - for k := range e0[i].Lock.Readers { - if !e1[i].Lock.Readers[k] { - return false - } - } - for k := range e1[i].Lock.Readers { - if !e0[i].Lock.Readers[k] { - return false - } - } - if !reflect.DeepEqual(e0[i].LockRange, e1[i].LockRange) { - return false - } - if e0[i].Lock.HasWriter != e1[i].Lock.HasWriter { - return false - } - if e0[i].Lock.Writer != e1[i].Lock.Writer { - return false - } - } - return true -} - -// fill a LockSet with consecutive region locks. Will panic if -// LockRanges are not consecutive. -func fill(entries []entry) LockSet { - l := LockSet{} - for _, e := range entries { - gap := l.FindGap(e.LockRange.Start) - if !gap.Ok() { - panic("cannot insert into existing segment") - } - l.Insert(gap, e.LockRange, e.Lock) - } - return l -} - -func TestCanLockEmpty(t *testing.T) { - l := LockSet{} - - // Expect to be able to take any locks given that the set is empty. - eof := l.FirstGap().End() - r := LockRange{0, eof} - if !l.canLock(1, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 1) - } - if !l.canLock(2, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 2) - } - if !l.canLock(1, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 1) - } - if !l.canLock(2, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 2) - } -} - -func TestCanLock(t *testing.T) { - // + -------------- + ---------- + -------------- + --------- + - // | Readers 1 & 2 | Readers 1 | Readers 1 & 3 | Writer 1 | - // + ------------- + ---------- + -------------- + --------- + - // 0 1024 2048 3072 4096 - l := fill([]entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{1: true}}, - LockRange: LockRange{1024, 2048}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 3: true}}, - LockRange: LockRange{2048, 3072}, - }, - { - Lock: Lock{HasWriter: true, Writer: 1}, - LockRange: LockRange{3072, 4096}, - }, - }) - - // Now that we have a mildly interesting layout, try some checks on different - // ranges, uids, and lock types. - // - // Expect to be able to extend the read lock, despite the writer lock, because - // the writer has the same uid as the requested read lock. - r := LockRange{0, 8192} - if !l.canLock(1, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 1) - } - // Expect to *not* be able to extend the read lock since there is an overlapping - // writer region locked by someone other than the uid. - if l.canLock(2, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got true, want false", ReadLock, r, 2) - } - // Expect to be able to extend the read lock if there are only other readers in - // the way. - r = LockRange{64, 3072} - if !l.canLock(2, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 2) - } - // Expect to be able to set a read lock beyond the range of any existing locks. - r = LockRange{4096, 10240} - if !l.canLock(2, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 2) - } - - // Expect to not be able to take a write lock with other readers in the way. - r = LockRange{0, 8192} - if l.canLock(1, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got true, want false", WriteLock, r, 1) - } - // Expect to be able to extend the write lock for the same uid. - r = LockRange{3072, 8192} - if !l.canLock(1, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 1) - } - // Expect to not be able to overlap a write lock for two different uids. - if l.canLock(2, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got true, want false", WriteLock, r, 2) - } - // Expect to be able to set a write lock that is beyond the range of any - // existing locks. - r = LockRange{8192, 10240} - if !l.canLock(2, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 2) - } - // Expect to be able to upgrade a read lock (any portion of it). - r = LockRange{1024, 2048} - if !l.canLock(1, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 1) - } - r = LockRange{1080, 2000} - if !l.canLock(1, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 1) - } -} - -func TestSetLock(t *testing.T) { - tests := []struct { - // description of test. - name string - - // LockSet entries to pre-fill. - before []entry - - // Description of region to lock: - // - // start is the file offset of the lock. - start uint64 - // end is the end file offset of the lock. - end uint64 - // uid of lock attempter. - uid UniqueID - // lock type requested. - lockType LockType - - // success is true if taking the above - // lock should succeed. - success bool - - // Expected layout of the set after locking - // if success is true. - after []entry - }{ - { - name: "set zero length ReadLock on empty set", - start: 0, - end: 0, - uid: 0, - lockType: ReadLock, - success: true, - }, - { - name: "set zero length WriteLock on empty set", - start: 0, - end: 0, - uid: 0, - lockType: WriteLock, - success: true, - }, - { - name: "set ReadLock on empty set", - start: 0, - end: LockEOF, - uid: 0, - lockType: ReadLock, - success: true, - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - }, - { - name: "set WriteLock on empty set", - start: 0, - end: LockEOF, - uid: 0, - lockType: WriteLock, - success: true, - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - after: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - }, - { - name: "set ReadLock on WriteLock same uid", - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 0, - lockType: ReadLock, - success: true, - // + ----------- + --------------------------- + - // | Readers 0 | Writer 0 | - // + ----------- + --------------------------- + - // 0 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, 4096}, - }, - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "set WriteLock on ReadLock same uid", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 0, - lockType: WriteLock, - success: true, - // + ----------- + --------------------------- + - // | Writer 0 | Readers 0 | - // + ----------- + --------------------------- + - // 0 4096 max uint64 - after: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "set ReadLock on WriteLock different uid", - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 1, - lockType: ReadLock, - success: false, - }, - { - name: "set WriteLock on ReadLock different uid", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 1, - lockType: WriteLock, - success: false, - }, - { - name: "split ReadLock for overlapping lock at start 0", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 1, - lockType: ReadLock, - success: true, - // + -------------- + --------------------------- + - // | Readers 0 & 1 | Readers 0 | - // + -------------- + --------------------------- + - // 0 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{0, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "split ReadLock for overlapping lock at non-zero start", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 4096, - end: 8192, - uid: 1, - lockType: ReadLock, - success: true, - // + ---------- + -------------- + ----------- + - // | Readers 0 | Readers 0 & 1 | Readers 0 | - // + ---------- + -------------- + ----------- + - // 0 4096 8192 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{4096, 8192}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{8192, LockEOF}, - }, - }, - }, - { - name: "fill front gap with ReadLock", - // + --------- + ---------------------------- + - // | gap | Readers 0 | - // + --------- + ---------------------------- + - // 0 1024 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{1024, LockEOF}, - }, - }, - start: 0, - end: 8192, - uid: 0, - lockType: ReadLock, - success: true, - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - }, - { - name: "fill end gap with ReadLock", - // + ---------------------------- + - // | Readers 0 | - // + ---------------------------- + - // 0 4096 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, 4096}, - }, - }, - start: 1024, - end: LockEOF, - uid: 0, - lockType: ReadLock, - success: true, - // Note that this is not merged after lock does a Split. This is - // fine because the two locks will still *behave* as one. In other - // words we can fragment any lock all we want and semantically it - // makes no difference. - // - // + ----------- + --------------------------- + - // | Readers 0 | Readers 0 | - // + ----------- + --------------------------- + - // 0 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{1024, LockEOF}, - }, - }, - }, - { - name: "fill gap with ReadLock and split", - // + --------- + ---------------------------- + - // | gap | Readers 0 | - // + --------- + ---------------------------- + - // 0 1024 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{1024, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 1, - lockType: ReadLock, - success: true, - // + --------- + ------------- + ------------- + - // | Reader 1 | Readers 0 & 1 | Reader 0 | - // + ----------+ ------------- + ------------- + - // 0 1024 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{1: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{1024, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "upgrade ReadLock to WriteLock for single uid fill gap", - // + ------------- + --------- + --- + ------------- + - // | Readers 0 & 1 | Readers 0 | gap | Readers 0 & 2 | - // + ------------- + --------- + --- + ------------- + - // 0 1024 2048 4096 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{1024, 2048}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 1024, - end: 4096, - uid: 0, - lockType: WriteLock, - success: true, - // + ------------- + -------- + ------------- + - // | Readers 0 & 1 | Writer 0 | Readers 0 & 2 | - // + ------------- + -------- + ------------- + - // 0 1024 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{1024, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "upgrade ReadLock to WriteLock for single uid keep gap", - // + ------------- + --------- + --- + ------------- + - // | Readers 0 & 1 | Readers 0 | gap | Readers 0 & 2 | - // + ------------- + --------- + --- + ------------- + - // 0 1024 2048 4096 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{1024, 2048}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 1024, - end: 3072, - uid: 0, - lockType: WriteLock, - success: true, - // + ------------- + -------- + --- + ------------- + - // | Readers 0 & 1 | Writer 0 | gap | Readers 0 & 2 | - // + ------------- + -------- + --- + ------------- + - // 0 1024 3072 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{1024, 3072}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "fail to upgrade ReadLock to WriteLock with conflicting Reader", - // + ------------- + --------- + - // | Readers 0 & 1 | Readers 0 | - // + ------------- + --------- + - // 0 1024 2048 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{1024, 2048}, - }, - }, - start: 0, - end: 2048, - uid: 0, - lockType: WriteLock, - success: false, - }, - { - name: "take WriteLock on whole file if all uids are the same", - // + ------------- + --------- + --------- + ---------- + - // | Writer 0 | Readers 0 | Readers 0 | Readers 0 | - // + ------------- + --------- + --------- + ---------- + - // 0 1024 2048 4096 max uint64 - before: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{1024, 2048}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{2048, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 0, - end: LockEOF, - uid: 0, - lockType: WriteLock, - success: true, - // We do not manually merge locks. Semantically a fragmented lock - // held by the same uid will behave as one lock so it makes no difference. - // - // + ------------- + ---------------------------- + - // | Writer 0 | Writer 0 | - // + ------------- + ---------------------------- + - // 0 1024 max uint64 - after: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{1024, LockEOF}, - }, - }, - }, - } - - for _, test := range tests { - l := fill(test.before) - - r := LockRange{Start: test.start, End: test.end} - success := l.lock(test.uid, test.lockType, r) - var got []entry - for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - got = append(got, entry{ - Lock: seg.Value(), - LockRange: seg.Range(), - }) - } - - if success != test.success { - t.Errorf("%s: setlock(%v, %+v, %d, %d) got success %v, want %v", test.name, test.before, r, test.uid, test.lockType, success, test.success) - continue - } - - if success { - if !equals(got, test.after) { - t.Errorf("%s: got set %+v, want %+v", test.name, got, test.after) - } - } - } -} - -func TestUnlock(t *testing.T) { - tests := []struct { - // description of test. - name string - - // LockSet entries to pre-fill. - before []entry - - // Description of region to unlock: - // - // start is the file start of the lock. - start uint64 - // end is the end file start of the lock. - end uint64 - // uid of lock holder. - uid UniqueID - - // Expected layout of the set after unlocking. - after []entry - }{ - { - name: "unlock zero length on empty set", - start: 0, - end: 0, - uid: 0, - }, - { - name: "unlock on empty set (no-op)", - start: 0, - end: LockEOF, - uid: 0, - }, - { - name: "unlock uid not locked (no-op)", - // + --------------------------- + - // | Readers 1 & 2 | - // + --------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 1024, - end: 4096, - uid: 0, - // + --------------------------- + - // | Readers 1 & 2 | - // + --------------------------- + - // 0 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - }, - { - name: "unlock ReadLock over entire file", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: LockEOF, - uid: 0, - }, - { - name: "unlock WriteLock over entire file", - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: LockEOF, - uid: 0, - }, - { - name: "unlock partial ReadLock (start)", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 0, - // + ------ + --------------------------- + - // | gap | Readers 0 | - // +------- + --------------------------- + - // 0 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "unlock partial WriteLock (start)", - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 0, - // + ------ + --------------------------- + - // | gap | Writer 0 | - // +------- + --------------------------- + - // 0 4096 max uint64 - after: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "unlock partial ReadLock (end)", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 4096, - end: LockEOF, - uid: 0, - // + --------------------------- + - // | Readers 0 | - // +---------------------------- + - // 0 4096 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, 4096}, - }, - }, - }, - { - name: "unlock partial WriteLock (end)", - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 4096, - end: LockEOF, - uid: 0, - // + --------------------------- + - // | Writer 0 | - // +---------------------------- + - // 0 4096 - after: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, 4096}, - }, - }, - }, - { - name: "unlock for single uid", - // + ------------- + --------- + ------------------- + - // | Readers 0 & 1 | Writer 0 | Readers 0 & 1 & 2 | - // + ------------- + --------- + ------------------- + - // 0 1024 4096 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{1024, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 0, - end: LockEOF, - uid: 0, - // + --------- + --- + --------------- + - // | Readers 1 | gap | Readers 1 & 2 | - // + --------- + --- + --------------- + - // 0 1024 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{1: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "unlock subsection locked", - // + ------------------------------- + - // | Readers 0 & 1 & 2 | - // + ------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 1024, - end: 4096, - uid: 0, - // + ----------------- + ------------- + ----------------- + - // | Readers 0 & 1 & 2 | Readers 1 & 2 | Readers 0 & 1 & 2 | - // + ----------------- + ------------- + ----------------- + - // 0 1024 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, - LockRange: LockRange{1024, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "unlock mid-gap to increase gap", - // + --------- + ----- + ------------------- + - // | Writer 0 | gap | Readers 0 & 1 | - // + --------- + ----- + ------------------- + - // 0 1024 4096 max uint64 - before: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 8, - end: 2048, - uid: 0, - // + --------- + ----- + ------------------- + - // | Writer 0 | gap | Readers 0 & 1 | - // + --------- + ----- + ------------------- + - // 0 8 4096 max uint64 - after: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, 8}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "unlock split region on uid mid-gap", - // + --------- + ----- + ------------------- + - // | Writer 0 | gap | Readers 0 & 1 | - // + --------- + ----- + ------------------- + - // 0 1024 4096 max uint64 - before: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 2048, - end: 8192, - uid: 0, - // + --------- + ----- + --------- + ------------- + - // | Writer 0 | gap | Readers 1 | Readers 0 & 1 | - // + --------- + ----- + --------- + ------------- + - // 0 1024 4096 8192 max uint64 - after: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{1: true}}, - LockRange: LockRange{4096, 8192}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{8192, LockEOF}, - }, - }, - }, - } - - for _, test := range tests { - l := fill(test.before) - - r := LockRange{Start: test.start, End: test.end} - l.unlock(test.uid, r) - var got []entry - for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - got = append(got, entry{ - Lock: seg.Value(), - LockRange: seg.Range(), - }) - } - if !equals(got, test.after) { - t.Errorf("%s: got set %+v, want %+v", test.name, got, test.after) - } - } -} diff --git a/pkg/sentry/fs/mount_test.go b/pkg/sentry/fs/mount_test.go deleted file mode 100644 index a3d10770b..000000000 --- a/pkg/sentry/fs/mount_test.go +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs - -import ( - "fmt" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/contexttest" -) - -// cacheReallyContains iterates through the dirent cache to determine whether -// it contains the given dirent. -func cacheReallyContains(cache *DirentCache, d *Dirent) bool { - for i := cache.list.Front(); i != nil; i = i.Next() { - if i == d { - return true - } - } - return false -} - -func mountPathsAre(root *Dirent, got []*Mount, want ...string) error { - gotPaths := make(map[string]struct{}, len(got)) - gotStr := make([]string, len(got)) - for i, g := range got { - if groot := g.Root(); groot != nil { - name, _ := groot.FullName(root) - groot.DecRef() - gotStr[i] = name - gotPaths[name] = struct{}{} - } - } - if len(got) != len(want) { - return fmt.Errorf("mount paths are different, got: %q, want: %q", gotStr, want) - } - for _, w := range want { - if _, ok := gotPaths[w]; !ok { - return fmt.Errorf("no mount with path %q found", w) - } - } - return nil -} - -// TestMountSourceOnlyCachedOnce tests that a Dirent that is mounted over only ends -// up in a single Dirent Cache. NOTE(b/63848693): Having a dirent in multiple -// caches causes major consistency issues. -func TestMountSourceOnlyCachedOnce(t *testing.T) { - ctx := contexttest.Context(t) - - rootCache := NewDirentCache(100) - rootInode := NewMockInode(ctx, NewMockMountSource(rootCache), StableAttr{ - Type: Directory, - }) - mm, err := NewMountNamespace(ctx, rootInode) - if err != nil { - t.Fatalf("NewMountNamespace failed: %v", err) - } - rootDirent := mm.Root() - defer rootDirent.DecRef() - - // Get a child of the root which we will mount over. Note that the - // MockInodeOperations causes Walk to always succeed. - child, err := rootDirent.Walk(ctx, rootDirent, "child") - if err != nil { - t.Fatalf("failed to walk to child dirent: %v", err) - } - child.maybeExtendReference() // Cache. - - // Ensure that the root cache contains the child. - if !cacheReallyContains(rootCache, child) { - t.Errorf("wanted rootCache to contain child dirent, but it did not") - } - - // Create a new cache and inode, and mount it over child. - submountCache := NewDirentCache(100) - submountInode := NewMockInode(ctx, NewMockMountSource(submountCache), StableAttr{ - Type: Directory, - }) - if err := mm.Mount(ctx, child, submountInode); err != nil { - t.Fatalf("failed to mount over child: %v", err) - } - - // Walk to the child again. - child2, err := rootDirent.Walk(ctx, rootDirent, "child") - if err != nil { - t.Fatalf("failed to walk to child dirent: %v", err) - } - - // Should have a different Dirent than before. - if child == child2 { - t.Fatalf("expected %v not equal to %v, but they are the same", child, child2) - } - - // Neither of the caches should no contain the child. - if cacheReallyContains(rootCache, child) { - t.Errorf("wanted rootCache not to contain child dirent, but it did") - } - if cacheReallyContains(submountCache, child) { - t.Errorf("wanted submountCache not to contain child dirent, but it did") - } -} - -func TestAllMountsUnder(t *testing.T) { - ctx := contexttest.Context(t) - - rootCache := NewDirentCache(100) - rootInode := NewMockInode(ctx, NewMockMountSource(rootCache), StableAttr{ - Type: Directory, - }) - mm, err := NewMountNamespace(ctx, rootInode) - if err != nil { - t.Fatalf("NewMountNamespace failed: %v", err) - } - rootDirent := mm.Root() - defer rootDirent.DecRef() - - // Add mounts at the following paths: - paths := []string{ - "/foo", - "/foo/bar", - "/foo/bar/baz", - "/foo/qux", - "/waldo", - } - - var maxTraversals uint - for _, p := range paths { - maxTraversals = 0 - d, err := mm.FindLink(ctx, rootDirent, nil, p, &maxTraversals) - if err != nil { - t.Fatalf("could not find path %q in mount manager: %v", p, err) - } - - submountInode := NewMockInode(ctx, NewMockMountSource(nil), StableAttr{ - Type: Directory, - }) - if err := mm.Mount(ctx, d, submountInode); err != nil { - t.Fatalf("could not mount at %q: %v", p, err) - } - d.DecRef() - } - - // mm root should contain all submounts (and does not include the root mount). - rootMnt := mm.FindMount(rootDirent) - submounts := mm.AllMountsUnder(rootMnt) - allPaths := append(paths, "/") - if err := mountPathsAre(rootDirent, submounts, allPaths...); err != nil { - t.Error(err) - } - - // Each mount should have a unique ID. - foundIDs := make(map[uint64]struct{}) - for _, m := range submounts { - if _, ok := foundIDs[m.ID]; ok { - t.Errorf("got multiple mounts with id %d", m.ID) - } - foundIDs[m.ID] = struct{}{} - } - - // Root mount should have no parent. - if p := rootMnt.ParentID; p != invalidMountID { - t.Errorf("root.Parent got %v wanted nil", p) - } - - // Check that "foo" mount has 3 children. - maxTraversals = 0 - d, err := mm.FindLink(ctx, rootDirent, nil, "/foo", &maxTraversals) - if err != nil { - t.Fatalf("could not find path %q in mount manager: %v", "/foo", err) - } - defer d.DecRef() - submounts = mm.AllMountsUnder(mm.FindMount(d)) - if err := mountPathsAre(rootDirent, submounts, "/foo", "/foo/bar", "/foo/qux", "/foo/bar/baz"); err != nil { - t.Error(err) - } - - // "waldo" mount should have no children. - maxTraversals = 0 - waldo, err := mm.FindLink(ctx, rootDirent, nil, "/waldo", &maxTraversals) - if err != nil { - t.Fatalf("could not find path %q in mount manager: %v", "/waldo", err) - } - defer waldo.DecRef() - submounts = mm.AllMountsUnder(mm.FindMount(waldo)) - if err := mountPathsAre(rootDirent, submounts, "/waldo"); err != nil { - t.Error(err) - } -} - -func TestUnmount(t *testing.T) { - ctx := contexttest.Context(t) - - rootCache := NewDirentCache(100) - rootInode := NewMockInode(ctx, NewMockMountSource(rootCache), StableAttr{ - Type: Directory, - }) - mm, err := NewMountNamespace(ctx, rootInode) - if err != nil { - t.Fatalf("NewMountNamespace failed: %v", err) - } - rootDirent := mm.Root() - defer rootDirent.DecRef() - - // Add mounts at the following paths: - paths := []string{ - "/foo", - "/foo/bar", - "/foo/bar/goo", - "/foo/bar/goo/abc", - "/foo/abc", - "/foo/def", - "/waldo", - "/wally", - } - - var maxTraversals uint - for _, p := range paths { - maxTraversals = 0 - d, err := mm.FindLink(ctx, rootDirent, nil, p, &maxTraversals) - if err != nil { - t.Fatalf("could not find path %q in mount manager: %v", p, err) - } - - submountInode := NewMockInode(ctx, NewMockMountSource(nil), StableAttr{ - Type: Directory, - }) - if err := mm.Mount(ctx, d, submountInode); err != nil { - t.Fatalf("could not mount at %q: %v", p, err) - } - d.DecRef() - } - - allPaths := make([]string, len(paths)+1) - allPaths[0] = "/" - copy(allPaths[1:], paths) - - rootMnt := mm.FindMount(rootDirent) - for i := len(paths) - 1; i >= 0; i-- { - maxTraversals = 0 - p := paths[i] - d, err := mm.FindLink(ctx, rootDirent, nil, p, &maxTraversals) - if err != nil { - t.Fatalf("could not find path %q in mount manager: %v", p, err) - } - - if err := mm.Unmount(ctx, d, false); err != nil { - t.Fatalf("could not unmount at %q: %v", p, err) - } - d.DecRef() - - // Remove the path that has been unmounted and the check that the remaining - // mounts are still there. - allPaths = allPaths[:len(allPaths)-1] - submounts := mm.AllMountsUnder(rootMnt) - if err := mountPathsAre(rootDirent, submounts, allPaths...); err != nil { - t.Error(err) - } - } -} diff --git a/pkg/sentry/fs/mounts_test.go b/pkg/sentry/fs/mounts_test.go deleted file mode 100644 index a69b41468..000000000 --- a/pkg/sentry/fs/mounts_test.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" -) - -// Creates a new MountNamespace with filesystem: -// / (root dir) -// |-foo (dir) -// |-bar (file) -func createMountNamespace(ctx context.Context) (*fs.MountNamespace, error) { - perms := fs.FilePermsFromMode(0777) - m := fs.NewPseudoMountSource(ctx) - - barFile := fsutil.NewSimpleFileInode(ctx, fs.RootOwner, perms, 0) - fooDir := ramfs.NewDir(ctx, map[string]*fs.Inode{ - "bar": fs.NewInode(ctx, barFile, m, fs.StableAttr{Type: fs.RegularFile}), - }, fs.RootOwner, perms) - rootDir := ramfs.NewDir(ctx, map[string]*fs.Inode{ - "foo": fs.NewInode(ctx, fooDir, m, fs.StableAttr{Type: fs.Directory}), - }, fs.RootOwner, perms) - - return fs.NewMountNamespace(ctx, fs.NewInode(ctx, rootDir, m, fs.StableAttr{Type: fs.Directory})) -} - -func TestFindLink(t *testing.T) { - ctx := contexttest.Context(t) - mm, err := createMountNamespace(ctx) - if err != nil { - t.Fatalf("createMountNamespace failed: %v", err) - } - - root := mm.Root() - defer root.DecRef() - foo, err := root.Walk(ctx, root, "foo") - if err != nil { - t.Fatalf("Error walking to foo: %v", err) - } - - // Positive cases. - for _, tc := range []struct { - findPath string - wd *fs.Dirent - wantPath string - }{ - {".", root, "/"}, - {".", foo, "/foo"}, - {"..", foo, "/"}, - {"../../..", foo, "/"}, - {"///foo", foo, "/foo"}, - {"/foo", foo, "/foo"}, - {"/foo/bar", foo, "/foo/bar"}, - {"/foo/.///./bar", foo, "/foo/bar"}, - {"/foo///bar", foo, "/foo/bar"}, - {"/foo/../foo/bar", foo, "/foo/bar"}, - {"foo/bar", root, "/foo/bar"}, - {"foo////bar", root, "/foo/bar"}, - {"bar", foo, "/foo/bar"}, - } { - wdPath, _ := tc.wd.FullName(root) - maxTraversals := uint(0) - if d, err := mm.FindLink(ctx, root, tc.wd, tc.findPath, &maxTraversals); err != nil { - t.Errorf("FindLink(%q, wd=%q) failed: %v", tc.findPath, wdPath, err) - } else if got, _ := d.FullName(root); got != tc.wantPath { - t.Errorf("FindLink(%q, wd=%q) got dirent %q, want %q", tc.findPath, wdPath, got, tc.wantPath) - } - } - - // Negative cases. - for _, tc := range []struct { - findPath string - wd *fs.Dirent - }{ - {"bar", root}, - {"/bar", root}, - {"/foo/../../bar", root}, - {"foo", foo}, - } { - wdPath, _ := tc.wd.FullName(root) - maxTraversals := uint(0) - if _, err := mm.FindLink(ctx, root, tc.wd, tc.findPath, &maxTraversals); err == nil { - t.Errorf("FindLink(%q, wd=%q) did not return error", tc.findPath, wdPath) - } - } -} diff --git a/pkg/sentry/fs/path_test.go b/pkg/sentry/fs/path_test.go deleted file mode 100644 index e6f57ebba..000000000 --- a/pkg/sentry/fs/path_test.go +++ /dev/null @@ -1,289 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs - -import ( - "testing" -) - -// TestSplitLast tests variants of path splitting. -func TestSplitLast(t *testing.T) { - cases := []struct { - path string - dir string - file string - }{ - {path: "/", dir: "/", file: "."}, - {path: "/.", dir: "/", file: "."}, - {path: "/./", dir: "/", file: "."}, - {path: "/./.", dir: "/.", file: "."}, - {path: "/././", dir: "/.", file: "."}, - {path: "/./..", dir: "/.", file: ".."}, - {path: "/./../", dir: "/.", file: ".."}, - {path: "/..", dir: "/", file: ".."}, - {path: "/../", dir: "/", file: ".."}, - {path: "/../.", dir: "/..", file: "."}, - {path: "/.././", dir: "/..", file: "."}, - {path: "/../..", dir: "/..", file: ".."}, - {path: "/../../", dir: "/..", file: ".."}, - - {path: "", dir: ".", file: "."}, - {path: ".", dir: ".", file: "."}, - {path: "./", dir: ".", file: "."}, - {path: "./.", dir: ".", file: "."}, - {path: "././", dir: ".", file: "."}, - {path: "./..", dir: ".", file: ".."}, - {path: "./../", dir: ".", file: ".."}, - {path: "..", dir: ".", file: ".."}, - {path: "../", dir: ".", file: ".."}, - {path: "../.", dir: "..", file: "."}, - {path: ".././", dir: "..", file: "."}, - {path: "../..", dir: "..", file: ".."}, - {path: "../../", dir: "..", file: ".."}, - - {path: "/foo", dir: "/", file: "foo"}, - {path: "/foo/", dir: "/", file: "foo"}, - {path: "/foo/.", dir: "/foo", file: "."}, - {path: "/foo/./", dir: "/foo", file: "."}, - {path: "/foo/./.", dir: "/foo/.", file: "."}, - {path: "/foo/./..", dir: "/foo/.", file: ".."}, - {path: "/foo/..", dir: "/foo", file: ".."}, - {path: "/foo/../", dir: "/foo", file: ".."}, - {path: "/foo/../.", dir: "/foo/..", file: "."}, - {path: "/foo/../..", dir: "/foo/..", file: ".."}, - - {path: "/foo/bar", dir: "/foo", file: "bar"}, - {path: "/foo/bar/", dir: "/foo", file: "bar"}, - {path: "/foo/bar/.", dir: "/foo/bar", file: "."}, - {path: "/foo/bar/./", dir: "/foo/bar", file: "."}, - {path: "/foo/bar/./.", dir: "/foo/bar/.", file: "."}, - {path: "/foo/bar/./..", dir: "/foo/bar/.", file: ".."}, - {path: "/foo/bar/..", dir: "/foo/bar", file: ".."}, - {path: "/foo/bar/../", dir: "/foo/bar", file: ".."}, - {path: "/foo/bar/../.", dir: "/foo/bar/..", file: "."}, - {path: "/foo/bar/../..", dir: "/foo/bar/..", file: ".."}, - - {path: "foo", dir: ".", file: "foo"}, - {path: "foo", dir: ".", file: "foo"}, - {path: "foo/", dir: ".", file: "foo"}, - {path: "foo/.", dir: "foo", file: "."}, - {path: "foo/./", dir: "foo", file: "."}, - {path: "foo/./.", dir: "foo/.", file: "."}, - {path: "foo/./..", dir: "foo/.", file: ".."}, - {path: "foo/..", dir: "foo", file: ".."}, - {path: "foo/../", dir: "foo", file: ".."}, - {path: "foo/../.", dir: "foo/..", file: "."}, - {path: "foo/../..", dir: "foo/..", file: ".."}, - {path: "foo/", dir: ".", file: "foo"}, - {path: "foo/.", dir: "foo", file: "."}, - - {path: "foo/bar", dir: "foo", file: "bar"}, - {path: "foo/bar/", dir: "foo", file: "bar"}, - {path: "foo/bar/.", dir: "foo/bar", file: "."}, - {path: "foo/bar/./", dir: "foo/bar", file: "."}, - {path: "foo/bar/./.", dir: "foo/bar/.", file: "."}, - {path: "foo/bar/./..", dir: "foo/bar/.", file: ".."}, - {path: "foo/bar/..", dir: "foo/bar", file: ".."}, - {path: "foo/bar/../", dir: "foo/bar", file: ".."}, - {path: "foo/bar/../.", dir: "foo/bar/..", file: "."}, - {path: "foo/bar/../..", dir: "foo/bar/..", file: ".."}, - {path: "foo/bar/", dir: "foo", file: "bar"}, - {path: "foo/bar/.", dir: "foo/bar", file: "."}, - } - - for _, c := range cases { - dir, file := SplitLast(c.path) - if dir != c.dir || file != c.file { - t.Errorf("SplitLast(%q) got (%q, %q), expected (%q, %q)", c.path, dir, file, c.dir, c.file) - } - } -} - -// TestSplitFirst tests variants of path splitting. -func TestSplitFirst(t *testing.T) { - cases := []struct { - path string - first string - remainder string - }{ - {path: "/", first: "/", remainder: ""}, - {path: "/.", first: "/", remainder: "."}, - {path: "///.", first: "/", remainder: "//."}, - {path: "/.///", first: "/", remainder: "."}, - {path: "/./.", first: "/", remainder: "./."}, - {path: "/././", first: "/", remainder: "./."}, - {path: "/./..", first: "/", remainder: "./.."}, - {path: "/./../", first: "/", remainder: "./.."}, - {path: "/..", first: "/", remainder: ".."}, - {path: "/../", first: "/", remainder: ".."}, - {path: "/../.", first: "/", remainder: "../."}, - {path: "/.././", first: "/", remainder: "../."}, - {path: "/../..", first: "/", remainder: "../.."}, - {path: "/../../", first: "/", remainder: "../.."}, - - {path: "", first: ".", remainder: ""}, - {path: ".", first: ".", remainder: ""}, - {path: "./", first: ".", remainder: ""}, - {path: ".///", first: ".", remainder: ""}, - {path: "./.", first: ".", remainder: "."}, - {path: "././", first: ".", remainder: "."}, - {path: "./..", first: ".", remainder: ".."}, - {path: "./../", first: ".", remainder: ".."}, - {path: "..", first: "..", remainder: ""}, - {path: "../", first: "..", remainder: ""}, - {path: "../.", first: "..", remainder: "."}, - {path: ".././", first: "..", remainder: "."}, - {path: "../..", first: "..", remainder: ".."}, - {path: "../../", first: "..", remainder: ".."}, - - {path: "/foo", first: "/", remainder: "foo"}, - {path: "/foo/", first: "/", remainder: "foo"}, - {path: "/foo///", first: "/", remainder: "foo"}, - {path: "/foo/.", first: "/", remainder: "foo/."}, - {path: "/foo/./", first: "/", remainder: "foo/."}, - {path: "/foo/./.", first: "/", remainder: "foo/./."}, - {path: "/foo/./..", first: "/", remainder: "foo/./.."}, - {path: "/foo/..", first: "/", remainder: "foo/.."}, - {path: "/foo/../", first: "/", remainder: "foo/.."}, - {path: "/foo/../.", first: "/", remainder: "foo/../."}, - {path: "/foo/../..", first: "/", remainder: "foo/../.."}, - - {path: "/foo/bar", first: "/", remainder: "foo/bar"}, - {path: "///foo/bar", first: "/", remainder: "//foo/bar"}, - {path: "/foo///bar", first: "/", remainder: "foo///bar"}, - {path: "/foo/bar/.", first: "/", remainder: "foo/bar/."}, - {path: "/foo/bar/./", first: "/", remainder: "foo/bar/."}, - {path: "/foo/bar/./.", first: "/", remainder: "foo/bar/./."}, - {path: "/foo/bar/./..", first: "/", remainder: "foo/bar/./.."}, - {path: "/foo/bar/..", first: "/", remainder: "foo/bar/.."}, - {path: "/foo/bar/../", first: "/", remainder: "foo/bar/.."}, - {path: "/foo/bar/../.", first: "/", remainder: "foo/bar/../."}, - {path: "/foo/bar/../..", first: "/", remainder: "foo/bar/../.."}, - - {path: "foo", first: "foo", remainder: ""}, - {path: "foo", first: "foo", remainder: ""}, - {path: "foo/", first: "foo", remainder: ""}, - {path: "foo///", first: "foo", remainder: ""}, - {path: "foo/.", first: "foo", remainder: "."}, - {path: "foo/./", first: "foo", remainder: "."}, - {path: "foo/./.", first: "foo", remainder: "./."}, - {path: "foo/./..", first: "foo", remainder: "./.."}, - {path: "foo/..", first: "foo", remainder: ".."}, - {path: "foo/../", first: "foo", remainder: ".."}, - {path: "foo/../.", first: "foo", remainder: "../."}, - {path: "foo/../..", first: "foo", remainder: "../.."}, - {path: "foo/", first: "foo", remainder: ""}, - {path: "foo/.", first: "foo", remainder: "."}, - - {path: "foo/bar", first: "foo", remainder: "bar"}, - {path: "foo///bar", first: "foo", remainder: "bar"}, - {path: "foo/bar/", first: "foo", remainder: "bar"}, - {path: "foo/bar/.", first: "foo", remainder: "bar/."}, - {path: "foo/bar/./", first: "foo", remainder: "bar/."}, - {path: "foo/bar/./.", first: "foo", remainder: "bar/./."}, - {path: "foo/bar/./..", first: "foo", remainder: "bar/./.."}, - {path: "foo/bar/..", first: "foo", remainder: "bar/.."}, - {path: "foo/bar/../", first: "foo", remainder: "bar/.."}, - {path: "foo/bar/../.", first: "foo", remainder: "bar/../."}, - {path: "foo/bar/../..", first: "foo", remainder: "bar/../.."}, - {path: "foo/bar/", first: "foo", remainder: "bar"}, - {path: "foo/bar/.", first: "foo", remainder: "bar/."}, - } - - for _, c := range cases { - first, remainder := SplitFirst(c.path) - if first != c.first || remainder != c.remainder { - t.Errorf("SplitFirst(%q) got (%q, %q), expected (%q, %q)", c.path, first, remainder, c.first, c.remainder) - } - } -} - -// TestIsSubpath tests the IsSubpath method. -func TestIsSubpath(t *testing.T) { - tcs := []struct { - // Two absolute paths. - pathA string - pathB string - - // Whether pathA is a subpath of pathB. - wantIsSubpath bool - - // Relative path from pathA to pathB. Only checked if - // wantIsSubpath is true. - wantRelpath string - }{ - { - pathA: "/foo/bar/baz", - pathB: "/foo", - wantIsSubpath: true, - wantRelpath: "bar/baz", - }, - { - pathA: "/foo", - pathB: "/foo/bar/baz", - wantIsSubpath: false, - }, - { - pathA: "/foo", - pathB: "/foo", - wantIsSubpath: false, - }, - { - pathA: "/foobar", - pathB: "/foo", - wantIsSubpath: false, - }, - { - pathA: "/foo", - pathB: "/foobar", - wantIsSubpath: false, - }, - { - pathA: "/foo", - pathB: "/foobar", - wantIsSubpath: false, - }, - { - pathA: "/", - pathB: "/foo", - wantIsSubpath: false, - }, - { - pathA: "/foo", - pathB: "/", - wantIsSubpath: true, - wantRelpath: "foo", - }, - { - pathA: "/foo/bar/../bar", - pathB: "/foo", - wantIsSubpath: true, - wantRelpath: "bar", - }, - { - pathA: "/foo/bar", - pathB: "/foo/../foo", - wantIsSubpath: true, - wantRelpath: "bar", - }, - } - - for _, tc := range tcs { - gotRelpath, gotIsSubpath := IsSubpath(tc.pathA, tc.pathB) - if gotRelpath != tc.wantRelpath || gotIsSubpath != tc.wantIsSubpath { - t.Errorf("IsSubpath(%q, %q) got %q %t, want %q %t", tc.pathA, tc.pathB, gotRelpath, gotIsSubpath, tc.wantRelpath, tc.wantIsSubpath) - } - } -} diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD deleted file mode 100644 index 77c2c5c0e..000000000 --- a/pkg/sentry/fs/proc/BUILD +++ /dev/null @@ -1,72 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "proc", - srcs = [ - "cgroup.go", - "cpuinfo.go", - "exec_args.go", - "fds.go", - "filesystems.go", - "fs.go", - "inode.go", - "loadavg.go", - "meminfo.go", - "mounts.go", - "net.go", - "proc.go", - "stat.go", - "sys.go", - "sys_net.go", - "sys_net_state.go", - "task.go", - "uid_gid_map.go", - "uptime.go", - "version.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/log", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/proc/device", - "//pkg/sentry/fs/proc/seqfile", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/fsbridge", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/mm", - "//pkg/sentry/socket", - "//pkg/sentry/socket/unix", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usage", - "//pkg/sync", - "//pkg/syserror", - "//pkg/tcpip/header", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "proc_test", - size = "small", - srcs = [ - "net_test.go", - "sys_net_test.go", - ], - library = ":proc", - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/inet", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fs/proc/README.md b/pkg/sentry/fs/proc/README.md deleted file mode 100644 index 6667a0916..000000000 --- a/pkg/sentry/fs/proc/README.md +++ /dev/null @@ -1,336 +0,0 @@ -This document tracks what is implemented in procfs. Refer to -Documentation/filesystems/proc.txt in the Linux project for information about -procfs generally. - -**NOTE**: This document is not guaranteed to be up to date. If you find an -inconsistency, please file a bug. - -[TOC] - -## Kernel data - -The following files are implemented: - -<!-- mdformat off(don't wrap the table) --> - -| File /proc/ | Content | -| :------------------------ | :---------------------------------------------------- | -| [cpuinfo](#cpuinfo) | Info about the CPU | -| [filesystems](#filesystems) | Supported filesystems | -| [loadavg](#loadavg) | Load average of last 1, 5 & 15 minutes | -| [meminfo](#meminfo) | Overall memory info | -| [stat](#stat) | Overall kernel statistics | -| [sys](#sys) | Change parameters within the kernel | -| [uptime](#uptime) | Wall clock since boot, combined idle time of all cpus | -| [version](#version) | Kernel version | - -<!-- mdformat on --> - -### cpuinfo - -```bash -$ cat /proc/cpuinfo -processor : 0 -vendor_id : GenuineIntel -cpu family : 6 -model : 45 -model name : unknown -stepping : unknown -cpu MHz : 1234.588 -fpu : yes -fpu_exception : yes -cpuid level : 13 -wp : yes -flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic popcnt tsc_deadline_timer aes xsave avx xsaveopt -bogomips : 1234.59 -clflush size : 64 -cache_alignment : 64 -address sizes : 46 bits physical, 48 bits virtual -power management: - -... -``` - -Notable divergences: - -Field name | Notes -:--------------- | :--------------------------------------- -model name | Always unknown -stepping | Always unknown -fpu | Always yes -fpu_exception | Always yes -wp | Always yes -bogomips | Bogus value (matches cpu MHz) -clflush size | Always 64 -cache_alignment | Always 64 -address sizes | Always 46 bits physical, 48 bits virtual -power management | Always blank - -Otherwise fields are derived from the sentry configuration. - -### filesystems - -```bash -$ cat /proc/filesystems -nodev 9p -nodev devpts -nodev devtmpfs -nodev proc -nodev sysfs -nodev tmpfs -``` - -### loadavg - -```bash -$ cat /proc/loadavg -0.00 0.00 0.00 0/0 0 -``` - -Column | Notes -:------------------------------------ | :---------- -CPU.IO utilization in last 1 minute | Always zero -CPU.IO utilization in last 5 minutes | Always zero -CPU.IO utilization in last 10 minutes | Always zero -Num currently running processes | Always zero -Total num processes | Always zero - -TODO(b/62345059): Populate the columns with accurate statistics. - -### meminfo - -```bash -$ cat /proc/meminfo -MemTotal: 2097152 kB -MemFree: 2083540 kB -MemAvailable: 2083540 kB -Buffers: 0 kB -Cached: 4428 kB -SwapCache: 0 kB -Active: 10812 kB -Inactive: 2216 kB -Active(anon): 8600 kB -Inactive(anon): 0 kB -Active(file): 2212 kB -Inactive(file): 2216 kB -Unevictable: 0 kB -Mlocked: 0 kB -SwapTotal: 0 kB -SwapFree: 0 kB -Dirty: 0 kB -Writeback: 0 kB -AnonPages: 8600 kB -Mapped: 4428 kB -Shmem: 0 kB - -``` - -Notable divergences: - -Field name | Notes -:---------------- | :----------------------------------------------------- -Buffers | Always zero, no block devices -SwapCache | Always zero, no swap -Inactive(anon) | Always zero, see SwapCache -Unevictable | Always zero TODO(b/31823263) -Mlocked | Always zero TODO(b/31823263) -SwapTotal | Always zero, no swap -SwapFree | Always zero, no swap -Dirty | Always zero TODO(b/31823263) -Writeback | Always zero TODO(b/31823263) -MemAvailable | Uses the same value as MemFree since there is no swap. -Slab | Missing -SReclaimable | Missing -SUnreclaim | Missing -KernelStack | Missing -PageTables | Missing -NFS_Unstable | Missing -Bounce | Missing -WritebackTmp | Missing -CommitLimit | Missing -Committed_AS | Missing -VmallocTotal | Missing -VmallocUsed | Missing -VmallocChunk | Missing -HardwareCorrupted | Missing -AnonHugePages | Missing -ShmemHugePages | Missing -ShmemPmdMapped | Missing -HugePages_Total | Missing -HugePages_Free | Missing -HugePages_Rsvd | Missing -HugePages_Surp | Missing -Hugepagesize | Missing -DirectMap4k | Missing -DirectMap2M | Missing -DirectMap1G | Missing - -### stat - -```bash -$ cat /proc/stat -cpu 0 0 0 0 0 0 0 0 0 0 -cpu0 0 0 0 0 0 0 0 0 0 0 -cpu1 0 0 0 0 0 0 0 0 0 0 -cpu2 0 0 0 0 0 0 0 0 0 0 -cpu3 0 0 0 0 0 0 0 0 0 0 -cpu4 0 0 0 0 0 0 0 0 0 0 -cpu5 0 0 0 0 0 0 0 0 0 0 -cpu6 0 0 0 0 0 0 0 0 0 0 -cpu7 0 0 0 0 0 0 0 0 0 0 -intr 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -ctxt 0 -btime 1504040968 -processes 0 -procs_running 0 -procs_blokkcked 0 -softirq 0 0 0 0 0 0 0 0 0 0 0 -``` - -All fields except for `btime` are always zero. - -TODO(b/37226836): Populate with accurate fields. - -### sys - -```bash -$ ls /proc/sys -kernel vm -``` - -Directory | Notes -:-------- | :---------------------------- -abi | Missing -debug | Missing -dev | Missing -fs | Missing -kernel | Contains hostname (only) -net | Missing -user | Missing -vm | Contains mmap_min_addr (only) - -### uptime - -```bash -$ cat /proc/uptime -3204.62 0.00 -``` - -Column | Notes -:------------------------------- | :---------------------------- -Total num seconds system running | Time since procfs was mounted -Number of seconds idle | Always zero - -### version - -```bash -$ cat /proc/version -Linux version 4.4 #1 SMP Sun Jan 10 15:06:54 PST 2016 -``` - -## Process-specific data - -The following files are implemented: - -File /proc/PID | Content -:---------------------- | :--------------------------------------------------- -[auxv](#auxv) | Copy of auxiliary vector for the process -[cmdline](#cmdline) | Command line arguments -[comm](#comm) | Command name associated with the process -[environ](#environ) | Process environment -[exe](#exe) | Symlink to the process's executable -[fd](#fd) | Directory containing links to open file descriptors -[fdinfo](#fdinfo) | Information associated with open file descriptors -[gid_map](#gid_map) | Mappings for group IDs inside the user namespace -[io](#io) | IO statistics -[maps](#maps) | Memory mappings (anon, executables, library files) -[mounts](#mounts) | Mounted filesystems -[mountinfo](#mountinfo) | Information about mounts -[ns](#ns) | Directory containing info about supported namespaces -[stat](#stat) | Process statistics -[statm](#statm) | Process memory statistics -[status](#status) | Process status in human readable format -[task](#task) | Directory containing info about running threads -[uid_map](#uid_map) | Mappings for user IDs inside the user namespace - -### auxv - -TODO - -### cmdline - -TODO - -### comm - -TODO - -### environment - -TODO - -### exe - -TODO - -### fd - -TODO - -### fdinfo - -TODO - -### gid_map - -TODO - -### io - -Only has data for rchar, wchar, syscr, and syscw. - -TODO: add more detail. - -### maps - -TODO - -### mounts - -TODO - -### mountinfo - -TODO - -### ns - -TODO - -### stat - -Only has data for pid, comm, state, ppid, utime, stime, cutime, cstime, -num_threads, and exit_signal. - -TODO: add more detail. - -### statm - -Only has data for vss and rss. - -TODO: add more detail. - -### status - -Contains data for Name, State, Tgid, Pid, Ppid, TracerPid, FDSize, VmSize, -VmRSS, Threads, CapInh, CapPrm, CapEff, CapBnd, Seccomp. - -TODO: add more detail. - -### task - -TODO - -### uid_map - -TODO diff --git a/pkg/sentry/fs/proc/device/BUILD b/pkg/sentry/fs/proc/device/BUILD deleted file mode 100644 index 52c9aa93d..000000000 --- a/pkg/sentry/fs/proc/device/BUILD +++ /dev/null @@ -1,10 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "device", - srcs = ["device.go"], - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/sentry/device"], -) diff --git a/pkg/sentry/fs/proc/device/device_state_autogen.go b/pkg/sentry/fs/proc/device/device_state_autogen.go new file mode 100755 index 000000000..4a5e3cc88 --- /dev/null +++ b/pkg/sentry/fs/proc/device/device_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package device diff --git a/pkg/sentry/fs/proc/net_test.go b/pkg/sentry/fs/proc/net_test.go deleted file mode 100644 index f18681405..000000000 --- a/pkg/sentry/fs/proc/net_test.go +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/inet" -) - -func newIPv6TestStack() *inet.TestStack { - s := inet.NewTestStack() - s.SupportsIPv6Flag = true - return s -} - -func TestIfinet6NoAddresses(t *testing.T) { - n := &ifinet6{s: newIPv6TestStack()} - if got := n.contents(); got != nil { - t.Errorf("Got n.contents() = %v, want = %v", got, nil) - } -} - -func TestIfinet6(t *testing.T) { - s := newIPv6TestStack() - s.InterfacesMap[1] = inet.Interface{Name: "eth0"} - s.InterfaceAddrsMap[1] = []inet.InterfaceAddr{ - { - Family: linux.AF_INET6, - PrefixLen: 128, - Addr: []byte("\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f"), - }, - } - s.InterfacesMap[2] = inet.Interface{Name: "eth1"} - s.InterfaceAddrsMap[2] = []inet.InterfaceAddr{ - { - Family: linux.AF_INET6, - PrefixLen: 128, - Addr: []byte("\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"), - }, - } - want := map[string]struct{}{ - "000102030405060708090a0b0c0d0e0f 01 80 00 00 eth0\n": {}, - "101112131415161718191a1b1c1d1e1f 02 80 00 00 eth1\n": {}, - } - - n := &ifinet6{s: s} - contents := n.contents() - if len(contents) != len(want) { - t.Errorf("Got len(n.contents()) = %d, want = %d", len(contents), len(want)) - } - got := map[string]struct{}{} - for _, l := range contents { - got[l] = struct{}{} - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("Got n.contents() = %v, want = %v", got, want) - } -} diff --git a/pkg/sentry/fs/proc/proc_state_autogen.go b/pkg/sentry/fs/proc/proc_state_autogen.go new file mode 100755 index 000000000..45f2b0a40 --- /dev/null +++ b/pkg/sentry/fs/proc/proc_state_autogen.go @@ -0,0 +1,743 @@ +// automatically generated by stateify. + +package proc + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *execArgInode) beforeSave() {} +func (x *execArgInode) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) + m.Save("arg", &x.arg) + m.Save("t", &x.t) +} + +func (x *execArgInode) afterLoad() {} +func (x *execArgInode) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.Load("arg", &x.arg) + m.Load("t", &x.t) +} + +func (x *execArgFile) beforeSave() {} +func (x *execArgFile) save(m state.Map) { + x.beforeSave() + m.Save("arg", &x.arg) + m.Save("t", &x.t) +} + +func (x *execArgFile) afterLoad() {} +func (x *execArgFile) load(m state.Map) { + m.Load("arg", &x.arg) + m.Load("t", &x.t) +} + +func (x *fdDir) beforeSave() {} +func (x *fdDir) save(m state.Map) { + x.beforeSave() + m.Save("Dir", &x.Dir) + m.Save("t", &x.t) +} + +func (x *fdDir) afterLoad() {} +func (x *fdDir) load(m state.Map) { + m.Load("Dir", &x.Dir) + m.Load("t", &x.t) +} + +func (x *fdDirFile) beforeSave() {} +func (x *fdDirFile) save(m state.Map) { + x.beforeSave() + m.Save("isInfoFile", &x.isInfoFile) + m.Save("t", &x.t) +} + +func (x *fdDirFile) afterLoad() {} +func (x *fdDirFile) load(m state.Map) { + m.Load("isInfoFile", &x.isInfoFile) + m.Load("t", &x.t) +} + +func (x *fdInfoDir) beforeSave() {} +func (x *fdInfoDir) save(m state.Map) { + x.beforeSave() + m.Save("Dir", &x.Dir) + m.Save("t", &x.t) +} + +func (x *fdInfoDir) afterLoad() {} +func (x *fdInfoDir) load(m state.Map) { + m.Load("Dir", &x.Dir) + m.Load("t", &x.t) +} + +func (x *filesystemsData) beforeSave() {} +func (x *filesystemsData) save(m state.Map) { + x.beforeSave() +} + +func (x *filesystemsData) afterLoad() {} +func (x *filesystemsData) load(m state.Map) { +} + +func (x *filesystem) beforeSave() {} +func (x *filesystem) save(m state.Map) { + x.beforeSave() +} + +func (x *filesystem) afterLoad() {} +func (x *filesystem) load(m state.Map) { +} + +func (x *taskOwnedInodeOps) beforeSave() {} +func (x *taskOwnedInodeOps) save(m state.Map) { + x.beforeSave() + m.Save("InodeOperations", &x.InodeOperations) + m.Save("t", &x.t) +} + +func (x *taskOwnedInodeOps) afterLoad() {} +func (x *taskOwnedInodeOps) load(m state.Map) { + m.Load("InodeOperations", &x.InodeOperations) + m.Load("t", &x.t) +} + +func (x *staticFileInodeOps) beforeSave() {} +func (x *staticFileInodeOps) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("InodeStaticFileGetter", &x.InodeStaticFileGetter) +} + +func (x *staticFileInodeOps) afterLoad() {} +func (x *staticFileInodeOps) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("InodeStaticFileGetter", &x.InodeStaticFileGetter) +} + +func (x *loadavgData) beforeSave() {} +func (x *loadavgData) save(m state.Map) { + x.beforeSave() +} + +func (x *loadavgData) afterLoad() {} +func (x *loadavgData) load(m state.Map) { +} + +func (x *meminfoData) beforeSave() {} +func (x *meminfoData) save(m state.Map) { + x.beforeSave() + m.Save("k", &x.k) +} + +func (x *meminfoData) afterLoad() {} +func (x *meminfoData) load(m state.Map) { + m.Load("k", &x.k) +} + +func (x *mountInfoFile) beforeSave() {} +func (x *mountInfoFile) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) +} + +func (x *mountInfoFile) afterLoad() {} +func (x *mountInfoFile) load(m state.Map) { + m.Load("t", &x.t) +} + +func (x *mountsFile) beforeSave() {} +func (x *mountsFile) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) +} + +func (x *mountsFile) afterLoad() {} +func (x *mountsFile) load(m state.Map) { + m.Load("t", &x.t) +} + +func (x *ifinet6) beforeSave() {} +func (x *ifinet6) save(m state.Map) { + x.beforeSave() + m.Save("s", &x.s) +} + +func (x *ifinet6) afterLoad() {} +func (x *ifinet6) load(m state.Map) { + m.Load("s", &x.s) +} + +func (x *netDev) beforeSave() {} +func (x *netDev) save(m state.Map) { + x.beforeSave() + m.Save("s", &x.s) +} + +func (x *netDev) afterLoad() {} +func (x *netDev) load(m state.Map) { + m.Load("s", &x.s) +} + +func (x *netSnmp) beforeSave() {} +func (x *netSnmp) save(m state.Map) { + x.beforeSave() + m.Save("s", &x.s) +} + +func (x *netSnmp) afterLoad() {} +func (x *netSnmp) load(m state.Map) { + m.Load("s", &x.s) +} + +func (x *netRoute) beforeSave() {} +func (x *netRoute) save(m state.Map) { + x.beforeSave() + m.Save("s", &x.s) +} + +func (x *netRoute) afterLoad() {} +func (x *netRoute) load(m state.Map) { + m.Load("s", &x.s) +} + +func (x *netUnix) beforeSave() {} +func (x *netUnix) save(m state.Map) { + x.beforeSave() + m.Save("k", &x.k) +} + +func (x *netUnix) afterLoad() {} +func (x *netUnix) load(m state.Map) { + m.Load("k", &x.k) +} + +func (x *netTCP) beforeSave() {} +func (x *netTCP) save(m state.Map) { + x.beforeSave() + m.Save("k", &x.k) +} + +func (x *netTCP) afterLoad() {} +func (x *netTCP) load(m state.Map) { + m.Load("k", &x.k) +} + +func (x *netTCP6) beforeSave() {} +func (x *netTCP6) save(m state.Map) { + x.beforeSave() + m.Save("k", &x.k) +} + +func (x *netTCP6) afterLoad() {} +func (x *netTCP6) load(m state.Map) { + m.Load("k", &x.k) +} + +func (x *netUDP) beforeSave() {} +func (x *netUDP) save(m state.Map) { + x.beforeSave() + m.Save("k", &x.k) +} + +func (x *netUDP) afterLoad() {} +func (x *netUDP) load(m state.Map) { + m.Load("k", &x.k) +} + +func (x *proc) beforeSave() {} +func (x *proc) save(m state.Map) { + x.beforeSave() + m.Save("Dir", &x.Dir) + m.Save("k", &x.k) + m.Save("pidns", &x.pidns) + m.Save("cgroupControllers", &x.cgroupControllers) +} + +func (x *proc) afterLoad() {} +func (x *proc) load(m state.Map) { + m.Load("Dir", &x.Dir) + m.Load("k", &x.k) + m.Load("pidns", &x.pidns) + m.Load("cgroupControllers", &x.cgroupControllers) +} + +func (x *self) beforeSave() {} +func (x *self) save(m state.Map) { + x.beforeSave() + m.Save("Symlink", &x.Symlink) + m.Save("pidns", &x.pidns) +} + +func (x *self) afterLoad() {} +func (x *self) load(m state.Map) { + m.Load("Symlink", &x.Symlink) + m.Load("pidns", &x.pidns) +} + +func (x *threadSelf) beforeSave() {} +func (x *threadSelf) save(m state.Map) { + x.beforeSave() + m.Save("Symlink", &x.Symlink) + m.Save("pidns", &x.pidns) +} + +func (x *threadSelf) afterLoad() {} +func (x *threadSelf) load(m state.Map) { + m.Load("Symlink", &x.Symlink) + m.Load("pidns", &x.pidns) +} + +func (x *rootProcFile) beforeSave() {} +func (x *rootProcFile) save(m state.Map) { + x.beforeSave() + m.Save("iops", &x.iops) +} + +func (x *rootProcFile) afterLoad() {} +func (x *rootProcFile) load(m state.Map) { + m.Load("iops", &x.iops) +} + +func (x *statData) beforeSave() {} +func (x *statData) save(m state.Map) { + x.beforeSave() + m.Save("k", &x.k) +} + +func (x *statData) afterLoad() {} +func (x *statData) load(m state.Map) { + m.Load("k", &x.k) +} + +func (x *mmapMinAddrData) beforeSave() {} +func (x *mmapMinAddrData) save(m state.Map) { + x.beforeSave() + m.Save("k", &x.k) +} + +func (x *mmapMinAddrData) afterLoad() {} +func (x *mmapMinAddrData) load(m state.Map) { + m.Load("k", &x.k) +} + +func (x *overcommitMemory) beforeSave() {} +func (x *overcommitMemory) save(m state.Map) { + x.beforeSave() +} + +func (x *overcommitMemory) afterLoad() {} +func (x *overcommitMemory) load(m state.Map) { +} + +func (x *hostname) beforeSave() {} +func (x *hostname) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) +} + +func (x *hostname) afterLoad() {} +func (x *hostname) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) +} + +func (x *hostnameFile) beforeSave() {} +func (x *hostnameFile) save(m state.Map) { + x.beforeSave() +} + +func (x *hostnameFile) afterLoad() {} +func (x *hostnameFile) load(m state.Map) { +} + +func (x *tcpMemInode) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) + m.Save("dir", &x.dir) + m.Save("s", &x.s) + m.Save("size", &x.size) +} + +func (x *tcpMemInode) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.Load("dir", &x.dir) + m.LoadWait("s", &x.s) + m.Load("size", &x.size) + m.AfterLoad(x.afterLoad) +} + +func (x *tcpMemFile) beforeSave() {} +func (x *tcpMemFile) save(m state.Map) { + x.beforeSave() + m.Save("tcpMemInode", &x.tcpMemInode) +} + +func (x *tcpMemFile) afterLoad() {} +func (x *tcpMemFile) load(m state.Map) { + m.Load("tcpMemInode", &x.tcpMemInode) +} + +func (x *tcpSack) beforeSave() {} +func (x *tcpSack) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) + m.Save("stack", &x.stack) + m.Save("enabled", &x.enabled) +} + +func (x *tcpSack) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.LoadWait("stack", &x.stack) + m.Load("enabled", &x.enabled) + m.AfterLoad(x.afterLoad) +} + +func (x *tcpSackFile) beforeSave() {} +func (x *tcpSackFile) save(m state.Map) { + x.beforeSave() + m.Save("tcpSack", &x.tcpSack) + m.Save("stack", &x.stack) +} + +func (x *tcpSackFile) afterLoad() {} +func (x *tcpSackFile) load(m state.Map) { + m.Load("tcpSack", &x.tcpSack) + m.LoadWait("stack", &x.stack) +} + +func (x *taskDir) beforeSave() {} +func (x *taskDir) save(m state.Map) { + x.beforeSave() + m.Save("Dir", &x.Dir) + m.Save("t", &x.t) + m.Save("pidns", &x.pidns) +} + +func (x *taskDir) afterLoad() {} +func (x *taskDir) load(m state.Map) { + m.Load("Dir", &x.Dir) + m.Load("t", &x.t) + m.Load("pidns", &x.pidns) +} + +func (x *subtasks) beforeSave() {} +func (x *subtasks) save(m state.Map) { + x.beforeSave() + m.Save("Dir", &x.Dir) + m.Save("t", &x.t) + m.Save("p", &x.p) +} + +func (x *subtasks) afterLoad() {} +func (x *subtasks) load(m state.Map) { + m.Load("Dir", &x.Dir) + m.Load("t", &x.t) + m.Load("p", &x.p) +} + +func (x *subtasksFile) beforeSave() {} +func (x *subtasksFile) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) + m.Save("pidns", &x.pidns) +} + +func (x *subtasksFile) afterLoad() {} +func (x *subtasksFile) load(m state.Map) { + m.Load("t", &x.t) + m.Load("pidns", &x.pidns) +} + +func (x *exe) beforeSave() {} +func (x *exe) save(m state.Map) { + x.beforeSave() + m.Save("Symlink", &x.Symlink) + m.Save("t", &x.t) +} + +func (x *exe) afterLoad() {} +func (x *exe) load(m state.Map) { + m.Load("Symlink", &x.Symlink) + m.Load("t", &x.t) +} + +func (x *namespaceSymlink) beforeSave() {} +func (x *namespaceSymlink) save(m state.Map) { + x.beforeSave() + m.Save("Symlink", &x.Symlink) + m.Save("t", &x.t) +} + +func (x *namespaceSymlink) afterLoad() {} +func (x *namespaceSymlink) load(m state.Map) { + m.Load("Symlink", &x.Symlink) + m.Load("t", &x.t) +} + +func (x *mapsData) beforeSave() {} +func (x *mapsData) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) +} + +func (x *mapsData) afterLoad() {} +func (x *mapsData) load(m state.Map) { + m.Load("t", &x.t) +} + +func (x *smapsData) beforeSave() {} +func (x *smapsData) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) +} + +func (x *smapsData) afterLoad() {} +func (x *smapsData) load(m state.Map) { + m.Load("t", &x.t) +} + +func (x *taskStatData) beforeSave() {} +func (x *taskStatData) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) + m.Save("tgstats", &x.tgstats) + m.Save("pidns", &x.pidns) +} + +func (x *taskStatData) afterLoad() {} +func (x *taskStatData) load(m state.Map) { + m.Load("t", &x.t) + m.Load("tgstats", &x.tgstats) + m.Load("pidns", &x.pidns) +} + +func (x *statmData) beforeSave() {} +func (x *statmData) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) +} + +func (x *statmData) afterLoad() {} +func (x *statmData) load(m state.Map) { + m.Load("t", &x.t) +} + +func (x *statusData) beforeSave() {} +func (x *statusData) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) + m.Save("pidns", &x.pidns) +} + +func (x *statusData) afterLoad() {} +func (x *statusData) load(m state.Map) { + m.Load("t", &x.t) + m.Load("pidns", &x.pidns) +} + +func (x *ioData) beforeSave() {} +func (x *ioData) save(m state.Map) { + x.beforeSave() + m.Save("ioUsage", &x.ioUsage) +} + +func (x *ioData) afterLoad() {} +func (x *ioData) load(m state.Map) { + m.Load("ioUsage", &x.ioUsage) +} + +func (x *comm) beforeSave() {} +func (x *comm) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) + m.Save("t", &x.t) +} + +func (x *comm) afterLoad() {} +func (x *comm) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.Load("t", &x.t) +} + +func (x *commFile) beforeSave() {} +func (x *commFile) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) +} + +func (x *commFile) afterLoad() {} +func (x *commFile) load(m state.Map) { + m.Load("t", &x.t) +} + +func (x *auxvec) beforeSave() {} +func (x *auxvec) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) + m.Save("t", &x.t) +} + +func (x *auxvec) afterLoad() {} +func (x *auxvec) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.Load("t", &x.t) +} + +func (x *auxvecFile) beforeSave() {} +func (x *auxvecFile) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) +} + +func (x *auxvecFile) afterLoad() {} +func (x *auxvecFile) load(m state.Map) { + m.Load("t", &x.t) +} + +func (x *oomScoreAdj) beforeSave() {} +func (x *oomScoreAdj) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) + m.Save("t", &x.t) +} + +func (x *oomScoreAdj) afterLoad() {} +func (x *oomScoreAdj) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.Load("t", &x.t) +} + +func (x *oomScoreAdjFile) beforeSave() {} +func (x *oomScoreAdjFile) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) +} + +func (x *oomScoreAdjFile) afterLoad() {} +func (x *oomScoreAdjFile) load(m state.Map) { + m.Load("t", &x.t) +} + +func (x *idMapInodeOperations) beforeSave() {} +func (x *idMapInodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Save("t", &x.t) + m.Save("gids", &x.gids) +} + +func (x *idMapInodeOperations) afterLoad() {} +func (x *idMapInodeOperations) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Load("t", &x.t) + m.Load("gids", &x.gids) +} + +func (x *idMapFileOperations) beforeSave() {} +func (x *idMapFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("iops", &x.iops) +} + +func (x *idMapFileOperations) afterLoad() {} +func (x *idMapFileOperations) load(m state.Map) { + m.Load("iops", &x.iops) +} + +func (x *uptime) beforeSave() {} +func (x *uptime) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) + m.Save("startTime", &x.startTime) +} + +func (x *uptime) afterLoad() {} +func (x *uptime) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.Load("startTime", &x.startTime) +} + +func (x *uptimeFile) beforeSave() {} +func (x *uptimeFile) save(m state.Map) { + x.beforeSave() + m.Save("startTime", &x.startTime) +} + +func (x *uptimeFile) afterLoad() {} +func (x *uptimeFile) load(m state.Map) { + m.Load("startTime", &x.startTime) +} + +func (x *versionData) beforeSave() {} +func (x *versionData) save(m state.Map) { + x.beforeSave() + m.Save("k", &x.k) +} + +func (x *versionData) afterLoad() {} +func (x *versionData) load(m state.Map) { + m.Load("k", &x.k) +} + +func init() { + state.Register("pkg/sentry/fs/proc.execArgInode", (*execArgInode)(nil), state.Fns{Save: (*execArgInode).save, Load: (*execArgInode).load}) + state.Register("pkg/sentry/fs/proc.execArgFile", (*execArgFile)(nil), state.Fns{Save: (*execArgFile).save, Load: (*execArgFile).load}) + state.Register("pkg/sentry/fs/proc.fdDir", (*fdDir)(nil), state.Fns{Save: (*fdDir).save, Load: (*fdDir).load}) + state.Register("pkg/sentry/fs/proc.fdDirFile", (*fdDirFile)(nil), state.Fns{Save: (*fdDirFile).save, Load: (*fdDirFile).load}) + state.Register("pkg/sentry/fs/proc.fdInfoDir", (*fdInfoDir)(nil), state.Fns{Save: (*fdInfoDir).save, Load: (*fdInfoDir).load}) + state.Register("pkg/sentry/fs/proc.filesystemsData", (*filesystemsData)(nil), state.Fns{Save: (*filesystemsData).save, Load: (*filesystemsData).load}) + state.Register("pkg/sentry/fs/proc.filesystem", (*filesystem)(nil), state.Fns{Save: (*filesystem).save, Load: (*filesystem).load}) + state.Register("pkg/sentry/fs/proc.taskOwnedInodeOps", (*taskOwnedInodeOps)(nil), state.Fns{Save: (*taskOwnedInodeOps).save, Load: (*taskOwnedInodeOps).load}) + state.Register("pkg/sentry/fs/proc.staticFileInodeOps", (*staticFileInodeOps)(nil), state.Fns{Save: (*staticFileInodeOps).save, Load: (*staticFileInodeOps).load}) + state.Register("pkg/sentry/fs/proc.loadavgData", (*loadavgData)(nil), state.Fns{Save: (*loadavgData).save, Load: (*loadavgData).load}) + state.Register("pkg/sentry/fs/proc.meminfoData", (*meminfoData)(nil), state.Fns{Save: (*meminfoData).save, Load: (*meminfoData).load}) + state.Register("pkg/sentry/fs/proc.mountInfoFile", (*mountInfoFile)(nil), state.Fns{Save: (*mountInfoFile).save, Load: (*mountInfoFile).load}) + state.Register("pkg/sentry/fs/proc.mountsFile", (*mountsFile)(nil), state.Fns{Save: (*mountsFile).save, Load: (*mountsFile).load}) + state.Register("pkg/sentry/fs/proc.ifinet6", (*ifinet6)(nil), state.Fns{Save: (*ifinet6).save, Load: (*ifinet6).load}) + state.Register("pkg/sentry/fs/proc.netDev", (*netDev)(nil), state.Fns{Save: (*netDev).save, Load: (*netDev).load}) + state.Register("pkg/sentry/fs/proc.netSnmp", (*netSnmp)(nil), state.Fns{Save: (*netSnmp).save, Load: (*netSnmp).load}) + state.Register("pkg/sentry/fs/proc.netRoute", (*netRoute)(nil), state.Fns{Save: (*netRoute).save, Load: (*netRoute).load}) + state.Register("pkg/sentry/fs/proc.netUnix", (*netUnix)(nil), state.Fns{Save: (*netUnix).save, Load: (*netUnix).load}) + state.Register("pkg/sentry/fs/proc.netTCP", (*netTCP)(nil), state.Fns{Save: (*netTCP).save, Load: (*netTCP).load}) + state.Register("pkg/sentry/fs/proc.netTCP6", (*netTCP6)(nil), state.Fns{Save: (*netTCP6).save, Load: (*netTCP6).load}) + state.Register("pkg/sentry/fs/proc.netUDP", (*netUDP)(nil), state.Fns{Save: (*netUDP).save, Load: (*netUDP).load}) + state.Register("pkg/sentry/fs/proc.proc", (*proc)(nil), state.Fns{Save: (*proc).save, Load: (*proc).load}) + state.Register("pkg/sentry/fs/proc.self", (*self)(nil), state.Fns{Save: (*self).save, Load: (*self).load}) + state.Register("pkg/sentry/fs/proc.threadSelf", (*threadSelf)(nil), state.Fns{Save: (*threadSelf).save, Load: (*threadSelf).load}) + state.Register("pkg/sentry/fs/proc.rootProcFile", (*rootProcFile)(nil), state.Fns{Save: (*rootProcFile).save, Load: (*rootProcFile).load}) + state.Register("pkg/sentry/fs/proc.statData", (*statData)(nil), state.Fns{Save: (*statData).save, Load: (*statData).load}) + state.Register("pkg/sentry/fs/proc.mmapMinAddrData", (*mmapMinAddrData)(nil), state.Fns{Save: (*mmapMinAddrData).save, Load: (*mmapMinAddrData).load}) + state.Register("pkg/sentry/fs/proc.overcommitMemory", (*overcommitMemory)(nil), state.Fns{Save: (*overcommitMemory).save, Load: (*overcommitMemory).load}) + state.Register("pkg/sentry/fs/proc.hostname", (*hostname)(nil), state.Fns{Save: (*hostname).save, Load: (*hostname).load}) + state.Register("pkg/sentry/fs/proc.hostnameFile", (*hostnameFile)(nil), state.Fns{Save: (*hostnameFile).save, Load: (*hostnameFile).load}) + state.Register("pkg/sentry/fs/proc.tcpMemInode", (*tcpMemInode)(nil), state.Fns{Save: (*tcpMemInode).save, Load: (*tcpMemInode).load}) + state.Register("pkg/sentry/fs/proc.tcpMemFile", (*tcpMemFile)(nil), state.Fns{Save: (*tcpMemFile).save, Load: (*tcpMemFile).load}) + state.Register("pkg/sentry/fs/proc.tcpSack", (*tcpSack)(nil), state.Fns{Save: (*tcpSack).save, Load: (*tcpSack).load}) + state.Register("pkg/sentry/fs/proc.tcpSackFile", (*tcpSackFile)(nil), state.Fns{Save: (*tcpSackFile).save, Load: (*tcpSackFile).load}) + state.Register("pkg/sentry/fs/proc.taskDir", (*taskDir)(nil), state.Fns{Save: (*taskDir).save, Load: (*taskDir).load}) + state.Register("pkg/sentry/fs/proc.subtasks", (*subtasks)(nil), state.Fns{Save: (*subtasks).save, Load: (*subtasks).load}) + state.Register("pkg/sentry/fs/proc.subtasksFile", (*subtasksFile)(nil), state.Fns{Save: (*subtasksFile).save, Load: (*subtasksFile).load}) + state.Register("pkg/sentry/fs/proc.exe", (*exe)(nil), state.Fns{Save: (*exe).save, Load: (*exe).load}) + state.Register("pkg/sentry/fs/proc.namespaceSymlink", (*namespaceSymlink)(nil), state.Fns{Save: (*namespaceSymlink).save, Load: (*namespaceSymlink).load}) + state.Register("pkg/sentry/fs/proc.mapsData", (*mapsData)(nil), state.Fns{Save: (*mapsData).save, Load: (*mapsData).load}) + state.Register("pkg/sentry/fs/proc.smapsData", (*smapsData)(nil), state.Fns{Save: (*smapsData).save, Load: (*smapsData).load}) + state.Register("pkg/sentry/fs/proc.taskStatData", (*taskStatData)(nil), state.Fns{Save: (*taskStatData).save, Load: (*taskStatData).load}) + state.Register("pkg/sentry/fs/proc.statmData", (*statmData)(nil), state.Fns{Save: (*statmData).save, Load: (*statmData).load}) + state.Register("pkg/sentry/fs/proc.statusData", (*statusData)(nil), state.Fns{Save: (*statusData).save, Load: (*statusData).load}) + state.Register("pkg/sentry/fs/proc.ioData", (*ioData)(nil), state.Fns{Save: (*ioData).save, Load: (*ioData).load}) + state.Register("pkg/sentry/fs/proc.comm", (*comm)(nil), state.Fns{Save: (*comm).save, Load: (*comm).load}) + state.Register("pkg/sentry/fs/proc.commFile", (*commFile)(nil), state.Fns{Save: (*commFile).save, Load: (*commFile).load}) + state.Register("pkg/sentry/fs/proc.auxvec", (*auxvec)(nil), state.Fns{Save: (*auxvec).save, Load: (*auxvec).load}) + state.Register("pkg/sentry/fs/proc.auxvecFile", (*auxvecFile)(nil), state.Fns{Save: (*auxvecFile).save, Load: (*auxvecFile).load}) + state.Register("pkg/sentry/fs/proc.oomScoreAdj", (*oomScoreAdj)(nil), state.Fns{Save: (*oomScoreAdj).save, Load: (*oomScoreAdj).load}) + state.Register("pkg/sentry/fs/proc.oomScoreAdjFile", (*oomScoreAdjFile)(nil), state.Fns{Save: (*oomScoreAdjFile).save, Load: (*oomScoreAdjFile).load}) + state.Register("pkg/sentry/fs/proc.idMapInodeOperations", (*idMapInodeOperations)(nil), state.Fns{Save: (*idMapInodeOperations).save, Load: (*idMapInodeOperations).load}) + state.Register("pkg/sentry/fs/proc.idMapFileOperations", (*idMapFileOperations)(nil), state.Fns{Save: (*idMapFileOperations).save, Load: (*idMapFileOperations).load}) + state.Register("pkg/sentry/fs/proc.uptime", (*uptime)(nil), state.Fns{Save: (*uptime).save, Load: (*uptime).load}) + state.Register("pkg/sentry/fs/proc.uptimeFile", (*uptimeFile)(nil), state.Fns{Save: (*uptimeFile).save, Load: (*uptimeFile).load}) + state.Register("pkg/sentry/fs/proc.versionData", (*versionData)(nil), state.Fns{Save: (*versionData).save, Load: (*versionData).load}) +} diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD deleted file mode 100644 index 21338d912..000000000 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "seqfile", - srcs = ["seqfile.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/proc/device", - "//pkg/sentry/kernel/time", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "seqfile_test", - size = "small", - srcs = ["seqfile_test.go"], - library = ":seqfile", - deps = [ - "//pkg/context", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/fs/ramfs", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fs/proc/seqfile/seqfile_state_autogen.go b/pkg/sentry/fs/proc/seqfile/seqfile_state_autogen.go new file mode 100755 index 000000000..cfd3a40b4 --- /dev/null +++ b/pkg/sentry/fs/proc/seqfile/seqfile_state_autogen.go @@ -0,0 +1,58 @@ +// automatically generated by stateify. + +package seqfile + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *SeqData) beforeSave() {} +func (x *SeqData) save(m state.Map) { + x.beforeSave() + m.Save("Buf", &x.Buf) + m.Save("Handle", &x.Handle) +} + +func (x *SeqData) afterLoad() {} +func (x *SeqData) load(m state.Map) { + m.Load("Buf", &x.Buf) + m.Load("Handle", &x.Handle) +} + +func (x *SeqFile) beforeSave() {} +func (x *SeqFile) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("SeqSource", &x.SeqSource) + m.Save("source", &x.source) + m.Save("generation", &x.generation) + m.Save("lastRead", &x.lastRead) +} + +func (x *SeqFile) afterLoad() {} +func (x *SeqFile) load(m state.Map) { + m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("SeqSource", &x.SeqSource) + m.Load("source", &x.source) + m.Load("generation", &x.generation) + m.Load("lastRead", &x.lastRead) +} + +func (x *seqFileOperations) beforeSave() {} +func (x *seqFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("seqFile", &x.seqFile) +} + +func (x *seqFileOperations) afterLoad() {} +func (x *seqFileOperations) load(m state.Map) { + m.Load("seqFile", &x.seqFile) +} + +func init() { + state.Register("pkg/sentry/fs/proc/seqfile.SeqData", (*SeqData)(nil), state.Fns{Save: (*SeqData).save, Load: (*SeqData).load}) + state.Register("pkg/sentry/fs/proc/seqfile.SeqFile", (*SeqFile)(nil), state.Fns{Save: (*SeqFile).save, Load: (*SeqFile).load}) + state.Register("pkg/sentry/fs/proc/seqfile.seqFileOperations", (*seqFileOperations)(nil), state.Fns{Save: (*seqFileOperations).save, Load: (*seqFileOperations).load}) +} diff --git a/pkg/sentry/fs/proc/seqfile/seqfile_test.go b/pkg/sentry/fs/proc/seqfile/seqfile_test.go deleted file mode 100644 index 98e394569..000000000 --- a/pkg/sentry/fs/proc/seqfile/seqfile_test.go +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package seqfile - -import ( - "bytes" - "fmt" - "io" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" - "gvisor.dev/gvisor/pkg/usermem" -) - -type seqTest struct { - actual []SeqData - update bool -} - -func (s *seqTest) Init() { - var sq []SeqData - // Create some SeqData. - for i := 0; i < 10; i++ { - var b []byte - for j := 0; j < 10; j++ { - b = append(b, byte(i)) - } - sq = append(sq, SeqData{ - Buf: b, - Handle: &testHandle{i: i}, - }) - } - s.actual = sq -} - -// NeedsUpdate reports whether we need to update the data we've previously read. -func (s *seqTest) NeedsUpdate(int64) bool { - return s.update -} - -// ReadSeqFiledata returns a slice of SeqData which contains elements -// greater than the handle. -func (s *seqTest) ReadSeqFileData(ctx context.Context, handle SeqHandle) ([]SeqData, int64) { - if handle == nil { - return s.actual, 0 - } - h := *handle.(*testHandle) - var ret []SeqData - for _, b := range s.actual { - // We want the next one. - h2 := *b.Handle.(*testHandle) - if h2.i > h.i { - ret = append(ret, b) - } - } - return ret, 0 -} - -// Flatten a slice of slices into one slice. -func flatten(buf ...[]byte) []byte { - var flat []byte - for _, b := range buf { - flat = append(flat, b...) - } - return flat -} - -type testHandle struct { - i int -} - -type testTable struct { - offset int64 - readBufferSize int - expectedData []byte - expectedError error -} - -func runTableTests(ctx context.Context, table []testTable, dirent *fs.Dirent) error { - for _, tt := range table { - file, err := dirent.Inode.InodeOperations.GetFile(ctx, dirent, fs.FileFlags{Read: true}) - if err != nil { - return fmt.Errorf("GetFile returned error: %v", err) - } - - data := make([]byte, tt.readBufferSize) - resultLen, err := file.Preadv(ctx, usermem.BytesIOSequence(data), tt.offset) - if err != tt.expectedError { - return fmt.Errorf("t.Preadv(len: %v, offset: %v) (error) => %v expected %v", tt.readBufferSize, tt.offset, err, tt.expectedError) - } - expectedLen := int64(len(tt.expectedData)) - if resultLen != expectedLen { - // We make this just an error so we wall through and print the data below. - return fmt.Errorf("t.Preadv(len: %v, offset: %v) (size) => %v expected %v", tt.readBufferSize, tt.offset, resultLen, expectedLen) - } - if !bytes.Equal(data[:expectedLen], tt.expectedData) { - return fmt.Errorf("t.Preadv(len: %v, offset: %v) (data) => %v expected %v", tt.readBufferSize, tt.offset, data[:expectedLen], tt.expectedData) - } - } - return nil -} - -func TestSeqFile(t *testing.T) { - testSource := &seqTest{} - testSource.Init() - - // Create a file that can be R/W. - ctx := contexttest.Context(t) - m := fs.NewPseudoMountSource(ctx) - contents := map[string]*fs.Inode{ - "foo": NewSeqFileInode(ctx, testSource, m), - } - root := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0777)) - - // How about opening it? - inode := fs.NewInode(ctx, root, m, fs.StableAttr{Type: fs.Directory}) - dirent2, err := root.Lookup(ctx, inode, "foo") - if err != nil { - t.Fatalf("failed to walk to foo for n2: %v", err) - } - n2 := dirent2.Inode.InodeOperations - file2, err := n2.GetFile(ctx, dirent2, fs.FileFlags{Read: true, Write: true}) - if err != nil { - t.Fatalf("GetFile returned error: %v", err) - } - - // Writing? - if _, err := file2.Writev(ctx, usermem.BytesIOSequence([]byte("test"))); err == nil { - t.Fatalf("managed to write to n2: %v", err) - } - - // How about reading? - dirent3, err := root.Lookup(ctx, inode, "foo") - if err != nil { - t.Fatalf("failed to walk to foo: %v", err) - } - n3 := dirent3.Inode.InodeOperations - if n2 != n3 { - t.Error("got n2 != n3, want same") - } - - testSource.update = true - - table := []testTable{ - // Read past the end. - {100, 4, []byte{}, io.EOF}, - {110, 4, []byte{}, io.EOF}, - {200, 4, []byte{}, io.EOF}, - // Read a truncated first line. - {0, 4, testSource.actual[0].Buf[:4], nil}, - // Read the whole first line. - {0, 10, testSource.actual[0].Buf, nil}, - // Read the whole first line + 5 bytes of second line. - {0, 15, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf[:5]), nil}, - // First 4 bytes of the second line. - {10, 4, testSource.actual[1].Buf[:4], nil}, - // Read the two first lines. - {0, 20, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf), nil}, - // Read three lines. - {0, 30, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf, testSource.actual[2].Buf), nil}, - // Read everything, but use a bigger buffer than necessary. - {0, 150, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf, testSource.actual[2].Buf, testSource.actual[3].Buf, testSource.actual[4].Buf, testSource.actual[5].Buf, testSource.actual[6].Buf, testSource.actual[7].Buf, testSource.actual[8].Buf, testSource.actual[9].Buf), nil}, - // Read the last 3 bytes. - {97, 10, testSource.actual[9].Buf[7:], nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed with testSource.update = %v : %v", testSource.update, err) - } - - // Disable updates and do it again. - testSource.update = false - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed with testSource.update = %v: %v", testSource.update, err) - } -} - -// Test that we behave correctly when the file is updated. -func TestSeqFileFileUpdated(t *testing.T) { - testSource := &seqTest{} - testSource.Init() - testSource.update = true - - // Create a file that can be R/W. - ctx := contexttest.Context(t) - m := fs.NewPseudoMountSource(ctx) - contents := map[string]*fs.Inode{ - "foo": NewSeqFileInode(ctx, testSource, m), - } - root := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0777)) - - // How about opening it? - inode := fs.NewInode(ctx, root, m, fs.StableAttr{Type: fs.Directory}) - dirent2, err := root.Lookup(ctx, inode, "foo") - if err != nil { - t.Fatalf("failed to walk to foo for dirent2: %v", err) - } - - table := []testTable{ - {0, 16, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf[:6]), nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed: %v", err) - } - // Delete the first entry. - cut := testSource.actual[0].Buf - testSource.actual = testSource.actual[1:] - - table = []testTable{ - // Try reading buffer 0 with an offset. This will not delete the old data. - {1, 5, cut[1:6], nil}, - // Reset our file by reading at offset 0. - {0, 10, testSource.actual[0].Buf, nil}, - {16, 14, flatten(testSource.actual[1].Buf[6:], testSource.actual[2].Buf), nil}, - // Read the same data a second time. - {16, 14, flatten(testSource.actual[1].Buf[6:], testSource.actual[2].Buf), nil}, - // Read the following two lines. - {30, 20, flatten(testSource.actual[3].Buf, testSource.actual[4].Buf), nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed after removing first entry: %v", err) - } - - // Add a new duplicate line in the middle (6666...) - after := testSource.actual[5:] - testSource.actual = testSource.actual[:4] - // Note the list must be sorted. - testSource.actual = append(testSource.actual, after[0]) - testSource.actual = append(testSource.actual, after...) - - table = []testTable{ - {50, 20, flatten(testSource.actual[4].Buf, testSource.actual[5].Buf), nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed after adding middle entry: %v", err) - } - // This will be used in a later test. - oldTestData := testSource.actual - - // Delete everything. - testSource.actual = testSource.actual[:0] - table = []testTable{ - {20, 20, []byte{}, io.EOF}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed after removing all entries: %v", err) - } - // Restore some of the data. - testSource.actual = oldTestData[:1] - table = []testTable{ - {6, 20, testSource.actual[0].Buf[6:], nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed after adding first entry back: %v", err) - } - - // Re-extend the data - testSource.actual = oldTestData - table = []testTable{ - {30, 20, flatten(testSource.actual[3].Buf, testSource.actual[4].Buf), nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed after extending testSource: %v", err) - } -} diff --git a/pkg/sentry/fs/proc/sys_net_test.go b/pkg/sentry/fs/proc/sys_net_test.go deleted file mode 100644 index 355e83d47..000000000 --- a/pkg/sentry/fs/proc/sys_net_test.go +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/usermem" -) - -func TestQuerySendBufferSize(t *testing.T) { - ctx := context.Background() - s := inet.NewTestStack() - s.TCPSendBufSize = inet.TCPBufferSize{100, 200, 300} - tmi := &tcpMemInode{s: s, dir: tcpWMem} - tmf := &tcpMemFile{tcpMemInode: tmi} - - buf := make([]byte, 100) - dst := usermem.BytesIOSequence(buf) - n, err := tmf.Read(ctx, nil, dst, 0) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - if got, want := string(buf[:n]), "100\t200\t300\n"; got != want { - t.Fatalf("Bad string: got %v, want %v", got, want) - } -} - -func TestQueryRecvBufferSize(t *testing.T) { - ctx := context.Background() - s := inet.NewTestStack() - s.TCPRecvBufSize = inet.TCPBufferSize{100, 200, 300} - tmi := &tcpMemInode{s: s, dir: tcpRMem} - tmf := &tcpMemFile{tcpMemInode: tmi} - - buf := make([]byte, 100) - dst := usermem.BytesIOSequence(buf) - n, err := tmf.Read(ctx, nil, dst, 0) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - if got, want := string(buf[:n]), "100\t200\t300\n"; got != want { - t.Fatalf("Bad string: got %v, want %v", got, want) - } -} - -var cases = []struct { - str string - initial inet.TCPBufferSize - final inet.TCPBufferSize -}{ - { - str: "", - initial: inet.TCPBufferSize{1, 2, 3}, - final: inet.TCPBufferSize{1, 2, 3}, - }, - { - str: "100\n", - initial: inet.TCPBufferSize{1, 100, 200}, - final: inet.TCPBufferSize{100, 100, 200}, - }, - { - str: "100 200 300\n", - initial: inet.TCPBufferSize{1, 2, 3}, - final: inet.TCPBufferSize{100, 200, 300}, - }, -} - -func TestConfigureSendBufferSize(t *testing.T) { - ctx := context.Background() - s := inet.NewTestStack() - for _, c := range cases { - s.TCPSendBufSize = c.initial - tmi := &tcpMemInode{s: s, dir: tcpWMem} - tmf := &tcpMemFile{tcpMemInode: tmi} - - // Write the values. - src := usermem.BytesIOSequence([]byte(c.str)) - if n, err := tmf.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil { - t.Errorf("Write, case = %q: got (%d, %v), wanted (%d, nil)", c.str, n, err, len(c.str)) - } - - // Read the values from the stack and check them. - if s.TCPSendBufSize != c.final { - t.Errorf("TCPSendBufferSize, case = %q: got %v, wanted %v", c.str, s.TCPSendBufSize, c.final) - } - } -} - -func TestConfigureRecvBufferSize(t *testing.T) { - ctx := context.Background() - s := inet.NewTestStack() - for _, c := range cases { - s.TCPRecvBufSize = c.initial - tmi := &tcpMemInode{s: s, dir: tcpRMem} - tmf := &tcpMemFile{tcpMemInode: tmi} - - // Write the values. - src := usermem.BytesIOSequence([]byte(c.str)) - if n, err := tmf.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil { - t.Errorf("Write, case = %q: got (%d, %v), wanted (%d, nil)", c.str, n, err, len(c.str)) - } - - // Read the values from the stack and check them. - if s.TCPRecvBufSize != c.final { - t.Errorf("TCPRecvBufferSize, case = %q: got %v, wanted %v", c.str, s.TCPRecvBufSize, c.final) - } - } -} diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD deleted file mode 100644 index 8ca823fb3..000000000 --- a/pkg/sentry/fs/ramfs/BUILD +++ /dev/null @@ -1,37 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ramfs", - srcs = [ - "dir.go", - "socket.go", - "symlink.go", - "tree.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/socket/unix/transport", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "ramfs_test", - size = "small", - srcs = ["tree_test.go"], - library = ":ramfs", - deps = [ - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - ], -) diff --git a/pkg/sentry/fs/ramfs/ramfs_state_autogen.go b/pkg/sentry/fs/ramfs/ramfs_state_autogen.go new file mode 100755 index 000000000..0a001e0b6 --- /dev/null +++ b/pkg/sentry/fs/ramfs/ramfs_state_autogen.go @@ -0,0 +1,94 @@ +// automatically generated by stateify. + +package ramfs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Dir) beforeSave() {} +func (x *Dir) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Save("children", &x.children) + m.Save("dentryMap", &x.dentryMap) +} + +func (x *Dir) afterLoad() {} +func (x *Dir) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Load("children", &x.children) + m.Load("dentryMap", &x.dentryMap) +} + +func (x *dirFileOperations) beforeSave() {} +func (x *dirFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("dirCursor", &x.dirCursor) + m.Save("dir", &x.dir) +} + +func (x *dirFileOperations) afterLoad() {} +func (x *dirFileOperations) load(m state.Map) { + m.Load("dirCursor", &x.dirCursor) + m.Load("dir", &x.dir) +} + +func (x *Socket) beforeSave() {} +func (x *Socket) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Save("ep", &x.ep) +} + +func (x *Socket) afterLoad() {} +func (x *Socket) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Load("ep", &x.ep) +} + +func (x *socketFileOperations) beforeSave() {} +func (x *socketFileOperations) save(m state.Map) { + x.beforeSave() +} + +func (x *socketFileOperations) afterLoad() {} +func (x *socketFileOperations) load(m state.Map) { +} + +func (x *Symlink) beforeSave() {} +func (x *Symlink) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Save("Target", &x.Target) +} + +func (x *Symlink) afterLoad() {} +func (x *Symlink) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Load("Target", &x.Target) +} + +func (x *symlinkFileOperations) beforeSave() {} +func (x *symlinkFileOperations) save(m state.Map) { + x.beforeSave() +} + +func (x *symlinkFileOperations) afterLoad() {} +func (x *symlinkFileOperations) load(m state.Map) { +} + +func init() { + state.Register("pkg/sentry/fs/ramfs.Dir", (*Dir)(nil), state.Fns{Save: (*Dir).save, Load: (*Dir).load}) + state.Register("pkg/sentry/fs/ramfs.dirFileOperations", (*dirFileOperations)(nil), state.Fns{Save: (*dirFileOperations).save, Load: (*dirFileOperations).load}) + state.Register("pkg/sentry/fs/ramfs.Socket", (*Socket)(nil), state.Fns{Save: (*Socket).save, Load: (*Socket).load}) + state.Register("pkg/sentry/fs/ramfs.socketFileOperations", (*socketFileOperations)(nil), state.Fns{Save: (*socketFileOperations).save, Load: (*socketFileOperations).load}) + state.Register("pkg/sentry/fs/ramfs.Symlink", (*Symlink)(nil), state.Fns{Save: (*Symlink).save, Load: (*Symlink).load}) + state.Register("pkg/sentry/fs/ramfs.symlinkFileOperations", (*symlinkFileOperations)(nil), state.Fns{Save: (*symlinkFileOperations).save, Load: (*symlinkFileOperations).load}) +} diff --git a/pkg/sentry/fs/ramfs/tree_test.go b/pkg/sentry/fs/ramfs/tree_test.go deleted file mode 100644 index a6ed8b2c5..000000000 --- a/pkg/sentry/fs/ramfs/tree_test.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ramfs - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" -) - -func TestMakeDirectoryTree(t *testing.T) { - - for _, test := range []struct { - name string - subdirs []string - }{ - { - name: "abs paths", - subdirs: []string{ - "/tmp", - "/tmp/a/b", - "/tmp/a/c/d", - "/tmp/c", - "/proc", - "/dev/a/b", - "/tmp", - }, - }, - { - name: "rel paths", - subdirs: []string{ - "tmp", - "tmp/a/b", - "tmp/a/c/d", - "tmp/c", - "proc", - "dev/a/b", - "tmp", - }, - }, - } { - ctx := contexttest.Context(t) - mount := fs.NewPseudoMountSource(ctx) - tree, err := MakeDirectoryTree(ctx, mount, test.subdirs) - if err != nil { - t.Errorf("%s: failed to make ramfs tree, got error %v, want nil", test.name, err) - continue - } - - // Expect to be able to find each of the paths. - mm, err := fs.NewMountNamespace(ctx, tree) - if err != nil { - t.Errorf("%s: failed to create mount manager: %v", test.name, err) - continue - } - root := mm.Root() - defer mm.DecRef() - - for _, p := range test.subdirs { - maxTraversals := uint(0) - if _, err := mm.FindInode(ctx, root, nil, p, &maxTraversals); err != nil { - t.Errorf("%s: failed to find node %s: %v", test.name, p, err) - break - } - } - } -} diff --git a/pkg/sentry/fs/sys/BUILD b/pkg/sentry/fs/sys/BUILD deleted file mode 100644 index f2e8b9932..000000000 --- a/pkg/sentry/fs/sys/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "sys", - srcs = [ - "device.go", - "devices.go", - "fs.go", - "sys.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/kernel", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fs/sys/sys_state_autogen.go b/pkg/sentry/fs/sys/sys_state_autogen.go new file mode 100755 index 000000000..733c504b1 --- /dev/null +++ b/pkg/sentry/fs/sys/sys_state_autogen.go @@ -0,0 +1,34 @@ +// automatically generated by stateify. + +package sys + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *cpunum) beforeSave() {} +func (x *cpunum) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("InodeStaticFileGetter", &x.InodeStaticFileGetter) +} + +func (x *cpunum) afterLoad() {} +func (x *cpunum) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("InodeStaticFileGetter", &x.InodeStaticFileGetter) +} + +func (x *filesystem) beforeSave() {} +func (x *filesystem) save(m state.Map) { + x.beforeSave() +} + +func (x *filesystem) afterLoad() {} +func (x *filesystem) load(m state.Map) { +} + +func init() { + state.Register("pkg/sentry/fs/sys.cpunum", (*cpunum)(nil), state.Fns{Save: (*cpunum).save, Load: (*cpunum).load}) + state.Register("pkg/sentry/fs/sys.filesystem", (*filesystem)(nil), state.Fns{Save: (*filesystem).save, Load: (*filesystem).load}) +} diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD deleted file mode 100644 index d16cdb4df..000000000 --- a/pkg/sentry/fs/timerfd/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "timerfd", - srcs = ["timerfd.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel/time", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fs/timerfd/timerfd_state_autogen.go b/pkg/sentry/fs/timerfd/timerfd_state_autogen.go new file mode 100755 index 000000000..b1335d3c7 --- /dev/null +++ b/pkg/sentry/fs/timerfd/timerfd_state_autogen.go @@ -0,0 +1,27 @@ +// automatically generated by stateify. + +package timerfd + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *TimerOperations) beforeSave() {} +func (x *TimerOperations) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.events) { + m.Failf("events is %v, expected zero", x.events) + } + m.Save("timer", &x.timer) + m.Save("val", &x.val) +} + +func (x *TimerOperations) afterLoad() {} +func (x *TimerOperations) load(m state.Map) { + m.Load("timer", &x.timer) + m.Load("val", &x.val) +} + +func init() { + state.Register("pkg/sentry/fs/timerfd.TimerOperations", (*TimerOperations)(nil), state.Fns{Save: (*TimerOperations).save, Load: (*TimerOperations).load}) +} diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD deleted file mode 100644 index aa7199014..000000000 --- a/pkg/sentry/fs/tmpfs/BUILD +++ /dev/null @@ -1,50 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tmpfs", - srcs = [ - "device.go", - "file_regular.go", - "fs.go", - "inode_file.go", - "tmpfs.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/metric", - "//pkg/safemem", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/pipe", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usage", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "tmpfs_test", - size = "small", - srcs = ["file_test.go"], - library = ":tmpfs", - deps = [ - "//pkg/context", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/contexttest", - "//pkg/sentry/usage", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fs/tmpfs/file_test.go b/pkg/sentry/fs/tmpfs/file_test.go deleted file mode 100644 index aaba35502..000000000 --- a/pkg/sentry/fs/tmpfs/file_test.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmpfs - -import ( - "bytes" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" - "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/usermem" -) - -func newFileInode(ctx context.Context) *fs.Inode { - m := fs.NewCachingMountSource(ctx, &Filesystem{}, fs.MountSourceFlags{}) - iops := NewInMemoryFile(ctx, usage.Tmpfs, fs.WithCurrentTime(ctx, fs.UnstableAttr{})) - return fs.NewInode(ctx, iops, m, fs.StableAttr{ - DeviceID: tmpfsDevice.DeviceID(), - InodeID: tmpfsDevice.NextIno(), - BlockSize: usermem.PageSize, - Type: fs.RegularFile, - }) -} - -func newFile(ctx context.Context) *fs.File { - inode := newFileInode(ctx) - f, _ := inode.GetFile(ctx, fs.NewDirent(ctx, inode, "stub"), fs.FileFlags{Read: true, Write: true}) - return f -} - -// Allocate once, write twice. -func TestGrow(t *testing.T) { - ctx := contexttest.Context(t) - f := newFile(ctx) - defer f.DecRef() - - abuf := bytes.Repeat([]byte{'a'}, 68) - n, err := f.Pwritev(ctx, usermem.BytesIOSequence(abuf), 0) - if n != int64(len(abuf)) || err != nil { - t.Fatalf("Pwritev got (%d, %v) want (%d, nil)", n, err, len(abuf)) - } - - bbuf := bytes.Repeat([]byte{'b'}, 856) - n, err = f.Pwritev(ctx, usermem.BytesIOSequence(bbuf), 68) - if n != int64(len(bbuf)) || err != nil { - t.Fatalf("Pwritev got (%d, %v) want (%d, nil)", n, err, len(bbuf)) - } - - rbuf := make([]byte, len(abuf)+len(bbuf)) - n, err = f.Preadv(ctx, usermem.BytesIOSequence(rbuf), 0) - if n != int64(len(rbuf)) || err != nil { - t.Fatalf("Preadv got (%d, %v) want (%d, nil)", n, err, len(rbuf)) - } - - if want := append(abuf, bbuf...); !bytes.Equal(rbuf, want) { - t.Fatalf("Read %v, want %v", rbuf, want) - } -} diff --git a/pkg/sentry/fs/tmpfs/tmpfs_state_autogen.go b/pkg/sentry/fs/tmpfs/tmpfs_state_autogen.go new file mode 100755 index 000000000..e4d2584fd --- /dev/null +++ b/pkg/sentry/fs/tmpfs/tmpfs_state_autogen.go @@ -0,0 +1,108 @@ +// automatically generated by stateify. + +package tmpfs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *regularFileOperations) beforeSave() {} +func (x *regularFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("iops", &x.iops) +} + +func (x *regularFileOperations) afterLoad() {} +func (x *regularFileOperations) load(m state.Map) { + m.Load("iops", &x.iops) +} + +func (x *Filesystem) beforeSave() {} +func (x *Filesystem) save(m state.Map) { + x.beforeSave() +} + +func (x *Filesystem) afterLoad() {} +func (x *Filesystem) load(m state.Map) { +} + +func (x *fileInodeOperations) beforeSave() {} +func (x *fileInodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Save("kernel", &x.kernel) + m.Save("memUsage", &x.memUsage) + m.Save("attr", &x.attr) + m.Save("mappings", &x.mappings) + m.Save("writableMappingPages", &x.writableMappingPages) + m.Save("data", &x.data) + m.Save("seals", &x.seals) +} + +func (x *fileInodeOperations) afterLoad() {} +func (x *fileInodeOperations) load(m state.Map) { + m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Load("kernel", &x.kernel) + m.Load("memUsage", &x.memUsage) + m.Load("attr", &x.attr) + m.Load("mappings", &x.mappings) + m.Load("writableMappingPages", &x.writableMappingPages) + m.Load("data", &x.data) + m.Load("seals", &x.seals) +} + +func (x *Dir) beforeSave() {} +func (x *Dir) save(m state.Map) { + x.beforeSave() + m.Save("ramfsDir", &x.ramfsDir) + m.Save("kernel", &x.kernel) +} + +func (x *Dir) load(m state.Map) { + m.Load("ramfsDir", &x.ramfsDir) + m.Load("kernel", &x.kernel) + m.AfterLoad(x.afterLoad) +} + +func (x *Symlink) beforeSave() {} +func (x *Symlink) save(m state.Map) { + x.beforeSave() + m.Save("Symlink", &x.Symlink) +} + +func (x *Symlink) afterLoad() {} +func (x *Symlink) load(m state.Map) { + m.Load("Symlink", &x.Symlink) +} + +func (x *Socket) beforeSave() {} +func (x *Socket) save(m state.Map) { + x.beforeSave() + m.Save("Socket", &x.Socket) +} + +func (x *Socket) afterLoad() {} +func (x *Socket) load(m state.Map) { + m.Load("Socket", &x.Socket) +} + +func (x *Fifo) beforeSave() {} +func (x *Fifo) save(m state.Map) { + x.beforeSave() + m.Save("InodeOperations", &x.InodeOperations) +} + +func (x *Fifo) afterLoad() {} +func (x *Fifo) load(m state.Map) { + m.Load("InodeOperations", &x.InodeOperations) +} + +func init() { + state.Register("pkg/sentry/fs/tmpfs.regularFileOperations", (*regularFileOperations)(nil), state.Fns{Save: (*regularFileOperations).save, Load: (*regularFileOperations).load}) + state.Register("pkg/sentry/fs/tmpfs.Filesystem", (*Filesystem)(nil), state.Fns{Save: (*Filesystem).save, Load: (*Filesystem).load}) + state.Register("pkg/sentry/fs/tmpfs.fileInodeOperations", (*fileInodeOperations)(nil), state.Fns{Save: (*fileInodeOperations).save, Load: (*fileInodeOperations).load}) + state.Register("pkg/sentry/fs/tmpfs.Dir", (*Dir)(nil), state.Fns{Save: (*Dir).save, Load: (*Dir).load}) + state.Register("pkg/sentry/fs/tmpfs.Symlink", (*Symlink)(nil), state.Fns{Save: (*Symlink).save, Load: (*Symlink).load}) + state.Register("pkg/sentry/fs/tmpfs.Socket", (*Socket)(nil), state.Fns{Save: (*Socket).save, Load: (*Socket).load}) + state.Register("pkg/sentry/fs/tmpfs.Fifo", (*Fifo)(nil), state.Fns{Save: (*Fifo).save, Load: (*Fifo).load}) +} diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD deleted file mode 100644 index 5cb0e0417..000000000 --- a/pkg/sentry/fs/tty/BUILD +++ /dev/null @@ -1,47 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tty", - srcs = [ - "dir.go", - "fs.go", - "line_discipline.go", - "master.go", - "queue.go", - "slave.go", - "terminal.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/refs", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/unimpl", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "tty_test", - size = "small", - srcs = ["tty_test.go"], - library = ":tty", - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/contexttest", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fs/tty/tty_state_autogen.go b/pkg/sentry/fs/tty/tty_state_autogen.go new file mode 100755 index 000000000..9963096dd --- /dev/null +++ b/pkg/sentry/fs/tty/tty_state_autogen.go @@ -0,0 +1,210 @@ +// automatically generated by stateify. + +package tty + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *dirInodeOperations) beforeSave() {} +func (x *dirInodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("msrc", &x.msrc) + m.Save("master", &x.master) + m.Save("slaves", &x.slaves) + m.Save("dentryMap", &x.dentryMap) + m.Save("next", &x.next) +} + +func (x *dirInodeOperations) afterLoad() {} +func (x *dirInodeOperations) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("msrc", &x.msrc) + m.Load("master", &x.master) + m.Load("slaves", &x.slaves) + m.Load("dentryMap", &x.dentryMap) + m.Load("next", &x.next) +} + +func (x *dirFileOperations) beforeSave() {} +func (x *dirFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("di", &x.di) + m.Save("dirCursor", &x.dirCursor) +} + +func (x *dirFileOperations) afterLoad() {} +func (x *dirFileOperations) load(m state.Map) { + m.Load("di", &x.di) + m.Load("dirCursor", &x.dirCursor) +} + +func (x *filesystem) beforeSave() {} +func (x *filesystem) save(m state.Map) { + x.beforeSave() +} + +func (x *filesystem) afterLoad() {} +func (x *filesystem) load(m state.Map) { +} + +func (x *superOperations) beforeSave() {} +func (x *superOperations) save(m state.Map) { + x.beforeSave() +} + +func (x *superOperations) afterLoad() {} +func (x *superOperations) load(m state.Map) { +} + +func (x *lineDiscipline) beforeSave() {} +func (x *lineDiscipline) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.masterWaiter) { + m.Failf("masterWaiter is %v, expected zero", x.masterWaiter) + } + if !state.IsZeroValue(x.slaveWaiter) { + m.Failf("slaveWaiter is %v, expected zero", x.slaveWaiter) + } + m.Save("size", &x.size) + m.Save("inQueue", &x.inQueue) + m.Save("outQueue", &x.outQueue) + m.Save("termios", &x.termios) + m.Save("column", &x.column) +} + +func (x *lineDiscipline) afterLoad() {} +func (x *lineDiscipline) load(m state.Map) { + m.Load("size", &x.size) + m.Load("inQueue", &x.inQueue) + m.Load("outQueue", &x.outQueue) + m.Load("termios", &x.termios) + m.Load("column", &x.column) +} + +func (x *outputQueueTransformer) beforeSave() {} +func (x *outputQueueTransformer) save(m state.Map) { + x.beforeSave() +} + +func (x *outputQueueTransformer) afterLoad() {} +func (x *outputQueueTransformer) load(m state.Map) { +} + +func (x *inputQueueTransformer) beforeSave() {} +func (x *inputQueueTransformer) save(m state.Map) { + x.beforeSave() +} + +func (x *inputQueueTransformer) afterLoad() {} +func (x *inputQueueTransformer) load(m state.Map) { +} + +func (x *masterInodeOperations) beforeSave() {} +func (x *masterInodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) + m.Save("d", &x.d) +} + +func (x *masterInodeOperations) afterLoad() {} +func (x *masterInodeOperations) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.Load("d", &x.d) +} + +func (x *masterFileOperations) beforeSave() {} +func (x *masterFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("d", &x.d) + m.Save("t", &x.t) +} + +func (x *masterFileOperations) afterLoad() {} +func (x *masterFileOperations) load(m state.Map) { + m.Load("d", &x.d) + m.Load("t", &x.t) +} + +func (x *queue) beforeSave() {} +func (x *queue) save(m state.Map) { + x.beforeSave() + m.Save("readBuf", &x.readBuf) + m.Save("waitBuf", &x.waitBuf) + m.Save("waitBufLen", &x.waitBufLen) + m.Save("readable", &x.readable) + m.Save("transformer", &x.transformer) +} + +func (x *queue) afterLoad() {} +func (x *queue) load(m state.Map) { + m.Load("readBuf", &x.readBuf) + m.Load("waitBuf", &x.waitBuf) + m.Load("waitBufLen", &x.waitBufLen) + m.Load("readable", &x.readable) + m.Load("transformer", &x.transformer) +} + +func (x *slaveInodeOperations) beforeSave() {} +func (x *slaveInodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) + m.Save("d", &x.d) + m.Save("t", &x.t) +} + +func (x *slaveInodeOperations) afterLoad() {} +func (x *slaveInodeOperations) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.Load("d", &x.d) + m.Load("t", &x.t) +} + +func (x *slaveFileOperations) beforeSave() {} +func (x *slaveFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("si", &x.si) +} + +func (x *slaveFileOperations) afterLoad() {} +func (x *slaveFileOperations) load(m state.Map) { + m.Load("si", &x.si) +} + +func (x *Terminal) beforeSave() {} +func (x *Terminal) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("n", &x.n) + m.Save("d", &x.d) + m.Save("ld", &x.ld) + m.Save("masterKTTY", &x.masterKTTY) + m.Save("slaveKTTY", &x.slaveKTTY) +} + +func (x *Terminal) afterLoad() {} +func (x *Terminal) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("n", &x.n) + m.Load("d", &x.d) + m.Load("ld", &x.ld) + m.Load("masterKTTY", &x.masterKTTY) + m.Load("slaveKTTY", &x.slaveKTTY) +} + +func init() { + state.Register("pkg/sentry/fs/tty.dirInodeOperations", (*dirInodeOperations)(nil), state.Fns{Save: (*dirInodeOperations).save, Load: (*dirInodeOperations).load}) + state.Register("pkg/sentry/fs/tty.dirFileOperations", (*dirFileOperations)(nil), state.Fns{Save: (*dirFileOperations).save, Load: (*dirFileOperations).load}) + state.Register("pkg/sentry/fs/tty.filesystem", (*filesystem)(nil), state.Fns{Save: (*filesystem).save, Load: (*filesystem).load}) + state.Register("pkg/sentry/fs/tty.superOperations", (*superOperations)(nil), state.Fns{Save: (*superOperations).save, Load: (*superOperations).load}) + state.Register("pkg/sentry/fs/tty.lineDiscipline", (*lineDiscipline)(nil), state.Fns{Save: (*lineDiscipline).save, Load: (*lineDiscipline).load}) + state.Register("pkg/sentry/fs/tty.outputQueueTransformer", (*outputQueueTransformer)(nil), state.Fns{Save: (*outputQueueTransformer).save, Load: (*outputQueueTransformer).load}) + state.Register("pkg/sentry/fs/tty.inputQueueTransformer", (*inputQueueTransformer)(nil), state.Fns{Save: (*inputQueueTransformer).save, Load: (*inputQueueTransformer).load}) + state.Register("pkg/sentry/fs/tty.masterInodeOperations", (*masterInodeOperations)(nil), state.Fns{Save: (*masterInodeOperations).save, Load: (*masterInodeOperations).load}) + state.Register("pkg/sentry/fs/tty.masterFileOperations", (*masterFileOperations)(nil), state.Fns{Save: (*masterFileOperations).save, Load: (*masterFileOperations).load}) + state.Register("pkg/sentry/fs/tty.queue", (*queue)(nil), state.Fns{Save: (*queue).save, Load: (*queue).load}) + state.Register("pkg/sentry/fs/tty.slaveInodeOperations", (*slaveInodeOperations)(nil), state.Fns{Save: (*slaveInodeOperations).save, Load: (*slaveInodeOperations).load}) + state.Register("pkg/sentry/fs/tty.slaveFileOperations", (*slaveFileOperations)(nil), state.Fns{Save: (*slaveFileOperations).save, Load: (*slaveFileOperations).load}) + state.Register("pkg/sentry/fs/tty.Terminal", (*Terminal)(nil), state.Fns{Save: (*Terminal).save, Load: (*Terminal).load}) +} diff --git a/pkg/sentry/fs/tty/tty_test.go b/pkg/sentry/fs/tty/tty_test.go deleted file mode 100644 index 2cbc05678..000000000 --- a/pkg/sentry/fs/tty/tty_test.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tty - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/usermem" -) - -func TestSimpleMasterToSlave(t *testing.T) { - ld := newLineDiscipline(linux.DefaultSlaveTermios) - ctx := contexttest.Context(t) - inBytes := []byte("hello, tty\n") - src := usermem.BytesIOSequence(inBytes) - outBytes := make([]byte, 32) - dst := usermem.BytesIOSequence(outBytes) - - // Write to the input queue. - nw, err := ld.inputQueueWrite(ctx, src) - if err != nil { - t.Fatalf("error writing to input queue: %v", err) - } - if nw != int64(len(inBytes)) { - t.Fatalf("wrote wrong length: got %d, want %d", nw, len(inBytes)) - } - - // Read from the input queue. - nr, err := ld.inputQueueRead(ctx, dst) - if err != nil { - t.Fatalf("error reading from input queue: %v", err) - } - if nr != int64(len(inBytes)) { - t.Fatalf("read wrong length: got %d, want %d", nr, len(inBytes)) - } - - outStr := string(outBytes[:nr]) - inStr := string(inBytes) - if outStr != inStr { - t.Fatalf("written and read strings do not match: got %q, want %q", outStr, inStr) - } -} diff --git a/pkg/sentry/fsbridge/BUILD b/pkg/sentry/fsbridge/BUILD deleted file mode 100644 index 6c798f0bd..000000000 --- a/pkg/sentry/fsbridge/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "fsbridge", - srcs = [ - "bridge.go", - "fs.go", - "vfs.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/fspath", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/memmap", - "//pkg/sentry/vfs", - "//pkg/syserror", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fsbridge/bridge.go b/pkg/sentry/fsbridge/bridge.go index 8e7590721..8e7590721 100644..100755 --- a/pkg/sentry/fsbridge/bridge.go +++ b/pkg/sentry/fsbridge/bridge.go diff --git a/pkg/sentry/fsbridge/fs.go b/pkg/sentry/fsbridge/fs.go index 093ce1fb3..093ce1fb3 100644..100755 --- a/pkg/sentry/fsbridge/fs.go +++ b/pkg/sentry/fsbridge/fs.go diff --git a/pkg/sentry/fsbridge/fsbridge_state_autogen.go b/pkg/sentry/fsbridge/fsbridge_state_autogen.go new file mode 100755 index 000000000..51b57d859 --- /dev/null +++ b/pkg/sentry/fsbridge/fsbridge_state_autogen.go @@ -0,0 +1,66 @@ +// automatically generated by stateify. + +package fsbridge + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *fsFile) beforeSave() {} +func (x *fsFile) save(m state.Map) { + x.beforeSave() + m.Save("file", &x.file) +} + +func (x *fsFile) afterLoad() {} +func (x *fsFile) load(m state.Map) { + m.Load("file", &x.file) +} + +func (x *fsLookup) beforeSave() {} +func (x *fsLookup) save(m state.Map) { + x.beforeSave() + m.Save("mntns", &x.mntns) + m.Save("root", &x.root) + m.Save("workingDir", &x.workingDir) +} + +func (x *fsLookup) afterLoad() {} +func (x *fsLookup) load(m state.Map) { + m.Load("mntns", &x.mntns) + m.Load("root", &x.root) + m.Load("workingDir", &x.workingDir) +} + +func (x *vfsFile) beforeSave() {} +func (x *vfsFile) save(m state.Map) { + x.beforeSave() + m.Save("file", &x.file) +} + +func (x *vfsFile) afterLoad() {} +func (x *vfsFile) load(m state.Map) { + m.Load("file", &x.file) +} + +func (x *vfsLookup) beforeSave() {} +func (x *vfsLookup) save(m state.Map) { + x.beforeSave() + m.Save("mntns", &x.mntns) + m.Save("root", &x.root) + m.Save("workingDir", &x.workingDir) +} + +func (x *vfsLookup) afterLoad() {} +func (x *vfsLookup) load(m state.Map) { + m.Load("mntns", &x.mntns) + m.Load("root", &x.root) + m.Load("workingDir", &x.workingDir) +} + +func init() { + state.Register("pkg/sentry/fsbridge.fsFile", (*fsFile)(nil), state.Fns{Save: (*fsFile).save, Load: (*fsFile).load}) + state.Register("pkg/sentry/fsbridge.fsLookup", (*fsLookup)(nil), state.Fns{Save: (*fsLookup).save, Load: (*fsLookup).load}) + state.Register("pkg/sentry/fsbridge.vfsFile", (*vfsFile)(nil), state.Fns{Save: (*vfsFile).save, Load: (*vfsFile).load}) + state.Register("pkg/sentry/fsbridge.vfsLookup", (*vfsLookup)(nil), state.Fns{Save: (*vfsLookup).save, Load: (*vfsLookup).load}) +} diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go index 79b808359..79b808359 100644..100755 --- a/pkg/sentry/fsbridge/vfs.go +++ b/pkg/sentry/fsbridge/vfs.go diff --git a/pkg/sentry/fsimpl/devtmpfs/BUILD b/pkg/sentry/fsimpl/devtmpfs/BUILD deleted file mode 100644 index aa0c2ad8c..000000000 --- a/pkg/sentry/fsimpl/devtmpfs/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -licenses(["notice"]) - -go_library( - name = "devtmpfs", - srcs = ["devtmpfs.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/fspath", - "//pkg/sentry/fsimpl/tmpfs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/sync", - ], -) - -go_test( - name = "devtmpfs_test", - size = "small", - srcs = ["devtmpfs_test.go"], - library = ":devtmpfs", - deps = [ - "//pkg/abi/linux", - "//pkg/fspath", - "//pkg/sentry/contexttest", - "//pkg/sentry/fsimpl/tmpfs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - ], -) diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go deleted file mode 100644 index abd4f24e7..000000000 --- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package devtmpfs provides an implementation of /dev based on tmpfs, -// analogous to Linux's devtmpfs. -package devtmpfs - -import ( - "fmt" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sync" -) - -// Name is the default filesystem name. -const Name = "devtmpfs" - -// FilesystemType implements vfs.FilesystemType. -type FilesystemType struct { - initOnce sync.Once - initErr error - - // fs is the tmpfs filesystem that backs all mounts of this FilesystemType. - // root is fs' root. fs and root are immutable. - fs *vfs.Filesystem - root *vfs.Dentry -} - -// GetFilesystem implements vfs.FilesystemType.GetFilesystem. -func (fst *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - fst.initOnce.Do(func() { - fs, root, err := tmpfs.FilesystemType{}.GetFilesystem(ctx, vfsObj, creds, "" /* source */, vfs.GetFilesystemOptions{ - Data: "mode=0755", // opts from drivers/base/devtmpfs.c:devtmpfs_init() - }) - if err != nil { - fst.initErr = err - return - } - fst.fs = fs - fst.root = root - }) - if fst.initErr != nil { - return nil, nil, fst.initErr - } - fst.fs.IncRef() - fst.root.IncRef() - return fst.fs, fst.root, nil -} - -// Accessor allows devices to create device special files in devtmpfs. -type Accessor struct { - vfsObj *vfs.VirtualFilesystem - mntns *vfs.MountNamespace - root vfs.VirtualDentry - creds *auth.Credentials -} - -// NewAccessor returns an Accessor that supports creation of device special -// files in the devtmpfs instance registered with name fsTypeName in vfsObj. -func NewAccessor(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, fsTypeName string) (*Accessor, error) { - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "devtmpfs" /* source */, fsTypeName, &vfs.GetFilesystemOptions{}) - if err != nil { - return nil, err - } - return &Accessor{ - vfsObj: vfsObj, - mntns: mntns, - root: mntns.Root(), - creds: creds, - }, nil -} - -// Release must be called when a is no longer in use. -func (a *Accessor) Release() { - a.root.DecRef() - a.mntns.DecRef() -} - -// accessorContext implements context.Context by extending an existing -// context.Context with an Accessor's values for VFS-relevant state. -type accessorContext struct { - context.Context - a *Accessor -} - -func (a *Accessor) wrapContext(ctx context.Context) *accessorContext { - return &accessorContext{ - Context: ctx, - a: a, - } -} - -// Value implements context.Context.Value. -func (ac *accessorContext) Value(key interface{}) interface{} { - switch key { - case vfs.CtxMountNamespace: - ac.a.mntns.IncRef() - return ac.a.mntns - case vfs.CtxRoot: - ac.a.root.IncRef() - return ac.a.root - default: - return ac.Context.Value(key) - } -} - -func (a *Accessor) pathOperationAt(pathname string) *vfs.PathOperation { - return &vfs.PathOperation{ - Root: a.root, - Start: a.root, - Path: fspath.Parse(pathname), - } -} - -// CreateDeviceFile creates a device special file at the given pathname in the -// devtmpfs instance accessed by the Accessor. -func (a *Accessor) CreateDeviceFile(ctx context.Context, pathname string, kind vfs.DeviceKind, major, minor uint32, perms uint16) error { - mode := (linux.FileMode)(perms) - switch kind { - case vfs.BlockDevice: - mode |= linux.S_IFBLK - case vfs.CharDevice: - mode |= linux.S_IFCHR - default: - panic(fmt.Sprintf("invalid vfs.DeviceKind: %v", kind)) - } - // NOTE: Linux's devtmpfs refuses to automatically delete files it didn't - // create, which it recognizes by storing a pointer to the kdevtmpfs struct - // thread in struct inode::i_private. Accessor doesn't yet support deletion - // of files at all, and probably won't as long as we don't need to support - // kernel modules, so this is moot for now. - return a.vfsObj.MknodAt(a.wrapContext(ctx), a.creds, a.pathOperationAt(pathname), &vfs.MknodOptions{ - Mode: mode, - DevMajor: major, - DevMinor: minor, - }) -} - -// UserspaceInit creates symbolic links and mount points in the devtmpfs -// instance accessed by the Accessor that are created by userspace in Linux. It -// does not create mounts. -func (a *Accessor) UserspaceInit(ctx context.Context) error { - actx := a.wrapContext(ctx) - - // systemd: src/shared/dev-setup.c:dev_setup() - for _, symlink := range []struct { - source string - target string - }{ - // /proc/kcore is not implemented. - {source: "fd", target: "/proc/self/fd"}, - {source: "stdin", target: "/proc/self/fd/0"}, - {source: "stdout", target: "/proc/self/fd/1"}, - {source: "stderr", target: "/proc/self/fd/2"}, - } { - if err := a.vfsObj.SymlinkAt(actx, a.creds, a.pathOperationAt(symlink.source), symlink.target); err != nil { - return fmt.Errorf("failed to create symlink %q => %q: %v", symlink.source, symlink.target, err) - } - } - - // systemd: src/core/mount-setup.c:mount_table - for _, dir := range []string{ - "shm", - "pts", - } { - if err := a.vfsObj.MkdirAt(actx, a.creds, a.pathOperationAt(dir), &vfs.MkdirOptions{ - // systemd: src/core/mount-setup.c:mount_one() - Mode: 0755, - }); err != nil { - return fmt.Errorf("failed to create directory %q: %v", dir, err) - } - } - - return nil -} diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go deleted file mode 100644 index b6d52c015..000000000 --- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package devtmpfs - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -func TestDevtmpfs(t *testing.T) { - ctx := contexttest.Context(t) - creds := auth.CredentialsFromContext(ctx) - - vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { - t.Fatalf("VFS init: %v", err) - } - // Register tmpfs just so that we can have a root filesystem that isn't - // devtmpfs. - vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - vfsObj.MustRegisterFilesystemType("devtmpfs", &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - - // Create a test mount namespace with devtmpfs mounted at "/dev". - const devPath = "/dev" - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "tmpfs" /* source */, "tmpfs" /* fsTypeName */, &vfs.GetFilesystemOptions{}) - if err != nil { - t.Fatalf("failed to create tmpfs root mount: %v", err) - } - defer mntns.DecRef() - root := mntns.Root() - defer root.DecRef() - devpop := vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(devPath), - } - if err := vfsObj.MkdirAt(ctx, creds, &devpop, &vfs.MkdirOptions{ - Mode: 0755, - }); err != nil { - t.Fatalf("failed to create mount point: %v", err) - } - if err := vfsObj.MountAt(ctx, creds, "devtmpfs" /* source */, &devpop, "devtmpfs" /* fsTypeName */, &vfs.MountOptions{}); err != nil { - t.Fatalf("failed to mount devtmpfs: %v", err) - } - - a, err := NewAccessor(ctx, vfsObj, creds, "devtmpfs") - if err != nil { - t.Fatalf("failed to create devtmpfs.Accessor: %v", err) - } - defer a.Release() - - // Create "userspace-initialized" files using a devtmpfs.Accessor. - if err := a.UserspaceInit(ctx); err != nil { - t.Fatalf("failed to userspace-initialize devtmpfs: %v", err) - } - // Created files should be visible in the test mount namespace. - abspath := devPath + "/fd" - target, err := vfsObj.ReadlinkAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(abspath), - }) - if want := "/proc/self/fd"; err != nil || target != want { - t.Fatalf("readlink(%q): got (%q, %v), wanted (%q, nil)", abspath, target, err, want) - } - - // Create a dummy device special file using a devtmpfs.Accessor. - const ( - pathInDev = "dummy" - kind = vfs.CharDevice - major = 12 - minor = 34 - perms = 0600 - wantMode = linux.S_IFCHR | perms - ) - if err := a.CreateDeviceFile(ctx, pathInDev, kind, major, minor, perms); err != nil { - t.Fatalf("failed to create device file: %v", err) - } - // The device special file should be visible in the test mount namespace. - abspath = devPath + "/" + pathInDev - stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(abspath), - }, &vfs.StatOptions{ - Mask: linux.STATX_TYPE | linux.STATX_MODE, - }) - if err != nil { - t.Fatalf("failed to stat device file at %q: %v", abspath, err) - } - if stat.Mode != wantMode { - t.Errorf("device file mode: got %v, wanted %v", stat.Mode, wantMode) - } - if stat.RdevMajor != major { - t.Errorf("major device number: got %v, wanted %v", stat.RdevMajor, major) - } - if stat.RdevMinor != minor { - t.Errorf("minor device number: got %v, wanted %v", stat.RdevMinor, minor) - } -} diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD deleted file mode 100644 index 6f78f478f..000000000 --- a/pkg/sentry/fsimpl/ext/BUILD +++ /dev/null @@ -1,88 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "dirent_list", - out = "dirent_list.go", - package = "ext", - prefix = "dirent", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*dirent", - "Linker": "*dirent", - }, -) - -go_library( - name = "ext", - srcs = [ - "block_map_file.go", - "dentry.go", - "directory.go", - "dirent_list.go", - "ext.go", - "extent_file.go", - "file_description.go", - "filesystem.go", - "inode.go", - "regular_file.go", - "symlink.go", - "utils.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/context", - "//pkg/fd", - "//pkg/fspath", - "//pkg/log", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/fs", - "//pkg/sentry/fsimpl/ext/disklayout", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/memmap", - "//pkg/sentry/syscalls/linux", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "ext_test", - size = "small", - srcs = [ - "block_map_test.go", - "ext_test.go", - "extent_test.go", - ], - data = [ - "//pkg/sentry/fsimpl/ext:assets/bigfile.txt", - "//pkg/sentry/fsimpl/ext:assets/file.txt", - "//pkg/sentry/fsimpl/ext:assets/tiny.ext2", - "//pkg/sentry/fsimpl/ext:assets/tiny.ext3", - "//pkg/sentry/fsimpl/ext:assets/tiny.ext4", - ], - library = ":ext", - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/context", - "//pkg/fspath", - "//pkg/sentry/contexttest", - "//pkg/sentry/fsimpl/ext/disklayout", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/syserror", - "//pkg/usermem", - "//runsc/testutil", - "@com_github_google_go-cmp//cmp:go_default_library", - "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", - ], -) diff --git a/pkg/sentry/fsimpl/ext/README.md b/pkg/sentry/fsimpl/ext/README.md deleted file mode 100644 index af00cfda8..000000000 --- a/pkg/sentry/fsimpl/ext/README.md +++ /dev/null @@ -1,117 +0,0 @@ -## EXT(2/3/4) File System - -This is a filesystem driver which supports ext2, ext3 and ext4 filesystems. -Linux has specialized drivers for each variant but none which supports all. This -library takes advantage of ext's backward compatibility and understands the -internal organization of on-disk structures to support all variants. - -This driver implementation diverges from the Linux implementations in being more -forgiving about versioning. For instance, if a filesystem contains both extent -based inodes and classical block map based inodes, this driver will not complain -and interpret them both correctly. While in Linux this would be an issue. This -blurs the line between the three ext fs variants. - -Ext2 is considered deprecated as of Red Hat Enterprise Linux 7, and ext3 has -been superseded by ext4 by large performance gains. Thus it is recommended to -upgrade older filesystem images to ext4 using e2fsprogs for better performance. - -### Read Only - -This driver currently only allows read only operations. A lot of the design -decisions are based on this feature. There are plans to implement write (the -process for which is documented in the future work section). - -### Performance - -One of the biggest wins about this driver is that it directly talks to the -underlying block device (or whatever persistent storage is being used), instead -of making expensive RPCs to a gofer. - -Another advantage is that ext fs supports fast concurrent reads. Currently the -device is represented using a `io.ReaderAt` which allows for concurrent reads. -All reads are directly passed to the device driver which intelligently serves -the read requests in the optimal order. There is no congestion due to locking -while reading in the filesystem level. - -Reads are optimized further in the way file data is transferred over to user -memory. Ext fs directly copies over file data from disk into user memory with no -additional allocations on the way. We can only get faster by preloading file -data into memory (see future work section). - -The internal structures used to represent files, inodes and file descriptors use -a lot of inheritance. With the level of indirection that an interface adds with -an internal pointer, it can quickly fragment a structure across memory. As this -runs along side a full blown kernel (which is memory intensive), having a -fragmented struct might hurt performance. Hence these internal structures, -though interfaced, are tightly packed in memory using the same inheritance -pattern that pkg/sentry/vfs uses. The pkg/sentry/fsimpl/ext/disklayout package -makes an execption to this pattern for reasons documented in the package. - -### Security - -This driver also intends to help sandbox the container better by reducing the -surface of the host kernel that the application touches. It prevents the -application from exploiting vulnerabilities in the host filesystem driver. All -`io.ReaderAt.ReadAt()` calls are translated to `pread(2)` which are directly -passed to the device driver in the kernel. Hence this reduces the surface for -attack. - -The application can not affect any host filesystems other than the one passed -via block device by the user. - -### Future Work - -#### Write - -To support write operations we would need to modify the block device underneath. -Currently, the driver does not modify the device at all, not even for updating -the access times for reads. Modifying the filesystem incorrectly can corrupt it -and render it unreadable for other correct ext(x) drivers. Hence caution must be -maintained while modifying metadata structures. - -Ext4 specifically is built for performance and has added a lot of complexity as -to how metadata structures are modified. For instance, files that are organized -via an extent tree which must be balanced and file data blocks must be placed in -the same extent as much as possible to increase locality. Such properties must -be maintained while modifying the tree. - -Ext filesystems boast a lot about locality, which plays a big role in them being -performant. The block allocation algorithm in Linux does a good job in keeping -related data together. This behavior must be maintained as much as possible, -else we might end up degrading the filesystem performance over time. - -Ext4 also supports a wide variety of features which are specialized for varying -use cases. Implementing all of them can get difficult very quickly. - -Ext(x) checksums all its metadata structures to check for corruption, so -modification of any metadata struct must correspond with re-checksumming the -struct. Linux filesystem drivers also order on-disk updates intelligently to not -corrupt the filesystem and also remain performant. The in-memory metadata -structures must be kept in sync with what is on disk. - -There is also replication of some important structures across the filesystem. -All replicas must be updated when their original copy is updated. There is also -provisioning for snapshotting which must be kept in mind, although it should not -affect this implementation unless we allow users to create filesystem snapshots. - -Ext4 also introduced journaling (jbd2). The journal must be updated -appropriately. - -#### Performance - -To improve performance we should implement a buffer cache, and optionally, read -ahead for small files. While doing so we must also keep in mind the memory usage -and have a reasonable cap on how much file data we want to hold in memory. - -#### Features - -Our current implementation will work with most ext4 filesystems for readonly -purposed. However, the following features are not supported yet: - -- Journal -- Snapshotting -- Extended Attributes -- Hash Tree Directories -- Meta Block Groups -- Multiple Mount Protection -- Bigalloc diff --git a/pkg/sentry/fsimpl/ext/assets/README.md b/pkg/sentry/fsimpl/ext/assets/README.md deleted file mode 100644 index 6f1e81b3a..000000000 --- a/pkg/sentry/fsimpl/ext/assets/README.md +++ /dev/null @@ -1,36 +0,0 @@ -### Tiny Ext(2/3/4) Images - -The images are of size 64Kb which supports 64 1k blocks and 16 inodes. This is -the smallest size mkfs.ext(2/3/4) works with. - -These images were generated using the following commands. - -```bash -fallocate -l 64K tiny.ext$VERSION -mkfs.ext$VERSION -j tiny.ext$VERSION -``` - -where `VERSION` is `2`, `3` or `4`. - -You can mount it using: - -```bash -sudo mount -o loop tiny.ext$VERSION $MOUNTPOINT -``` - -`file.txt`, `bigfile.txt` and `symlink.txt` were added to this image by just -mounting it and copying (while preserving links) those files to the mountpoint -directory using: - -```bash -sudo cp -P {file.txt,symlink.txt,bigfile.txt} $MOUNTPOINT -``` - -The files in this directory mirror the contents and organisation of the files -stored in the image. - -You can umount the filesystem using: - -```bash -sudo umount $MOUNTPOINT -``` diff --git a/pkg/sentry/fsimpl/ext/assets/bigfile.txt b/pkg/sentry/fsimpl/ext/assets/bigfile.txt deleted file mode 100644 index 3857cf516..000000000 --- a/pkg/sentry/fsimpl/ext/assets/bigfile.txt +++ /dev/null @@ -1,41 +0,0 @@ -Lorem ipsum dolor sit amet, consectetur adipiscing elit. Phasellus faucibus eleifend orci, ut ornare nibh faucibus eu. Cras at condimentum massa. Nullam luctus, elit non porttitor congue, sapien diam feugiat sapien, sed eleifend nulla mauris non arcu. Sed lacinia mauris magna, eu mollis libero varius sit amet. Donec mollis, quam convallis commodo posuere, dolor nisi placerat nisi, in faucibus augue mi eu lorem. In pharetra consectetur faucibus. Ut euismod ex efficitur egestas tincidunt. Maecenas condimentum ut ante in rutrum. Vivamus sed arcu tempor, faucibus turpis et, lacinia diam. - -Sed in lacus vel nisl interdum bibendum in sed justo. Nunc tellus risus, molestie vitae arcu sed, molestie tempus ligula. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Nunc risus neque, volutpat et ante non, ullamcorper condimentum ante. Aliquam sed metus in urna condimentum convallis. Vivamus ut libero mauris. Proin mollis posuere consequat. Vestibulum placerat mollis est et pulvinar. - -Donec rutrum odio ac diam pharetra, id fermentum magna cursus. Pellentesque in dapibus elit, et condimentum orci. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Suspendisse euismod dapibus est, id vestibulum mauris. Nulla facilisi. Nulla cursus gravida nisi. Phasellus vestibulum rutrum lectus, a dignissim mauris hendrerit vitae. In at elementum mauris. Integer vel efficitur velit. Nullam fringilla sapien mi, quis luctus neque efficitur ac. Aenean nec quam dapibus nunc commodo pharetra. Proin sapien mi, fermentum aliquet vulputate non, aliquet porttitor diam. Quisque lacinia, urna et finibus fermentum, nunc lacus vehicula ex, sed congue metus lectus ac quam. Aliquam erat volutpat. Suspendisse sodales, dolor ut tincidunt finibus, augue erat varius tellus, a interdum erat sem at nunc. Vestibulum cursus iaculis sapien, vitae feugiat dui auctor quis. - -Pellentesque nec maximus nulla, eu blandit diam. Maecenas quis arcu ornare, congue ante at, vehicula ipsum. Praesent feugiat mauris rutrum sem fermentum, nec luctus ipsum placerat. Pellentesque placerat ipsum at dignissim fringilla. Vivamus et posuere sem, eget hendrerit felis. Aenean vulputate, augue vel mollis feugiat, justo ipsum mollis dolor, eu mollis elit neque ut ipsum. Orci varius natoque penatibus et magnis dis parturient montes, nascetur ridiculus mus. Fusce bibendum sem quam, vulputate laoreet mi dapibus imperdiet. Sed a purus non nibh pretium aliquet. Integer eget luctus augue, vitae tincidunt magna. Ut eros enim, egestas eu nulla et, lobortis egestas arcu. Cras id ipsum ac justo lacinia rutrum. Vivamus lectus leo, ultricies sed justo at, pellentesque feugiat magna. Ut sollicitudin neque elit, vel ornare mauris commodo id. - -Duis dapibus orci et sapien finibus finibus. Mauris eleifend, lacus at vestibulum maximus, quam ligula pharetra erat, sit amet dapibus neque elit vitae neque. In bibendum sollicitudin erat, eget ultricies tortor malesuada at. Sed sit amet orci turpis. Donec feugiat ligula nibh, molestie tincidunt lectus elementum id. Donec volutpat maximus nibh, in vulputate felis posuere eu. Cras tincidunt ullamcorper lacus. Phasellus porta lorem auctor, congue magna a, commodo elit. - -Etiam auctor mi quis elit sodales, eu pulvinar arcu condimentum. Aenean imperdiet risus et dapibus tincidunt. Nullam tincidunt dictum dui, sed commodo urna rutrum id. Ut mollis libero vel elit laoreet bibendum. Quisque arcu arcu, tincidunt at ultricies id, vulputate nec metus. In tristique posuere quam sit amet volutpat. Vivamus scelerisque et nunc at dapibus. Fusce finibus libero ut ligula pretium rhoncus. Mauris non elit in arcu finibus imperdiet. Pellentesque nec massa odio. Proin rutrum mauris non sagittis efficitur. Aliquam auctor quam at dignissim faucibus. Ut eget ligula in magna posuere ultricies vitae sit amet turpis. Duis maximus odio nulla. Donec gravida sem tristique tempus scelerisque. - -Interdum et malesuada fames ac ante ipsum primis in faucibus. Fusce pharetra magna vulputate aliquet tempus. Duis id hendrerit arcu. Quisque ut ex elit. Integer velit orci, venenatis ut sapien ac, placerat porttitor dui. Interdum et malesuada fames ac ante ipsum primis in faucibus. Nunc hendrerit cursus diam, hendrerit finibus ipsum scelerisque ut. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. - -Nulla non euismod neque. Phasellus vel sapien eu metus pulvinar rhoncus. Suspendisse eu mollis tellus, quis vestibulum tortor. Maecenas interdum dolor sed nulla fermentum maximus. Donec imperdiet ullamcorper condimentum. Nam quis nibh ante. Praesent quis tellus ut tortor pulvinar blandit sit amet ut sapien. Vestibulum est orci, pellentesque vitae tristique sit amet, tristique non felis. - -Vivamus sodales pellentesque varius. Sed vel tempus ligula. Nulla tristique nisl vel dui facilisis, ac sodales augue hendrerit. Proin augue nisi, vestibulum quis augue nec, sagittis tincidunt velit. Vestibulum euismod, nulla nec sodales faucibus, urna sapien vulputate magna, id varius metus sapien ut neque. Duis in mollis urna, in scelerisque enim. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Nunc condimentum dictum turpis, et egestas neque dapibus eget. Quisque fringilla, dui eu venenatis eleifend, erat nibh lacinia urna, at lacinia lacus sapien eu dui. Duis eu erat ut mi lacinia convallis a sed ex. - -Fusce elit metus, tincidunt nec eleifend a, hendrerit nec ligula. Duis placerat finibus sollicitudin. In euismod porta tellus, in luctus justo bibendum bibendum. Maecenas at magna eleifend lectus tincidunt suscipit ut a ligula. Nulla tempor accumsan felis, fermentum dapibus est eleifend vitae. Mauris urna sem, fringilla at ultricies non, ultrices in arcu. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Nam vehicula nunc at laoreet imperdiet. Nunc tristique ut risus id aliquet. Integer eleifend massa orci. - -Vestibulum sed ante sollicitudin nisi fringilla bibendum nec vel quam. Sed pretium augue eu ligula congue pulvinar. Donec vitae magna tincidunt, pharetra lacus id, convallis nulla. Cras viverra nisl nisl, varius convallis leo vulputate nec. Morbi at consequat dui, sed aliquet metus. Sed suscipit fermentum mollis. Maecenas nec mi sodales, tincidunt purus in, tristique mauris. Orci varius natoque penatibus et magnis dis parturient montes, nascetur ridiculus mus. Donec interdum mi in velit efficitur, quis ultrices ex imperdiet. Sed vestibulum, magna ut tristique pretium, mi ipsum placerat tellus, non tempor enim augue et ex. Pellentesque eget felis quis ante sodales viverra ac sed lacus. Donec suscipit tempus massa, eget laoreet massa molestie at. - -Aenean fringilla dui non aliquet consectetur. Fusce cursus quam nec orci hendrerit faucibus. Donec consequat suscipit enim, non volutpat lectus auctor interdum. Proin lorem purus, maximus vel orci vitae, suscipit egestas turpis. Donec risus urna, congue a sem eu, aliquet placerat odio. Morbi gravida tristique turpis, quis efficitur enim. Nunc interdum gravida ipsum vel facilisis. Nunc congue finibus sollicitudin. Quisque euismod aliquet lectus et tincidunt. Curabitur ultrices sem ut mi fringilla fermentum. Morbi pretium, nisi sit amet dapibus congue, dolor enim consectetur risus, a interdum ligula odio sed odio. Quisque facilisis, mi at suscipit gravida, nunc sapien cursus justo, ut luctus odio nulla quis leo. Integer condimentum lobortis mauris, non egestas tellus lobortis sit amet. - -In sollicitudin velit ac ante vehicula, vitae varius tortor mollis. In hac habitasse platea dictumst. Quisque et orci lorem. Integer malesuada fringilla luctus. Pellentesque malesuada, mi non lobortis porttitor, ante ligula vulputate ante, nec dictum risus eros sit amet sapien. Nulla aliquam lorem libero, ac varius nulla tristique eget. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Ut pellentesque mauris orci, vel consequat mi varius a. Ut sit amet elit vulputate, lacinia metus non, fermentum nisl. Pellentesque eu nisi sed quam egestas blandit. Duis sit amet lobortis dolor. Donec consectetur sem interdum, tristique elit sit amet, sodales lacus. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Fusce id aliquam augue. Sed pretium congue risus vitae lacinia. Vestibulum non vulputate risus, ut malesuada justo. - -Sed odio elit, consectetur ac mauris quis, consequat commodo libero. Fusce sodales velit vulputate pulvinar fermentum. Donec iaculis nec nisl eget faucibus. Mauris at dictum velit. Donec fermentum lectus eu viverra volutpat. Aliquam consequat facilisis lorem, cursus consequat dui bibendum ullamcorper. Pellentesque nulla magna, imperdiet at magna et, cursus egestas enim. Nullam semper molestie lectus sit amet semper. Duis eget tincidunt est. Integer id neque risus. Integer ultricies hendrerit vestibulum. Donec blandit blandit sagittis. Nunc consectetur vitae nisi consectetur volutpat. - -Nulla id lorem fermentum, efficitur magna a, hendrerit dui. Vivamus sagittis orci gravida, bibendum quam eget, molestie est. Phasellus nec enim tincidunt, volutpat sapien non, laoreet diam. Nulla posuere enim nec porttitor lobortis. Donec auctor odio ut orci eleifend, ut eleifend purus convallis. Interdum et malesuada fames ac ante ipsum primis in faucibus. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Ut hendrerit, purus eget viverra tincidunt, sem magna imperdiet libero, et aliquam turpis neque vitae elit. Maecenas semper varius iaculis. Cras non lorem quis quam bibendum eleifend in et libero. Curabitur at purus mauris. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Vivamus porta diam sed elit eleifend gravida. - -Nulla facilisi. Ut ultricies diam vel diam consectetur, vel porta augue molestie. Fusce interdum sapien et metus facilisis pellentesque. Nulla convallis sem at nunc vehicula facilisis. Nam ac rutrum purus. Nunc bibendum, dolor sit amet tempus ullamcorper, lorem leo tempor sem, id fringilla nunc augue scelerisque augue. Nullam sit amet rutrum nisl. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Donec sed mauris gravida eros vehicula sagittis at eget orci. Cras elementum, eros at accumsan bibendum, libero neque blandit purus, vitae vestibulum libero massa ac nibh. Integer at placerat nulla. Mauris eu eleifend orci. Aliquam consequat ligula vitae erat porta lobortis. Duis fermentum elit ac aliquet ornare. - -Mauris eget cursus tellus, eget sodales purus. Aliquam malesuada, augue id vulputate finibus, nisi ex bibendum nisl, sit amet laoreet quam urna a dolor. Nullam ultricies, sapien eu laoreet consequat, erat eros dignissim diam, ultrices sodales lectus mauris et leo. Morbi lacinia eu ante at tempus. Sed iaculis finibus magna malesuada efficitur. Donec faucibus erat sit amet elementum feugiat. Praesent a placerat nisi. Etiam lacinia gravida diam, et sollicitudin sapien tincidunt ut. - -Maecenas felis quam, tincidunt vitae venenatis scelerisque, viverra vitae odio. Phasellus enim neque, ultricies suscipit malesuada sit amet, vehicula sit amet purus. Nulla placerat sit amet dui vel tincidunt. Nam quis neque vel magna commodo egestas. Vestibulum sagittis rutrum lorem ut congue. Maecenas vel ultrices tellus. Donec efficitur, urna ac consequat iaculis, lorem felis pharetra eros, eget faucibus orci lectus sit amet arcu. - -Ut a tempus nisi. Nulla facilisi. Praesent vulputate maximus mi et dapibus. Sed sit amet libero ac augue hendrerit efficitur in a sapien. Mauris placerat velit sit amet tellus sollicitudin faucibus. Donec egestas a magna ac suscipit. Duis enim sapien, mollis sed egestas et, vestibulum vel leo. - -Proin quis dapibus dui. Donec eu tincidunt nunc. Vivamus eget purus consectetur, maximus ante vitae, tincidunt elit. Aenean mattis dolor a gravida aliquam. Praesent quis tellus id sem maximus vulputate nec sed nulla. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Curabitur metus nulla, volutpat volutpat est eu, hendrerit congue erat. Aliquam sollicitudin augue ante. Sed sollicitudin, magna eu consequat elementum, mi augue ullamcorper felis, molestie imperdiet erat metus iaculis est. Proin ac tortor nisi. Pellentesque quis nisi risus. Integer enim sapien, tincidunt quis tortor id, accumsan venenatis mi. Nulla facilisi. - -Cras pretium sit amet quam congue maximus. Morbi lacus libero, imperdiet commodo massa sed, scelerisque placerat libero. Cras nisl nisi, consectetur sed bibendum eu, venenatis at enim. Proin sodales justo at quam aliquam, a consectetur mi ornare. Donec porta ac est sit amet efficitur. Suspendisse vestibulum tortor id neque imperdiet, id lacinia risus vehicula. Phasellus ac eleifend purus. Mauris vel gravida ante. Aliquam vitae lobortis risus. Sed vehicula consectetur tincidunt. Nam et justo vitae purus molestie consequat. Pellentesque ipsum ex, convallis quis blandit non, gravida et urna. Donec diam ligula amet. diff --git a/pkg/sentry/fsimpl/ext/assets/file.txt b/pkg/sentry/fsimpl/ext/assets/file.txt deleted file mode 100644 index 980a0d5f1..000000000 --- a/pkg/sentry/fsimpl/ext/assets/file.txt +++ /dev/null @@ -1 +0,0 @@ -Hello World! diff --git a/pkg/sentry/fsimpl/ext/assets/symlink.txt b/pkg/sentry/fsimpl/ext/assets/symlink.txt deleted file mode 120000 index 4c330738c..000000000 --- a/pkg/sentry/fsimpl/ext/assets/symlink.txt +++ /dev/null @@ -1 +0,0 @@ -file.txt
\ No newline at end of file diff --git a/pkg/sentry/fsimpl/ext/assets/tiny.ext2 b/pkg/sentry/fsimpl/ext/assets/tiny.ext2 Binary files differdeleted file mode 100644 index 381ade9bf..000000000 --- a/pkg/sentry/fsimpl/ext/assets/tiny.ext2 +++ /dev/null diff --git a/pkg/sentry/fsimpl/ext/assets/tiny.ext3 b/pkg/sentry/fsimpl/ext/assets/tiny.ext3 Binary files differdeleted file mode 100644 index 0e97a324c..000000000 --- a/pkg/sentry/fsimpl/ext/assets/tiny.ext3 +++ /dev/null diff --git a/pkg/sentry/fsimpl/ext/assets/tiny.ext4 b/pkg/sentry/fsimpl/ext/assets/tiny.ext4 Binary files differdeleted file mode 100644 index a6859736d..000000000 --- a/pkg/sentry/fsimpl/ext/assets/tiny.ext4 +++ /dev/null diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD deleted file mode 100644 index 6c5a559fd..000000000 --- a/pkg/sentry/fsimpl/ext/benchmark/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_test") - -package(licenses = ["notice"]) - -go_test( - name = "benchmark_test", - size = "small", - srcs = ["benchmark_test.go"], - deps = [ - "//pkg/context", - "//pkg/fspath", - "//pkg/sentry/contexttest", - "//pkg/sentry/fsimpl/ext", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - ], -) diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go deleted file mode 100644 index 89caee3df..000000000 --- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// These benchmarks emulate memfs benchmarks. Ext4 images must be created -// before this benchmark is run using the `make_deep_ext4.sh` script at -// /tmp/image-{depth}.ext4 for all the depths tested below. -// -// The benchmark itself cannot run the script because the script requires -// sudo privileges to create the file system images. -package benchmark_test - -import ( - "fmt" - "os" - "runtime" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -var depths = []int{1, 2, 3, 8, 64, 100} - -const filename = "file.txt" - -// setUp opens imagePath as an ext Filesystem and returns all necessary -// elements required to run tests. If error is nil, it also returns a tear -// down function which must be called after the test is run for clean up. -func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesystem, *vfs.VirtualDentry, func(), error) { - f, err := os.Open(imagePath) - if err != nil { - return nil, nil, nil, nil, err - } - - ctx := contexttest.Context(b) - creds := auth.CredentialsFromContext(ctx) - - // Create VFS. - vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { - return nil, nil, nil, nil, err - } - vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}) - if err != nil { - f.Close() - return nil, nil, nil, nil, err - } - - root := mntns.Root() - - tearDown := func() { - root.DecRef() - - if err := f.Close(); err != nil { - b.Fatalf("tearDown failed: %v", err) - } - } - return ctx, vfsObj, &root, tearDown, nil -} - -// mount mounts extfs at the path operation passed. Returns a tear down -// function which must be called after the test is run for clean up. -func mount(b *testing.B, imagePath string, vfsfs *vfs.VirtualFilesystem, pop *vfs.PathOperation) func() { - b.Helper() - - f, err := os.Open(imagePath) - if err != nil { - b.Fatalf("could not open image at %s: %v", imagePath, err) - } - - ctx := contexttest.Context(b) - creds := auth.CredentialsFromContext(ctx) - - if err := vfsfs.MountAt(ctx, creds, imagePath, pop, "extfs", &vfs.MountOptions{ - GetFilesystemOptions: vfs.GetFilesystemOptions{ - InternalData: int(f.Fd()), - }, - }); err != nil { - b.Fatalf("failed to mount tmpfs submount: %v", err) - } - return func() { - if err := f.Close(); err != nil { - b.Fatalf("tearDown failed: %v", err) - } - } -} - -// BenchmarkVFS2Ext4fsStat emulates BenchmarkVFS2MemfsStat. -func BenchmarkVFS2Ext4fsStat(b *testing.B) { - for _, depth := range depths { - b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) { - ctx, vfsfs, root, tearDown, err := setUp(b, fmt.Sprintf("/tmp/image-%d.ext4", depth)) - if err != nil { - b.Fatalf("setUp failed: %v", err) - } - defer tearDown() - - creds := auth.CredentialsFromContext(ctx) - var filePathBuilder strings.Builder - filePathBuilder.WriteByte('/') - for i := 1; i <= depth; i++ { - filePathBuilder.WriteString(fmt.Sprintf("%d", i)) - filePathBuilder.WriteByte('/') - } - filePathBuilder.WriteString(filename) - filePath := filePathBuilder.String() - - runtime.GC() - b.ResetTimer() - for i := 0; i < b.N; i++ { - stat, err := vfsfs.StatAt(ctx, creds, &vfs.PathOperation{ - Root: *root, - Start: *root, - Path: fspath.Parse(filePath), - FollowFinalSymlink: true, - }, &vfs.StatOptions{}) - if err != nil { - b.Fatalf("stat(%q) failed: %v", filePath, err) - } - // Sanity check. - if stat.Size > 0 { - b.Fatalf("got wrong file size (%d)", stat.Size) - } - } - }) - } -} - -// BenchmarkVFS2ExtfsMountStat emulates BenchmarkVFS2MemfsMountStat. -func BenchmarkVFS2ExtfsMountStat(b *testing.B) { - for _, depth := range depths { - b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) { - // Create root extfs with depth 1 so we can mount extfs again at /1/. - ctx, vfsfs, root, tearDown, err := setUp(b, fmt.Sprintf("/tmp/image-%d.ext4", 1)) - if err != nil { - b.Fatalf("setUp failed: %v", err) - } - defer tearDown() - - creds := auth.CredentialsFromContext(ctx) - mountPointName := "/1/" - pop := vfs.PathOperation{ - Root: *root, - Start: *root, - Path: fspath.Parse(mountPointName), - } - - // Save the mount point for later use. - mountPoint, err := vfsfs.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{}) - if err != nil { - b.Fatalf("failed to walk to mount point: %v", err) - } - defer mountPoint.DecRef() - - // Create extfs submount. - mountTearDown := mount(b, fmt.Sprintf("/tmp/image-%d.ext4", depth), vfsfs, &pop) - defer mountTearDown() - - var filePathBuilder strings.Builder - filePathBuilder.WriteString(mountPointName) - for i := 1; i <= depth; i++ { - filePathBuilder.WriteString(fmt.Sprintf("%d", i)) - filePathBuilder.WriteByte('/') - } - filePathBuilder.WriteString(filename) - filePath := filePathBuilder.String() - - runtime.GC() - b.ResetTimer() - for i := 0; i < b.N; i++ { - stat, err := vfsfs.StatAt(ctx, creds, &vfs.PathOperation{ - Root: *root, - Start: *root, - Path: fspath.Parse(filePath), - FollowFinalSymlink: true, - }, &vfs.StatOptions{}) - if err != nil { - b.Fatalf("stat(%q) failed: %v", filePath, err) - } - // Sanity check. touch(1) always creates files of size 0 (empty). - if stat.Size > 0 { - b.Fatalf("got wrong file size (%d)", stat.Size) - } - } - }) - } -} diff --git a/pkg/sentry/fsimpl/ext/benchmark/make_deep_ext4.sh b/pkg/sentry/fsimpl/ext/benchmark/make_deep_ext4.sh deleted file mode 100755 index d0910da1f..000000000 --- a/pkg/sentry/fsimpl/ext/benchmark/make_deep_ext4.sh +++ /dev/null @@ -1,72 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script creates an ext4 image with $1 depth of directories and a file in -# the inner most directory. The created file is at path /1/2/.../depth/file.txt. -# The ext4 image is written to $2. The image is temporarily mounted at -# /tmp/mountpoint. This script must be run with sudo privileges. - -# Usage: -# sudo bash make_deep_ext4.sh {depth} {output path} - -# Check positional arguments. -if [ "$#" -ne 2 ]; then - echo "Usage: sudo bash make_deep_ext4.sh {depth} {output path}" - exit 1 -fi - -# Make sure depth is a non-negative number. -if ! [[ "$1" =~ ^[0-9]+$ ]]; then - echo "Depth must be a non-negative number." - exit 1 -fi - -# Create a 1 MB filesystem image at the requested output path. -rm -f $2 -fallocate -l 1M $2 -if [ $? -ne 0 ]; then - echo "fallocate failed" - exit $? -fi - -# Convert that blank into an ext4 image. -mkfs.ext4 -j $2 -if [ $? -ne 0 ]; then - echo "mkfs.ext4 failed" - exit $? -fi - -# Mount the image. -MOUNTPOINT=/tmp/mountpoint -mkdir -p $MOUNTPOINT -mount -o loop $2 $MOUNTPOINT -if [ $? -ne 0 ]; then - echo "mount failed" - exit $? -fi - -# Create nested directories and the file. -if [ "$1" -eq 0 ]; then - FILEPATH=$MOUNTPOINT/file.txt -else - FILEPATH=$MOUNTPOINT/$(seq -s '/' 1 $1)/file.txt -fi -mkdir -p $(dirname $FILEPATH) || exit -touch $FILEPATH - -# Clean up. -umount $MOUNTPOINT -rm -rf $MOUNTPOINT diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go deleted file mode 100644 index a2d8c3ad6..000000000 --- a/pkg/sentry/fsimpl/ext/block_map_file.go +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ext - -import ( - "io" - "math" - - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/syserror" -) - -const ( - // numDirectBlks is the number of direct blocks in ext block map inodes. - numDirectBlks = 12 -) - -// blockMapFile is a type of regular file which uses direct/indirect block -// addressing to store file data. This was deprecated in ext4. -type blockMapFile struct { - regFile regularFile - - // directBlks are the direct blocks numbers. The physical blocks pointed by - // these holds file data. Contains file blocks 0 to 11. - directBlks [numDirectBlks]uint32 - - // indirectBlk is the physical block which contains (blkSize/4) direct block - // numbers (as uint32 integers). - indirectBlk uint32 - - // doubleIndirectBlk is the physical block which contains (blkSize/4) indirect - // block numbers (as uint32 integers). - doubleIndirectBlk uint32 - - // tripleIndirectBlk is the physical block which contains (blkSize/4) doubly - // indirect block numbers (as uint32 integers). - tripleIndirectBlk uint32 - - // coverage at (i)th index indicates the amount of file data a node at - // height (i) covers. Height 0 is the direct block. - coverage [4]uint64 -} - -// Compiles only if blockMapFile implements io.ReaderAt. -var _ io.ReaderAt = (*blockMapFile)(nil) - -// newBlockMapFile is the blockMapFile constructor. It initializes the file to -// physical blocks map with (at most) the first 12 (direct) blocks. -func newBlockMapFile(regFile regularFile) (*blockMapFile, error) { - file := &blockMapFile{regFile: regFile} - file.regFile.impl = file - - for i := uint(0); i < 4; i++ { - file.coverage[i] = getCoverage(regFile.inode.blkSize, i) - } - - blkMap := regFile.inode.diskInode.Data() - binary.Unmarshal(blkMap[:numDirectBlks*4], binary.LittleEndian, &file.directBlks) - binary.Unmarshal(blkMap[numDirectBlks*4:(numDirectBlks+1)*4], binary.LittleEndian, &file.indirectBlk) - binary.Unmarshal(blkMap[(numDirectBlks+1)*4:(numDirectBlks+2)*4], binary.LittleEndian, &file.doubleIndirectBlk) - binary.Unmarshal(blkMap[(numDirectBlks+2)*4:(numDirectBlks+3)*4], binary.LittleEndian, &file.tripleIndirectBlk) - return file, nil -} - -// ReadAt implements io.ReaderAt.ReadAt. -func (f *blockMapFile) ReadAt(dst []byte, off int64) (int, error) { - if len(dst) == 0 { - return 0, nil - } - - if off < 0 { - return 0, syserror.EINVAL - } - - offset := uint64(off) - size := f.regFile.inode.diskInode.Size() - if offset >= size { - return 0, io.EOF - } - - // dirBlksEnd is the file offset until which direct blocks cover file data. - // Direct blocks cover 0 <= file offset < dirBlksEnd. - dirBlksEnd := numDirectBlks * f.coverage[0] - - // indirBlkEnd is the file offset until which the indirect block covers file - // data. The indirect block covers dirBlksEnd <= file offset < indirBlkEnd. - indirBlkEnd := dirBlksEnd + f.coverage[1] - - // doubIndirBlkEnd is the file offset until which the double indirect block - // covers file data. The double indirect block covers the range - // indirBlkEnd <= file offset < doubIndirBlkEnd. - doubIndirBlkEnd := indirBlkEnd + f.coverage[2] - - read := 0 - toRead := len(dst) - if uint64(toRead)+offset > size { - toRead = int(size - offset) - } - for read < toRead { - var err error - var curR int - - // Figure out which block to delegate the read to. - switch { - case offset < dirBlksEnd: - // Direct block. - curR, err = f.read(f.directBlks[offset/f.regFile.inode.blkSize], offset%f.regFile.inode.blkSize, 0, dst[read:]) - case offset < indirBlkEnd: - // Indirect block. - curR, err = f.read(f.indirectBlk, offset-dirBlksEnd, 1, dst[read:]) - case offset < doubIndirBlkEnd: - // Doubly indirect block. - curR, err = f.read(f.doubleIndirectBlk, offset-indirBlkEnd, 2, dst[read:]) - default: - // Triply indirect block. - curR, err = f.read(f.tripleIndirectBlk, offset-doubIndirBlkEnd, 3, dst[read:]) - } - - read += curR - offset += uint64(curR) - if err != nil { - return read, err - } - } - - if read < len(dst) { - return read, io.EOF - } - return read, nil -} - -// read is the recursive step of the ReadAt function. It relies on knowing the -// current node's location on disk (curPhyBlk) and its height in the block map -// tree. A height of 0 shows that the current node is actually holding file -// data. relFileOff tells the offset from which we need to start to reading -// under the current node. It is completely relative to the current node. -func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, dst []byte) (int, error) { - curPhyBlkOff := int64(curPhyBlk) * int64(f.regFile.inode.blkSize) - if height == 0 { - toRead := int(f.regFile.inode.blkSize - relFileOff) - if len(dst) < toRead { - toRead = len(dst) - } - - n, _ := f.regFile.inode.fs.dev.ReadAt(dst[:toRead], curPhyBlkOff+int64(relFileOff)) - if n < toRead { - return n, syserror.EIO - } - return n, nil - } - - childCov := f.coverage[height-1] - startIdx := relFileOff / childCov - endIdx := f.regFile.inode.blkSize / 4 // This is exclusive. - wantEndIdx := (relFileOff + uint64(len(dst))) / childCov - wantEndIdx++ // Make this exclusive. - if wantEndIdx < endIdx { - endIdx = wantEndIdx - } - - read := 0 - curChildOff := relFileOff % childCov - for i := startIdx; i < endIdx; i++ { - var childPhyBlk uint32 - err := readFromDisk(f.regFile.inode.fs.dev, curPhyBlkOff+int64(i*4), &childPhyBlk) - if err != nil { - return read, err - } - - n, err := f.read(childPhyBlk, curChildOff, height-1, dst[read:]) - read += n - if err != nil { - return read, err - } - - curChildOff = 0 - } - - return read, nil -} - -// getCoverage returns the number of bytes a node at the given height covers. -// Height 0 is the file data block itself. Height 1 is the indirect block. -// -// Formula: blkSize * ((blkSize / 4)^height) -func getCoverage(blkSize uint64, height uint) uint64 { - return blkSize * uint64(math.Pow(float64(blkSize/4), float64(height))) -} diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go deleted file mode 100644 index 181727ef7..000000000 --- a/pkg/sentry/fsimpl/ext/block_map_test.go +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ext - -import ( - "bytes" - "math/rand" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" -) - -// These consts are for mocking the block map tree. -const ( - mockBMBlkSize = uint32(16) - mockBMDiskSize = 2500 -) - -// TestBlockMapReader stress tests block map reader functionality. It performs -// random length reads from all possible positions in the block map structure. -func TestBlockMapReader(t *testing.T) { - mockBMFile, want := blockMapSetUp(t) - n := len(want) - - for from := 0; from < n; from++ { - got := make([]byte, n-from) - - if read, err := mockBMFile.ReadAt(got, int64(from)); err != nil { - t.Fatalf("file read operation from offset %d to %d only read %d bytes: %v", from, n, read, err) - } - - if diff := cmp.Diff(got, want[from:]); diff != "" { - t.Fatalf("file data from offset %d to %d mismatched (-want +got):\n%s", from, n, diff) - } - } -} - -// blkNumGen is a number generator which gives block numbers for building the -// block map file on disk. It gives unique numbers in a random order which -// facilitates in creating an extremely fragmented filesystem. -type blkNumGen struct { - nums []uint32 -} - -// newBlkNumGen is the blkNumGen constructor. -func newBlkNumGen() *blkNumGen { - blkNums := &blkNumGen{} - lim := mockBMDiskSize / mockBMBlkSize - blkNums.nums = make([]uint32, lim) - for i := uint32(0); i < lim; i++ { - blkNums.nums[i] = i - } - - rand.Shuffle(int(lim), func(i, j int) { - blkNums.nums[i], blkNums.nums[j] = blkNums.nums[j], blkNums.nums[i] - }) - return blkNums -} - -// next returns the next random block number. -func (n *blkNumGen) next() uint32 { - ret := n.nums[0] - n.nums = n.nums[1:] - return ret -} - -// blockMapSetUp creates a mock disk and a block map file. It initializes the -// block map file with 12 direct block, 1 indirect block, 1 double indirect -// block and 1 triple indirect block (basically fill it till the rim). It -// initializes the disk to reflect the inode. Also returns the file data that -// the inode covers and that is written to disk. -func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) { - mockDisk := make([]byte, mockBMDiskSize) - regFile := regularFile{ - inode: inode{ - fs: &filesystem{ - dev: bytes.NewReader(mockDisk), - }, - diskInode: &disklayout.InodeNew{ - InodeOld: disklayout.InodeOld{ - SizeLo: getMockBMFileFize(), - }, - }, - blkSize: uint64(mockBMBlkSize), - }, - } - - var fileData []byte - blkNums := newBlkNumGen() - var data []byte - - // Write the direct blocks. - for i := 0; i < numDirectBlks; i++ { - curBlkNum := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, curBlkNum) - fileData = append(fileData, writeFileDataToBlock(mockDisk, curBlkNum, 0, blkNums)...) - } - - // Write to indirect block. - indirectBlk := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, indirectBlk) - fileData = append(fileData, writeFileDataToBlock(mockDisk, indirectBlk, 1, blkNums)...) - - // Write to indirect block. - doublyIndirectBlk := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, doublyIndirectBlk) - fileData = append(fileData, writeFileDataToBlock(mockDisk, doublyIndirectBlk, 2, blkNums)...) - - // Write to indirect block. - triplyIndirectBlk := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, triplyIndirectBlk) - fileData = append(fileData, writeFileDataToBlock(mockDisk, triplyIndirectBlk, 3, blkNums)...) - - copy(regFile.inode.diskInode.Data(), data) - - mockFile, err := newBlockMapFile(regFile) - if err != nil { - t.Fatalf("newBlockMapFile failed: %v", err) - } - return mockFile, fileData -} - -// writeFileDataToBlock writes random bytes to the block on disk. -func writeFileDataToBlock(disk []byte, blkNum uint32, height uint, blkNums *blkNumGen) []byte { - if height == 0 { - start := blkNum * mockBMBlkSize - end := start + mockBMBlkSize - rand.Read(disk[start:end]) - return disk[start:end] - } - - var fileData []byte - for off := blkNum * mockBMBlkSize; off < (blkNum+1)*mockBMBlkSize; off += 4 { - curBlkNum := blkNums.next() - copy(disk[off:off+4], binary.Marshal(nil, binary.LittleEndian, curBlkNum)) - fileData = append(fileData, writeFileDataToBlock(disk, curBlkNum, height-1, blkNums)...) - } - return fileData -} - -// getMockBMFileFize gets the size of the mock block map file which is used for -// testing. -func getMockBMFileFize() uint32 { - return uint32(numDirectBlks*getCoverage(uint64(mockBMBlkSize), 0) + getCoverage(uint64(mockBMBlkSize), 1) + getCoverage(uint64(mockBMBlkSize), 2) + getCoverage(uint64(mockBMBlkSize), 3)) -} diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go deleted file mode 100644 index a080cb189..000000000 --- a/pkg/sentry/fsimpl/ext/dentry.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ext - -import ( - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -// dentry implements vfs.DentryImpl. -type dentry struct { - vfsd vfs.Dentry - - // inode is the inode represented by this dentry. Multiple Dentries may - // share a single non-directory Inode (with hard links). inode is - // immutable. - inode *inode -} - -// Compiles only if dentry implements vfs.DentryImpl. -var _ vfs.DentryImpl = (*dentry)(nil) - -// newDentry is the dentry constructor. -func newDentry(in *inode) *dentry { - d := &dentry{ - inode: in, - } - d.vfsd.Init(d) - return d -} - -// IncRef implements vfs.DentryImpl.IncRef. -func (d *dentry) IncRef() { - d.inode.incRef() -} - -// TryIncRef implements vfs.DentryImpl.TryIncRef. -func (d *dentry) TryIncRef() bool { - return d.inode.tryIncRef() -} - -// DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef() { - // FIXME(b/134676337): filesystem.mu may not be locked as required by - // inode.decRef(). - d.inode.decRef() -} diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go deleted file mode 100644 index bd6ede995..000000000 --- a/pkg/sentry/fsimpl/ext/directory.go +++ /dev/null @@ -1,307 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ext - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" -) - -// directory represents a directory inode. It holds the childList in memory. -type directory struct { - inode inode - - // mu serializes the changes to childList. - // Lock Order (outermost locks must be taken first): - // directory.mu - // filesystem.mu - mu sync.Mutex - - // childList is a list containing (1) child dirents and (2) fake dirents - // (with diskDirent == nil) that represent the iteration position of - // directoryFDs. childList is used to support directoryFD.IterDirents() - // efficiently. childList is protected by mu. - childList direntList - - // childMap maps the child's filename to the dirent structure stored in - // childList. This adds some data replication but helps in faster path - // traversal. For consistency, key == childMap[key].diskDirent.FileName(). - // Immutable. - childMap map[string]*dirent -} - -// newDirectroy is the directory constructor. -func newDirectroy(inode inode, newDirent bool) (*directory, error) { - file := &directory{inode: inode, childMap: make(map[string]*dirent)} - file.inode.impl = file - - // Initialize childList by reading dirents from the underlying file. - if inode.diskInode.Flags().Index { - // TODO(b/134676337): Support hash tree directories. Currently only the '.' - // and '..' entries are read in. - - // Users cannot navigate this hash tree directory yet. - log.Warningf("hash tree directory being used which is unsupported") - return file, nil - } - - // The dirents are organized in a linear array in the file data. - // Extract the file data and decode the dirents. - regFile, err := newRegularFile(inode) - if err != nil { - return nil, err - } - - // buf is used as scratch space for reading in dirents from disk and - // unmarshalling them into dirent structs. - buf := make([]byte, disklayout.DirentSize) - size := inode.diskInode.Size() - for off, inc := uint64(0), uint64(0); off < size; off += inc { - toRead := size - off - if toRead > disklayout.DirentSize { - toRead = disklayout.DirentSize - } - if n, err := regFile.impl.ReadAt(buf[:toRead], int64(off)); uint64(n) < toRead { - return nil, err - } - - var curDirent dirent - if newDirent { - curDirent.diskDirent = &disklayout.DirentNew{} - } else { - curDirent.diskDirent = &disklayout.DirentOld{} - } - binary.Unmarshal(buf, binary.LittleEndian, curDirent.diskDirent) - - if curDirent.diskDirent.Inode() != 0 && len(curDirent.diskDirent.FileName()) != 0 { - // Inode number and name length fields being set to 0 is used to indicate - // an unused dirent. - file.childList.PushBack(&curDirent) - file.childMap[curDirent.diskDirent.FileName()] = &curDirent - } - - // The next dirent is placed exactly after this dirent record on disk. - inc = uint64(curDirent.diskDirent.RecordSize()) - } - - return file, nil -} - -func (i *inode) isDir() bool { - _, ok := i.impl.(*directory) - return ok -} - -// dirent is the directory.childList node. -type dirent struct { - diskDirent disklayout.Dirent - - // direntEntry links dirents into their parent directory.childList. - direntEntry -} - -// directoryFD represents a directory file description. It implements -// vfs.FileDescriptionImpl. -type directoryFD struct { - fileDescription - vfs.DirectoryFileDescriptionDefaultImpl - - // Protected by directory.mu. - iter *dirent - off int64 -} - -// Compiles only if directoryFD implements vfs.FileDescriptionImpl. -var _ vfs.FileDescriptionImpl = (*directoryFD)(nil) - -// Release implements vfs.FileDescriptionImpl.Release. -func (fd *directoryFD) Release() { - if fd.iter == nil { - return - } - - dir := fd.inode().impl.(*directory) - dir.mu.Lock() - dir.childList.Remove(fd.iter) - dir.mu.Unlock() - fd.iter = nil -} - -// IterDirents implements vfs.FileDescriptionImpl.IterDirents. -func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { - extfs := fd.filesystem() - dir := fd.inode().impl.(*directory) - - dir.mu.Lock() - defer dir.mu.Unlock() - - // Ensure that fd.iter exists and is not linked into dir.childList. - var child *dirent - if fd.iter == nil { - // Start iteration at the beginning of dir. - child = dir.childList.Front() - fd.iter = &dirent{} - } else { - // Continue iteration from where we left off. - child = fd.iter.Next() - dir.childList.Remove(fd.iter) - } - for ; child != nil; child = child.Next() { - // Skip other directoryFD iterators. - if child.diskDirent != nil { - childType, ok := child.diskDirent.FileType() - if !ok { - // We will need to read the inode off disk. Do not increment - // ref count here because this inode is not being added to the - // dentry tree. - extfs.mu.Lock() - childInode, err := extfs.getOrCreateInodeLocked(child.diskDirent.Inode()) - extfs.mu.Unlock() - if err != nil { - // Usage of the file description after the error is - // undefined. This implementation would continue reading - // from the next dirent. - fd.off++ - dir.childList.InsertAfter(child, fd.iter) - return err - } - childType = fs.ToInodeType(childInode.diskInode.Mode().FileType()) - } - - if err := cb.Handle(vfs.Dirent{ - Name: child.diskDirent.FileName(), - Type: fs.ToDirentType(childType), - Ino: uint64(child.diskDirent.Inode()), - NextOff: fd.off + 1, - }); err != nil { - dir.childList.InsertBefore(child, fd.iter) - return err - } - fd.off++ - } - } - dir.childList.PushBack(fd.iter) - return nil -} - -// Seek implements vfs.FileDescriptionImpl.Seek. -func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { - if whence != linux.SEEK_SET && whence != linux.SEEK_CUR { - return 0, syserror.EINVAL - } - - dir := fd.inode().impl.(*directory) - - dir.mu.Lock() - defer dir.mu.Unlock() - - // Find resulting offset. - if whence == linux.SEEK_CUR { - offset += fd.off - } - - if offset < 0 { - // lseek(2) specifies that EINVAL should be returned if the resulting offset - // is negative. - return 0, syserror.EINVAL - } - - n := int64(len(dir.childMap)) - realWantOff := offset - if realWantOff > n { - realWantOff = n - } - realCurOff := fd.off - if realCurOff > n { - realCurOff = n - } - - // Ensure that fd.iter exists and is linked into dir.childList so we can - // intelligently seek from the optimal position. - if fd.iter == nil { - fd.iter = &dirent{} - dir.childList.PushFront(fd.iter) - } - - // Guess that iterating from the current position is optimal. - child := fd.iter - diff := realWantOff - realCurOff // Shows direction and magnitude of travel. - - // See if starting from the beginning or end is better. - abDiff := diff - if diff < 0 { - abDiff = -diff - } - if abDiff > realWantOff { - // Starting from the beginning is best. - child = dir.childList.Front() - diff = realWantOff - } else if abDiff > (n - realWantOff) { - // Starting from the end is best. - child = dir.childList.Back() - // (n - 1) because the last non-nil dirent represents the (n-1)th offset. - diff = realWantOff - (n - 1) - } - - for child != nil { - // Skip other directoryFD iterators. - if child.diskDirent != nil { - if diff == 0 { - if child != fd.iter { - dir.childList.Remove(fd.iter) - dir.childList.InsertBefore(child, fd.iter) - } - - fd.off = offset - return offset, nil - } - - if diff < 0 { - diff++ - child = child.Prev() - } else { - diff-- - child = child.Next() - } - continue - } - - if diff < 0 { - child = child.Prev() - } else { - child = child.Next() - } - } - - // Reaching here indicates that the offset is beyond the end of the childList. - dir.childList.Remove(fd.iter) - dir.childList.PushBack(fd.iter) - fd.off = offset - return offset, nil -} - -// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. -func (fd *directoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { - // mmap(2) specifies that EACCESS should be returned for non-regular file fds. - return syserror.EACCES -} diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD deleted file mode 100644 index 9bd9c76c0..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/BUILD +++ /dev/null @@ -1,47 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "disklayout", - srcs = [ - "block_group.go", - "block_group_32.go", - "block_group_64.go", - "dirent.go", - "dirent_new.go", - "dirent_old.go", - "disklayout.go", - "extent.go", - "inode.go", - "inode_new.go", - "inode_old.go", - "superblock.go", - "superblock_32.go", - "superblock_64.go", - "superblock_old.go", - "test_utils.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - ], -) - -go_test( - name = "disklayout_test", - size = "small", - srcs = [ - "block_group_test.go", - "dirent_test.go", - "extent_test.go", - "inode_test.go", - "superblock_test.go", - ], - library = ":disklayout", - deps = ["//pkg/sentry/kernel/time"], -) diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group.go b/pkg/sentry/fsimpl/ext/disklayout/block_group.go deleted file mode 100644 index ad6f4fef8..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -// BlockGroup represents a Linux ext block group descriptor. An ext file system -// is split into a series of block groups. This provides an access layer to -// information needed to access and use a block group. -// -// Location: -// - The block group descriptor table is always placed in the blocks -// immediately after the block containing the superblock. -// - The 1st block group descriptor in the original table is in the -// (sb.FirstDataBlock() + 1)th block. -// - See SuperBlock docs to see where the block group descriptor table is -// replicated. -// - sb.BgDescSize() must be used as the block group descriptor entry size -// while reading the table from disk. -// -// See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#block-group-descriptors. -type BlockGroup interface { - // InodeTable returns the absolute block number of the block containing the - // inode table. This points to an array of Inode structs. Inode tables are - // statically allocated at mkfs time. The superblock records the number of - // inodes per group (length of this table) and the size of each inode struct. - InodeTable() uint64 - - // BlockBitmap returns the absolute block number of the block containing the - // block bitmap. This bitmap tracks the usage of data blocks within this block - // group and has its own checksum. - BlockBitmap() uint64 - - // InodeBitmap returns the absolute block number of the block containing the - // inode bitmap. This bitmap tracks the usage of this group's inode table - // entries and has its own checksum. - InodeBitmap() uint64 - - // ExclusionBitmap returns the absolute block number of the snapshot exclusion - // bitmap. - ExclusionBitmap() uint64 - - // FreeBlocksCount returns the number of free blocks in the group. - FreeBlocksCount() uint32 - - // FreeInodesCount returns the number of free inodes in the group. - FreeInodesCount() uint32 - - // DirectoryCount returns the number of inodes that represent directories - // under this block group. - DirectoryCount() uint32 - - // UnusedInodeCount returns the number of unused inodes beyond the last used - // inode in this group's inode table. As a result, we needn’t scan past the - // (InodesPerGroup - UnusedInodeCount())th entry in the inode table. - UnusedInodeCount() uint32 - - // BlockBitmapChecksum returns the block bitmap checksum. This is calculated - // using crc32c(FS UUID + group number + entire bitmap). - BlockBitmapChecksum() uint32 - - // InodeBitmapChecksum returns the inode bitmap checksum. This is calculated - // using crc32c(FS UUID + group number + entire bitmap). - InodeBitmapChecksum() uint32 - - // Checksum returns this block group's checksum. - // - // If SbMetadataCsum feature is set: - // - checksum is crc32c(FS UUID + group number + group descriptor - // structure) & 0xFFFF. - // - // If SbGdtCsum feature is set: - // - checksum is crc16(FS UUID + group number + group descriptor - // structure). - // - // SbMetadataCsum and SbGdtCsum should not be both set. - // If they are, Linux warns and asks to run fsck. - Checksum() uint16 - - // Flags returns BGFlags which represents the block group flags. - Flags() BGFlags -} - -// These are the different block group flags. -const ( - // BgInodeUninit indicates that inode table and bitmap are not initialized. - BgInodeUninit uint16 = 0x1 - - // BgBlockUninit indicates that block bitmap is not initialized. - BgBlockUninit uint16 = 0x2 - - // BgInodeZeroed indicates that inode table is zeroed. - BgInodeZeroed uint16 = 0x4 -) - -// BGFlags represents all the different combinations of block group flags. -type BGFlags struct { - InodeUninit bool - BlockUninit bool - InodeZeroed bool -} - -// ToInt converts a BGFlags struct back to its 16-bit representation. -func (f BGFlags) ToInt() uint16 { - var res uint16 - - if f.InodeUninit { - res |= BgInodeUninit - } - if f.BlockUninit { - res |= BgBlockUninit - } - if f.InodeZeroed { - res |= BgInodeZeroed - } - - return res -} - -// BGFlagsFromInt converts the 16-bit flag representation to a BGFlags struct. -func BGFlagsFromInt(flags uint16) BGFlags { - return BGFlags{ - InodeUninit: flags&BgInodeUninit > 0, - BlockUninit: flags&BgBlockUninit > 0, - InodeZeroed: flags&BgInodeZeroed > 0, - } -} diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go deleted file mode 100644 index 3e16c76db..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -// BlockGroup32Bit emulates the first half of struct ext4_group_desc in -// fs/ext4/ext4.h. It is the block group descriptor struct for ext2, ext3 and -// 32-bit ext4 filesystems. It implements BlockGroup interface. -type BlockGroup32Bit struct { - BlockBitmapLo uint32 - InodeBitmapLo uint32 - InodeTableLo uint32 - FreeBlocksCountLo uint16 - FreeInodesCountLo uint16 - UsedDirsCountLo uint16 - FlagsRaw uint16 - ExcludeBitmapLo uint32 - BlockBitmapChecksumLo uint16 - InodeBitmapChecksumLo uint16 - ItableUnusedLo uint16 - ChecksumRaw uint16 -} - -// Compiles only if BlockGroup32Bit implements BlockGroup. -var _ BlockGroup = (*BlockGroup32Bit)(nil) - -// InodeTable implements BlockGroup.InodeTable. -func (bg *BlockGroup32Bit) InodeTable() uint64 { return uint64(bg.InodeTableLo) } - -// BlockBitmap implements BlockGroup.BlockBitmap. -func (bg *BlockGroup32Bit) BlockBitmap() uint64 { return uint64(bg.BlockBitmapLo) } - -// InodeBitmap implements BlockGroup.InodeBitmap. -func (bg *BlockGroup32Bit) InodeBitmap() uint64 { return uint64(bg.InodeBitmapLo) } - -// ExclusionBitmap implements BlockGroup.ExclusionBitmap. -func (bg *BlockGroup32Bit) ExclusionBitmap() uint64 { return uint64(bg.ExcludeBitmapLo) } - -// FreeBlocksCount implements BlockGroup.FreeBlocksCount. -func (bg *BlockGroup32Bit) FreeBlocksCount() uint32 { return uint32(bg.FreeBlocksCountLo) } - -// FreeInodesCount implements BlockGroup.FreeInodesCount. -func (bg *BlockGroup32Bit) FreeInodesCount() uint32 { return uint32(bg.FreeInodesCountLo) } - -// DirectoryCount implements BlockGroup.DirectoryCount. -func (bg *BlockGroup32Bit) DirectoryCount() uint32 { return uint32(bg.UsedDirsCountLo) } - -// UnusedInodeCount implements BlockGroup.UnusedInodeCount. -func (bg *BlockGroup32Bit) UnusedInodeCount() uint32 { return uint32(bg.ItableUnusedLo) } - -// BlockBitmapChecksum implements BlockGroup.BlockBitmapChecksum. -func (bg *BlockGroup32Bit) BlockBitmapChecksum() uint32 { return uint32(bg.BlockBitmapChecksumLo) } - -// InodeBitmapChecksum implements BlockGroup.InodeBitmapChecksum. -func (bg *BlockGroup32Bit) InodeBitmapChecksum() uint32 { return uint32(bg.InodeBitmapChecksumLo) } - -// Checksum implements BlockGroup.Checksum. -func (bg *BlockGroup32Bit) Checksum() uint16 { return bg.ChecksumRaw } - -// Flags implements BlockGroup.Flags. -func (bg *BlockGroup32Bit) Flags() BGFlags { return BGFlagsFromInt(bg.FlagsRaw) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go deleted file mode 100644 index 9a809197a..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -// BlockGroup64Bit emulates struct ext4_group_desc in fs/ext4/ext4.h. -// It is the block group descriptor struct for 64-bit ext4 filesystems. -// It implements BlockGroup interface. It is an extension of the 32-bit -// version of BlockGroup. -type BlockGroup64Bit struct { - // We embed the 32-bit struct here because 64-bit version is just an extension - // of the 32-bit version. - BlockGroup32Bit - - // 64-bit specific fields. - BlockBitmapHi uint32 - InodeBitmapHi uint32 - InodeTableHi uint32 - FreeBlocksCountHi uint16 - FreeInodesCountHi uint16 - UsedDirsCountHi uint16 - ItableUnusedHi uint16 - ExcludeBitmapHi uint32 - BlockBitmapChecksumHi uint16 - InodeBitmapChecksumHi uint16 - _ uint32 // Padding to 64 bytes. -} - -// Compiles only if BlockGroup64Bit implements BlockGroup. -var _ BlockGroup = (*BlockGroup64Bit)(nil) - -// Methods to override. Checksum() and Flags() are not overridden. - -// InodeTable implements BlockGroup.InodeTable. -func (bg *BlockGroup64Bit) InodeTable() uint64 { - return (uint64(bg.InodeTableHi) << 32) | uint64(bg.InodeTableLo) -} - -// BlockBitmap implements BlockGroup.BlockBitmap. -func (bg *BlockGroup64Bit) BlockBitmap() uint64 { - return (uint64(bg.BlockBitmapHi) << 32) | uint64(bg.BlockBitmapLo) -} - -// InodeBitmap implements BlockGroup.InodeBitmap. -func (bg *BlockGroup64Bit) InodeBitmap() uint64 { - return (uint64(bg.InodeBitmapHi) << 32) | uint64(bg.InodeBitmapLo) -} - -// ExclusionBitmap implements BlockGroup.ExclusionBitmap. -func (bg *BlockGroup64Bit) ExclusionBitmap() uint64 { - return (uint64(bg.ExcludeBitmapHi) << 32) | uint64(bg.ExcludeBitmapLo) -} - -// FreeBlocksCount implements BlockGroup.FreeBlocksCount. -func (bg *BlockGroup64Bit) FreeBlocksCount() uint32 { - return (uint32(bg.FreeBlocksCountHi) << 16) | uint32(bg.FreeBlocksCountLo) -} - -// FreeInodesCount implements BlockGroup.FreeInodesCount. -func (bg *BlockGroup64Bit) FreeInodesCount() uint32 { - return (uint32(bg.FreeInodesCountHi) << 16) | uint32(bg.FreeInodesCountLo) -} - -// DirectoryCount implements BlockGroup.DirectoryCount. -func (bg *BlockGroup64Bit) DirectoryCount() uint32 { - return (uint32(bg.UsedDirsCountHi) << 16) | uint32(bg.UsedDirsCountLo) -} - -// UnusedInodeCount implements BlockGroup.UnusedInodeCount. -func (bg *BlockGroup64Bit) UnusedInodeCount() uint32 { - return (uint32(bg.ItableUnusedHi) << 16) | uint32(bg.ItableUnusedLo) -} - -// BlockBitmapChecksum implements BlockGroup.BlockBitmapChecksum. -func (bg *BlockGroup64Bit) BlockBitmapChecksum() uint32 { - return (uint32(bg.BlockBitmapChecksumHi) << 16) | uint32(bg.BlockBitmapChecksumLo) -} - -// InodeBitmapChecksum implements BlockGroup.InodeBitmapChecksum. -func (bg *BlockGroup64Bit) InodeBitmapChecksum() uint32 { - return (uint32(bg.InodeBitmapChecksumHi) << 16) | uint32(bg.InodeBitmapChecksumLo) -} diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go deleted file mode 100644 index 0ef4294c0..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -import ( - "testing" -) - -// TestBlockGroupSize tests that the block group descriptor structs are of the -// correct size. -func TestBlockGroupSize(t *testing.T) { - assertSize(t, BlockGroup32Bit{}, 32) - assertSize(t, BlockGroup64Bit{}, 64) -} diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent.go b/pkg/sentry/fsimpl/ext/disklayout/dirent.go deleted file mode 100644 index 417b6cf65..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -import ( - "gvisor.dev/gvisor/pkg/sentry/fs" -) - -const ( - // MaxFileName is the maximum length of an ext fs file's name. - MaxFileName = 255 - - // DirentSize is the size of ext dirent structures. - DirentSize = 263 -) - -var ( - // inodeTypeByFileType maps ext4 file types to vfs inode types. - // - // See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#ftype. - inodeTypeByFileType = map[uint8]fs.InodeType{ - 0: fs.Anonymous, - 1: fs.RegularFile, - 2: fs.Directory, - 3: fs.CharacterDevice, - 4: fs.BlockDevice, - 5: fs.Pipe, - 6: fs.Socket, - 7: fs.Symlink, - } -) - -// The Dirent interface should be implemented by structs representing ext -// directory entries. These are for the linear classical directories which -// just store a list of dirent structs. A directory is a series of data blocks -// where is each data block contains a linear array of dirents. The last entry -// of the block has a record size that takes it to the end of the block. The -// end of the directory is when you read dirInode.Size() bytes from the blocks. -// -// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#linear-classic-directories. -type Dirent interface { - // Inode returns the absolute inode number of the underlying inode. - // Inode number 0 signifies an unused dirent. - Inode() uint32 - - // RecordSize returns the record length of this dirent on disk. The next - // dirent in the dirent list should be read after these many bytes from - // the current dirent. Must be a multiple of 4. - RecordSize() uint16 - - // FileName returns the name of the file. Can be at most 255 is length. - FileName() string - - // FileType returns the inode type of the underlying inode. This is a - // performance hack so that we do not have to read the underlying inode struct - // to know the type of inode. This will only work when the SbDirentFileType - // feature is set. If not, the second returned value will be false indicating - // that user code has to use the inode mode to extract the file type. - FileType() (fs.InodeType, bool) -} diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go deleted file mode 100644 index 29ae4a5c2..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -import ( - "fmt" - - "gvisor.dev/gvisor/pkg/sentry/fs" -) - -// DirentNew represents the ext4 directory entry struct. This emulates Linux's -// ext4_dir_entry_2 struct. The FileName can not be more than 255 bytes so we -// only need 8 bits to store the NameLength. As a result, NameLength has been -// shortened and the other 8 bits are used to encode the file type. Use the -// FileTypeRaw field only if the SbDirentFileType feature is set. -// -// Note: This struct can be of variable size on disk. The one described below -// is of maximum size and the FileName beyond NameLength bytes might contain -// garbage. -type DirentNew struct { - InodeNumber uint32 - RecordLength uint16 - NameLength uint8 - FileTypeRaw uint8 - FileNameRaw [MaxFileName]byte -} - -// Compiles only if DirentNew implements Dirent. -var _ Dirent = (*DirentNew)(nil) - -// Inode implements Dirent.Inode. -func (d *DirentNew) Inode() uint32 { return d.InodeNumber } - -// RecordSize implements Dirent.RecordSize. -func (d *DirentNew) RecordSize() uint16 { return d.RecordLength } - -// FileName implements Dirent.FileName. -func (d *DirentNew) FileName() string { - return string(d.FileNameRaw[:d.NameLength]) -} - -// FileType implements Dirent.FileType. -func (d *DirentNew) FileType() (fs.InodeType, bool) { - if inodeType, ok := inodeTypeByFileType[d.FileTypeRaw]; ok { - return inodeType, true - } - - panic(fmt.Sprintf("unknown file type %v", d.FileTypeRaw)) -} diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go deleted file mode 100644 index 6fff12a6e..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -import "gvisor.dev/gvisor/pkg/sentry/fs" - -// DirentOld represents the old directory entry struct which does not contain -// the file type. This emulates Linux's ext4_dir_entry struct. -// -// Note: This struct can be of variable size on disk. The one described below -// is of maximum size and the FileName beyond NameLength bytes might contain -// garbage. -type DirentOld struct { - InodeNumber uint32 - RecordLength uint16 - NameLength uint16 - FileNameRaw [MaxFileName]byte -} - -// Compiles only if DirentOld implements Dirent. -var _ Dirent = (*DirentOld)(nil) - -// Inode implements Dirent.Inode. -func (d *DirentOld) Inode() uint32 { return d.InodeNumber } - -// RecordSize implements Dirent.RecordSize. -func (d *DirentOld) RecordSize() uint16 { return d.RecordLength } - -// FileName implements Dirent.FileName. -func (d *DirentOld) FileName() string { - return string(d.FileNameRaw[:d.NameLength]) -} - -// FileType implements Dirent.FileType. -func (d *DirentOld) FileType() (fs.InodeType, bool) { - return fs.Anonymous, false -} diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go deleted file mode 100644 index 934919f8a..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -import ( - "testing" -) - -// TestDirentSize tests that the dirent structs are of the correct -// size. -func TestDirentSize(t *testing.T) { - assertSize(t, DirentOld{}, uintptr(DirentSize)) - assertSize(t, DirentNew{}, uintptr(DirentSize)) -} diff --git a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go deleted file mode 100644 index bdf4e2132..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package disklayout provides Linux ext file system's disk level structures -// which can be directly read into from the underlying device. Structs aim to -// emulate structures `exactly` how they are layed out on disk. -// -// This library aims to be compatible with all ext(2/3/4) systems so it -// provides a generic interface for all major structures and various -// implementations (for different versions). The user code is responsible for -// using appropriate implementations based on the underlying device. -// -// Interfacing all major structures here serves a few purposes: -// - Abstracts away the complexity of the underlying structure from client -// code. The client only has to figure out versioning on set up and then -// can use these as black boxes and pass it higher up the stack. -// - Having pointer receivers forces the user to use pointers to these -// heavy structs. Hence, prevents the client code from unintentionally -// copying these by value while passing the interface around. -// - Version-based implementation selection is resolved on set up hence -// avoiding per call overhead of choosing implementation. -// - All interface methods are pretty light weight (do not take in any -// parameters by design). Passing pointer arguments to interface methods -// can lead to heap allocation as the compiler won't be able to perform -// escape analysis on an unknown implementation at compile time. -// -// Notes: -// - All fields in these structs are exported because binary.Read would -// panic otherwise. -// - All structures on disk are in little-endian order. Only jbd2 (journal) -// structures are in big-endian order. -// - All OS dependent fields in these structures will be interpretted using -// the Linux version of that field. -// - The suffix `Lo` in field names stands for lower bits of that field. -// - The suffix `Hi` in field names stands for upper bits of that field. -// - The suffix `Raw` has been added to indicate that the field is not split -// into Lo and Hi fields and also to resolve name collision with the -// respective interface. -package disklayout diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go deleted file mode 100644 index 4110649ab..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/extent.go +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -// Extents were introduced in ext4 and provide huge performance gains in terms -// data locality and reduced metadata block usage. Extents are organized in -// extent trees. The root node is contained in inode.BlocksRaw. -// -// Terminology: -// - Physical Block: -// Filesystem data block which is addressed normally wrt the entire -// filesystem (addressed with 48 bits). -// -// - File Block: -// Data block containing *only* file data and addressed wrt to the file -// with only 32 bits. The (i)th file block contains file data from -// byte (i * sb.BlockSize()) to ((i+1) * sb.BlockSize()). - -const ( - // ExtentHeaderSize is the size of the header of an extent tree node. - ExtentHeaderSize = 12 - - // ExtentEntrySize is the size of an entry in an extent tree node. - // This size is the same for both leaf and internal nodes. - ExtentEntrySize = 12 - - // ExtentMagic is the magic number which must be present in the header. - ExtentMagic = 0xf30a -) - -// ExtentEntryPair couples an in-memory ExtendNode with the ExtentEntry that -// points to it. We want to cache these structs in memory to avoid repeated -// disk reads. -// -// Note: This struct itself does not represent an on-disk struct. -type ExtentEntryPair struct { - // Entry points to the child node on disk. - Entry ExtentEntry - // Node points to child node in memory. Is nil if the current node is a leaf. - Node *ExtentNode -} - -// ExtentNode represents an extent tree node. For internal nodes, all Entries -// will be ExtendIdxs. For leaf nodes, they will all be Extents. -// -// Note: This struct itself does not represent an on-disk struct. -type ExtentNode struct { - Header ExtentHeader - Entries []ExtentEntryPair -} - -// ExtentEntry represents an extent tree node entry. The entry can either be -// an ExtentIdx or Extent itself. This exists to simplify navigation logic. -type ExtentEntry interface { - // FileBlock returns the first file block number covered by this entry. - FileBlock() uint32 - - // PhysicalBlock returns the child physical block that this entry points to. - PhysicalBlock() uint64 -} - -// ExtentHeader emulates the ext4_extent_header struct in ext4. Each extent -// tree node begins with this and is followed by `NumEntries` number of: -// - Extent if `Depth` == 0 -// - ExtentIdx otherwise -type ExtentHeader struct { - // Magic in the extent magic number, must be 0xf30a. - Magic uint16 - - // NumEntries indicates the number of valid entries following the header. - NumEntries uint16 - - // MaxEntries that could follow the header. Used while adding entries. - MaxEntries uint16 - - // Height represents the distance of this node from the farthest leaf. Please - // note that Linux incorrectly calls this `Depth` (which means the distance - // of the node from the root). - Height uint16 - _ uint32 -} - -// ExtentIdx emulates the ext4_extent_idx struct in ext4. Only present in -// internal nodes. Sorted in ascending order based on FirstFileBlock since -// Linux does a binary search on this. This points to a block containing the -// child node. -type ExtentIdx struct { - FirstFileBlock uint32 - ChildBlockLo uint32 - ChildBlockHi uint16 - _ uint16 -} - -// Compiles only if ExtentIdx implements ExtentEntry. -var _ ExtentEntry = (*ExtentIdx)(nil) - -// FileBlock implements ExtentEntry.FileBlock. -func (ei *ExtentIdx) FileBlock() uint32 { - return ei.FirstFileBlock -} - -// PhysicalBlock implements ExtentEntry.PhysicalBlock. It returns the -// physical block number of the child block. -func (ei *ExtentIdx) PhysicalBlock() uint64 { - return (uint64(ei.ChildBlockHi) << 32) | uint64(ei.ChildBlockLo) -} - -// Extent represents the ext4_extent struct in ext4. Only present in leaf -// nodes. Sorted in ascending order based on FirstFileBlock since Linux does a -// binary search on this. This points to an array of data blocks containing the -// file data. It covers `Length` data blocks starting from `StartBlock`. -type Extent struct { - FirstFileBlock uint32 - Length uint16 - StartBlockHi uint16 - StartBlockLo uint32 -} - -// Compiles only if Extent implements ExtentEntry. -var _ ExtentEntry = (*Extent)(nil) - -// FileBlock implements ExtentEntry.FileBlock. -func (e *Extent) FileBlock() uint32 { - return e.FirstFileBlock -} - -// PhysicalBlock implements ExtentEntry.PhysicalBlock. It returns the -// physical block number of the first data block this extent covers. -func (e *Extent) PhysicalBlock() uint64 { - return (uint64(e.StartBlockHi) << 32) | uint64(e.StartBlockLo) -} diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go deleted file mode 100644 index 8762b90db..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -import ( - "testing" -) - -// TestExtentSize tests that the extent structs are of the correct -// size. -func TestExtentSize(t *testing.T) { - assertSize(t, ExtentHeader{}, ExtentHeaderSize) - assertSize(t, ExtentIdx{}, ExtentEntrySize) - assertSize(t, Extent{}, ExtentEntrySize) -} diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode.go b/pkg/sentry/fsimpl/ext/disklayout/inode.go deleted file mode 100644 index 88ae913f5..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/inode.go +++ /dev/null @@ -1,274 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/time" -) - -// Special inodes. See https://www.kernel.org/doc/html/latest/filesystems/ext4/overview.html#special-inodes. -const ( - // RootDirInode is the inode number of the root directory inode. - RootDirInode = 2 -) - -// The Inode interface must be implemented by structs representing ext inodes. -// The inode stores all the metadata pertaining to the file (except for the -// file name which is held by the directory entry). It does NOT expose all -// fields and should be extended if need be. -// -// Some file systems (e.g. FAT) use the directory entry to store all this -// information. Ext file systems do not so that they can support hard links. -// However, ext4 cheats a little bit and duplicates the file type in the -// directory entry for performance gains. -// -// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#index-nodes. -type Inode interface { - // Mode returns the linux file mode which is majorly used to extract - // information like: - // - File permissions (read/write/execute by user/group/others). - // - Sticky, set UID and GID bits. - // - File type. - // - // Masks to extract this information are provided in pkg/abi/linux/file.go. - Mode() linux.FileMode - - // UID returns the owner UID. - UID() auth.KUID - - // GID returns the owner GID. - GID() auth.KGID - - // Size returns the size of the file in bytes. - Size() uint64 - - // InodeSize returns the size of this inode struct in bytes. - // In ext2 and ext3, the inode struct and inode disk record size was fixed at - // 128 bytes. Ext4 makes it possible for the inode struct to be bigger. - // However, accessing any field beyond the 128 bytes marker must be verified - // using this method. - InodeSize() uint16 - - // AccessTime returns the last access time. Shows when the file was last read. - // - // If InExtendedAttr is set, then this should NOT be used because the - // underlying field is used to store the extended attribute value checksum. - AccessTime() time.Time - - // ChangeTime returns the last change time. Shows when the file meta data - // (like permissions) was last changed. - // - // If InExtendedAttr is set, then this should NOT be used because the - // underlying field is used to store the lower 32 bits of the attribute - // value’s reference count. - ChangeTime() time.Time - - // ModificationTime returns the last modification time. Shows when the file - // content was last modified. - // - // If InExtendedAttr is set, then this should NOT be used because - // the underlying field contains the number of the inode that owns the - // extended attribute. - ModificationTime() time.Time - - // DeletionTime returns the deletion time. Inodes are marked as deleted by - // writing to the underlying field. FS tools can restore files until they are - // actually overwritten. - DeletionTime() time.Time - - // LinksCount returns the number of hard links to this inode. - // - // Normally there is an upper limit on the number of hard links: - // - ext2/ext3 = 32,000 - // - ext4 = 65,000 - // - // This implies that an ext4 directory cannot have more than 64,998 - // subdirectories because each subdirectory will have a hard link to the - // directory via the `..` entry. The directory has hard link via the `.` entry - // of its own. And finally the inode is initiated with 1 hard link (itself). - // - // The underlying value is reset to 1 if all the following hold: - // - Inode is a directory. - // - SbDirNlink is enabled. - // - Number of hard links is incremented past 64,999. - // Hard link value of 1 for a directory would indicate that the number of hard - // links is unknown because a directory can have minimum 2 hard links (itself - // and `.` entry). - LinksCount() uint16 - - // Flags returns InodeFlags which represents the inode flags. - Flags() InodeFlags - - // Data returns the underlying inode.i_block array as a slice so it's - // modifiable. This field is special and is used to store various kinds of - // things depending on the filesystem version and inode type. The underlying - // field name in Linux is a little misleading. - // - In ext2/ext3, it contains the block map. - // - In ext4, it contains the extent tree root node. - // - For inline files, it contains the file contents. - // - For symlinks, it contains the link path (if it fits here). - // - // See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#the-contents-of-inode-i-block. - Data() []byte -} - -// Inode flags. This is not comprehensive and flags which were not used in -// the Linux kernel have been excluded. -const ( - // InSync indicates that all writes to the file must be synchronous. - InSync = 0x8 - - // InImmutable indicates that this file is immutable. - InImmutable = 0x10 - - // InAppend indicates that this file can only be appended to. - InAppend = 0x20 - - // InNoDump indicates that teh dump(1) utility should not dump this file. - InNoDump = 0x40 - - // InNoAccessTime indicates that the access time of this inode must not be - // updated. - InNoAccessTime = 0x80 - - // InIndex indicates that this directory has hashed indexes. - InIndex = 0x1000 - - // InJournalData indicates that file data must always be written through a - // journal device. - InJournalData = 0x4000 - - // InDirSync indicates that all the directory entiry data must be written - // synchronously. - InDirSync = 0x10000 - - // InTopDir indicates that this inode is at the top of the directory hierarchy. - InTopDir = 0x20000 - - // InHugeFile indicates that this is a huge file. - InHugeFile = 0x40000 - - // InExtents indicates that this inode uses extents. - InExtents = 0x80000 - - // InExtendedAttr indicates that this inode stores a large extended attribute - // value in its data blocks. - InExtendedAttr = 0x200000 - - // InInline indicates that this inode has inline data. - InInline = 0x10000000 - - // InReserved indicates that this inode is reserved for the ext4 library. - InReserved = 0x80000000 -) - -// InodeFlags represents all possible combinations of inode flags. It aims to -// cover the bit masks and provide a more user-friendly interface. -type InodeFlags struct { - Sync bool - Immutable bool - Append bool - NoDump bool - NoAccessTime bool - Index bool - JournalData bool - DirSync bool - TopDir bool - HugeFile bool - Extents bool - ExtendedAttr bool - Inline bool - Reserved bool -} - -// ToInt converts inode flags back to its 32-bit rep. -func (f InodeFlags) ToInt() uint32 { - var res uint32 - - if f.Sync { - res |= InSync - } - if f.Immutable { - res |= InImmutable - } - if f.Append { - res |= InAppend - } - if f.NoDump { - res |= InNoDump - } - if f.NoAccessTime { - res |= InNoAccessTime - } - if f.Index { - res |= InIndex - } - if f.JournalData { - res |= InJournalData - } - if f.DirSync { - res |= InDirSync - } - if f.TopDir { - res |= InTopDir - } - if f.HugeFile { - res |= InHugeFile - } - if f.Extents { - res |= InExtents - } - if f.ExtendedAttr { - res |= InExtendedAttr - } - if f.Inline { - res |= InInline - } - if f.Reserved { - res |= InReserved - } - - return res -} - -// InodeFlagsFromInt converts the integer representation of inode flags to -// a InodeFlags struct. -func InodeFlagsFromInt(f uint32) InodeFlags { - return InodeFlags{ - Sync: f&InSync > 0, - Immutable: f&InImmutable > 0, - Append: f&InAppend > 0, - NoDump: f&InNoDump > 0, - NoAccessTime: f&InNoAccessTime > 0, - Index: f&InIndex > 0, - JournalData: f&InJournalData > 0, - DirSync: f&InDirSync > 0, - TopDir: f&InTopDir > 0, - HugeFile: f&InHugeFile > 0, - Extents: f&InExtents > 0, - ExtendedAttr: f&InExtendedAttr > 0, - Inline: f&InInline > 0, - Reserved: f&InReserved > 0, - } -} - -// These masks define how users can view/modify inode flags. The rest of the -// flags are for internal kernel usage only. -const ( - InUserReadFlagMask = 0x4BDFFF - InUserWriteFlagMask = 0x4B80FF -) diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go deleted file mode 100644 index 8f9f574ce..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -import "gvisor.dev/gvisor/pkg/sentry/kernel/time" - -// InodeNew represents ext4 inode structure which can be bigger than -// OldInodeSize. The actual size of this struct should be determined using -// inode.ExtraInodeSize. Accessing any field here should be verified with the -// actual size. The extra space between the end of the inode struct and end of -// the inode record can be used to store extended attr. -// -// If the TimeExtra fields are in scope, the lower 2 bits of those are used -// to extend their counter part to be 34 bits wide; the rest (upper) 30 bits -// are used to provide nanoscond precision. Hence, these timestamps will now -// overflow in May 2446. -// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#inode-timestamps. -type InodeNew struct { - InodeOld - - ExtraInodeSize uint16 - ChecksumHi uint16 - ChangeTimeExtra uint32 - ModificationTimeExtra uint32 - AccessTimeExtra uint32 - CreationTime uint32 - CreationTimeExtra uint32 - VersionHi uint32 - ProjectID uint32 -} - -// Compiles only if InodeNew implements Inode. -var _ Inode = (*InodeNew)(nil) - -// fromExtraTime decodes the extra time and constructs the kernel time struct -// with nanosecond precision. -func fromExtraTime(lo int32, extra uint32) time.Time { - // See description above InodeNew for format. - seconds := (int64(extra&0x3) << 32) + int64(lo) - nanoseconds := int64(extra >> 2) - return time.FromUnix(seconds, nanoseconds) -} - -// Only override methods which change due to ext4 specific fields. - -// Size implements Inode.Size. -func (in *InodeNew) Size() uint64 { - return (uint64(in.SizeHi) << 32) | uint64(in.SizeLo) -} - -// InodeSize implements Inode.InodeSize. -func (in *InodeNew) InodeSize() uint16 { - return OldInodeSize + in.ExtraInodeSize -} - -// ChangeTime implements Inode.ChangeTime. -func (in *InodeNew) ChangeTime() time.Time { - // Apply new timestamp logic if inode.ChangeTimeExtra is in scope. - if in.ExtraInodeSize >= 8 { - return fromExtraTime(in.ChangeTimeRaw, in.ChangeTimeExtra) - } - - return in.InodeOld.ChangeTime() -} - -// ModificationTime implements Inode.ModificationTime. -func (in *InodeNew) ModificationTime() time.Time { - // Apply new timestamp logic if inode.ModificationTimeExtra is in scope. - if in.ExtraInodeSize >= 12 { - return fromExtraTime(in.ModificationTimeRaw, in.ModificationTimeExtra) - } - - return in.InodeOld.ModificationTime() -} - -// AccessTime implements Inode.AccessTime. -func (in *InodeNew) AccessTime() time.Time { - // Apply new timestamp logic if inode.AccessTimeExtra is in scope. - if in.ExtraInodeSize >= 16 { - return fromExtraTime(in.AccessTimeRaw, in.AccessTimeExtra) - } - - return in.InodeOld.AccessTime() -} diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go deleted file mode 100644 index db25b11b6..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/time" -) - -const ( - // OldInodeSize is the inode size in ext2/ext3. - OldInodeSize = 128 -) - -// InodeOld implements Inode interface. It emulates ext2/ext3 inode struct. -// Inode struct size and record size are both 128 bytes for this. -// -// All fields representing time are in seconds since the epoch. Which means that -// they will overflow in January 2038. -type InodeOld struct { - ModeRaw uint16 - UIDLo uint16 - SizeLo uint32 - - // The time fields are signed integers because they could be negative to - // represent time before the epoch. - AccessTimeRaw int32 - ChangeTimeRaw int32 - ModificationTimeRaw int32 - DeletionTimeRaw int32 - - GIDLo uint16 - LinksCountRaw uint16 - BlocksCountLo uint32 - FlagsRaw uint32 - VersionLo uint32 // This is OS dependent. - DataRaw [60]byte - Generation uint32 - FileACLLo uint32 - SizeHi uint32 - ObsoFaddr uint32 - - // OS dependent fields have been inlined here. - BlocksCountHi uint16 - FileACLHi uint16 - UIDHi uint16 - GIDHi uint16 - ChecksumLo uint16 - _ uint16 -} - -// Compiles only if InodeOld implements Inode. -var _ Inode = (*InodeOld)(nil) - -// Mode implements Inode.Mode. -func (in *InodeOld) Mode() linux.FileMode { return linux.FileMode(in.ModeRaw) } - -// UID implements Inode.UID. -func (in *InodeOld) UID() auth.KUID { - return auth.KUID((uint32(in.UIDHi) << 16) | uint32(in.UIDLo)) -} - -// GID implements Inode.GID. -func (in *InodeOld) GID() auth.KGID { - return auth.KGID((uint32(in.GIDHi) << 16) | uint32(in.GIDLo)) -} - -// Size implements Inode.Size. -func (in *InodeOld) Size() uint64 { - // In ext2/ext3, in.SizeHi did not exist, it was instead named in.DirACL. - return uint64(in.SizeLo) -} - -// InodeSize implements Inode.InodeSize. -func (in *InodeOld) InodeSize() uint16 { return OldInodeSize } - -// AccessTime implements Inode.AccessTime. -func (in *InodeOld) AccessTime() time.Time { - return time.FromUnix(int64(in.AccessTimeRaw), 0) -} - -// ChangeTime implements Inode.ChangeTime. -func (in *InodeOld) ChangeTime() time.Time { - return time.FromUnix(int64(in.ChangeTimeRaw), 0) -} - -// ModificationTime implements Inode.ModificationTime. -func (in *InodeOld) ModificationTime() time.Time { - return time.FromUnix(int64(in.ModificationTimeRaw), 0) -} - -// DeletionTime implements Inode.DeletionTime. -func (in *InodeOld) DeletionTime() time.Time { - return time.FromUnix(int64(in.DeletionTimeRaw), 0) -} - -// LinksCount implements Inode.LinksCount. -func (in *InodeOld) LinksCount() uint16 { return in.LinksCountRaw } - -// Flags implements Inode.Flags. -func (in *InodeOld) Flags() InodeFlags { return InodeFlagsFromInt(in.FlagsRaw) } - -// Data implements Inode.Data. -func (in *InodeOld) Data() []byte { return in.DataRaw[:] } diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go deleted file mode 100644 index dd03ee50e..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -import ( - "fmt" - "strconv" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/kernel/time" -) - -// TestInodeSize tests that the inode structs are of the correct size. -func TestInodeSize(t *testing.T) { - assertSize(t, InodeOld{}, OldInodeSize) - - // This was updated from 156 bytes to 160 bytes in Oct 2015. - assertSize(t, InodeNew{}, 160) -} - -// TestTimestampSeconds tests that the seconds part of [a/c/m] timestamps in -// ext4 inode structs are decoded correctly. -// -// These tests are derived from the table under https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#inode-timestamps. -func TestTimestampSeconds(t *testing.T) { - type timestampTest struct { - // msbSet tells if the most significant bit of InodeOld.[X]TimeRaw is set. - // If this is set then the 32-bit time is negative. - msbSet bool - - // lowerBound tells if we should take the lowest possible value of - // InodeOld.[X]TimeRaw while satisfying test.msbSet condition. If set to - // false it tells to take the highest possible value. - lowerBound bool - - // extraBits is InodeNew.[X]TimeExtra. - extraBits uint32 - - // want is the kernel time struct that is expected. - want time.Time - } - - tests := []timestampTest{ - // 1901-12-13 - { - msbSet: true, - lowerBound: true, - extraBits: 0, - want: time.FromUnix(int64(-0x80000000), 0), - }, - - // 1969-12-31 - { - msbSet: true, - lowerBound: false, - extraBits: 0, - want: time.FromUnix(int64(-1), 0), - }, - - // 1970-01-01 - { - msbSet: false, - lowerBound: true, - extraBits: 0, - want: time.FromUnix(int64(0), 0), - }, - - // 2038-01-19 - { - msbSet: false, - lowerBound: false, - extraBits: 0, - want: time.FromUnix(int64(0x7fffffff), 0), - }, - - // 2038-01-19 - { - msbSet: true, - lowerBound: true, - extraBits: 1, - want: time.FromUnix(int64(0x80000000), 0), - }, - - // 2106-02-07 - { - msbSet: true, - lowerBound: false, - extraBits: 1, - want: time.FromUnix(int64(0xffffffff), 0), - }, - - // 2106-02-07 - { - msbSet: false, - lowerBound: true, - extraBits: 1, - want: time.FromUnix(int64(0x100000000), 0), - }, - - // 2174-02-25 - { - msbSet: false, - lowerBound: false, - extraBits: 1, - want: time.FromUnix(int64(0x17fffffff), 0), - }, - - // 2174-02-25 - { - msbSet: true, - lowerBound: true, - extraBits: 2, - want: time.FromUnix(int64(0x180000000), 0), - }, - - // 2242-03-16 - { - msbSet: true, - lowerBound: false, - extraBits: 2, - want: time.FromUnix(int64(0x1ffffffff), 0), - }, - - // 2242-03-16 - { - msbSet: false, - lowerBound: true, - extraBits: 2, - want: time.FromUnix(int64(0x200000000), 0), - }, - - // 2310-04-04 - { - msbSet: false, - lowerBound: false, - extraBits: 2, - want: time.FromUnix(int64(0x27fffffff), 0), - }, - - // 2310-04-04 - { - msbSet: true, - lowerBound: true, - extraBits: 3, - want: time.FromUnix(int64(0x280000000), 0), - }, - - // 2378-04-22 - { - msbSet: true, - lowerBound: false, - extraBits: 3, - want: time.FromUnix(int64(0x2ffffffff), 0), - }, - - // 2378-04-22 - { - msbSet: false, - lowerBound: true, - extraBits: 3, - want: time.FromUnix(int64(0x300000000), 0), - }, - - // 2446-05-10 - { - msbSet: false, - lowerBound: false, - extraBits: 3, - want: time.FromUnix(int64(0x37fffffff), 0), - }, - } - - lowerMSB0 := int32(0) // binary: 00000000 00000000 00000000 00000000 - upperMSB0 := int32(0x7fffffff) // binary: 01111111 11111111 11111111 11111111 - lowerMSB1 := int32(-0x80000000) // binary: 10000000 00000000 00000000 00000000 - upperMSB1 := int32(-1) // binary: 11111111 11111111 11111111 11111111 - - get32BitTime := func(test timestampTest) int32 { - if test.msbSet { - if test.lowerBound { - return lowerMSB1 - } - - return upperMSB1 - } - - if test.lowerBound { - return lowerMSB0 - } - - return upperMSB0 - } - - getTestName := func(test timestampTest) string { - return fmt.Sprintf( - "Tests time decoding with epoch bits 0b%s and 32-bit raw time: MSB set=%t, lower bound=%t", - strconv.FormatInt(int64(test.extraBits), 2), - test.msbSet, - test.lowerBound, - ) - } - - for _, test := range tests { - t.Run(getTestName(test), func(t *testing.T) { - if got := fromExtraTime(get32BitTime(test), test.extraBits); got != test.want { - t.Errorf("Expected: %v, Got: %v", test.want, got) - } - }) - } -} diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock.go b/pkg/sentry/fsimpl/ext/disklayout/superblock.go deleted file mode 100644 index 8bb327006..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock.go +++ /dev/null @@ -1,471 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -const ( - // SbOffset is the absolute offset at which the superblock is placed. - SbOffset = 1024 -) - -// SuperBlock should be implemented by structs representing the ext superblock. -// The superblock holds a lot of information about the enclosing filesystem. -// This interface aims to provide access methods to important information held -// by the superblock. It does NOT expose all fields of the superblock, only the -// ones necessary. This can be expanded when need be. -// -// Location and replication: -// - The superblock is located at offset 1024 in block group 0. -// - Redundant copies of the superblock and group descriptors are kept in -// all groups if SbSparse feature flag is NOT set. If it is set, the -// replicas only exist in groups whose group number is either 0 or a -// power of 3, 5, or 7. -// - There is also a sparse superblock feature v2 in which there are just -// two replicas saved in the block groups pointed by sb.s_backup_bgs. -// -// Replicas should eventually be updated if the superblock is updated. -// -// See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#super-block. -type SuperBlock interface { - // InodesCount returns the total number of inodes in this filesystem. - InodesCount() uint32 - - // BlocksCount returns the total number of data blocks in this filesystem. - BlocksCount() uint64 - - // FreeBlocksCount returns the number of free blocks in this filesystem. - FreeBlocksCount() uint64 - - // FreeInodesCount returns the number of free inodes in this filesystem. - FreeInodesCount() uint32 - - // MountCount returns the number of mounts since the last fsck. - MountCount() uint16 - - // MaxMountCount returns the number of mounts allowed beyond which a fsck is - // needed. - MaxMountCount() uint16 - - // FirstDataBlock returns the absolute block number of the first data block, - // which contains the super block itself. - // - // If the filesystem has 1kb data blocks then this should return 1. For all - // other configurations, this typically returns 0. - FirstDataBlock() uint32 - - // BlockSize returns the size of one data block in this filesystem. - // This can be calculated by 2^(10 + sb.s_log_block_size). This ensures that - // the smallest block size is 1kb. - BlockSize() uint64 - - // BlocksPerGroup returns the number of data blocks in a block group. - BlocksPerGroup() uint32 - - // ClusterSize returns block cluster size (set during mkfs time by admin). - // This can be calculated by 2^(10 + sb.s_log_cluster_size). This ensures that - // the smallest cluster size is 1kb. - // - // sb.s_log_cluster_size must equal sb.s_log_block_size if bigalloc feature - // is NOT set and consequently BlockSize() = ClusterSize() in that case. - ClusterSize() uint64 - - // ClustersPerGroup returns: - // - number of clusters per group if bigalloc is enabled. - // - BlocksPerGroup() otherwise. - ClustersPerGroup() uint32 - - // InodeSize returns the size of the inode disk record size in bytes. Use this - // to iterate over inode arrays on disk. - // - // In ext2 and ext3: - // - Each inode had a disk record of 128 bytes. - // - The inode struct size was fixed at 128 bytes. - // - // In ext4 its possible to allocate larger on-disk inodes: - // - Inode disk record size = sb.s_inode_size (function return value). - // = 256 (default) - // - Inode struct size = 128 + inode.i_extra_isize. - // = 128 + 32 = 160 (default) - InodeSize() uint16 - - // InodesPerGroup returns the number of inodes in a block group. - InodesPerGroup() uint32 - - // BgDescSize returns the size of the block group descriptor struct. - // - // In ext2, ext3, ext4 (without 64-bit feature), the block group descriptor - // is only 32 bytes long. - // In ext4 with 64-bit feature, the block group descriptor expands to AT LEAST - // 64 bytes. It might be bigger than that. - BgDescSize() uint16 - - // CompatibleFeatures returns the CompatFeatures struct which holds all the - // compatible features this fs supports. - CompatibleFeatures() CompatFeatures - - // IncompatibleFeatures returns the CompatFeatures struct which holds all the - // incompatible features this fs supports. - IncompatibleFeatures() IncompatFeatures - - // ReadOnlyCompatibleFeatures returns the CompatFeatures struct which holds all the - // readonly compatible features this fs supports. - ReadOnlyCompatibleFeatures() RoCompatFeatures - - // Magic() returns the magic signature which must be 0xef53. - Magic() uint16 - - // Revision returns the superblock revision. Superblock struct fields from - // offset 0x54 till 0x150 should only be used if superblock has DynamicRev. - Revision() SbRevision -} - -// SbRevision is the type for superblock revisions. -type SbRevision uint32 - -// Super block revisions. -const ( - // OldRev is the good old (original) format. - OldRev SbRevision = 0 - - // DynamicRev is v2 format w/ dynamic inode sizes. - DynamicRev SbRevision = 1 -) - -// Superblock compatible features. -// This is not exhaustive, unused features are not listed. -const ( - // SbDirPrealloc indicates directory preallocation. - SbDirPrealloc = 0x1 - - // SbHasJournal indicates the presence of a journal. jbd2 should only work - // with this being set. - SbHasJournal = 0x4 - - // SbExtAttr indicates extended attributes support. - SbExtAttr = 0x8 - - // SbResizeInode indicates that the fs has reserved GDT blocks (right after - // group descriptors) for fs expansion. - SbResizeInode = 0x10 - - // SbDirIndex indicates that the fs has directory indices. - SbDirIndex = 0x20 - - // SbSparseV2 stands for Sparse superblock version 2. - SbSparseV2 = 0x200 -) - -// CompatFeatures represents a superblock's compatible feature set. If the -// kernel does not understand any of these feature, it can still read/write -// to this fs. -type CompatFeatures struct { - DirPrealloc bool - HasJournal bool - ExtAttr bool - ResizeInode bool - DirIndex bool - SparseV2 bool -} - -// ToInt converts superblock compatible features back to its 32-bit rep. -func (f CompatFeatures) ToInt() uint32 { - var res uint32 - - if f.DirPrealloc { - res |= SbDirPrealloc - } - if f.HasJournal { - res |= SbHasJournal - } - if f.ExtAttr { - res |= SbExtAttr - } - if f.ResizeInode { - res |= SbResizeInode - } - if f.DirIndex { - res |= SbDirIndex - } - if f.SparseV2 { - res |= SbSparseV2 - } - - return res -} - -// CompatFeaturesFromInt converts the integer representation of superblock -// compatible features to CompatFeatures struct. -func CompatFeaturesFromInt(f uint32) CompatFeatures { - return CompatFeatures{ - DirPrealloc: f&SbDirPrealloc > 0, - HasJournal: f&SbHasJournal > 0, - ExtAttr: f&SbExtAttr > 0, - ResizeInode: f&SbResizeInode > 0, - DirIndex: f&SbDirIndex > 0, - SparseV2: f&SbSparseV2 > 0, - } -} - -// Superblock incompatible features. -// This is not exhaustive, unused features are not listed. -const ( - // SbDirentFileType indicates that directory entries record the file type. - // We should use struct DirentNew for dirents then. - SbDirentFileType = 0x2 - - // SbRecovery indicates that the filesystem needs recovery. - SbRecovery = 0x4 - - // SbJournalDev indicates that the filesystem has a separate journal device. - SbJournalDev = 0x8 - - // SbMetaBG indicates that the filesystem is using Meta block groups. Moves - // the group descriptors from the congested first block group into the first - // group of each metablock group to increase the maximum block groups limit - // and hence support much larger filesystems. - // - // See https://www.kernel.org/doc/html/latest/filesystems/ext4/overview.html#meta-block-groups. - SbMetaBG = 0x10 - - // SbExtents indicates that the filesystem uses extents. Must be set in ext4 - // filesystems. - SbExtents = 0x40 - - // SbIs64Bit indicates that this filesystem addresses blocks with 64-bits. - // Hence can support 2^64 data blocks. - SbIs64Bit = 0x80 - - // SbMMP indicates that this filesystem has multiple mount protection. - // - // See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#multiple-mount-protection. - SbMMP = 0x100 - - // SbFlexBg indicates that this filesystem has flexible block groups. Several - // block groups are tied into one logical block group so that all the metadata - // for the block groups (bitmaps and inode tables) are close together for - // faster loading. Consequently, large files will be continuous on disk. - // However, this does not affect the placement of redundant superblocks and - // group descriptors. - // - // See https://www.kernel.org/doc/html/latest/filesystems/ext4/overview.html#flexible-block-groups. - SbFlexBg = 0x200 - - // SbLargeDir shows that large directory enabled. Directory htree can be 3 - // levels deep. Directory htrees are allowed to be 2 levels deep otherwise. - SbLargeDir = 0x4000 - - // SbInlineData allows inline data in inodes for really small files. - SbInlineData = 0x8000 - - // SbEncrypted indicates that this fs contains encrypted inodes. - SbEncrypted = 0x10000 -) - -// IncompatFeatures represents a superblock's incompatible feature set. If the -// kernel does not understand any of these feature, it should refuse to mount. -type IncompatFeatures struct { - DirentFileType bool - Recovery bool - JournalDev bool - MetaBG bool - Extents bool - Is64Bit bool - MMP bool - FlexBg bool - LargeDir bool - InlineData bool - Encrypted bool -} - -// ToInt converts superblock incompatible features back to its 32-bit rep. -func (f IncompatFeatures) ToInt() uint32 { - var res uint32 - - if f.DirentFileType { - res |= SbDirentFileType - } - if f.Recovery { - res |= SbRecovery - } - if f.JournalDev { - res |= SbJournalDev - } - if f.MetaBG { - res |= SbMetaBG - } - if f.Extents { - res |= SbExtents - } - if f.Is64Bit { - res |= SbIs64Bit - } - if f.MMP { - res |= SbMMP - } - if f.FlexBg { - res |= SbFlexBg - } - if f.LargeDir { - res |= SbLargeDir - } - if f.InlineData { - res |= SbInlineData - } - if f.Encrypted { - res |= SbEncrypted - } - - return res -} - -// IncompatFeaturesFromInt converts the integer representation of superblock -// incompatible features to IncompatFeatures struct. -func IncompatFeaturesFromInt(f uint32) IncompatFeatures { - return IncompatFeatures{ - DirentFileType: f&SbDirentFileType > 0, - Recovery: f&SbRecovery > 0, - JournalDev: f&SbJournalDev > 0, - MetaBG: f&SbMetaBG > 0, - Extents: f&SbExtents > 0, - Is64Bit: f&SbIs64Bit > 0, - MMP: f&SbMMP > 0, - FlexBg: f&SbFlexBg > 0, - LargeDir: f&SbLargeDir > 0, - InlineData: f&SbInlineData > 0, - Encrypted: f&SbEncrypted > 0, - } -} - -// Superblock readonly compatible features. -// This is not exhaustive, unused features are not listed. -const ( - // SbSparse indicates sparse superblocks. Only groups with number either 0 or - // a power of 3, 5, or 7 will have redundant copies of the superblock and - // block descriptors. - SbSparse = 0x1 - - // SbLargeFile indicates that this fs has been used to store a file >= 2GiB. - SbLargeFile = 0x2 - - // SbHugeFile indicates that this fs contains files whose sizes are - // represented in units of logicals blocks, not 512-byte sectors. - SbHugeFile = 0x8 - - // SbGdtCsum indicates that group descriptors have checksums. - SbGdtCsum = 0x10 - - // SbDirNlink indicates that the new subdirectory limit is 64,999. Ext3 has a - // 32,000 subdirectory limit. - SbDirNlink = 0x20 - - // SbExtraIsize indicates that large inodes exist on this filesystem. - SbExtraIsize = 0x40 - - // SbHasSnapshot indicates the existence of a snapshot. - SbHasSnapshot = 0x80 - - // SbQuota enables usage tracking for all quota types. - SbQuota = 0x100 - - // SbBigalloc maps to the bigalloc feature. When set, the minimum allocation - // unit becomes a cluster rather than a data block. Then block bitmaps track - // clusters, not data blocks. - // - // See https://www.kernel.org/doc/html/latest/filesystems/ext4/overview.html#bigalloc. - SbBigalloc = 0x200 - - // SbMetadataCsum indicates that the fs supports metadata checksumming. - SbMetadataCsum = 0x400 - - // SbReadOnly marks this filesystem as readonly. Should refuse to mount in - // read/write mode. - SbReadOnly = 0x1000 -) - -// RoCompatFeatures represents a superblock's readonly compatible feature set. -// If the kernel does not understand any of these feature, it can still mount -// readonly. But if the user wants to mount read/write, the kernel should -// refuse to mount. -type RoCompatFeatures struct { - Sparse bool - LargeFile bool - HugeFile bool - GdtCsum bool - DirNlink bool - ExtraIsize bool - HasSnapshot bool - Quota bool - Bigalloc bool - MetadataCsum bool - ReadOnly bool -} - -// ToInt converts superblock readonly compatible features to its 32-bit rep. -func (f RoCompatFeatures) ToInt() uint32 { - var res uint32 - - if f.Sparse { - res |= SbSparse - } - if f.LargeFile { - res |= SbLargeFile - } - if f.HugeFile { - res |= SbHugeFile - } - if f.GdtCsum { - res |= SbGdtCsum - } - if f.DirNlink { - res |= SbDirNlink - } - if f.ExtraIsize { - res |= SbExtraIsize - } - if f.HasSnapshot { - res |= SbHasSnapshot - } - if f.Quota { - res |= SbQuota - } - if f.Bigalloc { - res |= SbBigalloc - } - if f.MetadataCsum { - res |= SbMetadataCsum - } - if f.ReadOnly { - res |= SbReadOnly - } - - return res -} - -// RoCompatFeaturesFromInt converts the integer representation of superblock -// readonly compatible features to RoCompatFeatures struct. -func RoCompatFeaturesFromInt(f uint32) RoCompatFeatures { - return RoCompatFeatures{ - Sparse: f&SbSparse > 0, - LargeFile: f&SbLargeFile > 0, - HugeFile: f&SbHugeFile > 0, - GdtCsum: f&SbGdtCsum > 0, - DirNlink: f&SbDirNlink > 0, - ExtraIsize: f&SbExtraIsize > 0, - HasSnapshot: f&SbHasSnapshot > 0, - Quota: f&SbQuota > 0, - Bigalloc: f&SbBigalloc > 0, - MetadataCsum: f&SbMetadataCsum > 0, - ReadOnly: f&SbReadOnly > 0, - } -} diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go deleted file mode 100644 index 53e515fd3..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -// SuperBlock32Bit implements SuperBlock and represents the 32-bit version of -// the ext4_super_block struct in fs/ext4/ext4.h. Should be used only if -// RevLevel = DynamicRev and 64-bit feature is disabled. -type SuperBlock32Bit struct { - // We embed the old superblock struct here because the 32-bit version is just - // an extension of the old version. - SuperBlockOld - - FirstInode uint32 - InodeSizeRaw uint16 - BlockGroupNumber uint16 - FeatureCompat uint32 - FeatureIncompat uint32 - FeatureRoCompat uint32 - UUID [16]byte - VolumeName [16]byte - LastMounted [64]byte - AlgoUsageBitmap uint32 - PreallocBlocks uint8 - PreallocDirBlocks uint8 - ReservedGdtBlocks uint16 - JournalUUID [16]byte - JournalInum uint32 - JournalDev uint32 - LastOrphan uint32 - HashSeed [4]uint32 - DefaultHashVersion uint8 - JnlBackupType uint8 - BgDescSizeRaw uint16 - DefaultMountOpts uint32 - FirstMetaBg uint32 - MkfsTime uint32 - JnlBlocks [17]uint32 -} - -// Compiles only if SuperBlock32Bit implements SuperBlock. -var _ SuperBlock = (*SuperBlock32Bit)(nil) - -// Only override methods which change based on the additional fields above. -// Not overriding SuperBlock.BgDescSize because it would still return 32 here. - -// InodeSize implements SuperBlock.InodeSize. -func (sb *SuperBlock32Bit) InodeSize() uint16 { - return sb.InodeSizeRaw -} - -// CompatibleFeatures implements SuperBlock.CompatibleFeatures. -func (sb *SuperBlock32Bit) CompatibleFeatures() CompatFeatures { - return CompatFeaturesFromInt(sb.FeatureCompat) -} - -// IncompatibleFeatures implements SuperBlock.IncompatibleFeatures. -func (sb *SuperBlock32Bit) IncompatibleFeatures() IncompatFeatures { - return IncompatFeaturesFromInt(sb.FeatureIncompat) -} - -// ReadOnlyCompatibleFeatures implements SuperBlock.ReadOnlyCompatibleFeatures. -func (sb *SuperBlock32Bit) ReadOnlyCompatibleFeatures() RoCompatFeatures { - return RoCompatFeaturesFromInt(sb.FeatureRoCompat) -} diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go deleted file mode 100644 index 7c1053fb4..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -// SuperBlock64Bit implements SuperBlock and represents the 64-bit version of -// the ext4_super_block struct in fs/ext4/ext4.h. This sums up to be exactly -// 1024 bytes (smallest possible block size) and hence the superblock always -// fits in no more than one data block. Should only be used when the 64-bit -// feature is set. -type SuperBlock64Bit struct { - // We embed the 32-bit struct here because 64-bit version is just an extension - // of the 32-bit version. - SuperBlock32Bit - - BlocksCountHi uint32 - ReservedBlocksCountHi uint32 - FreeBlocksCountHi uint32 - MinInodeSize uint16 - WantInodeSize uint16 - Flags uint32 - RaidStride uint16 - MmpInterval uint16 - MmpBlock uint64 - RaidStripeWidth uint32 - LogGroupsPerFlex uint8 - ChecksumType uint8 - _ uint16 - KbytesWritten uint64 - SnapshotInum uint32 - SnapshotID uint32 - SnapshotRsrvBlocksCount uint64 - SnapshotList uint32 - ErrorCount uint32 - FirstErrorTime uint32 - FirstErrorInode uint32 - FirstErrorBlock uint64 - FirstErrorFunction [32]byte - FirstErrorLine uint32 - LastErrorTime uint32 - LastErrorInode uint32 - LastErrorLine uint32 - LastErrorBlock uint64 - LastErrorFunction [32]byte - MountOpts [64]byte - UserQuotaInum uint32 - GroupQuotaInum uint32 - OverheadBlocks uint32 - BackupBgs [2]uint32 - EncryptAlgos [4]uint8 - EncryptPwSalt [16]uint8 - LostFoundInode uint32 - ProjectQuotaInode uint32 - ChecksumSeed uint32 - WtimeHi uint8 - MtimeHi uint8 - MkfsTimeHi uint8 - LastCheckHi uint8 - FirstErrorTimeHi uint8 - LastErrorTimeHi uint8 - _ [2]uint8 - Encoding uint16 - EncodingFlags uint16 - _ [95]uint32 - Checksum uint32 -} - -// Compiles only if SuperBlock64Bit implements SuperBlock. -var _ SuperBlock = (*SuperBlock64Bit)(nil) - -// Only override methods which change based on the 64-bit feature. - -// BlocksCount implements SuperBlock.BlocksCount. -func (sb *SuperBlock64Bit) BlocksCount() uint64 { - return (uint64(sb.BlocksCountHi) << 32) | uint64(sb.BlocksCountLo) -} - -// FreeBlocksCount implements SuperBlock.FreeBlocksCount. -func (sb *SuperBlock64Bit) FreeBlocksCount() uint64 { - return (uint64(sb.FreeBlocksCountHi) << 32) | uint64(sb.FreeBlocksCountLo) -} - -// BgDescSize implements SuperBlock.BgDescSize. -func (sb *SuperBlock64Bit) BgDescSize() uint16 { return sb.BgDescSizeRaw } diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go deleted file mode 100644 index 9221e0251..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -// SuperBlockOld implements SuperBlock and represents the old version of the -// superblock struct. Should be used only if RevLevel = OldRev. -type SuperBlockOld struct { - InodesCountRaw uint32 - BlocksCountLo uint32 - ReservedBlocksCount uint32 - FreeBlocksCountLo uint32 - FreeInodesCountRaw uint32 - FirstDataBlockRaw uint32 - LogBlockSize uint32 - LogClusterSize uint32 - BlocksPerGroupRaw uint32 - ClustersPerGroupRaw uint32 - InodesPerGroupRaw uint32 - Mtime uint32 - Wtime uint32 - MountCountRaw uint16 - MaxMountCountRaw uint16 - MagicRaw uint16 - State uint16 - Errors uint16 - MinorRevLevel uint16 - LastCheck uint32 - CheckInterval uint32 - CreatorOS uint32 - RevLevel uint32 - DefResUID uint16 - DefResGID uint16 -} - -// Compiles only if SuperBlockOld implements SuperBlock. -var _ SuperBlock = (*SuperBlockOld)(nil) - -// InodesCount implements SuperBlock.InodesCount. -func (sb *SuperBlockOld) InodesCount() uint32 { return sb.InodesCountRaw } - -// BlocksCount implements SuperBlock.BlocksCount. -func (sb *SuperBlockOld) BlocksCount() uint64 { return uint64(sb.BlocksCountLo) } - -// FreeBlocksCount implements SuperBlock.FreeBlocksCount. -func (sb *SuperBlockOld) FreeBlocksCount() uint64 { return uint64(sb.FreeBlocksCountLo) } - -// FreeInodesCount implements SuperBlock.FreeInodesCount. -func (sb *SuperBlockOld) FreeInodesCount() uint32 { return sb.FreeInodesCountRaw } - -// MountCount implements SuperBlock.MountCount. -func (sb *SuperBlockOld) MountCount() uint16 { return sb.MountCountRaw } - -// MaxMountCount implements SuperBlock.MaxMountCount. -func (sb *SuperBlockOld) MaxMountCount() uint16 { return sb.MaxMountCountRaw } - -// FirstDataBlock implements SuperBlock.FirstDataBlock. -func (sb *SuperBlockOld) FirstDataBlock() uint32 { return sb.FirstDataBlockRaw } - -// BlockSize implements SuperBlock.BlockSize. -func (sb *SuperBlockOld) BlockSize() uint64 { return 1 << (10 + sb.LogBlockSize) } - -// BlocksPerGroup implements SuperBlock.BlocksPerGroup. -func (sb *SuperBlockOld) BlocksPerGroup() uint32 { return sb.BlocksPerGroupRaw } - -// ClusterSize implements SuperBlock.ClusterSize. -func (sb *SuperBlockOld) ClusterSize() uint64 { return 1 << (10 + sb.LogClusterSize) } - -// ClustersPerGroup implements SuperBlock.ClustersPerGroup. -func (sb *SuperBlockOld) ClustersPerGroup() uint32 { return sb.ClustersPerGroupRaw } - -// InodeSize implements SuperBlock.InodeSize. -func (sb *SuperBlockOld) InodeSize() uint16 { return OldInodeSize } - -// InodesPerGroup implements SuperBlock.InodesPerGroup. -func (sb *SuperBlockOld) InodesPerGroup() uint32 { return sb.InodesPerGroupRaw } - -// BgDescSize implements SuperBlock.BgDescSize. -func (sb *SuperBlockOld) BgDescSize() uint16 { return 32 } - -// CompatibleFeatures implements SuperBlock.CompatibleFeatures. -func (sb *SuperBlockOld) CompatibleFeatures() CompatFeatures { return CompatFeatures{} } - -// IncompatibleFeatures implements SuperBlock.IncompatibleFeatures. -func (sb *SuperBlockOld) IncompatibleFeatures() IncompatFeatures { return IncompatFeatures{} } - -// ReadOnlyCompatibleFeatures implements SuperBlock.ReadOnlyCompatibleFeatures. -func (sb *SuperBlockOld) ReadOnlyCompatibleFeatures() RoCompatFeatures { return RoCompatFeatures{} } - -// Magic implements SuperBlock.Magic. -func (sb *SuperBlockOld) Magic() uint16 { return sb.MagicRaw } - -// Revision implements SuperBlock.Revision. -func (sb *SuperBlockOld) Revision() SbRevision { return SbRevision(sb.RevLevel) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go deleted file mode 100644 index 463b5ba21..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -import ( - "testing" -) - -// TestSuperBlockSize tests that the superblock structs are of the correct -// size. -func TestSuperBlockSize(t *testing.T) { - assertSize(t, SuperBlockOld{}, 84) - assertSize(t, SuperBlock32Bit{}, 336) - assertSize(t, SuperBlock64Bit{}, 1024) -} diff --git a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go deleted file mode 100644 index 9c63f04c0..000000000 --- a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package disklayout - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/binary" -) - -func assertSize(t *testing.T, v interface{}, want uintptr) { - t.Helper() - - if got := binary.Size(v); got != want { - t.Errorf("struct %s should be exactly %d bytes but is %d bytes", reflect.TypeOf(v).Name(), want, got) - } -} diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go deleted file mode 100644 index 373d23b74..000000000 --- a/pkg/sentry/fsimpl/ext/ext.go +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package ext implements readonly ext(2/3/4) filesystems. -package ext - -import ( - "errors" - "fmt" - "io" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -// FilesystemType implements vfs.FilesystemType. -type FilesystemType struct{} - -// Compiles only if FilesystemType implements vfs.FilesystemType. -var _ vfs.FilesystemType = (*FilesystemType)(nil) - -// getDeviceFd returns an io.ReaderAt to the underlying device. -// Currently there are two ways of mounting an ext(2/3/4) fs: -// 1. Specify a mount with our internal special MountType in the OCI spec. -// 2. Expose the device to the container and mount it from application layer. -func getDeviceFd(source string, opts vfs.GetFilesystemOptions) (io.ReaderAt, error) { - if opts.InternalData == nil { - // User mount call. - // TODO(b/134676337): Open the device specified by `source` and return that. - panic("unimplemented") - } - - // GetFilesystem call originated from within the sentry. - devFd, ok := opts.InternalData.(int) - if !ok { - return nil, errors.New("internal data for ext fs must be an int containing the file descriptor to device") - } - - if devFd < 0 { - return nil, fmt.Errorf("ext device file descriptor is not valid: %d", devFd) - } - - // The fd.ReadWriter returned from fd.NewReadWriter() does not take ownership - // of the file descriptor and hence will not close it when it is garbage - // collected. - return fd.NewReadWriter(devFd), nil -} - -// isCompatible checks if the superblock has feature sets which are compatible. -// We only need to check the superblock incompatible feature set since we are -// mounting readonly. We will also need to check readonly compatible feature -// set when mounting for read/write. -func isCompatible(sb disklayout.SuperBlock) bool { - // Please note that what is being checked is limited based on the fact that we - // are mounting readonly and that we are not journaling. When mounting - // read/write or with a journal, this must be reevaluated. - incompatFeatures := sb.IncompatibleFeatures() - if incompatFeatures.MetaBG { - log.Warningf("ext fs: meta block groups are not supported") - return false - } - if incompatFeatures.MMP { - log.Warningf("ext fs: multiple mount protection is not supported") - return false - } - if incompatFeatures.Encrypted { - log.Warningf("ext fs: encrypted inodes not supported") - return false - } - if incompatFeatures.InlineData { - log.Warningf("ext fs: inline files not supported") - return false - } - return true -} - -// GetFilesystem implements vfs.FilesystemType.GetFilesystem. -func (FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - // TODO(b/134676337): Ensure that the user is mounting readonly. If not, - // EACCESS should be returned according to mount(2). Filesystem independent - // flags (like readonly) are currently not available in pkg/sentry/vfs. - - dev, err := getDeviceFd(source, opts) - if err != nil { - return nil, nil, err - } - - fs := filesystem{dev: dev, inodeCache: make(map[uint32]*inode)} - fs.vfsfs.Init(vfsObj, &fs) - fs.sb, err = readSuperBlock(dev) - if err != nil { - return nil, nil, err - } - - if fs.sb.Magic() != linux.EXT_SUPER_MAGIC { - // mount(2) specifies that EINVAL should be returned if the superblock is - // invalid. - return nil, nil, syserror.EINVAL - } - - // Refuse to mount if the filesystem is incompatible. - if !isCompatible(fs.sb) { - return nil, nil, syserror.EINVAL - } - - fs.bgs, err = readBlockGroups(dev, fs.sb) - if err != nil { - return nil, nil, err - } - - rootInode, err := fs.getOrCreateInodeLocked(disklayout.RootDirInode) - if err != nil { - return nil, nil, err - } - rootInode.incRef() - - return &fs.vfsfs, &newDentry(rootInode).vfsd, nil -} diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go deleted file mode 100644 index 29bb73765..000000000 --- a/pkg/sentry/fsimpl/ext/ext_test.go +++ /dev/null @@ -1,922 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ext - -import ( - "fmt" - "io" - "os" - "path" - "sort" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" - - "gvisor.dev/gvisor/runsc/testutil" -) - -const ( - assetsDir = "pkg/sentry/fsimpl/ext/assets" -) - -var ( - ext2ImagePath = path.Join(assetsDir, "tiny.ext2") - ext3ImagePath = path.Join(assetsDir, "tiny.ext3") - ext4ImagePath = path.Join(assetsDir, "tiny.ext4") -) - -// setUp opens imagePath as an ext Filesystem and returns all necessary -// elements required to run tests. If error is non-nil, it also returns a tear -// down function which must be called after the test is run for clean up. -func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesystem, *vfs.VirtualDentry, func(), error) { - localImagePath, err := testutil.FindFile(imagePath) - if err != nil { - return nil, nil, nil, nil, fmt.Errorf("failed to open local image at path %s: %v", imagePath, err) - } - - f, err := os.Open(localImagePath) - if err != nil { - return nil, nil, nil, nil, err - } - - ctx := contexttest.Context(t) - creds := auth.CredentialsFromContext(ctx) - - // Create VFS. - vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { - t.Fatalf("VFS init: %v", err) - } - vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}) - if err != nil { - f.Close() - return nil, nil, nil, nil, err - } - - root := mntns.Root() - - tearDown := func() { - root.DecRef() - - if err := f.Close(); err != nil { - t.Fatalf("tearDown failed: %v", err) - } - } - return ctx, vfsObj, &root, tearDown, nil -} - -// TODO(b/134676337): Test vfs.FilesystemImpl.ReadlinkAt and -// vfs.FilesystemImpl.StatFSAt which are not implemented in -// vfs.VirtualFilesystem yet. - -// TestSeek tests vfs.FileDescriptionImpl.Seek functionality. -func TestSeek(t *testing.T) { - type seekTest struct { - name string - image string - path string - } - - tests := []seekTest{ - { - name: "ext4 root dir seek", - image: ext4ImagePath, - path: "/", - }, - { - name: "ext3 root dir seek", - image: ext3ImagePath, - path: "/", - }, - { - name: "ext2 root dir seek", - image: ext2ImagePath, - path: "/", - }, - { - name: "ext4 reg file seek", - image: ext4ImagePath, - path: "/file.txt", - }, - { - name: "ext3 reg file seek", - image: ext3ImagePath, - path: "/file.txt", - }, - { - name: "ext2 reg file seek", - image: ext2ImagePath, - path: "/file.txt", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ctx, vfsfs, root, tearDown, err := setUp(t, test.image) - if err != nil { - t.Fatalf("setUp failed: %v", err) - } - defer tearDown() - - fd, err := vfsfs.OpenAt( - ctx, - auth.CredentialsFromContext(ctx), - &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.path)}, - &vfs.OpenOptions{}, - ) - if err != nil { - t.Fatalf("vfsfs.OpenAt failed: %v", err) - } - - if n, err := fd.Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil { - t.Errorf("expected seek position 0, got %d and error %v", n, err) - } - - stat, err := fd.Stat(ctx, vfs.StatOptions{}) - if err != nil { - t.Errorf("fd.stat failed for file %s in image %s: %v", test.path, test.image, err) - } - - // We should be able to seek beyond the end of file. - size := int64(stat.Size) - if n, err := fd.Seek(ctx, size, linux.SEEK_SET); n != size || err != nil { - t.Errorf("expected seek position %d, got %d and error %v", size, n, err) - } - - // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL { - t.Errorf("expected error EINVAL but got %v", err) - } - - if n, err := fd.Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil { - t.Errorf("expected seek position %d, got %d and error %v", size+3, n, err) - } - - // Make sure negative offsets work with SEEK_CUR. - if n, err := fd.Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil { - t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err) - } - - // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL { - t.Errorf("expected error EINVAL but got %v", err) - } - - // Make sure SEEK_END works with regular files. - if _, ok := fd.Impl().(*regularFileFD); ok { - // Seek back to 0. - if n, err := fd.Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil { - t.Errorf("expected seek position %d, got %d and error %v", 0, n, err) - } - - // Seek forward beyond EOF. - if n, err := fd.Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil { - t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err) - } - - // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL { - t.Errorf("expected error EINVAL but got %v", err) - } - } - }) - } -} - -// TestStatAt tests filesystem.StatAt functionality. -func TestStatAt(t *testing.T) { - type statAtTest struct { - name string - image string - path string - want linux.Statx - } - - tests := []statAtTest{ - { - name: "ext4 statx small file", - image: ext4ImagePath, - path: "/file.txt", - want: linux.Statx{ - Blksize: 0x400, - Nlink: 1, - UID: 0, - GID: 0, - Mode: 0644 | linux.ModeRegular, - Size: 13, - }, - }, - { - name: "ext3 statx small file", - image: ext3ImagePath, - path: "/file.txt", - want: linux.Statx{ - Blksize: 0x400, - Nlink: 1, - UID: 0, - GID: 0, - Mode: 0644 | linux.ModeRegular, - Size: 13, - }, - }, - { - name: "ext2 statx small file", - image: ext2ImagePath, - path: "/file.txt", - want: linux.Statx{ - Blksize: 0x400, - Nlink: 1, - UID: 0, - GID: 0, - Mode: 0644 | linux.ModeRegular, - Size: 13, - }, - }, - { - name: "ext4 statx big file", - image: ext4ImagePath, - path: "/bigfile.txt", - want: linux.Statx{ - Blksize: 0x400, - Nlink: 1, - UID: 0, - GID: 0, - Mode: 0644 | linux.ModeRegular, - Size: 13042, - }, - }, - { - name: "ext3 statx big file", - image: ext3ImagePath, - path: "/bigfile.txt", - want: linux.Statx{ - Blksize: 0x400, - Nlink: 1, - UID: 0, - GID: 0, - Mode: 0644 | linux.ModeRegular, - Size: 13042, - }, - }, - { - name: "ext2 statx big file", - image: ext2ImagePath, - path: "/bigfile.txt", - want: linux.Statx{ - Blksize: 0x400, - Nlink: 1, - UID: 0, - GID: 0, - Mode: 0644 | linux.ModeRegular, - Size: 13042, - }, - }, - { - name: "ext4 statx symlink file", - image: ext4ImagePath, - path: "/symlink.txt", - want: linux.Statx{ - Blksize: 0x400, - Nlink: 1, - UID: 0, - GID: 0, - Mode: 0777 | linux.ModeSymlink, - Size: 8, - }, - }, - { - name: "ext3 statx symlink file", - image: ext3ImagePath, - path: "/symlink.txt", - want: linux.Statx{ - Blksize: 0x400, - Nlink: 1, - UID: 0, - GID: 0, - Mode: 0777 | linux.ModeSymlink, - Size: 8, - }, - }, - { - name: "ext2 statx symlink file", - image: ext2ImagePath, - path: "/symlink.txt", - want: linux.Statx{ - Blksize: 0x400, - Nlink: 1, - UID: 0, - GID: 0, - Mode: 0777 | linux.ModeSymlink, - Size: 8, - }, - }, - } - - // Ignore the fields that are not supported by filesystem.StatAt yet and - // those which are likely to change as the image does. - ignoredFields := map[string]bool{ - "Attributes": true, - "AttributesMask": true, - "Atime": true, - "Blocks": true, - "Btime": true, - "Ctime": true, - "DevMajor": true, - "DevMinor": true, - "Ino": true, - "Mask": true, - "Mtime": true, - "RdevMajor": true, - "RdevMinor": true, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ctx, vfsfs, root, tearDown, err := setUp(t, test.image) - if err != nil { - t.Fatalf("setUp failed: %v", err) - } - defer tearDown() - - got, err := vfsfs.StatAt(ctx, - auth.CredentialsFromContext(ctx), - &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.path)}, - &vfs.StatOptions{}, - ) - if err != nil { - t.Fatalf("vfsfs.StatAt failed for file %s in image %s: %v", test.path, test.image, err) - } - - cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool { - _, ok := ignoredFields[p.String()] - return ok - }, cmp.Ignore()) - if diff := cmp.Diff(got, test.want, cmpIgnoreFields, cmpopts.IgnoreUnexported(linux.Statx{})); diff != "" { - t.Errorf("stat mismatch (-want +got):\n%s", diff) - } - }) - } -} - -// TestRead tests the read functionality for vfs file descriptions. -func TestRead(t *testing.T) { - type readTest struct { - name string - image string - absPath string - } - - tests := []readTest{ - { - name: "ext4 read small file", - image: ext4ImagePath, - absPath: "/file.txt", - }, - { - name: "ext3 read small file", - image: ext3ImagePath, - absPath: "/file.txt", - }, - { - name: "ext2 read small file", - image: ext2ImagePath, - absPath: "/file.txt", - }, - { - name: "ext4 read big file", - image: ext4ImagePath, - absPath: "/bigfile.txt", - }, - { - name: "ext3 read big file", - image: ext3ImagePath, - absPath: "/bigfile.txt", - }, - { - name: "ext2 read big file", - image: ext2ImagePath, - absPath: "/bigfile.txt", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ctx, vfsfs, root, tearDown, err := setUp(t, test.image) - if err != nil { - t.Fatalf("setUp failed: %v", err) - } - defer tearDown() - - fd, err := vfsfs.OpenAt( - ctx, - auth.CredentialsFromContext(ctx), - &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.absPath)}, - &vfs.OpenOptions{}, - ) - if err != nil { - t.Fatalf("vfsfs.OpenAt failed: %v", err) - } - - // Get a local file descriptor and compare its functionality with a vfs file - // description for the same file. - localFile, err := testutil.FindFile(path.Join(assetsDir, test.absPath)) - if err != nil { - t.Fatalf("testutil.FindFile failed for %s: %v", test.absPath, err) - } - - f, err := os.Open(localFile) - if err != nil { - t.Fatalf("os.Open failed for %s: %v", localFile, err) - } - defer f.Close() - - // Read the entire file by reading one byte repeatedly. Doing this stress - // tests the underlying file reader implementation. - got := make([]byte, 1) - want := make([]byte, 1) - for { - n, err := f.Read(want) - fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}) - - if diff := cmp.Diff(got, want); diff != "" { - t.Errorf("file data mismatch (-want +got):\n%s", diff) - } - - // Make sure there is no more file data left after getting EOF. - if n == 0 || err == io.EOF { - if n, _ := fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 { - t.Errorf("extra unexpected file data in file %s in image %s", test.absPath, test.image) - } - - break - } - - if err != nil { - t.Fatalf("read failed: %v", err) - } - } - }) - } -} - -// iterDirentsCb is a simple callback which just keeps adding the dirents to an -// internal list. Implements vfs.IterDirentsCallback. -type iterDirentsCb struct { - dirents []vfs.Dirent -} - -// Compiles only if iterDirentCb implements vfs.IterDirentsCallback. -var _ vfs.IterDirentsCallback = (*iterDirentsCb)(nil) - -// newIterDirentsCb is the iterDirent -func newIterDirentCb() *iterDirentsCb { - return &iterDirentsCb{dirents: make([]vfs.Dirent, 0)} -} - -// Handle implements vfs.IterDirentsCallback.Handle. -func (cb *iterDirentsCb) Handle(dirent vfs.Dirent) error { - cb.dirents = append(cb.dirents, dirent) - return nil -} - -// TestIterDirents tests the FileDescriptionImpl.IterDirents functionality. -func TestIterDirents(t *testing.T) { - type iterDirentTest struct { - name string - image string - path string - want []vfs.Dirent - } - - wantDirents := []vfs.Dirent{ - { - Name: ".", - Type: linux.DT_DIR, - }, - { - Name: "..", - Type: linux.DT_DIR, - }, - { - Name: "lost+found", - Type: linux.DT_DIR, - }, - { - Name: "file.txt", - Type: linux.DT_REG, - }, - { - Name: "bigfile.txt", - Type: linux.DT_REG, - }, - { - Name: "symlink.txt", - Type: linux.DT_LNK, - }, - } - tests := []iterDirentTest{ - { - name: "ext4 root dir iteration", - image: ext4ImagePath, - path: "/", - want: wantDirents, - }, - { - name: "ext3 root dir iteration", - image: ext3ImagePath, - path: "/", - want: wantDirents, - }, - { - name: "ext2 root dir iteration", - image: ext2ImagePath, - path: "/", - want: wantDirents, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ctx, vfsfs, root, tearDown, err := setUp(t, test.image) - if err != nil { - t.Fatalf("setUp failed: %v", err) - } - defer tearDown() - - fd, err := vfsfs.OpenAt( - ctx, - auth.CredentialsFromContext(ctx), - &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.path)}, - &vfs.OpenOptions{}, - ) - if err != nil { - t.Fatalf("vfsfs.OpenAt failed: %v", err) - } - - cb := &iterDirentsCb{} - if err = fd.IterDirents(ctx, cb); err != nil { - t.Fatalf("dir fd.IterDirents() failed: %v", err) - } - - sort.Slice(cb.dirents, func(i int, j int) bool { return cb.dirents[i].Name < cb.dirents[j].Name }) - sort.Slice(test.want, func(i int, j int) bool { return test.want[i].Name < test.want[j].Name }) - - // Ignore the inode number and offset of dirents because those are likely to - // change as the underlying image changes. - cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool { - return p.String() == "Ino" || p.String() == "NextOff" - }, cmp.Ignore()) - if diff := cmp.Diff(cb.dirents, test.want, cmpIgnoreFields); diff != "" { - t.Errorf("dirents mismatch (-want +got):\n%s", diff) - } - }) - } -} - -// TestRootDir tests that the root directory inode is correctly initialized and -// returned from setUp. -func TestRootDir(t *testing.T) { - type inodeProps struct { - Mode linux.FileMode - UID auth.KUID - GID auth.KGID - Size uint64 - InodeSize uint16 - Links uint16 - Flags disklayout.InodeFlags - } - - type rootDirTest struct { - name string - image string - wantInode inodeProps - } - - tests := []rootDirTest{ - { - name: "ext4 root dir", - image: ext4ImagePath, - wantInode: inodeProps{ - Mode: linux.ModeDirectory | 0755, - Size: 0x400, - InodeSize: 0x80, - Links: 3, - Flags: disklayout.InodeFlags{Extents: true}, - }, - }, - { - name: "ext3 root dir", - image: ext3ImagePath, - wantInode: inodeProps{ - Mode: linux.ModeDirectory | 0755, - Size: 0x400, - InodeSize: 0x80, - Links: 3, - }, - }, - { - name: "ext2 root dir", - image: ext2ImagePath, - wantInode: inodeProps{ - Mode: linux.ModeDirectory | 0755, - Size: 0x400, - InodeSize: 0x80, - Links: 3, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - _, _, vd, tearDown, err := setUp(t, test.image) - if err != nil { - t.Fatalf("setUp failed: %v", err) - } - defer tearDown() - - d, ok := vd.Dentry().Impl().(*dentry) - if !ok { - t.Fatalf("ext dentry of incorrect type: %T", vd.Dentry().Impl()) - } - - // Offload inode contents into local structs for comparison. - gotInode := inodeProps{ - Mode: d.inode.diskInode.Mode(), - UID: d.inode.diskInode.UID(), - GID: d.inode.diskInode.GID(), - Size: d.inode.diskInode.Size(), - InodeSize: d.inode.diskInode.InodeSize(), - Links: d.inode.diskInode.LinksCount(), - Flags: d.inode.diskInode.Flags(), - } - - if diff := cmp.Diff(gotInode, test.wantInode); diff != "" { - t.Errorf("inode mismatch (-want +got):\n%s", diff) - } - }) - } -} - -// TestFilesystemInit tests that the filesystem superblock and block group -// descriptors are correctly read in and initialized. -func TestFilesystemInit(t *testing.T) { - // sb only contains the immutable properties of the superblock. - type sb struct { - InodesCount uint32 - BlocksCount uint64 - MaxMountCount uint16 - FirstDataBlock uint32 - BlockSize uint64 - BlocksPerGroup uint32 - ClusterSize uint64 - ClustersPerGroup uint32 - InodeSize uint16 - InodesPerGroup uint32 - BgDescSize uint16 - Magic uint16 - Revision disklayout.SbRevision - CompatFeatures disklayout.CompatFeatures - IncompatFeatures disklayout.IncompatFeatures - RoCompatFeatures disklayout.RoCompatFeatures - } - - // bg only contains the immutable properties of the block group descriptor. - type bg struct { - InodeTable uint64 - BlockBitmap uint64 - InodeBitmap uint64 - ExclusionBitmap uint64 - Flags disklayout.BGFlags - } - - type fsInitTest struct { - name string - image string - wantSb sb - wantBgs []bg - } - - tests := []fsInitTest{ - { - name: "ext4 filesystem init", - image: ext4ImagePath, - wantSb: sb{ - InodesCount: 0x10, - BlocksCount: 0x40, - MaxMountCount: 0xffff, - FirstDataBlock: 0x1, - BlockSize: 0x400, - BlocksPerGroup: 0x2000, - ClusterSize: 0x400, - ClustersPerGroup: 0x2000, - InodeSize: 0x80, - InodesPerGroup: 0x10, - BgDescSize: 0x40, - Magic: linux.EXT_SUPER_MAGIC, - Revision: disklayout.DynamicRev, - CompatFeatures: disklayout.CompatFeatures{ - ExtAttr: true, - ResizeInode: true, - DirIndex: true, - }, - IncompatFeatures: disklayout.IncompatFeatures{ - DirentFileType: true, - Extents: true, - Is64Bit: true, - FlexBg: true, - }, - RoCompatFeatures: disklayout.RoCompatFeatures{ - Sparse: true, - LargeFile: true, - HugeFile: true, - DirNlink: true, - ExtraIsize: true, - MetadataCsum: true, - }, - }, - wantBgs: []bg{ - { - InodeTable: 0x23, - BlockBitmap: 0x3, - InodeBitmap: 0x13, - Flags: disklayout.BGFlags{ - InodeZeroed: true, - }, - }, - }, - }, - { - name: "ext3 filesystem init", - image: ext3ImagePath, - wantSb: sb{ - InodesCount: 0x10, - BlocksCount: 0x40, - MaxMountCount: 0xffff, - FirstDataBlock: 0x1, - BlockSize: 0x400, - BlocksPerGroup: 0x2000, - ClusterSize: 0x400, - ClustersPerGroup: 0x2000, - InodeSize: 0x80, - InodesPerGroup: 0x10, - BgDescSize: 0x20, - Magic: linux.EXT_SUPER_MAGIC, - Revision: disklayout.DynamicRev, - CompatFeatures: disklayout.CompatFeatures{ - ExtAttr: true, - ResizeInode: true, - DirIndex: true, - }, - IncompatFeatures: disklayout.IncompatFeatures{ - DirentFileType: true, - }, - RoCompatFeatures: disklayout.RoCompatFeatures{ - Sparse: true, - LargeFile: true, - }, - }, - wantBgs: []bg{ - { - InodeTable: 0x5, - BlockBitmap: 0x3, - InodeBitmap: 0x4, - Flags: disklayout.BGFlags{ - InodeZeroed: true, - }, - }, - }, - }, - { - name: "ext2 filesystem init", - image: ext2ImagePath, - wantSb: sb{ - InodesCount: 0x10, - BlocksCount: 0x40, - MaxMountCount: 0xffff, - FirstDataBlock: 0x1, - BlockSize: 0x400, - BlocksPerGroup: 0x2000, - ClusterSize: 0x400, - ClustersPerGroup: 0x2000, - InodeSize: 0x80, - InodesPerGroup: 0x10, - BgDescSize: 0x20, - Magic: linux.EXT_SUPER_MAGIC, - Revision: disklayout.DynamicRev, - CompatFeatures: disklayout.CompatFeatures{ - ExtAttr: true, - ResizeInode: true, - DirIndex: true, - }, - IncompatFeatures: disklayout.IncompatFeatures{ - DirentFileType: true, - }, - RoCompatFeatures: disklayout.RoCompatFeatures{ - Sparse: true, - LargeFile: true, - }, - }, - wantBgs: []bg{ - { - InodeTable: 0x5, - BlockBitmap: 0x3, - InodeBitmap: 0x4, - Flags: disklayout.BGFlags{ - InodeZeroed: true, - }, - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - _, _, vd, tearDown, err := setUp(t, test.image) - if err != nil { - t.Fatalf("setUp failed: %v", err) - } - defer tearDown() - - fs, ok := vd.Mount().Filesystem().Impl().(*filesystem) - if !ok { - t.Fatalf("ext filesystem of incorrect type: %T", vd.Mount().Filesystem().Impl()) - } - - // Offload superblock and block group descriptors contents into - // local structs for comparison. - totalFreeInodes := uint32(0) - totalFreeBlocks := uint64(0) - gotSb := sb{ - InodesCount: fs.sb.InodesCount(), - BlocksCount: fs.sb.BlocksCount(), - MaxMountCount: fs.sb.MaxMountCount(), - FirstDataBlock: fs.sb.FirstDataBlock(), - BlockSize: fs.sb.BlockSize(), - BlocksPerGroup: fs.sb.BlocksPerGroup(), - ClusterSize: fs.sb.ClusterSize(), - ClustersPerGroup: fs.sb.ClustersPerGroup(), - InodeSize: fs.sb.InodeSize(), - InodesPerGroup: fs.sb.InodesPerGroup(), - BgDescSize: fs.sb.BgDescSize(), - Magic: fs.sb.Magic(), - Revision: fs.sb.Revision(), - CompatFeatures: fs.sb.CompatibleFeatures(), - IncompatFeatures: fs.sb.IncompatibleFeatures(), - RoCompatFeatures: fs.sb.ReadOnlyCompatibleFeatures(), - } - gotNumBgs := len(fs.bgs) - gotBgs := make([]bg, gotNumBgs) - for i := 0; i < gotNumBgs; i++ { - gotBgs[i].InodeTable = fs.bgs[i].InodeTable() - gotBgs[i].BlockBitmap = fs.bgs[i].BlockBitmap() - gotBgs[i].InodeBitmap = fs.bgs[i].InodeBitmap() - gotBgs[i].ExclusionBitmap = fs.bgs[i].ExclusionBitmap() - gotBgs[i].Flags = fs.bgs[i].Flags() - - totalFreeInodes += fs.bgs[i].FreeInodesCount() - totalFreeBlocks += uint64(fs.bgs[i].FreeBlocksCount()) - } - - if diff := cmp.Diff(gotSb, test.wantSb); diff != "" { - t.Errorf("superblock mismatch (-want +got):\n%s", diff) - } - - if diff := cmp.Diff(gotBgs, test.wantBgs); diff != "" { - t.Errorf("block group descriptors mismatch (-want +got):\n%s", diff) - } - - if diff := cmp.Diff(totalFreeInodes, fs.sb.FreeInodesCount()); diff != "" { - t.Errorf("total free inodes mismatch (-want +got):\n%s", diff) - } - - if diff := cmp.Diff(totalFreeBlocks, fs.sb.FreeBlocksCount()); diff != "" { - t.Errorf("total free blocks mismatch (-want +got):\n%s", diff) - } - }) - } -} diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go deleted file mode 100644 index 11dcc0346..000000000 --- a/pkg/sentry/fsimpl/ext/extent_file.go +++ /dev/null @@ -1,237 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ext - -import ( - "io" - "sort" - - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" - "gvisor.dev/gvisor/pkg/syserror" -) - -// extentFile is a type of regular file which uses extents to store file data. -type extentFile struct { - regFile regularFile - - // root is the root extent node. This lives in the 60 byte diskInode.Data(). - // Immutable. - root disklayout.ExtentNode -} - -// Compiles only if extentFile implements io.ReaderAt. -var _ io.ReaderAt = (*extentFile)(nil) - -// newExtentFile is the extent file constructor. It reads the entire extent -// tree into memory. -// TODO(b/134676337): Build extent tree on demand to reduce memory usage. -func newExtentFile(regFile regularFile) (*extentFile, error) { - file := &extentFile{regFile: regFile} - file.regFile.impl = file - err := file.buildExtTree() - if err != nil { - return nil, err - } - return file, nil -} - -// buildExtTree builds the extent tree by reading it from disk by doing -// running a simple DFS. It first reads the root node from the inode struct in -// memory. Then it recursively builds the rest of the tree by reading it off -// disk. -// -// Precondition: inode flag InExtents must be set. -func (f *extentFile) buildExtTree() error { - rootNodeData := f.regFile.inode.diskInode.Data() - - binary.Unmarshal(rootNodeData[:disklayout.ExtentHeaderSize], binary.LittleEndian, &f.root.Header) - - // Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries. - if f.root.Header.NumEntries > 4 { - // read(2) specifies that EINVAL should be returned if the file is unsuitable - // for reading. - return syserror.EINVAL - } - - f.root.Entries = make([]disklayout.ExtentEntryPair, f.root.Header.NumEntries) - for i, off := uint16(0), disklayout.ExtentEntrySize; i < f.root.Header.NumEntries; i, off = i+1, off+disklayout.ExtentEntrySize { - var curEntry disklayout.ExtentEntry - if f.root.Header.Height == 0 { - // Leaf node. - curEntry = &disklayout.Extent{} - } else { - // Internal node. - curEntry = &disklayout.ExtentIdx{} - } - binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentEntrySize], binary.LittleEndian, curEntry) - f.root.Entries[i].Entry = curEntry - } - - // If this node is internal, perform DFS. - if f.root.Header.Height > 0 { - for i := uint16(0); i < f.root.Header.NumEntries; i++ { - var err error - if f.root.Entries[i].Node, err = f.buildExtTreeFromDisk(f.root.Entries[i].Entry); err != nil { - return err - } - } - } - - return nil -} - -// buildExtTreeFromDisk reads the extent tree nodes from disk and recursively -// builds the tree. Performs a simple DFS. It returns the ExtentNode pointed to -// by the ExtentEntry. -func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*disklayout.ExtentNode, error) { - var header disklayout.ExtentHeader - off := entry.PhysicalBlock() * f.regFile.inode.blkSize - err := readFromDisk(f.regFile.inode.fs.dev, int64(off), &header) - if err != nil { - return nil, err - } - - entries := make([]disklayout.ExtentEntryPair, header.NumEntries) - for i, off := uint16(0), off+disklayout.ExtentEntrySize; i < header.NumEntries; i, off = i+1, off+disklayout.ExtentEntrySize { - var curEntry disklayout.ExtentEntry - if header.Height == 0 { - // Leaf node. - curEntry = &disklayout.Extent{} - } else { - // Internal node. - curEntry = &disklayout.ExtentIdx{} - } - - err := readFromDisk(f.regFile.inode.fs.dev, int64(off), curEntry) - if err != nil { - return nil, err - } - entries[i].Entry = curEntry - } - - // If this node is internal, perform DFS. - if header.Height > 0 { - for i := uint16(0); i < header.NumEntries; i++ { - var err error - entries[i].Node, err = f.buildExtTreeFromDisk(entries[i].Entry) - if err != nil { - return nil, err - } - } - } - - return &disklayout.ExtentNode{header, entries}, nil -} - -// ReadAt implements io.ReaderAt.ReadAt. -func (f *extentFile) ReadAt(dst []byte, off int64) (int, error) { - if len(dst) == 0 { - return 0, nil - } - - if off < 0 { - return 0, syserror.EINVAL - } - - if uint64(off) >= f.regFile.inode.diskInode.Size() { - return 0, io.EOF - } - - n, err := f.read(&f.root, uint64(off), dst) - if n < len(dst) && err == nil { - err = io.EOF - } - return n, err -} - -// read is the recursive step of extentFile.ReadAt which traverses the extent -// tree from the node passed and reads file data. -func (f *extentFile) read(node *disklayout.ExtentNode, off uint64, dst []byte) (int, error) { - // Perform a binary search for the node covering bytes starting at r.fileOff. - // A highly fragmented filesystem can have upto 340 entries and so linear - // search should be avoided. Finds the first entry which does not cover the - // file block we want and subtracts 1 to get the desired index. - fileBlk := uint32(off / f.regFile.inode.blkSize) - n := len(node.Entries) - found := sort.Search(n, func(i int) bool { - return node.Entries[i].Entry.FileBlock() > fileBlk - }) - 1 - - // We should be in this recursive step only if the data we want exists under - // the current node. - if found < 0 { - panic("searching for a file block in an extent entry which does not cover it") - } - - read := 0 - toRead := len(dst) - var curR int - var err error - for i := found; i < n && read < toRead; i++ { - if node.Header.Height == 0 { - curR, err = f.readFromExtent(node.Entries[i].Entry.(*disklayout.Extent), off, dst[read:]) - } else { - curR, err = f.read(node.Entries[i].Node, off, dst[read:]) - } - - read += curR - off += uint64(curR) - if err != nil { - return read, err - } - } - - return read, nil -} - -// readFromExtent reads file data from the extent. It takes advantage of the -// sequential nature of extents and reads file data from multiple blocks in one -// call. -// -// A non-nil error indicates that this is a partial read and there is probably -// more to read from this extent. The caller should propagate the error upward -// and not move to the next extent in the tree. -// -// A subsequent call to extentReader.Read should continue reading from where we -// left off as expected. -func (f *extentFile) readFromExtent(ex *disklayout.Extent, off uint64, dst []byte) (int, error) { - curFileBlk := uint32(off / f.regFile.inode.blkSize) - exFirstFileBlk := ex.FileBlock() - exLastFileBlk := exFirstFileBlk + uint32(ex.Length) // This is exclusive. - - // We should be in this recursive step only if the data we want exists under - // the current extent. - if curFileBlk < exFirstFileBlk || exLastFileBlk <= curFileBlk { - panic("searching for a file block in an extent which does not cover it") - } - - curPhyBlk := uint64(curFileBlk-exFirstFileBlk) + ex.PhysicalBlock() - readStart := curPhyBlk*f.regFile.inode.blkSize + (off % f.regFile.inode.blkSize) - - endPhyBlk := ex.PhysicalBlock() + uint64(ex.Length) - extentEnd := endPhyBlk * f.regFile.inode.blkSize // This is exclusive. - - toRead := int(extentEnd - readStart) - if len(dst) < toRead { - toRead = len(dst) - } - - n, _ := f.regFile.inode.fs.dev.ReadAt(dst[:toRead], int64(readStart)) - if n < toRead { - return n, syserror.EIO - } - return n, nil -} diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go deleted file mode 100644 index a2382daa3..000000000 --- a/pkg/sentry/fsimpl/ext/extent_test.go +++ /dev/null @@ -1,267 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ext - -import ( - "bytes" - "math/rand" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" -) - -const ( - // mockExtentBlkSize is the mock block size used for testing. - // No block has more than 1 header + 4 entries. - mockExtentBlkSize = uint64(64) -) - -// The tree described below looks like: -// -// 0.{Head}[Idx][Idx] -// / \ -// / \ -// 1.{Head}[Ext][Ext] 2.{Head}[Idx] -// / | \ -// [Phy] [Phy, Phy] 3.{Head}[Ext] -// | -// [Phy, Phy, Phy] -// -// Legend: -// - Head = ExtentHeader -// - Idx = ExtentIdx -// - Ext = Extent -// - Phy = Physical Block -// -// Please note that ext4 might not construct extent trees looking like this. -// This is purely for testing the tree traversal logic. -var ( - node3 = &disklayout.ExtentNode{ - Header: disklayout.ExtentHeader{ - Magic: disklayout.ExtentMagic, - NumEntries: 1, - MaxEntries: 4, - Height: 0, - }, - Entries: []disklayout.ExtentEntryPair{ - { - Entry: &disklayout.Extent{ - FirstFileBlock: 3, - Length: 3, - StartBlockLo: 6, - }, - Node: nil, - }, - }, - } - - node2 = &disklayout.ExtentNode{ - Header: disklayout.ExtentHeader{ - Magic: disklayout.ExtentMagic, - NumEntries: 1, - MaxEntries: 4, - Height: 1, - }, - Entries: []disklayout.ExtentEntryPair{ - { - Entry: &disklayout.ExtentIdx{ - FirstFileBlock: 3, - ChildBlockLo: 2, - }, - Node: node3, - }, - }, - } - - node1 = &disklayout.ExtentNode{ - Header: disklayout.ExtentHeader{ - Magic: disklayout.ExtentMagic, - NumEntries: 2, - MaxEntries: 4, - Height: 0, - }, - Entries: []disklayout.ExtentEntryPair{ - { - Entry: &disklayout.Extent{ - FirstFileBlock: 0, - Length: 1, - StartBlockLo: 3, - }, - Node: nil, - }, - { - Entry: &disklayout.Extent{ - FirstFileBlock: 1, - Length: 2, - StartBlockLo: 4, - }, - Node: nil, - }, - }, - } - - node0 = &disklayout.ExtentNode{ - Header: disklayout.ExtentHeader{ - Magic: disklayout.ExtentMagic, - NumEntries: 2, - MaxEntries: 4, - Height: 2, - }, - Entries: []disklayout.ExtentEntryPair{ - { - Entry: &disklayout.ExtentIdx{ - FirstFileBlock: 0, - ChildBlockLo: 0, - }, - Node: node1, - }, - { - Entry: &disklayout.ExtentIdx{ - FirstFileBlock: 3, - ChildBlockLo: 1, - }, - Node: node2, - }, - }, - } -) - -// TestExtentReader stress tests extentReader functionality. It performs random -// length reads from all possible positions in the extent tree. -func TestExtentReader(t *testing.T) { - mockExtentFile, want := extentTreeSetUp(t, node0) - n := len(want) - - for from := 0; from < n; from++ { - got := make([]byte, n-from) - - if read, err := mockExtentFile.ReadAt(got, int64(from)); err != nil { - t.Fatalf("file read operation from offset %d to %d only read %d bytes: %v", from, n, read, err) - } - - if diff := cmp.Diff(got, want[from:]); diff != "" { - t.Fatalf("file data from offset %d to %d mismatched (-want +got):\n%s", from, n, diff) - } - } -} - -// TestBuildExtentTree tests the extent tree building logic. -func TestBuildExtentTree(t *testing.T) { - mockExtentFile, _ := extentTreeSetUp(t, node0) - - opt := cmpopts.IgnoreUnexported(disklayout.ExtentIdx{}, disklayout.ExtentHeader{}) - if diff := cmp.Diff(&mockExtentFile.root, node0, opt); diff != "" { - t.Errorf("extent tree mismatch (-want +got):\n%s", diff) - } -} - -// extentTreeSetUp writes the passed extent tree to a mock disk as an extent -// tree. It also constucts a mock extent file with the same tree built in it. -// It also writes random data file data and returns it. -func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, []byte) { - t.Helper() - - mockDisk := make([]byte, mockExtentBlkSize*10) - mockExtentFile := &extentFile{ - regFile: regularFile{ - inode: inode{ - fs: &filesystem{ - dev: bytes.NewReader(mockDisk), - }, - diskInode: &disklayout.InodeNew{ - InodeOld: disklayout.InodeOld{ - SizeLo: uint32(mockExtentBlkSize) * getNumPhyBlks(root), - }, - }, - blkSize: mockExtentBlkSize, - }, - }, - } - - fileData := writeTree(&mockExtentFile.regFile.inode, mockDisk, node0, mockExtentBlkSize) - - if err := mockExtentFile.buildExtTree(); err != nil { - t.Fatalf("inode.buildExtTree failed: %v", err) - } - return mockExtentFile, fileData -} - -// writeTree writes the tree represented by `root` to the inode and disk. It -// also writes random file data on disk. -func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBlkSize uint64) []byte { - rootData := binary.Marshal(nil, binary.LittleEndian, root.Header) - for _, ep := range root.Entries { - rootData = binary.Marshal(rootData, binary.LittleEndian, ep.Entry) - } - - copy(in.diskInode.Data(), rootData) - - var fileData []byte - for _, ep := range root.Entries { - if root.Header.Height == 0 { - fileData = append(fileData, writeFileDataToExtent(disk, ep.Entry.(*disklayout.Extent))...) - } else { - fileData = append(fileData, writeTreeToDisk(disk, ep)...) - } - } - return fileData -} - -// writeTreeToDisk is the recursive step for writeTree which writes the tree -// on the disk only. Also writes random file data on disk. -func writeTreeToDisk(disk []byte, curNode disklayout.ExtentEntryPair) []byte { - nodeData := binary.Marshal(nil, binary.LittleEndian, curNode.Node.Header) - for _, ep := range curNode.Node.Entries { - nodeData = binary.Marshal(nodeData, binary.LittleEndian, ep.Entry) - } - - copy(disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:], nodeData) - - var fileData []byte - for _, ep := range curNode.Node.Entries { - if curNode.Node.Header.Height == 0 { - fileData = append(fileData, writeFileDataToExtent(disk, ep.Entry.(*disklayout.Extent))...) - } else { - fileData = append(fileData, writeTreeToDisk(disk, ep)...) - } - } - return fileData -} - -// writeFileDataToExtent writes random bytes to the blocks on disk that the -// passed extent points to. -func writeFileDataToExtent(disk []byte, ex *disklayout.Extent) []byte { - phyExStartBlk := ex.PhysicalBlock() - phyExStartOff := phyExStartBlk * mockExtentBlkSize - phyExEndOff := phyExStartOff + uint64(ex.Length)*mockExtentBlkSize - rand.Read(disk[phyExStartOff:phyExEndOff]) - return disk[phyExStartOff:phyExEndOff] -} - -// getNumPhyBlks returns the number of physical blocks covered under the node. -func getNumPhyBlks(node *disklayout.ExtentNode) uint32 { - var res uint32 - for _, ep := range node.Entries { - if node.Header.Height == 0 { - res += uint32(ep.Entry.(*disklayout.Extent).Length) - } else { - res += getNumPhyBlks(ep.Node) - } - } - return res -} diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go deleted file mode 100644 index 92f7da40d..000000000 --- a/pkg/sentry/fsimpl/ext/file_description.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ext - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -// fileDescription is embedded by ext implementations of -// vfs.FileDescriptionImpl. -type fileDescription struct { - vfsfd vfs.FileDescription - vfs.FileDescriptionDefaultImpl -} - -func (fd *fileDescription) filesystem() *filesystem { - return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem) -} - -func (fd *fileDescription) inode() *inode { - return fd.vfsfd.Dentry().Impl().(*dentry).inode -} - -// Stat implements vfs.FileDescriptionImpl.Stat. -func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { - var stat linux.Statx - fd.inode().statTo(&stat) - return stat, nil -} - -// SetStat implements vfs.FileDescriptionImpl.SetStat. -func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { - if opts.Stat.Mask == 0 { - return nil - } - return syserror.EPERM -} - -// SetStat implements vfs.FileDescriptionImpl.StatFS. -func (fd *fileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { - var stat linux.Statfs - fd.filesystem().statTo(&stat) - return stat, nil -} - -// Sync implements vfs.FileDescriptionImpl.Sync. -func (fd *fileDescription) Sync(ctx context.Context) error { - return nil -} diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go deleted file mode 100644 index 8497be615..000000000 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ /dev/null @@ -1,507 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ext - -import ( - "errors" - "io" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" -) - -var ( - // errResolveDirent indicates that the vfs.ResolvingPath.Component() does - // not exist on the dentry tree but does exist on disk. So it has to be read in - // using the in-memory dirent and added to the dentry tree. Usually indicates - // the need to lock filesystem.mu for writing. - errResolveDirent = errors.New("resolve path component using dirent") -) - -// filesystem implements vfs.FilesystemImpl. -type filesystem struct { - vfsfs vfs.Filesystem - - // mu serializes changes to the Dentry tree. - mu sync.RWMutex - - // dev represents the underlying fs device. It does not require protection - // because io.ReaderAt permits concurrent read calls to it. It translates to - // the pread syscall which passes on the read request directly to the device - // driver. Device drivers are intelligent in serving multiple concurrent read - // requests in the optimal order (taking locality into consideration). - dev io.ReaderAt - - // inodeCache maps absolute inode numbers to the corresponding Inode struct. - // Inodes should be removed from this once their reference count hits 0. - // - // Protected by mu because most additions (see IterDirents) and all removals - // from this corresponds to a change in the dentry tree. - inodeCache map[uint32]*inode - - // sb represents the filesystem superblock. Immutable after initialization. - sb disklayout.SuperBlock - - // bgs represents all the block group descriptors for the filesystem. - // Immutable after initialization. - bgs []disklayout.BlockGroup -} - -// Compiles only if filesystem implements vfs.FilesystemImpl. -var _ vfs.FilesystemImpl = (*filesystem)(nil) - -// stepLocked resolves rp.Component() in parent directory vfsd. The write -// parameter passed tells if the caller has acquired filesystem.mu for writing -// or not. If set to true, an existing inode on disk can be added to the dentry -// tree if not present already. -// -// stepLocked is loosely analogous to fs/namei.c:walk_component(). -// -// Preconditions: -// - filesystem.mu must be locked (for writing if write param is true). -// - !rp.Done(). -// - inode == vfsd.Impl().(*Dentry).inode. -func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write bool) (*vfs.Dentry, *inode, error) { - if !inode.isDir() { - return nil, nil, syserror.ENOTDIR - } - if err := inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { - return nil, nil, err - } - - for { - nextVFSD, err := rp.ResolveComponent(vfsd) - if err != nil { - return nil, nil, err - } - if nextVFSD == nil { - // Since the Dentry tree is not the sole source of truth for extfs, if it's - // not in the Dentry tree, it might need to be pulled from disk. - childDirent, ok := inode.impl.(*directory).childMap[rp.Component()] - if !ok { - // The underlying inode does not exist on disk. - return nil, nil, syserror.ENOENT - } - - if !write { - // filesystem.mu must be held for writing to add to the dentry tree. - return nil, nil, errResolveDirent - } - - // Create and add the component's dirent to the dentry tree. - fs := rp.Mount().Filesystem().Impl().(*filesystem) - childInode, err := fs.getOrCreateInodeLocked(childDirent.diskDirent.Inode()) - if err != nil { - return nil, nil, err - } - // incRef because this is being added to the dentry tree. - childInode.incRef() - child := newDentry(childInode) - vfsd.InsertChild(&child.vfsd, rp.Component()) - - // Continue as usual now that nextVFSD is not nil. - nextVFSD = &child.vfsd - } - nextInode := nextVFSD.Impl().(*dentry).inode - if nextInode.isSymlink() && rp.ShouldFollowSymlink() { - if err := rp.HandleSymlink(inode.impl.(*symlink).target); err != nil { - return nil, nil, err - } - continue - } - rp.Advance() - return nextVFSD, nextInode, nil - } -} - -// walkLocked resolves rp to an existing file. The write parameter -// passed tells if the caller has acquired filesystem.mu for writing or not. -// If set to true, additions can be made to the dentry tree while walking. -// If errResolveDirent is returned, the walk needs to be continued with an -// upgraded filesystem.mu. -// -// walkLocked is loosely analogous to Linux's fs/namei.c:path_lookupat(). -// -// Preconditions: -// - filesystem.mu must be locked (for writing if write param is true). -func walkLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) { - vfsd := rp.Start() - inode := vfsd.Impl().(*dentry).inode - for !rp.Done() { - var err error - vfsd, inode, err = stepLocked(rp, vfsd, inode, write) - if err != nil { - return nil, nil, err - } - } - if rp.MustBeDir() && !inode.isDir() { - return nil, nil, syserror.ENOTDIR - } - return vfsd, inode, nil -} - -// walkParentLocked resolves all but the last path component of rp to an -// existing directory. It does not check that the returned directory is -// searchable by the provider of rp. The write parameter passed tells if the -// caller has acquired filesystem.mu for writing or not. If set to true, -// additions can be made to the dentry tree while walking. -// If errResolveDirent is returned, the walk needs to be continued with an -// upgraded filesystem.mu. -// -// walkParentLocked is loosely analogous to Linux's fs/namei.c:path_parentat(). -// -// Preconditions: -// - filesystem.mu must be locked (for writing if write param is true). -// - !rp.Done(). -func walkParentLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) { - vfsd := rp.Start() - inode := vfsd.Impl().(*dentry).inode - for !rp.Final() { - var err error - vfsd, inode, err = stepLocked(rp, vfsd, inode, write) - if err != nil { - return nil, nil, err - } - } - if !inode.isDir() { - return nil, nil, syserror.ENOTDIR - } - return vfsd, inode, nil -} - -// walk resolves rp to an existing file. If parent is set to true, it resolves -// the rp till the parent of the last component which should be an existing -// directory. If parent is false then resolves rp entirely. Attemps to resolve -// the path as far as it can with a read lock and upgrades the lock if needed. -func (fs *filesystem) walk(rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *inode, error) { - var ( - vfsd *vfs.Dentry - inode *inode - err error - ) - - // Try walking with the hopes that all dentries have already been pulled out - // of disk. This reduces congestion (allows concurrent walks). - fs.mu.RLock() - if parent { - vfsd, inode, err = walkParentLocked(rp, false) - } else { - vfsd, inode, err = walkLocked(rp, false) - } - fs.mu.RUnlock() - - if err == errResolveDirent { - // Upgrade lock and continue walking. Lock upgrading in the middle of the - // walk is fine as this is a read only filesystem. - fs.mu.Lock() - if parent { - vfsd, inode, err = walkParentLocked(rp, true) - } else { - vfsd, inode, err = walkLocked(rp, true) - } - fs.mu.Unlock() - } - - return vfsd, inode, err -} - -// getOrCreateInodeLocked gets the inode corresponding to the inode number passed in. -// It creates a new one with the given inode number if one does not exist. -// The caller must increment the ref count if adding this to the dentry tree. -// -// Precondition: must be holding fs.mu for writing. -func (fs *filesystem) getOrCreateInodeLocked(inodeNum uint32) (*inode, error) { - if in, ok := fs.inodeCache[inodeNum]; ok { - return in, nil - } - - in, err := newInode(fs, inodeNum) - if err != nil { - return nil, err - } - - fs.inodeCache[inodeNum] = in - return in, nil -} - -// statTo writes the statfs fields to the output parameter. -func (fs *filesystem) statTo(stat *linux.Statfs) { - stat.Type = uint64(fs.sb.Magic()) - stat.BlockSize = int64(fs.sb.BlockSize()) - stat.Blocks = fs.sb.BlocksCount() - stat.BlocksFree = fs.sb.FreeBlocksCount() - stat.BlocksAvailable = fs.sb.FreeBlocksCount() - stat.Files = uint64(fs.sb.InodesCount()) - stat.FilesFree = uint64(fs.sb.FreeInodesCount()) - stat.NameLength = disklayout.MaxFileName - stat.FragmentSize = int64(fs.sb.BlockSize()) - // TODO(b/134676337): Set Statfs.Flags and Statfs.FSID. -} - -// AccessAt implements vfs.Filesystem.Impl.AccessAt. -func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { - _, inode, err := fs.walk(rp, false) - if err != nil { - return err - } - return inode.checkPermissions(rp.Credentials(), ats) -} - -// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. -func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { - vfsd, inode, err := fs.walk(rp, false) - if err != nil { - return nil, err - } - - if opts.CheckSearchable { - if !inode.isDir() { - return nil, syserror.ENOTDIR - } - if err := inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { - return nil, err - } - } - - inode.incRef() - return vfsd, nil -} - -// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt. -func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { - vfsd, inode, err := fs.walk(rp, true) - if err != nil { - return nil, err - } - inode.incRef() - return vfsd, nil -} - -// OpenAt implements vfs.FilesystemImpl.OpenAt. -func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - vfsd, inode, err := fs.walk(rp, false) - if err != nil { - return nil, err - } - - // EROFS is returned if write access is needed. - if vfs.MayWriteFileWithOpenFlags(opts.Flags) || opts.Flags&(linux.O_CREAT|linux.O_EXCL|linux.O_TMPFILE) != 0 { - return nil, syserror.EROFS - } - return inode.open(rp, vfsd, &opts) -} - -// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. -func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { - _, inode, err := fs.walk(rp, false) - if err != nil { - return "", err - } - symlink, ok := inode.impl.(*symlink) - if !ok { - return "", syserror.EINVAL - } - return symlink.target, nil -} - -// StatAt implements vfs.FilesystemImpl.StatAt. -func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { - _, inode, err := fs.walk(rp, false) - if err != nil { - return linux.Statx{}, err - } - var stat linux.Statx - inode.statTo(&stat) - return stat, nil -} - -// StatFSAt implements vfs.FilesystemImpl.StatFSAt. -func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { - if _, _, err := fs.walk(rp, false); err != nil { - return linux.Statfs{}, err - } - - var stat linux.Statfs - fs.statTo(&stat) - return stat, nil -} - -// Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() {} - -// Sync implements vfs.FilesystemImpl.Sync. -func (fs *filesystem) Sync(ctx context.Context) error { - // This is a readonly filesystem for now. - return nil -} - -// The vfs.FilesystemImpl functions below return EROFS because their respective -// man pages say that EROFS must be returned if the path resolves to a file on -// this read-only filesystem. - -// LinkAt implements vfs.FilesystemImpl.LinkAt. -func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { - if rp.Done() { - return syserror.EEXIST - } - - if _, _, err := fs.walk(rp, true); err != nil { - return err - } - - return syserror.EROFS -} - -// MkdirAt implements vfs.FilesystemImpl.MkdirAt. -func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { - if rp.Done() { - return syserror.EEXIST - } - - if _, _, err := fs.walk(rp, true); err != nil { - return err - } - - return syserror.EROFS -} - -// MknodAt implements vfs.FilesystemImpl.MknodAt. -func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { - if rp.Done() { - return syserror.EEXIST - } - - _, _, err := fs.walk(rp, true) - if err != nil { - return err - } - - return syserror.EROFS -} - -// RenameAt implements vfs.FilesystemImpl.RenameAt. -func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { - if rp.Done() { - return syserror.ENOENT - } - - _, _, err := fs.walk(rp, false) - if err != nil { - return err - } - - return syserror.EROFS -} - -// RmdirAt implements vfs.FilesystemImpl.RmdirAt. -func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { - _, inode, err := fs.walk(rp, false) - if err != nil { - return err - } - - if !inode.isDir() { - return syserror.ENOTDIR - } - - return syserror.EROFS -} - -// SetStatAt implements vfs.FilesystemImpl.SetStatAt. -func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { - _, _, err := fs.walk(rp, false) - if err != nil { - return err - } - - return syserror.EROFS -} - -// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. -func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { - if rp.Done() { - return syserror.EEXIST - } - - _, _, err := fs.walk(rp, true) - if err != nil { - return err - } - - return syserror.EROFS -} - -// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. -func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { - _, inode, err := fs.walk(rp, false) - if err != nil { - return err - } - - if inode.isDir() { - return syserror.EISDIR - } - - return syserror.EROFS -} - -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { - _, _, err := fs.walk(rp, false) - if err != nil { - return nil, err - } - return nil, syserror.ENOTSUP -} - -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { - _, _, err := fs.walk(rp, false) - if err != nil { - return "", err - } - return "", syserror.ENOTSUP -} - -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { - _, _, err := fs.walk(rp, false) - if err != nil { - return err - } - return syserror.ENOTSUP -} - -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { - _, _, err := fs.walk(rp, false) - if err != nil { - return err - } - return syserror.ENOTSUP -} - -// PrependPath implements vfs.FilesystemImpl.PrependPath. -func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { - fs.mu.RLock() - defer fs.mu.RUnlock() - return vfs.GenericPrependPath(vfsroot, vd, b) -} diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go deleted file mode 100644 index 6962083f5..000000000 --- a/pkg/sentry/fsimpl/ext/inode.go +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ext - -import ( - "fmt" - "sync/atomic" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -// inode represents an ext inode. -// -// inode uses the same inheritance pattern that pkg/sentry/vfs structures use. -// This has been done to increase memory locality. -// -// Implementations: -// inode -- -// |-- dir -// |-- symlink -// |-- regular-- -// |-- extent file -// |-- block map file -type inode struct { - // refs is a reference count. refs is accessed using atomic memory operations. - refs int64 - - // fs is the containing filesystem. - fs *filesystem - - // inodeNum is the inode number of this inode on disk. This is used to - // identify inodes within the ext filesystem. - inodeNum uint32 - - // blkSize is the fs data block size. Same as filesystem.sb.BlockSize(). - blkSize uint64 - - // diskInode gives us access to the inode struct on disk. Immutable. - diskInode disklayout.Inode - - // This is immutable. The first field of the implementations must have inode - // as the first field to ensure temporality. - impl interface{} -} - -// incRef increments the inode ref count. -func (in *inode) incRef() { - atomic.AddInt64(&in.refs, 1) -} - -// tryIncRef tries to increment the ref count. Returns true if successful. -func (in *inode) tryIncRef() bool { - for { - refs := atomic.LoadInt64(&in.refs) - if refs == 0 { - return false - } - if atomic.CompareAndSwapInt64(&in.refs, refs, refs+1) { - return true - } - } -} - -// decRef decrements the inode ref count and releases the inode resources if -// the ref count hits 0. -// -// Precondition: Must have locked filesystem.mu. -func (in *inode) decRef() { - if refs := atomic.AddInt64(&in.refs, -1); refs == 0 { - delete(in.fs.inodeCache, in.inodeNum) - } else if refs < 0 { - panic("ext.inode.decRef() called without holding a reference") - } -} - -// newInode is the inode constructor. Reads the inode off disk. Identifies -// inodes based on the absolute inode number on disk. -func newInode(fs *filesystem, inodeNum uint32) (*inode, error) { - if inodeNum == 0 { - panic("inode number 0 on ext filesystems is not possible") - } - - inodeRecordSize := fs.sb.InodeSize() - var diskInode disklayout.Inode - if inodeRecordSize == disklayout.OldInodeSize { - diskInode = &disklayout.InodeOld{} - } else { - diskInode = &disklayout.InodeNew{} - } - - // Calculate where the inode is actually placed. - inodesPerGrp := fs.sb.InodesPerGroup() - blkSize := fs.sb.BlockSize() - inodeTableOff := fs.bgs[getBGNum(inodeNum, inodesPerGrp)].InodeTable() * blkSize - inodeOff := inodeTableOff + uint64(uint32(inodeRecordSize)*getBGOff(inodeNum, inodesPerGrp)) - - if err := readFromDisk(fs.dev, int64(inodeOff), diskInode); err != nil { - return nil, err - } - - // Build the inode based on its type. - inode := inode{ - fs: fs, - inodeNum: inodeNum, - blkSize: blkSize, - diskInode: diskInode, - } - - switch diskInode.Mode().FileType() { - case linux.ModeSymlink: - f, err := newSymlink(inode) - if err != nil { - return nil, err - } - return &f.inode, nil - case linux.ModeRegular: - f, err := newRegularFile(inode) - if err != nil { - return nil, err - } - return &f.inode, nil - case linux.ModeDirectory: - f, err := newDirectroy(inode, fs.sb.IncompatibleFeatures().DirentFileType) - if err != nil { - return nil, err - } - return &f.inode, nil - default: - // TODO(b/134676337): Return appropriate errors for sockets, pipes and devices. - return nil, syserror.EINVAL - } -} - -// open creates and returns a file description for the dentry passed in. -func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { - ats := vfs.AccessTypesForOpenFlags(opts) - if err := in.checkPermissions(rp.Credentials(), ats); err != nil { - return nil, err - } - mnt := rp.Mount() - switch in.impl.(type) { - case *regularFile: - var fd regularFileFD - if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil { - return nil, err - } - return &fd.vfsfd, nil - case *directory: - // Can't open directories writably. This check is not necessary for a read - // only filesystem but will be required when write is implemented. - if ats&vfs.MayWrite != 0 { - return nil, syserror.EISDIR - } - var fd directoryFD - if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil { - return nil, err - } - return &fd.vfsfd, nil - case *symlink: - if opts.Flags&linux.O_PATH == 0 { - // Can't open symlinks without O_PATH. - return nil, syserror.ELOOP - } - var fd symlinkFD - fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) - return &fd.vfsfd, nil - default: - panic(fmt.Sprintf("unknown inode type: %T", in.impl)) - } -} - -func (in *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error { - return vfs.GenericCheckPermissions(creds, ats, in.isDir(), uint16(in.diskInode.Mode()), in.diskInode.UID(), in.diskInode.GID()) -} - -// statTo writes the statx fields to the output parameter. -func (in *inode) statTo(stat *linux.Statx) { - stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | - linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_SIZE | - linux.STATX_ATIME | linux.STATX_CTIME | linux.STATX_MTIME - stat.Blksize = uint32(in.blkSize) - stat.Mode = uint16(in.diskInode.Mode()) - stat.Nlink = uint32(in.diskInode.LinksCount()) - stat.UID = uint32(in.diskInode.UID()) - stat.GID = uint32(in.diskInode.GID()) - stat.Ino = uint64(in.inodeNum) - stat.Size = in.diskInode.Size() - stat.Atime = in.diskInode.AccessTime().StatxTimestamp() - stat.Ctime = in.diskInode.ChangeTime().StatxTimestamp() - stat.Mtime = in.diskInode.ModificationTime().StatxTimestamp() - // TODO(b/134676337): Set stat.Blocks which is the number of 512 byte blocks - // (including metadata blocks) required to represent this file. -} - -// getBGNum returns the block group number that a given inode belongs to. -func getBGNum(inodeNum uint32, inodesPerGrp uint32) uint32 { - return (inodeNum - 1) / inodesPerGrp -} - -// getBGOff returns the offset at which the given inode lives in the block -// group's inode table, i.e. the index of the inode in the inode table. -func getBGOff(inodeNum uint32, inodesPerGrp uint32) uint32 { - return (inodeNum - 1) % inodesPerGrp -} diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go deleted file mode 100644 index 30135ddb0..000000000 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ext - -import ( - "io" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/safemem" - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -// regularFile represents a regular file's inode. This too follows the -// inheritance pattern prevelant in the vfs layer described in -// pkg/sentry/vfs/README.md. -type regularFile struct { - inode inode - - // This is immutable. The first field of fileReader implementations must be - // regularFile to ensure temporality. - // io.ReaderAt is more strict than io.Reader in the sense that a partial read - // is always accompanied by an error. If a read spans past the end of file, a - // partial read (within file range) is done and io.EOF is returned. - impl io.ReaderAt -} - -// newRegularFile is the regularFile constructor. It figures out what kind of -// file this is and initializes the fileReader. -func newRegularFile(inode inode) (*regularFile, error) { - regFile := regularFile{ - inode: inode, - } - - inodeFlags := inode.diskInode.Flags() - - if inodeFlags.Extents { - file, err := newExtentFile(regFile) - if err != nil { - return nil, err - } - - file.regFile.inode.impl = &file.regFile - return &file.regFile, nil - } - - file, err := newBlockMapFile(regFile) - if err != nil { - return nil, err - } - file.regFile.inode.impl = &file.regFile - return &file.regFile, nil -} - -func (in *inode) isRegular() bool { - _, ok := in.impl.(*regularFile) - return ok -} - -// directoryFD represents a directory file description. It implements -// vfs.FileDescriptionImpl. -type regularFileFD struct { - fileDescription - - // off is the file offset. off is accessed using atomic memory operations. - off int64 - - // offMu serializes operations that may mutate off. - offMu sync.Mutex -} - -// Release implements vfs.FileDescriptionImpl.Release. -func (fd *regularFileFD) Release() {} - -// PRead implements vfs.FileDescriptionImpl.PRead. -func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { - safeReader := safemem.FromIOReaderAt{ - ReaderAt: fd.inode().impl.(*regularFile).impl, - Offset: offset, - } - - // Copies data from disk directly into usermem without any intermediate - // allocations (if dst is converted into BlockSeq such that it does not need - // safe copying). - return dst.CopyOutFrom(ctx, safeReader) -} - -// Read implements vfs.FileDescriptionImpl.Read. -func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - n, err := fd.PRead(ctx, dst, fd.off, opts) - fd.offMu.Lock() - fd.off += n - fd.offMu.Unlock() - return n, err -} - -// PWrite implements vfs.FileDescriptionImpl.PWrite. -func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - // write(2) specifies that EBADF must be returned if the fd is not open for - // writing. - return 0, syserror.EBADF -} - -// Write implements vfs.FileDescriptionImpl.Write. -func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.offMu.Lock() - fd.off += n - fd.offMu.Unlock() - return n, err -} - -// IterDirents implements vfs.FileDescriptionImpl.IterDirents. -func (fd *regularFileFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { - return syserror.ENOTDIR -} - -// Seek implements vfs.FileDescriptionImpl.Seek. -func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { - fd.offMu.Lock() - defer fd.offMu.Unlock() - switch whence { - case linux.SEEK_SET: - // Use offset as specified. - case linux.SEEK_CUR: - offset += fd.off - case linux.SEEK_END: - offset += int64(fd.inode().diskInode.Size()) - default: - return 0, syserror.EINVAL - } - if offset < 0 { - return 0, syserror.EINVAL - } - fd.off = offset - return offset, nil -} - -// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. -func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { - // TODO(b/134676337): Implement mmap(2). - return syserror.ENODEV -} diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go deleted file mode 100644 index 1447a4dc1..000000000 --- a/pkg/sentry/fsimpl/ext/symlink.go +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ext - -import ( - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -// symlink represents a symlink inode. -type symlink struct { - inode inode - target string // immutable -} - -// newSymlink is the symlink constructor. It reads out the symlink target from -// the inode (however it might have been stored). -func newSymlink(inode inode) (*symlink, error) { - var file *symlink - var link []byte - - // If the symlink target is lesser than 60 bytes, its stores in inode.Data(). - // Otherwise either extents or block maps will be used to store the link. - size := inode.diskInode.Size() - if size < 60 { - link = inode.diskInode.Data()[:size] - } else { - // Create a regular file out of this inode and read out the target. - regFile, err := newRegularFile(inode) - if err != nil { - return nil, err - } - - link = make([]byte, size) - if n, err := regFile.impl.ReadAt(link, 0); uint64(n) < size { - return nil, err - } - } - - file = &symlink{inode: inode, target: string(link)} - file.inode.impl = file - return file, nil -} - -func (in *inode) isSymlink() bool { - _, ok := in.impl.(*symlink) - return ok -} - -// symlinkFD represents a symlink file description and implements implements -// vfs.FileDescriptionImpl. which may only be used if open options contains -// O_PATH. For this reason most of the functions return EBADF. -type symlinkFD struct { - fileDescription -} - -// Compiles only if symlinkFD implements vfs.FileDescriptionImpl. -var _ vfs.FileDescriptionImpl = (*symlinkFD)(nil) - -// Release implements vfs.FileDescriptionImpl.Release. -func (fd *symlinkFD) Release() {} - -// PRead implements vfs.FileDescriptionImpl.PRead. -func (fd *symlinkFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { - return 0, syserror.EBADF -} - -// Read implements vfs.FileDescriptionImpl.Read. -func (fd *symlinkFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - return 0, syserror.EBADF -} - -// PWrite implements vfs.FileDescriptionImpl.PWrite. -func (fd *symlinkFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - return 0, syserror.EBADF -} - -// Write implements vfs.FileDescriptionImpl.Write. -func (fd *symlinkFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - return 0, syserror.EBADF -} - -// IterDirents implements vfs.FileDescriptionImpl.IterDirents. -func (fd *symlinkFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { - return syserror.ENOTDIR -} - -// Seek implements vfs.FileDescriptionImpl.Seek. -func (fd *symlinkFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { - return 0, syserror.EBADF -} - -// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. -func (fd *symlinkFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { - return syserror.EBADF -} diff --git a/pkg/sentry/fsimpl/ext/utils.go b/pkg/sentry/fsimpl/ext/utils.go deleted file mode 100644 index d8b728f8c..000000000 --- a/pkg/sentry/fsimpl/ext/utils.go +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ext - -import ( - "io" - - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" - "gvisor.dev/gvisor/pkg/syserror" -) - -// readFromDisk performs a binary read from disk into the given struct from -// the absolute offset provided. -func readFromDisk(dev io.ReaderAt, abOff int64, v interface{}) error { - n := binary.Size(v) - buf := make([]byte, n) - if read, _ := dev.ReadAt(buf, abOff); read < int(n) { - return syserror.EIO - } - - binary.Unmarshal(buf, binary.LittleEndian, v) - return nil -} - -// readSuperBlock reads the SuperBlock from block group 0 in the underlying -// device. There are three versions of the superblock. This function identifies -// and returns the correct version. -func readSuperBlock(dev io.ReaderAt) (disklayout.SuperBlock, error) { - var sb disklayout.SuperBlock = &disklayout.SuperBlockOld{} - if err := readFromDisk(dev, disklayout.SbOffset, sb); err != nil { - return nil, err - } - if sb.Revision() == disklayout.OldRev { - return sb, nil - } - - sb = &disklayout.SuperBlock32Bit{} - if err := readFromDisk(dev, disklayout.SbOffset, sb); err != nil { - return nil, err - } - if !sb.IncompatibleFeatures().Is64Bit { - return sb, nil - } - - sb = &disklayout.SuperBlock64Bit{} - if err := readFromDisk(dev, disklayout.SbOffset, sb); err != nil { - return nil, err - } - return sb, nil -} - -// blockGroupsCount returns the number of block groups in the ext fs. -func blockGroupsCount(sb disklayout.SuperBlock) uint64 { - blocksCount := sb.BlocksCount() - blocksPerGroup := uint64(sb.BlocksPerGroup()) - - // Round up the result. float64 can compromise precision so do it manually. - return (blocksCount + blocksPerGroup - 1) / blocksPerGroup -} - -// readBlockGroups reads the block group descriptor table from block group 0 in -// the underlying device. -func readBlockGroups(dev io.ReaderAt, sb disklayout.SuperBlock) ([]disklayout.BlockGroup, error) { - bgCount := blockGroupsCount(sb) - bgdSize := uint64(sb.BgDescSize()) - is64Bit := sb.IncompatibleFeatures().Is64Bit - bgds := make([]disklayout.BlockGroup, bgCount) - - for i, off := uint64(0), uint64(sb.FirstDataBlock()+1)*sb.BlockSize(); i < bgCount; i, off = i+1, off+bgdSize { - if is64Bit { - bgds[i] = &disklayout.BlockGroup64Bit{} - } else { - bgds[i] = &disklayout.BlockGroup32Bit{} - } - - if err := readFromDisk(dev, int64(off), bgds[i]); err != nil { - return nil, err - } - } - return bgds, nil -} diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD deleted file mode 100644 index 4ba76a1e8..000000000 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ /dev/null @@ -1,55 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -licenses(["notice"]) - -go_template_instance( - name = "dentry_list", - out = "dentry_list.go", - package = "gofer", - prefix = "dentry", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*dentry", - "Linker": "*dentry", - }, -) - -go_library( - name = "gofer", - srcs = [ - "dentry_list.go", - "directory.go", - "filesystem.go", - "gofer.go", - "handle.go", - "handle_unsafe.go", - "p9file.go", - "pagemath.go", - "regular_file.go", - "special_file.go", - "symlink.go", - "time.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/fd", - "//pkg/fspath", - "//pkg/log", - "//pkg/p9", - "//pkg/safemem", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/usage", - "//pkg/sentry/vfs", - "//pkg/syserror", - "//pkg/unet", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go deleted file mode 100644 index 5dbfc6250..000000000 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gofer - -import ( - "sync" - "sync/atomic" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -func (d *dentry) isDir() bool { - return d.fileType() == linux.S_IFDIR -} - -// Preconditions: d.dirMu must be locked. d.isDir(). fs.opts.interop != -// InteropModeShared. -func (d *dentry) cacheNegativeChildLocked(name string) { - if d.negativeChildren == nil { - d.negativeChildren = make(map[string]struct{}) - } - d.negativeChildren[name] = struct{}{} -} - -type directoryFD struct { - fileDescription - vfs.DirectoryFileDescriptionDefaultImpl - - mu sync.Mutex - off int64 - dirents []vfs.Dirent -} - -// Release implements vfs.FileDescriptionImpl.Release. -func (fd *directoryFD) Release() { -} - -// IterDirents implements vfs.FileDescriptionImpl.IterDirents. -func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { - fd.mu.Lock() - defer fd.mu.Unlock() - - if fd.dirents == nil { - ds, err := fd.dentry().getDirents(ctx) - if err != nil { - return err - } - fd.dirents = ds - } - - for fd.off < int64(len(fd.dirents)) { - if err := cb.Handle(fd.dirents[fd.off]); err != nil { - return err - } - fd.off++ - } - return nil -} - -// Preconditions: d.isDir(). There exists at least one directoryFD representing d. -func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { - // 9P2000.L's readdir does not specify behavior in the presence of - // concurrent mutation of an iterated directory, so implementations may - // duplicate or omit entries in this case, which violates POSIX semantics. - // Thus we read all directory entries while holding d.dirMu to exclude - // directory mutations. (Note that it is impossible for the client to - // exclude concurrent mutation from other remote filesystem users. Since - // there is no way to detect if the server has incorrectly omitted - // directory entries, we simply assume that the server is well-behaved - // under InteropModeShared.) This is inconsistent with Linux (which appears - // to assume that directory fids have the correct semantics, and translates - // struct file_operations::readdir calls directly to readdir RPCs), but is - // consistent with VFS1. - // - // NOTE(b/135560623): In particular, some gofer implementations may not - // retain state between calls to Readdir, so may not provide a coherent - // directory stream across in the presence of mutation. - - d.fs.renameMu.RLock() - defer d.fs.renameMu.RUnlock() - d.dirMu.Lock() - defer d.dirMu.Unlock() - if d.dirents != nil { - return d.dirents, nil - } - - // It's not clear if 9P2000.L's readdir is expected to return "." and "..", - // so we generate them here. - parent := d.vfsd.ParentOrSelf().Impl().(*dentry) - dirents := []vfs.Dirent{ - { - Name: ".", - Type: linux.DT_DIR, - Ino: d.ino, - NextOff: 1, - }, - { - Name: "..", - Type: uint8(atomic.LoadUint32(&parent.mode) >> 12), - Ino: parent.ino, - NextOff: 2, - }, - } - off := uint64(0) - const count = 64 * 1024 // for consistency with the vfs1 client - d.handleMu.RLock() - defer d.handleMu.RUnlock() - if !d.handleReadable { - // This should not be possible because a readable handle should have - // been opened when the calling directoryFD was opened. - panic("gofer.dentry.getDirents called without a readable handle") - } - for { - p9ds, err := d.handle.file.readdir(ctx, off, count) - if err != nil { - return nil, err - } - if len(p9ds) == 0 { - // Cache dirents for future directoryFDs if permitted. - if d.fs.opts.interop != InteropModeShared { - d.dirents = dirents - } - return dirents, nil - } - for _, p9d := range p9ds { - if p9d.Name == "." || p9d.Name == ".." { - continue - } - dirent := vfs.Dirent{ - Name: p9d.Name, - Ino: p9d.QID.Path, - NextOff: int64(len(dirents) + 1), - } - // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or - // DMSOCKET. - switch p9d.Type { - case p9.TypeSymlink: - dirent.Type = linux.DT_LNK - case p9.TypeDir: - dirent.Type = linux.DT_DIR - default: - dirent.Type = linux.DT_REG - } - dirents = append(dirents, dirent) - } - off = p9ds[len(p9ds)-1].Offset - } -} - -// Seek implements vfs.FileDescriptionImpl.Seek. -func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { - fd.mu.Lock() - defer fd.mu.Unlock() - - switch whence { - case linux.SEEK_SET: - if offset < 0 { - return 0, syserror.EINVAL - } - if offset == 0 { - // Ensure that the next call to fd.IterDirents() calls - // fd.dentry().getDirents(). - fd.dirents = nil - } - fd.off = offset - return fd.off, nil - case linux.SEEK_CUR: - offset += fd.off - if offset < 0 { - return 0, syserror.EINVAL - } - // Don't clear fd.dirents in this case, even if offset == 0. - fd.off = offset - return fd.off, nil - default: - return 0, syserror.EINVAL - } -} diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go deleted file mode 100644 index 38e4cdbc5..000000000 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ /dev/null @@ -1,1103 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gofer - -import ( - "sync" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -// Sync implements vfs.FilesystemImpl.Sync. -func (fs *filesystem) Sync(ctx context.Context) error { - // Snapshot current dentries and special files. - fs.syncMu.Lock() - ds := make([]*dentry, 0, len(fs.dentries)) - for d := range fs.dentries { - ds = append(ds, d) - } - sffds := make([]*specialFileFD, 0, len(fs.specialFileFDs)) - for sffd := range fs.specialFileFDs { - sffds = append(sffds, sffd) - } - fs.syncMu.Unlock() - - // Return the first error we encounter, but sync everything we can - // regardless. - var retErr error - - // Sync regular files. - for _, d := range ds { - if !d.TryIncRef() { - continue - } - err := d.syncSharedHandle(ctx) - d.DecRef() - if err != nil && retErr == nil { - retErr = err - } - } - - // Sync special files, which may be writable but do not use dentry shared - // handles (so they won't be synced by the above). - for _, sffd := range sffds { - if !sffd.vfsfd.TryIncRef() { - continue - } - err := sffd.Sync(ctx) - sffd.vfsfd.DecRef() - if err != nil && retErr == nil { - retErr = err - } - } - - return retErr -} - -// maxFilenameLen is the maximum length of a filename. This is dictated by 9P's -// encoding of strings, which uses 2 bytes for the length prefix. -const maxFilenameLen = (1 << 16) - 1 - -// dentrySlicePool is a pool of *[]*dentry used to store dentries for which -// dentry.checkCachingLocked() must be called. The pool holds pointers to -// slices because Go lacks generics, so sync.Pool operates on interface{}, so -// every call to (what should be) sync.Pool<[]*dentry>.Put() allocates a copy -// of the slice header on the heap. -var dentrySlicePool = sync.Pool{ - New: func() interface{} { - ds := make([]*dentry, 0, 4) // arbitrary non-zero initial capacity - return &ds - }, -} - -func appendDentry(ds *[]*dentry, d *dentry) *[]*dentry { - if ds == nil { - ds = dentrySlicePool.Get().(*[]*dentry) - } - *ds = append(*ds, d) - return ds -} - -// Preconditions: ds != nil. -func putDentrySlice(ds *[]*dentry) { - // Allow dentries to be GC'd. - for i := range *ds { - (*ds)[i] = nil - } - *ds = (*ds)[:0] - dentrySlicePool.Put(ds) -} - -// stepLocked resolves rp.Component() to an existing file, starting from the -// given directory. -// -// Dentries which may become cached as a result of the traversal are appended -// to *ds. -// -// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. -// !rp.Done(). If fs.opts.interop == InteropModeShared, then d's cached -// metadata must be up to date. -func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { - if !d.isDir() { - return nil, syserror.ENOTDIR - } - if err := d.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil { - return nil, err - } -afterSymlink: - name := rp.Component() - if name == "." { - rp.Advance() - return d, nil - } - if name == ".." { - parentVFSD, err := rp.ResolveParent(&d.vfsd) - if err != nil { - return nil, err - } - parent := parentVFSD.Impl().(*dentry) - if fs.opts.interop == InteropModeShared { - // We must assume that parentVFSD is correct, because if d has been - // moved elsewhere in the remote filesystem so that its parent has - // changed, we have no way of determining its new parent's location - // in the filesystem. Get updated metadata for parentVFSD. - _, attrMask, attr, err := parent.file.getAttr(ctx, dentryAttrMask()) - if err != nil { - return nil, err - } - parent.updateFromP9Attrs(attrMask, &attr) - } - rp.Advance() - return parent, nil - } - childVFSD, err := rp.ResolveChild(&d.vfsd, name) - if err != nil { - return nil, err - } - // FIXME(jamieliu): Linux performs revalidation before mount lookup - // (fs/namei.c:lookup_fast() => __d_lookup_rcu(), d_revalidate(), - // __follow_mount_rcu()). - child, err := fs.revalidateChildLocked(ctx, rp.VirtualFilesystem(), d, name, childVFSD, ds) - if err != nil { - return nil, err - } - if child == nil { - return nil, syserror.ENOENT - } - if child.isSymlink() && rp.ShouldFollowSymlink() { - target, err := child.readlink(ctx, rp.Mount()) - if err != nil { - return nil, err - } - if err := rp.HandleSymlink(target); err != nil { - return nil, err - } - goto afterSymlink // don't check the current directory again - } - rp.Advance() - return child, nil -} - -// revalidateChildLocked must be called after a call to parent.vfsd.Child(name) -// or vfs.ResolvingPath.ResolveChild(name) returns childVFSD (which may be -// nil) to verify that the returned child (or lack thereof) is correct. If no file -// exists at name, revalidateChildLocked returns (nil, nil). -// -// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked. -// parent.isDir(). name is not "." or "..". -// -// Postconditions: If revalidateChildLocked returns a non-nil dentry, its -// cached metadata is up to date. -func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, childVFSD *vfs.Dentry, ds **[]*dentry) (*dentry, error) { - if childVFSD != nil && fs.opts.interop != InteropModeShared { - // We have a cached dentry that is assumed to be correct. - return childVFSD.Impl().(*dentry), nil - } - // We either don't have a cached dentry or need to verify that it's still - // correct, either of which requires a remote lookup. Check if this name is - // valid before performing the lookup. - if len(name) > maxFilenameLen { - return nil, syserror.ENAMETOOLONG - } - // Check if we've already cached this lookup with a negative result. - if _, ok := parent.negativeChildren[name]; ok { - return nil, nil - } - // Perform the remote lookup. - qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name) - if err != nil && err != syserror.ENOENT { - return nil, err - } - if childVFSD != nil { - child := childVFSD.Impl().(*dentry) - if !file.isNil() && qid.Path == child.ino { - // The file at this path hasn't changed. Just update cached - // metadata. - file.close(ctx) - child.updateFromP9Attrs(attrMask, &attr) - return child, nil - } - // The file at this path has changed or no longer exists. Remove - // the stale dentry from the tree, and re-evaluate its caching - // status (i.e. if it has 0 references, drop it). - vfsObj.ForceDeleteDentry(childVFSD) - *ds = appendDentry(*ds, child) - childVFSD = nil - } - if file.isNil() { - // No file exists at this path now. Cache the negative lookup if - // allowed. - if fs.opts.interop != InteropModeShared { - parent.cacheNegativeChildLocked(name) - } - return nil, nil - } - // Create a new dentry representing the file. - child, err := fs.newDentry(ctx, file, qid, attrMask, &attr) - if err != nil { - file.close(ctx) - return nil, err - } - parent.IncRef() // reference held by child on its parent - parent.vfsd.InsertChild(&child.vfsd, name) - // For now, child has 0 references, so our caller should call - // child.checkCachingLocked(). - *ds = appendDentry(*ds, child) - return child, nil -} - -// walkParentDirLocked resolves all but the last path component of rp to an -// existing directory, starting from the given directory (which is usually -// rp.Start().Impl().(*dentry)). It does not check that the returned directory -// is searchable by the provider of rp. -// -// Preconditions: fs.renameMu must be locked. !rp.Done(). If fs.opts.interop == -// InteropModeShared, then d's cached metadata must be up to date. -func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { - for !rp.Final() { - d.dirMu.Lock() - next, err := fs.stepLocked(ctx, rp, d, ds) - d.dirMu.Unlock() - if err != nil { - return nil, err - } - d = next - } - if !d.isDir() { - return nil, syserror.ENOTDIR - } - return d, nil -} - -// resolveLocked resolves rp to an existing file. -// -// Preconditions: fs.renameMu must be locked. -func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) { - d := rp.Start().Impl().(*dentry) - if fs.opts.interop == InteropModeShared { - // Get updated metadata for rp.Start() as required by fs.stepLocked(). - if err := d.updateFromGetattr(ctx); err != nil { - return nil, err - } - } - for !rp.Done() { - d.dirMu.Lock() - next, err := fs.stepLocked(ctx, rp, d, ds) - d.dirMu.Unlock() - if err != nil { - return nil, err - } - d = next - } - if rp.MustBeDir() && !d.isDir() { - return nil, syserror.ENOTDIR - } - return d, nil -} - -// doCreateAt checks that creating a file at rp is permitted, then invokes -// create to do so. -// -// Preconditions: !rp.Done(). For the final path component in rp, -// !rp.ShouldFollowSymlink(). -func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string) error) error { - var ds *[]*dentry - fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) - start := rp.Start().Impl().(*dentry) - if fs.opts.interop == InteropModeShared { - // Get updated metadata for start as required by - // fs.walkParentDirLocked(). - if err := start.updateFromGetattr(ctx); err != nil { - return err - } - } - parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) - if err != nil { - return err - } - if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil { - return err - } - if parent.isDeleted() { - return syserror.ENOENT - } - name := rp.Component() - if name == "." || name == ".." { - return syserror.EEXIST - } - if len(name) > maxFilenameLen { - return syserror.ENAMETOOLONG - } - if !dir && rp.MustBeDir() { - return syserror.ENOENT - } - mnt := rp.Mount() - if err := mnt.CheckBeginWrite(); err != nil { - return err - } - defer mnt.EndWrite() - parent.dirMu.Lock() - defer parent.dirMu.Unlock() - if fs.opts.interop == InteropModeShared { - // The existence of a dentry at name would be inconclusive because the - // file it represents may have been deleted from the remote filesystem, - // so we would need to make an RPC to revalidate the dentry. Just - // attempt the file creation RPC instead. If a file does exist, the RPC - // will fail with EEXIST like we would have. If the RPC succeeds, and a - // stale dentry exists, the dentry will fail revalidation next time - // it's used. - return create(parent, name) - } - if parent.vfsd.Child(name) != nil { - return syserror.EEXIST - } - // No cached dentry exists; however, there might still be an existing file - // at name. As above, we attempt the file creation RPC anyway. - if err := create(parent, name); err != nil { - return err - } - parent.touchCMtime(ctx) - delete(parent.negativeChildren, name) - parent.dirents = nil - return nil -} - -// Preconditions: !rp.Done(). -func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool) error { - var ds *[]*dentry - fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) - start := rp.Start().Impl().(*dentry) - if fs.opts.interop == InteropModeShared { - // Get updated metadata for start as required by - // fs.walkParentDirLocked(). - if err := start.updateFromGetattr(ctx); err != nil { - return err - } - } - parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) - if err != nil { - return err - } - if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil { - return err - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return err - } - defer rp.Mount().EndWrite() - - name := rp.Component() - if dir { - if name == "." { - return syserror.EINVAL - } - if name == ".." { - return syserror.ENOTEMPTY - } - } else { - if name == "." || name == ".." { - return syserror.EISDIR - } - } - vfsObj := rp.VirtualFilesystem() - mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() - parent.dirMu.Lock() - defer parent.dirMu.Unlock() - childVFSD := parent.vfsd.Child(name) - var child *dentry - // We only need a dentry representing the file at name if it can be a mount - // point. If childVFSD is nil, then it can't be a mount point. If childVFSD - // is non-nil but stale, the actual file can't be a mount point either; we - // detect this case by just speculatively calling PrepareDeleteDentry and - // only revalidating the dentry if that fails (indicating that the existing - // dentry is a mount point). - if childVFSD != nil { - child = childVFSD.Impl().(*dentry) - if err := vfsObj.PrepareDeleteDentry(mntns, childVFSD); err != nil { - child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, childVFSD, &ds) - if err != nil { - return err - } - if child != nil { - childVFSD = &child.vfsd - if err := vfsObj.PrepareDeleteDentry(mntns, childVFSD); err != nil { - return err - } - } else { - childVFSD = nil - } - } - } else if _, ok := parent.negativeChildren[name]; ok { - return syserror.ENOENT - } - flags := uint32(0) - if dir { - if child != nil && !child.isDir() { - return syserror.ENOTDIR - } - flags = linux.AT_REMOVEDIR - } else { - if child != nil && child.isDir() { - return syserror.EISDIR - } - if rp.MustBeDir() { - return syserror.ENOTDIR - } - } - err = parent.file.unlinkAt(ctx, name, flags) - if err != nil { - if childVFSD != nil { - vfsObj.AbortDeleteDentry(childVFSD) - } - return err - } - if fs.opts.interop != InteropModeShared { - parent.touchCMtime(ctx) - parent.cacheNegativeChildLocked(name) - parent.dirents = nil - } - if child != nil { - child.setDeleted() - vfsObj.CommitDeleteDentry(childVFSD) - ds = appendDentry(ds, child) - } - return nil -} - -// renameMuRUnlockAndCheckCaching calls fs.renameMu.RUnlock(), then calls -// dentry.checkCachingLocked on all dentries in *ds with fs.renameMu locked for -// writing. -// -// ds is a pointer-to-pointer since defer evaluates its arguments immediately, -// but dentry slices are allocated lazily, and it's much easier to say "defer -// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() { -// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this. -func (fs *filesystem) renameMuRUnlockAndCheckCaching(ds **[]*dentry) { - fs.renameMu.RUnlock() - if *ds == nil { - return - } - if len(**ds) != 0 { - fs.renameMu.Lock() - for _, d := range **ds { - d.checkCachingLocked() - } - fs.renameMu.Unlock() - } - putDentrySlice(*ds) -} - -func (fs *filesystem) renameMuUnlockAndCheckCaching(ds **[]*dentry) { - if *ds == nil { - fs.renameMu.Unlock() - return - } - for _, d := range **ds { - d.checkCachingLocked() - } - fs.renameMu.Unlock() - putDentrySlice(*ds) -} - -// AccessAt implements vfs.Filesystem.Impl.AccessAt. -func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { - var ds *[]*dentry - fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) - d, err := fs.resolveLocked(ctx, rp, &ds) - if err != nil { - return err - } - return d.checkPermissions(creds, ats, d.isDir()) -} - -// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. -func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { - var ds *[]*dentry - fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) - d, err := fs.resolveLocked(ctx, rp, &ds) - if err != nil { - return nil, err - } - if opts.CheckSearchable { - if !d.isDir() { - return nil, syserror.ENOTDIR - } - if err := d.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil { - return nil, err - } - } - d.IncRef() - return &d.vfsd, nil -} - -// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt. -func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { - var ds *[]*dentry - fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) - start := rp.Start().Impl().(*dentry) - if fs.opts.interop == InteropModeShared { - // Get updated metadata for start as required by - // fs.walkParentDirLocked(). - if err := start.updateFromGetattr(ctx); err != nil { - return nil, err - } - } - d, err := fs.walkParentDirLocked(ctx, rp, start, &ds) - if err != nil { - return nil, err - } - d.IncRef() - return &d.vfsd, nil -} - -// LinkAt implements vfs.FilesystemImpl.LinkAt. -func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string) error { - if rp.Mount() != vd.Mount() { - return syserror.EXDEV - } - // 9P2000.L supports hard links, but we don't. - return syserror.EPERM - }) -} - -// MkdirAt implements vfs.FilesystemImpl.MkdirAt. -func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { - return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string) error { - creds := rp.Credentials() - _, err := parent.file.mkdir(ctx, name, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) - return err - }) -} - -// MknodAt implements vfs.FilesystemImpl.MknodAt. -func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string) error { - creds := rp.Credentials() - _, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) - return err - }) -} - -// OpenAt implements vfs.FilesystemImpl.OpenAt. -func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - // Reject O_TMPFILE, which is not supported; supporting it correctly in the - // presence of other remote filesystem users requires remote filesystem - // support, and it isn't clear that there's any way to implement this in - // 9P. - if opts.Flags&linux.O_TMPFILE != 0 { - return nil, syserror.EOPNOTSUPP - } - mayCreate := opts.Flags&linux.O_CREAT != 0 - mustCreate := opts.Flags&(linux.O_CREAT|linux.O_EXCL) == (linux.O_CREAT | linux.O_EXCL) - - var ds *[]*dentry - fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) - - start := rp.Start().Impl().(*dentry) - if fs.opts.interop == InteropModeShared { - // Get updated metadata for start as required by fs.stepLocked(). - if err := start.updateFromGetattr(ctx); err != nil { - return nil, err - } - } - if rp.Done() { - return start.openLocked(ctx, rp, &opts) - } - -afterTrailingSymlink: - parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) - if err != nil { - return nil, err - } - // Check for search permission in the parent directory. - if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil { - return nil, err - } - // Determine whether or not we need to create a file. - parent.dirMu.Lock() - child, err := fs.stepLocked(ctx, rp, parent, &ds) - if err == syserror.ENOENT && mayCreate { - fd, err := parent.createAndOpenChildLocked(ctx, rp, &opts) - parent.dirMu.Unlock() - return fd, err - } - if err != nil { - parent.dirMu.Unlock() - return nil, err - } - // Open existing child or follow symlink. - parent.dirMu.Unlock() - if mustCreate { - return nil, syserror.EEXIST - } - if child.isSymlink() && rp.ShouldFollowSymlink() { - target, err := child.readlink(ctx, rp.Mount()) - if err != nil { - return nil, err - } - if err := rp.HandleSymlink(target); err != nil { - return nil, err - } - start = parent - goto afterTrailingSymlink - } - return child.openLocked(ctx, rp, &opts) -} - -// Preconditions: fs.renameMu must be locked. -func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { - ats := vfs.AccessTypesForOpenFlags(opts) - if err := d.checkPermissions(rp.Credentials(), ats, d.isDir()); err != nil { - return nil, err - } - mnt := rp.Mount() - filetype := d.fileType() - switch { - case filetype == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD: - if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0); err != nil { - return nil, err - } - fd := ®ularFileFD{} - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{ - AllowDirectIO: true, - }); err != nil { - return nil, err - } - return &fd.vfsfd, nil - case filetype == linux.S_IFDIR: - // Can't open directories with O_CREAT. - if opts.Flags&linux.O_CREAT != 0 { - return nil, syserror.EISDIR - } - // Can't open directories writably. - if ats&vfs.MayWrite != 0 { - return nil, syserror.EISDIR - } - if opts.Flags&linux.O_DIRECT != 0 { - return nil, syserror.EINVAL - } - if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, false /* write */, false /* trunc */); err != nil { - return nil, err - } - fd := &directoryFD{} - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { - return nil, err - } - return &fd.vfsfd, nil - case filetype == linux.S_IFLNK: - // Can't open symlinks without O_PATH (which is unimplemented). - return nil, syserror.ELOOP - default: - if opts.Flags&linux.O_DIRECT != 0 { - return nil, syserror.EINVAL - } - h, err := openHandle(ctx, d.file, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0) - if err != nil { - return nil, err - } - fd := &specialFileFD{ - handle: h, - } - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { - h.close(ctx) - return nil, err - } - return &fd.vfsfd, nil - } -} - -// Preconditions: d.fs.renameMu must be locked. d.dirMu must be locked. -func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { - if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil { - return nil, err - } - if d.isDeleted() { - return nil, syserror.ENOENT - } - mnt := rp.Mount() - if err := mnt.CheckBeginWrite(); err != nil { - return nil, err - } - defer mnt.EndWrite() - - // 9P2000.L's lcreate takes a fid representing the parent directory, and - // converts it into an open fid representing the created file, so we need - // to duplicate the directory fid first. - _, dirfile, err := d.file.walk(ctx, nil) - if err != nil { - return nil, err - } - creds := rp.Credentials() - name := rp.Component() - fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, (p9.OpenFlags)(opts.Flags), (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) - if err != nil { - dirfile.close(ctx) - return nil, err - } - // Then we need to walk to the file we just created to get a non-open fid - // representing it, and to get its metadata. This must use d.file since, as - // explained above, dirfile was invalidated by dirfile.Create(). - walkQID, nonOpenFile, attrMask, attr, err := d.file.walkGetAttrOne(ctx, name) - if err != nil { - openFile.close(ctx) - if fdobj != nil { - fdobj.Close() - } - return nil, err - } - // Sanity-check that we walked to the file we created. - if createQID.Path != walkQID.Path { - // Probably due to concurrent remote filesystem mutation? - ctx.Warningf("gofer.dentry.createAndOpenChildLocked: created file has QID %v before walk, QID %v after (interop=%v)", createQID, walkQID, d.fs.opts.interop) - nonOpenFile.close(ctx) - openFile.close(ctx) - if fdobj != nil { - fdobj.Close() - } - return nil, syserror.EAGAIN - } - - // Construct the new dentry. - child, err := d.fs.newDentry(ctx, nonOpenFile, createQID, attrMask, &attr) - if err != nil { - nonOpenFile.close(ctx) - openFile.close(ctx) - if fdobj != nil { - fdobj.Close() - } - return nil, err - } - // Incorporate the fid that was opened by lcreate. - useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD - if useRegularFileFD { - child.handleMu.Lock() - child.handle.file = openFile - if fdobj != nil { - child.handle.fd = int32(fdobj.Release()) - } - child.handleReadable = vfs.MayReadFileWithOpenFlags(opts.Flags) - child.handleWritable = vfs.MayWriteFileWithOpenFlags(opts.Flags) - child.handleMu.Unlock() - } - // Take a reference on the new dentry to be held by the new file - // description. (This reference also means that the new dentry is not - // eligible for caching yet, so we don't need to append to a dentry slice.) - child.refs = 1 - // Insert the dentry into the tree. - d.IncRef() // reference held by child on its parent d - d.vfsd.InsertChild(&child.vfsd, name) - if d.fs.opts.interop != InteropModeShared { - d.touchCMtime(ctx) - delete(d.negativeChildren, name) - d.dirents = nil - } - - // Finally, construct a file description representing the created file. - var childVFSFD *vfs.FileDescription - mnt.IncRef() - if useRegularFileFD { - fd := ®ularFileFD{} - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &child.vfsd, &vfs.FileDescriptionOptions{ - AllowDirectIO: true, - }); err != nil { - return nil, err - } - childVFSFD = &fd.vfsfd - } else { - fd := &specialFileFD{ - handle: handle{ - file: openFile, - fd: -1, - }, - } - if fdobj != nil { - fd.handle.fd = int32(fdobj.Release()) - } - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &child.vfsd, &vfs.FileDescriptionOptions{}); err != nil { - fd.handle.close(ctx) - return nil, err - } - childVFSFD = &fd.vfsfd - } - return childVFSFD, nil -} - -// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. -func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { - var ds *[]*dentry - fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) - d, err := fs.resolveLocked(ctx, rp, &ds) - if err != nil { - return "", err - } - if !d.isSymlink() { - return "", syserror.EINVAL - } - return d.readlink(ctx, rp.Mount()) -} - -// RenameAt implements vfs.FilesystemImpl.RenameAt. -func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { - if opts.Flags != 0 { - // Requires 9P support. - return syserror.EINVAL - } - - var ds *[]*dentry - fs.renameMu.Lock() - defer fs.renameMuUnlockAndCheckCaching(&ds) - newParent, err := fs.walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry), &ds) - if err != nil { - return err - } - newName := rp.Component() - if newName == "." || newName == ".." { - return syserror.EBUSY - } - mnt := rp.Mount() - if mnt != oldParentVD.Mount() { - return syserror.EXDEV - } - if err := mnt.CheckBeginWrite(); err != nil { - return err - } - defer mnt.EndWrite() - - oldParent := oldParentVD.Dentry().Impl().(*dentry) - if fs.opts.interop == InteropModeShared { - if err := oldParent.updateFromGetattr(ctx); err != nil { - return err - } - } - if err := oldParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil { - return err - } - vfsObj := rp.VirtualFilesystem() - // We need a dentry representing the renamed file since, if it's a - // directory, we need to check for write permission on it. - oldParent.dirMu.Lock() - defer oldParent.dirMu.Unlock() - renamed, err := fs.revalidateChildLocked(ctx, vfsObj, oldParent, oldName, oldParent.vfsd.Child(oldName), &ds) - if err != nil { - return err - } - if renamed == nil { - return syserror.ENOENT - } - if renamed.isDir() { - if renamed == newParent || renamed.vfsd.IsAncestorOf(&newParent.vfsd) { - return syserror.EINVAL - } - if oldParent != newParent { - if err := renamed.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil { - return err - } - } - } else { - if opts.MustBeDir || rp.MustBeDir() { - return syserror.ENOTDIR - } - } - - if oldParent != newParent { - if err := newParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil { - return err - } - newParent.dirMu.Lock() - defer newParent.dirMu.Unlock() - } - if newParent.isDeleted() { - return syserror.ENOENT - } - replacedVFSD := newParent.vfsd.Child(newName) - var replaced *dentry - // This is similar to unlinkAt, except: - // - // - We revalidate the replaced dentry unconditionally for simplicity. - // - // - If rp.MustBeDir(), then we need a dentry representing the replaced - // file regardless to confirm that it's a directory. - if replacedVFSD != nil || rp.MustBeDir() { - replaced, err = fs.revalidateChildLocked(ctx, vfsObj, newParent, newName, replacedVFSD, &ds) - if err != nil { - return err - } - if replaced != nil { - if replaced.isDir() { - if !renamed.isDir() { - return syserror.EISDIR - } - } else { - if rp.MustBeDir() || renamed.isDir() { - return syserror.ENOTDIR - } - } - replacedVFSD = &replaced.vfsd - } else { - replacedVFSD = nil - } - } - - if oldParent == newParent && oldName == newName { - return nil - } - mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() - if err := vfsObj.PrepareRenameDentry(mntns, &renamed.vfsd, replacedVFSD); err != nil { - return err - } - if err := renamed.file.rename(ctx, newParent.file, newName); err != nil { - vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD) - return err - } - if fs.opts.interop != InteropModeShared { - oldParent.cacheNegativeChildLocked(oldName) - oldParent.dirents = nil - delete(newParent.negativeChildren, newName) - newParent.dirents = nil - } - vfsObj.CommitRenameReplaceDentry(&renamed.vfsd, &newParent.vfsd, newName, replacedVFSD) - return nil -} - -// RmdirAt implements vfs.FilesystemImpl.RmdirAt. -func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { - return fs.unlinkAt(ctx, rp, true /* dir */) -} - -// SetStatAt implements vfs.FilesystemImpl.SetStatAt. -func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { - var ds *[]*dentry - fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) - d, err := fs.resolveLocked(ctx, rp, &ds) - if err != nil { - return err - } - return d.setStat(ctx, rp.Credentials(), &opts.Stat, rp.Mount()) -} - -// StatAt implements vfs.FilesystemImpl.StatAt. -func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { - var ds *[]*dentry - fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) - d, err := fs.resolveLocked(ctx, rp, &ds) - if err != nil { - return linux.Statx{}, err - } - // Since walking updates metadata for all traversed dentries under - // InteropModeShared, including the returned one, we can return cached - // metadata here regardless of fs.opts.interop. - var stat linux.Statx - d.statTo(&stat) - return stat, nil -} - -// StatFSAt implements vfs.FilesystemImpl.StatFSAt. -func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { - var ds *[]*dentry - fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) - d, err := fs.resolveLocked(ctx, rp, &ds) - if err != nil { - return linux.Statfs{}, err - } - fsstat, err := d.file.statFS(ctx) - if err != nil { - return linux.Statfs{}, err - } - nameLen := uint64(fsstat.NameLength) - if nameLen > maxFilenameLen { - nameLen = maxFilenameLen - } - return linux.Statfs{ - // This is primarily for distinguishing a gofer file system in - // tests. Testing is important, so instead of defining - // something completely random, use a standard value. - Type: linux.V9FS_MAGIC, - BlockSize: int64(fsstat.BlockSize), - Blocks: fsstat.Blocks, - BlocksFree: fsstat.BlocksFree, - BlocksAvailable: fsstat.BlocksAvailable, - Files: fsstat.Files, - FilesFree: fsstat.FilesFree, - NameLength: nameLen, - }, nil -} - -// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. -func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string) error { - creds := rp.Credentials() - _, err := parent.file.symlink(ctx, target, name, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) - return err - }) -} - -// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. -func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { - return fs.unlinkAt(ctx, rp, false /* dir */) -} - -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { - var ds *[]*dentry - fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) - d, err := fs.resolveLocked(ctx, rp, &ds) - if err != nil { - return nil, err - } - return d.listxattr(ctx) -} - -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { - var ds *[]*dentry - fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) - d, err := fs.resolveLocked(ctx, rp, &ds) - if err != nil { - return "", err - } - return d.getxattr(ctx, name) -} - -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { - var ds *[]*dentry - fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) - d, err := fs.resolveLocked(ctx, rp, &ds) - if err != nil { - return err - } - return d.setxattr(ctx, &opts) -} - -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { - var ds *[]*dentry - fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) - d, err := fs.resolveLocked(ctx, rp, &ds) - if err != nil { - return err - } - return d.removexattr(ctx, name) -} - -// PrependPath implements vfs.FilesystemImpl.PrependPath. -func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { - fs.renameMu.RLock() - defer fs.renameMu.RUnlock() - return vfs.GenericPrependPath(vfsroot, vd, b) -} diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go deleted file mode 100644 index c4a8f0b38..000000000 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ /dev/null @@ -1,1150 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package gofer provides a filesystem implementation that is backed by a 9p -// server, interchangably referred to as "gofers" throughout this package. -// -// Lock order: -// regularFileFD/directoryFD.mu -// filesystem.renameMu -// dentry.dirMu -// filesystem.syncMu -// dentry.metadataMu -// *** "memmap.Mappable locks" below this point -// dentry.mapsMu -// *** "memmap.Mappable locks taken by Translate" below this point -// dentry.handleMu -// dentry.dataMu -// -// Locking dentry.dirMu in multiple dentries requires holding -// filesystem.renameMu for writing. -package gofer - -import ( - "fmt" - "strconv" - "sync" - "sync/atomic" - "syscall" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/unet" - "gvisor.dev/gvisor/pkg/usermem" -) - -// Name is the default filesystem name. -const Name = "9p" - -// FilesystemType implements vfs.FilesystemType. -type FilesystemType struct{} - -// filesystem implements vfs.FilesystemImpl. -type filesystem struct { - vfsfs vfs.Filesystem - - // mfp is used to allocate memory that caches regular file contents. mfp is - // immutable. - mfp pgalloc.MemoryFileProvider - - // Immutable options. - opts filesystemOptions - - // client is the client used by this filesystem. client is immutable. - client *p9.Client - - // uid and gid are the effective KUID and KGID of the filesystem's creator, - // and are used as the owner and group for files that don't specify one. - // uid and gid are immutable. - uid auth.KUID - gid auth.KGID - - // renameMu serves two purposes: - // - // - It synchronizes path resolution with renaming initiated by this - // client. - // - // - It is held by path resolution to ensure that reachable dentries remain - // valid. A dentry is reachable by path resolution if it has a non-zero - // reference count (such that it is usable as vfs.ResolvingPath.Start() or - // is reachable from its children), or if it is a child dentry (such that - // it is reachable from its parent). - renameMu sync.RWMutex - - // cachedDentries contains all dentries with 0 references. (Due to race - // conditions, it may also contain dentries with non-zero references.) - // cachedDentriesLen is the number of dentries in cachedDentries. These - // fields are protected by renameMu. - cachedDentries dentryList - cachedDentriesLen uint64 - - // dentries contains all dentries in this filesystem. specialFileFDs - // contains all open specialFileFDs. These fields are protected by syncMu. - syncMu sync.Mutex - dentries map[*dentry]struct{} - specialFileFDs map[*specialFileFD]struct{} -} - -type filesystemOptions struct { - // "Standard" 9P options. - fd int - aname string - interop InteropMode // derived from the "cache" mount option - msize uint32 - version string - - // maxCachedDentries is the maximum number of dentries with 0 references - // retained by the client. - maxCachedDentries uint64 - - // If forcePageCache is true, host FDs may not be used for application - // memory mappings even if available; instead, the client must perform its - // own caching of regular file pages. This is primarily useful for testing. - forcePageCache bool - - // If limitHostFDTranslation is true, apply maxFillRange() constraints to - // host FD mappings returned by dentry.(memmap.Mappable).Translate(). This - // makes memory accounting behavior more consistent between cases where - // host FDs are / are not available, but may increase the frequency of - // sentry-handled page faults on files for which a host FD is available. - limitHostFDTranslation bool - - // If overlayfsStaleRead is true, O_RDONLY host FDs provided by the remote - // filesystem may not be coherent with writable host FDs opened later, so - // mappings of the former must be replaced by mappings of the latter. This - // is usually only the case when the remote filesystem is an overlayfs - // mount on Linux < 4.19. - overlayfsStaleRead bool - - // If regularFilesUseSpecialFileFD is true, application FDs representing - // regular files will use distinct file handles for each FD, in the same - // way that application FDs representing "special files" such as sockets - // do. Note that this disables client caching and mmap for regular files. - regularFilesUseSpecialFileFD bool -} - -// InteropMode controls the client's interaction with other remote filesystem -// users. -type InteropMode uint32 - -const ( - // InteropModeExclusive is appropriate when the filesystem client is the - // only user of the remote filesystem. - // - // - The client may cache arbitrary filesystem state (file data, metadata, - // filesystem structure, etc.). - // - // - Client changes to filesystem state may be sent to the remote - // filesystem asynchronously, except when server permission checks are - // necessary. - // - // - File timestamps are based on client clocks. This ensures that users of - // the client observe timestamps that are coherent with their own clocks - // and consistent with Linux's semantics. However, since it is not always - // possible for clients to set arbitrary atimes and mtimes, and never - // possible for clients to set arbitrary ctimes, file timestamp changes are - // stored in the client only and never sent to the remote filesystem. - InteropModeExclusive InteropMode = iota - - // InteropModeWritethrough is appropriate when there are read-only users of - // the remote filesystem that expect to observe changes made by the - // filesystem client. - // - // - The client may cache arbitrary filesystem state. - // - // - Client changes to filesystem state must be sent to the remote - // filesystem synchronously. - // - // - File timestamps are based on client clocks. As a corollary, access - // timestamp changes from other remote filesystem users will not be visible - // to the client. - InteropModeWritethrough - - // InteropModeShared is appropriate when there are users of the remote - // filesystem that may mutate its state other than the client. - // - // - The client must verify cached filesystem state before using it. - // - // - Client changes to filesystem state must be sent to the remote - // filesystem synchronously. - // - // - File timestamps are based on server clocks. This is necessary to - // ensure that timestamp changes are synchronized between remote filesystem - // users. - // - // Note that the correctness of InteropModeShared depends on the server - // correctly implementing 9P fids (i.e. each fid immutably represents a - // single filesystem object), even in the presence of remote filesystem - // mutations from other users. If this is violated, the behavior of the - // client is undefined. - InteropModeShared -) - -// GetFilesystem implements vfs.FilesystemType.GetFilesystem. -func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - mfp := pgalloc.MemoryFileProviderFromContext(ctx) - if mfp == nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: context does not provide a pgalloc.MemoryFileProvider") - return nil, nil, syserror.EINVAL - } - - mopts := vfs.GenericParseMountOptions(opts.Data) - var fsopts filesystemOptions - - // Check that the transport is "fd". - trans, ok := mopts["trans"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: transport must be specified as 'trans=fd'") - return nil, nil, syserror.EINVAL - } - delete(mopts, "trans") - if trans != "fd" { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: unsupported transport: trans=%s", trans) - return nil, nil, syserror.EINVAL - } - - // Check that read and write FDs are provided and identical. - rfdstr, ok := mopts["rfdno"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD must be specified as 'rfdno=<file descriptor>") - return nil, nil, syserror.EINVAL - } - delete(mopts, "rfdno") - rfd, err := strconv.Atoi(rfdstr) - if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid read FD: rfdno=%s", rfdstr) - return nil, nil, syserror.EINVAL - } - wfdstr, ok := mopts["wfdno"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: write FD must be specified as 'wfdno=<file descriptor>") - return nil, nil, syserror.EINVAL - } - delete(mopts, "wfdno") - wfd, err := strconv.Atoi(wfdstr) - if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid write FD: wfdno=%s", wfdstr) - return nil, nil, syserror.EINVAL - } - if rfd != wfd { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD (%d) and write FD (%d) must be equal", rfd, wfd) - return nil, nil, syserror.EINVAL - } - fsopts.fd = rfd - - // Get the attach name. - fsopts.aname = "/" - if aname, ok := mopts["aname"]; ok { - delete(mopts, "aname") - fsopts.aname = aname - } - - // Parse the cache policy. For historical reasons, this defaults to the - // least generally-applicable option, InteropModeExclusive. - fsopts.interop = InteropModeExclusive - if cache, ok := mopts["cache"]; ok { - delete(mopts, "cache") - switch cache { - case "fscache": - fsopts.interop = InteropModeExclusive - case "fscache_writethrough": - fsopts.interop = InteropModeWritethrough - case "none": - fsopts.regularFilesUseSpecialFileFD = true - fallthrough - case "remote_revalidating": - fsopts.interop = InteropModeShared - default: - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid cache policy: cache=%s", cache) - return nil, nil, syserror.EINVAL - } - } - - // Parse the 9P message size. - fsopts.msize = 1024 * 1024 // 1M, tested to give good enough performance up to 64M - if msizestr, ok := mopts["msize"]; ok { - delete(mopts, "msize") - msize, err := strconv.ParseUint(msizestr, 10, 32) - if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid message size: msize=%s", msizestr) - return nil, nil, syserror.EINVAL - } - fsopts.msize = uint32(msize) - } - - // Parse the 9P protocol version. - fsopts.version = p9.HighestVersionString() - if version, ok := mopts["version"]; ok { - delete(mopts, "version") - fsopts.version = version - } - - // Parse the dentry cache limit. - fsopts.maxCachedDentries = 1000 - if str, ok := mopts["dentry_cache_limit"]; ok { - delete(mopts, "dentry_cache_limit") - maxCachedDentries, err := strconv.ParseUint(str, 10, 64) - if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) - return nil, nil, syserror.EINVAL - } - fsopts.maxCachedDentries = maxCachedDentries - } - - // Handle simple flags. - if _, ok := mopts["force_page_cache"]; ok { - delete(mopts, "force_page_cache") - fsopts.forcePageCache = true - } - if _, ok := mopts["limit_host_fd_translation"]; ok { - delete(mopts, "limit_host_fd_translation") - fsopts.limitHostFDTranslation = true - } - if _, ok := mopts["overlayfs_stale_read"]; ok { - delete(mopts, "overlayfs_stale_read") - fsopts.overlayfsStaleRead = true - } - // fsopts.regularFilesUseSpecialFileFD can only be enabled by specifying - // "cache=none". - - // Check for unparsed options. - if len(mopts) != 0 { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: unknown options: %v", mopts) - return nil, nil, syserror.EINVAL - } - - // Establish a connection with the server. - conn, err := unet.NewSocket(fsopts.fd) - if err != nil { - return nil, nil, err - } - - // Perform version negotiation with the server. - ctx.UninterruptibleSleepStart(false) - client, err := p9.NewClient(conn, fsopts.msize, fsopts.version) - ctx.UninterruptibleSleepFinish(false) - if err != nil { - conn.Close() - return nil, nil, err - } - // Ownership of conn has been transferred to client. - - // Perform attach to obtain the filesystem root. - ctx.UninterruptibleSleepStart(false) - attached, err := client.Attach(fsopts.aname) - ctx.UninterruptibleSleepFinish(false) - if err != nil { - client.Close() - return nil, nil, err - } - attachFile := p9file{attached} - qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) - if err != nil { - attachFile.close(ctx) - client.Close() - return nil, nil, err - } - - // Construct the filesystem object. - fs := &filesystem{ - mfp: mfp, - opts: fsopts, - uid: creds.EffectiveKUID, - gid: creds.EffectiveKGID, - client: client, - dentries: make(map[*dentry]struct{}), - specialFileFDs: make(map[*specialFileFD]struct{}), - } - fs.vfsfs.Init(vfsObj, fs) - - // Construct the root dentry. - root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr) - if err != nil { - attachFile.close(ctx) - fs.vfsfs.DecRef() - return nil, nil, err - } - // Set the root's reference count to 2. One reference is returned to the - // caller, and the other is deliberately leaked to prevent the root from - // being "cached" and subsequently evicted. Its resources will still be - // cleaned up by fs.Release(). - root.refs = 2 - - return &fs.vfsfs, &root.vfsd, nil -} - -// Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { - ctx := context.Background() - mf := fs.mfp.MemoryFile() - - fs.syncMu.Lock() - for d := range fs.dentries { - d.handleMu.Lock() - d.dataMu.Lock() - if d.handleWritable { - // Write dirty cached data to the remote file. - if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt); err != nil { - log.Warningf("gofer.filesystem.Release: failed to flush dentry: %v", err) - } - // TODO(jamieliu): Do we need to flushf/fsync d? - } - // Discard cached pages. - d.cache.DropAll(mf) - d.dirty.RemoveAll() - d.dataMu.Unlock() - // Close the host fd if one exists. - if d.handle.fd >= 0 { - syscall.Close(int(d.handle.fd)) - d.handle.fd = -1 - } - d.handleMu.Unlock() - } - // There can't be any specialFileFDs still using fs, since each such - // FileDescription would hold a reference on a Mount holding a reference on - // fs. - fs.syncMu.Unlock() - - // Close the connection to the server. This implicitly clunks all fids. - fs.client.Close() -} - -// dentry implements vfs.DentryImpl. -type dentry struct { - vfsd vfs.Dentry - - // refs is the reference count. Each dentry holds a reference on its - // parent, even if disowned. refs is accessed using atomic memory - // operations. - refs int64 - - // fs is the owning filesystem. fs is immutable. - fs *filesystem - - // We don't support hard links, so each dentry maps 1:1 to an inode. - - // file is the unopened p9.File that backs this dentry. file is immutable. - file p9file - - // If deleted is non-zero, the file represented by this dentry has been - // deleted. deleted is accessed using atomic memory operations. - deleted uint32 - - // If cached is true, dentryEntry links dentry into - // filesystem.cachedDentries. cached and dentryEntry are protected by - // filesystem.renameMu. - cached bool - dentryEntry - - dirMu sync.Mutex - - // If this dentry represents a directory, and InteropModeShared is not in - // effect, negativeChildren is a set of child names in this directory that - // are known not to exist. negativeChildren is protected by dirMu. - negativeChildren map[string]struct{} - - // If this dentry represents a directory, InteropModeShared is not in - // effect, and dirents is not nil, it is a cache of all entries in the - // directory, in the order they were returned by the server. dirents is - // protected by dirMu. - dirents []vfs.Dirent - - // Cached metadata; protected by metadataMu and accessed using atomic - // memory operations unless otherwise specified. - metadataMu sync.Mutex - ino uint64 // immutable - mode uint32 // type is immutable, perms are mutable - uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic - gid uint32 // auth.KGID, but ... - blockSize uint32 // 0 if unknown - // Timestamps, all nsecs from the Unix epoch. - atime int64 - mtime int64 - ctime int64 - btime int64 - // File size, protected by both metadataMu and dataMu (i.e. both must be - // locked to mutate it). - size uint64 - - mapsMu sync.Mutex - - // If this dentry represents a regular file, mappings tracks mappings of - // the file into memmap.MappingSpaces. mappings is protected by mapsMu. - mappings memmap.MappingSet - - // If this dentry represents a regular file or directory: - // - // - handle is the I/O handle used by all regularFileFDs/directoryFDs - // representing this dentry. - // - // - handleReadable is true if handle is readable. - // - // - handleWritable is true if handle is writable. - // - // Invariants: - // - // - If handleReadable == handleWritable == false, then handle.file == nil - // (i.e. there is no open handle). Conversely, if handleReadable || - // handleWritable == true, then handle.file != nil (i.e. there is an open - // handle). - // - // - handleReadable and handleWritable cannot transition from true to false - // (i.e. handles may not be downgraded). - // - // These fields are protected by handleMu. - handleMu sync.RWMutex - handle handle - handleReadable bool - handleWritable bool - - dataMu sync.RWMutex - - // If this dentry represents a regular file that is client-cached, cache - // maps offsets into the cached file to offsets into - // filesystem.mfp.MemoryFile() that store the file's data. cache is - // protected by dataMu. - cache fsutil.FileRangeSet - - // If this dentry represents a regular file that is client-cached, dirty - // tracks dirty segments in cache. dirty is protected by dataMu. - dirty fsutil.DirtySet - - // pf implements platform.File for mappings of handle.fd. - pf dentryPlatformFile - - // If this dentry represents a symbolic link, InteropModeShared is not in - // effect, and haveTarget is true, target is the symlink target. haveTarget - // and target are protected by dataMu. - haveTarget bool - target string -} - -// dentryAttrMask returns a p9.AttrMask enabling all attributes used by the -// gofer client. -func dentryAttrMask() p9.AttrMask { - return p9.AttrMask{ - Mode: true, - UID: true, - GID: true, - ATime: true, - MTime: true, - CTime: true, - Size: true, - BTime: true, - } -} - -// newDentry creates a new dentry representing the given file. The dentry -// initially has no references, but is not cached; it is the caller's -// responsibility to set the dentry's reference count and/or call -// dentry.checkCachingLocked() as appropriate. -func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, mask p9.AttrMask, attr *p9.Attr) (*dentry, error) { - if !mask.Mode { - ctx.Warningf("can't create gofer.dentry without file type") - return nil, syserror.EIO - } - if attr.Mode.FileType() == p9.ModeRegular && !mask.Size { - ctx.Warningf("can't create regular file gofer.dentry without file size") - return nil, syserror.EIO - } - - d := &dentry{ - fs: fs, - file: file, - ino: qid.Path, - mode: uint32(attr.Mode), - uid: uint32(fs.uid), - gid: uint32(fs.gid), - blockSize: usermem.PageSize, - handle: handle{ - fd: -1, - }, - } - d.pf.dentry = d - if mask.UID { - d.uid = uint32(attr.UID) - } - if mask.GID { - d.gid = uint32(attr.GID) - } - if mask.Size { - d.size = attr.Size - } - if attr.BlockSize != 0 { - d.blockSize = uint32(attr.BlockSize) - } - if mask.ATime { - d.atime = dentryTimestampFromP9(attr.ATimeSeconds, attr.ATimeNanoSeconds) - } - if mask.MTime { - d.mtime = dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds) - } - if mask.CTime { - d.ctime = dentryTimestampFromP9(attr.CTimeSeconds, attr.CTimeNanoSeconds) - } - if mask.BTime { - d.btime = dentryTimestampFromP9(attr.BTimeSeconds, attr.BTimeNanoSeconds) - } - d.vfsd.Init(d) - - fs.syncMu.Lock() - fs.dentries[d] = struct{}{} - fs.syncMu.Unlock() - return d, nil -} - -// updateFromP9Attrs is called to update d's metadata after an update from the -// remote filesystem. -func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) { - d.metadataMu.Lock() - if mask.Mode { - if got, want := uint32(attr.Mode.FileType()), d.fileType(); got != want { - d.metadataMu.Unlock() - panic(fmt.Sprintf("gofer.dentry file type changed from %#o to %#o", want, got)) - } - atomic.StoreUint32(&d.mode, uint32(attr.Mode)) - } - if mask.UID { - atomic.StoreUint32(&d.uid, uint32(attr.UID)) - } - if mask.GID { - atomic.StoreUint32(&d.gid, uint32(attr.GID)) - } - // There is no P9_GETATTR_* bit for I/O block size. - if attr.BlockSize != 0 { - atomic.StoreUint32(&d.blockSize, uint32(attr.BlockSize)) - } - if mask.ATime { - atomic.StoreInt64(&d.atime, dentryTimestampFromP9(attr.ATimeSeconds, attr.ATimeNanoSeconds)) - } - if mask.MTime { - atomic.StoreInt64(&d.mtime, dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds)) - } - if mask.CTime { - atomic.StoreInt64(&d.ctime, dentryTimestampFromP9(attr.CTimeSeconds, attr.CTimeNanoSeconds)) - } - if mask.BTime { - atomic.StoreInt64(&d.btime, dentryTimestampFromP9(attr.BTimeSeconds, attr.BTimeNanoSeconds)) - } - if mask.Size { - d.dataMu.Lock() - atomic.StoreUint64(&d.size, attr.Size) - d.dataMu.Unlock() - } - d.metadataMu.Unlock() -} - -func (d *dentry) updateFromGetattr(ctx context.Context) error { - // Use d.handle.file, which represents a 9P fid that has been opened, in - // preference to d.file, which represents a 9P fid that has not. This may - // be significantly more efficient in some implementations. - var ( - file p9file - handleMuRLocked bool - ) - d.handleMu.RLock() - if !d.handle.file.isNil() { - file = d.handle.file - handleMuRLocked = true - } else { - file = d.file - d.handleMu.RUnlock() - } - _, attrMask, attr, err := file.getAttr(ctx, dentryAttrMask()) - if handleMuRLocked { - d.handleMu.RUnlock() - } - if err != nil { - return err - } - d.updateFromP9Attrs(attrMask, &attr) - return nil -} - -func (d *dentry) fileType() uint32 { - return atomic.LoadUint32(&d.mode) & linux.S_IFMT -} - -func (d *dentry) statTo(stat *linux.Statx) { - stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME - stat.Blksize = atomic.LoadUint32(&d.blockSize) - stat.Nlink = 1 - if d.isDir() { - stat.Nlink = 2 - } - stat.UID = atomic.LoadUint32(&d.uid) - stat.GID = atomic.LoadUint32(&d.gid) - stat.Mode = uint16(atomic.LoadUint32(&d.mode)) - stat.Ino = d.ino - stat.Size = atomic.LoadUint64(&d.size) - // This is consistent with regularFileFD.Seek(), which treats regular files - // as having no holes. - stat.Blocks = (stat.Size + 511) / 512 - stat.Atime = statxTimestampFromDentry(atomic.LoadInt64(&d.atime)) - stat.Btime = statxTimestampFromDentry(atomic.LoadInt64(&d.btime)) - stat.Ctime = statxTimestampFromDentry(atomic.LoadInt64(&d.ctime)) - stat.Mtime = statxTimestampFromDentry(atomic.LoadInt64(&d.mtime)) - // TODO(jamieliu): device number -} - -func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mnt *vfs.Mount) error { - if stat.Mask == 0 { - return nil - } - if stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 { - return syserror.EPERM - } - if err := vfs.CheckSetStat(creds, stat, uint16(atomic.LoadUint32(&d.mode))&^linux.S_IFMT, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { - return err - } - if err := mnt.CheckBeginWrite(); err != nil { - return err - } - defer mnt.EndWrite() - setLocalAtime := false - setLocalMtime := false - if d.fs.opts.interop != InteropModeShared { - // Timestamp updates will be handled locally. - setLocalAtime = stat.Mask&linux.STATX_ATIME != 0 - setLocalMtime = stat.Mask&linux.STATX_MTIME != 0 - stat.Mask &^= linux.STATX_ATIME | linux.STATX_MTIME - if !setLocalMtime && (stat.Mask&linux.STATX_SIZE != 0) { - // Truncate updates mtime. - setLocalMtime = true - stat.Mtime.Nsec = linux.UTIME_NOW - } - } - d.metadataMu.Lock() - defer d.metadataMu.Unlock() - if stat.Mask != 0 { - if err := d.file.setAttr(ctx, p9.SetAttrMask{ - Permissions: stat.Mask&linux.STATX_MODE != 0, - UID: stat.Mask&linux.STATX_UID != 0, - GID: stat.Mask&linux.STATX_GID != 0, - Size: stat.Mask&linux.STATX_SIZE != 0, - ATime: stat.Mask&linux.STATX_ATIME != 0, - MTime: stat.Mask&linux.STATX_MTIME != 0, - ATimeNotSystemTime: stat.Atime.Nsec != linux.UTIME_NOW, - MTimeNotSystemTime: stat.Mtime.Nsec != linux.UTIME_NOW, - }, p9.SetAttr{ - Permissions: p9.FileMode(stat.Mode), - UID: p9.UID(stat.UID), - GID: p9.GID(stat.GID), - Size: stat.Size, - ATimeSeconds: uint64(stat.Atime.Sec), - ATimeNanoSeconds: uint64(stat.Atime.Nsec), - MTimeSeconds: uint64(stat.Mtime.Sec), - MTimeNanoSeconds: uint64(stat.Mtime.Nsec), - }); err != nil { - return err - } - } - if d.fs.opts.interop == InteropModeShared { - // There's no point to updating d's metadata in this case since it'll - // be overwritten by revalidation before the next time it's used - // anyway. (InteropModeShared inhibits client caching of regular file - // data, so there's no cache to truncate either.) - return nil - } - now, haveNow := nowFromContext(ctx) - if !haveNow { - ctx.Warningf("gofer.dentry.setStat: current time not available") - } - if stat.Mask&linux.STATX_MODE != 0 { - atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode)) - } - if stat.Mask&linux.STATX_UID != 0 { - atomic.StoreUint32(&d.uid, stat.UID) - } - if stat.Mask&linux.STATX_GID != 0 { - atomic.StoreUint32(&d.gid, stat.GID) - } - if setLocalAtime { - if stat.Atime.Nsec == linux.UTIME_NOW { - if haveNow { - atomic.StoreInt64(&d.atime, now) - } - } else { - atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime)) - } - } - if setLocalMtime { - if stat.Mtime.Nsec == linux.UTIME_NOW { - if haveNow { - atomic.StoreInt64(&d.mtime, now) - } - } else { - atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime)) - } - } - if haveNow { - atomic.StoreInt64(&d.ctime, now) - } - if stat.Mask&linux.STATX_SIZE != 0 { - d.dataMu.Lock() - oldSize := d.size - d.size = stat.Size - // d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings - // below. This allows concurrent calls to Read/Translate/etc. These - // functions synchronize with truncation by refusing to use cache - // contents beyond the new d.size. (We are still holding d.metadataMu, - // so we can't race with Write or another truncate.) - d.dataMu.Unlock() - if d.size < oldSize { - oldpgend := pageRoundUp(oldSize) - newpgend := pageRoundUp(d.size) - if oldpgend != newpgend { - d.mapsMu.Lock() - d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ - // Compare Linux's mm/truncate.c:truncate_setsize() => - // truncate_pagecache() => - // mm/memory.c:unmap_mapping_range(evencows=1). - InvalidatePrivate: true, - }) - d.mapsMu.Unlock() - } - // We are now guaranteed that there are no translations of - // truncated pages, and can remove them from the cache. Since - // truncated pages have been removed from the remote file, they - // should be dropped without being written back. - d.dataMu.Lock() - d.cache.Truncate(d.size, d.fs.mfp.MemoryFile()) - d.dirty.KeepClean(memmap.MappableRange{d.size, oldpgend}) - d.dataMu.Unlock() - } - } - return nil -} - -func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, isDir bool) error { - return vfs.GenericCheckPermissions(creds, ats, isDir, uint16(atomic.LoadUint32(&d.mode))&0777, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))) -} - -// IncRef implements vfs.DentryImpl.IncRef. -func (d *dentry) IncRef() { - // d.refs may be 0 if d.fs.renameMu is locked, which serializes against - // d.checkCachingLocked(). - atomic.AddInt64(&d.refs, 1) -} - -// TryIncRef implements vfs.DentryImpl.TryIncRef. -func (d *dentry) TryIncRef() bool { - for { - refs := atomic.LoadInt64(&d.refs) - if refs == 0 { - return false - } - if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) { - return true - } - } -} - -// DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef() { - if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { - d.fs.renameMu.Lock() - d.checkCachingLocked() - d.fs.renameMu.Unlock() - } else if refs < 0 { - panic("gofer.dentry.DecRef() called without holding a reference") - } -} - -// checkCachingLocked should be called after d's reference count becomes 0 or it -// becomes disowned. -// -// Preconditions: d.fs.renameMu must be locked for writing. -func (d *dentry) checkCachingLocked() { - // Dentries with a non-zero reference count must be retained. (The only way - // to obtain a reference on a dentry with zero references is via path - // resolution, which requires renameMu, so if d.refs is zero then it will - // remain zero while we hold renameMu for writing.) - if atomic.LoadInt64(&d.refs) != 0 { - if d.cached { - d.fs.cachedDentries.Remove(d) - d.fs.cachedDentriesLen-- - d.cached = false - } - return - } - // Non-child dentries with zero references are no longer reachable by path - // resolution and should be dropped immediately. - if d.vfsd.Parent() == nil || d.vfsd.IsDisowned() { - if d.cached { - d.fs.cachedDentries.Remove(d) - d.fs.cachedDentriesLen-- - d.cached = false - } - d.destroyLocked() - return - } - // If d is already cached, just move it to the front of the LRU. - if d.cached { - d.fs.cachedDentries.Remove(d) - d.fs.cachedDentries.PushFront(d) - return - } - // Cache the dentry, then evict the least recently used cached dentry if - // the cache becomes over-full. - d.fs.cachedDentries.PushFront(d) - d.fs.cachedDentriesLen++ - d.cached = true - if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries { - victim := d.fs.cachedDentries.Back() - d.fs.cachedDentries.Remove(victim) - d.fs.cachedDentriesLen-- - victim.cached = false - // victim.refs may have become non-zero from an earlier path - // resolution since it was inserted into fs.cachedDentries; see - // dentry.incRefLocked(). Either way, we brought - // fs.cachedDentriesLen back down to fs.opts.maxCachedDentries, so - // we don't loop. - if atomic.LoadInt64(&victim.refs) == 0 { - if victimParentVFSD := victim.vfsd.Parent(); victimParentVFSD != nil { - victimParent := victimParentVFSD.Impl().(*dentry) - victimParent.dirMu.Lock() - if !victim.vfsd.IsDisowned() { - // victim can't be a mount point (in any mount - // namespace), since VFS holds references on mount - // points. - d.fs.vfsfs.VirtualFilesystem().ForceDeleteDentry(&victim.vfsd) - // We're only deleting the dentry, not the file it - // represents, so we don't need to update - // victimParent.dirents etc. - } - victimParent.dirMu.Unlock() - } - victim.destroyLocked() - } - } -} - -// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. d is -// not a child dentry. -func (d *dentry) destroyLocked() { - ctx := context.Background() - d.handleMu.Lock() - if !d.handle.file.isNil() { - mf := d.fs.mfp.MemoryFile() - d.dataMu.Lock() - // Write dirty pages back to the remote filesystem. - if d.handleWritable { - if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, d.handle.writeFromBlocksAt); err != nil { - log.Warningf("gofer.dentry.DecRef: failed to write dirty data back: %v", err) - } - } - // Discard cached data. - d.cache.DropAll(mf) - d.dirty.RemoveAll() - d.dataMu.Unlock() - // Clunk open fids and close open host FDs. - d.handle.close(ctx) - } - d.handleMu.Unlock() - d.file.close(ctx) - // Remove d from the set of all dentries. - d.fs.syncMu.Lock() - delete(d.fs.dentries, d) - d.fs.syncMu.Unlock() - // Drop the reference held by d on its parent. - if parentVFSD := d.vfsd.Parent(); parentVFSD != nil { - parent := parentVFSD.Impl().(*dentry) - // This is parent.DecRef() without recursive locking of d.fs.renameMu. - if refs := atomic.AddInt64(&parent.refs, -1); refs == 0 { - parent.checkCachingLocked() - } else if refs < 0 { - panic("gofer.dentry.DecRef() called without holding a reference") - } - } -} - -func (d *dentry) isDeleted() bool { - return atomic.LoadUint32(&d.deleted) != 0 -} - -func (d *dentry) setDeleted() { - atomic.StoreUint32(&d.deleted, 1) -} - -func (d *dentry) listxattr(ctx context.Context) ([]string, error) { - return nil, syserror.ENOTSUP -} - -func (d *dentry) getxattr(ctx context.Context, name string) (string, error) { - // TODO(jamieliu): add vfs.GetxattrOptions.Size - return d.file.getXattr(ctx, name, linux.XATTR_SIZE_MAX) -} - -func (d *dentry) setxattr(ctx context.Context, opts *vfs.SetxattrOptions) error { - return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags) -} - -func (d *dentry) removexattr(ctx context.Context, name string) error { - return syserror.ENOTSUP -} - -// Preconditions: d.isRegularFile() || d.isDirectory(). -func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool) error { - // O_TRUNC unconditionally requires us to obtain a new handle (opened with - // O_TRUNC). - if !trunc { - d.handleMu.RLock() - if (!read || d.handleReadable) && (!write || d.handleWritable) { - // The current handle is sufficient. - d.handleMu.RUnlock() - return nil - } - d.handleMu.RUnlock() - } - - haveOldFD := false - d.handleMu.Lock() - if (read && !d.handleReadable) || (write && !d.handleWritable) || trunc { - // Get a new handle. - wantReadable := d.handleReadable || read - wantWritable := d.handleWritable || write - h, err := openHandle(ctx, d.file, wantReadable, wantWritable, trunc) - if err != nil { - d.handleMu.Unlock() - return err - } - if !d.handle.file.isNil() { - // Check that old and new handles are compatible: If the old handle - // includes a host file descriptor but the new one does not, or - // vice versa, old and new memory mappings may be incoherent. - haveOldFD = d.handle.fd >= 0 - haveNewFD := h.fd >= 0 - if haveOldFD != haveNewFD { - d.handleMu.Unlock() - ctx.Warningf("gofer.dentry.ensureSharedHandle: can't change host FD availability from %v to %v across dentry handle upgrade", haveOldFD, haveNewFD) - h.close(ctx) - return syserror.EIO - } - if haveOldFD { - // We may have raced with callers of d.pf.FD() that are now - // using the old file descriptor, preventing us from safely - // closing it. We could handle this by invalidating existing - // memmap.Translations, but this is expensive. Instead, use - // dup3 to make the old file descriptor refer to the new file - // description, then close the new file descriptor (which is no - // longer needed). Racing callers may use the old or new file - // description, but this doesn't matter since they refer to the - // same file (unless d.fs.opts.overlayfsStaleRead is true, - // which we handle separately). - if err := syscall.Dup3(int(h.fd), int(d.handle.fd), 0); err != nil { - d.handleMu.Unlock() - ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, d.handle.fd, err) - h.close(ctx) - return err - } - syscall.Close(int(h.fd)) - h.fd = d.handle.fd - if d.fs.opts.overlayfsStaleRead { - // Replace sentry mappings of the old FD with mappings of - // the new FD, since the two are not necessarily coherent. - if err := d.pf.hostFileMapper.RegenerateMappings(int(h.fd)); err != nil { - d.handleMu.Unlock() - ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to replace sentry mappings of old FD with mappings of new FD: %v", err) - h.close(ctx) - return err - } - } - // Clunk the old fid before making the new handle visible (by - // unlocking d.handleMu). - d.handle.file.close(ctx) - } - } - // Switch to the new handle. - d.handle = h - d.handleReadable = wantReadable - d.handleWritable = wantWritable - } - d.handleMu.Unlock() - - if d.fs.opts.overlayfsStaleRead && haveOldFD { - // Invalidate application mappings that may be using the old FD; they - // will be replaced with mappings using the new FD after future calls - // to d.Translate(). This requires holding d.mapsMu, which precedes - // d.handleMu in the lock order. - d.mapsMu.Lock() - d.mappings.InvalidateAll(memmap.InvalidateOpts{}) - d.mapsMu.Unlock() - } - - return nil -} - -// fileDescription is embedded by gofer implementations of -// vfs.FileDescriptionImpl. -type fileDescription struct { - vfsfd vfs.FileDescription - vfs.FileDescriptionDefaultImpl -} - -func (fd *fileDescription) filesystem() *filesystem { - return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem) -} - -func (fd *fileDescription) dentry() *dentry { - return fd.vfsfd.Dentry().Impl().(*dentry) -} - -// Stat implements vfs.FileDescriptionImpl.Stat. -func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { - d := fd.dentry() - if d.fs.opts.interop == InteropModeShared && opts.Mask&(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME|linux.STATX_SIZE|linux.STATX_BLOCKS|linux.STATX_BTIME) != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC { - // TODO(jamieliu): Use specialFileFD.handle.file for the getattr if - // available? - if err := d.updateFromGetattr(ctx); err != nil { - return linux.Statx{}, err - } - } - var stat linux.Statx - d.statTo(&stat) - return stat, nil -} - -// SetStat implements vfs.FileDescriptionImpl.SetStat. -func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { - return fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, fd.vfsfd.Mount()) -} - -// Listxattr implements vfs.FileDescriptionImpl.Listxattr. -func (fd *fileDescription) Listxattr(ctx context.Context) ([]string, error) { - return fd.dentry().listxattr(ctx) -} - -// Getxattr implements vfs.FileDescriptionImpl.Getxattr. -func (fd *fileDescription) Getxattr(ctx context.Context, name string) (string, error) { - return fd.dentry().getxattr(ctx, name) -} - -// Setxattr implements vfs.FileDescriptionImpl.Setxattr. -func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error { - return fd.dentry().setxattr(ctx, &opts) -} - -// Removexattr implements vfs.FileDescriptionImpl.Removexattr. -func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { - return fd.dentry().removexattr(ctx, name) -} diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go deleted file mode 100644 index cfe66f797..000000000 --- a/pkg/sentry/fsimpl/gofer/handle.go +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gofer - -import ( - "syscall" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/safemem" -) - -// handle represents a remote "open file descriptor", consisting of an opened -// fid (p9.File) and optionally a host file descriptor. -type handle struct { - file p9file - fd int32 // -1 if unavailable -} - -// Preconditions: read || write. -func openHandle(ctx context.Context, file p9file, read, write, trunc bool) (handle, error) { - _, newfile, err := file.walk(ctx, nil) - if err != nil { - return handle{fd: -1}, err - } - var flags p9.OpenFlags - switch { - case read && !write: - flags = p9.ReadOnly - case !read && write: - flags = p9.WriteOnly - case read && write: - flags = p9.ReadWrite - } - if trunc { - flags |= p9.OpenTruncate - } - fdobj, _, _, err := newfile.open(ctx, flags) - if err != nil { - newfile.close(ctx) - return handle{fd: -1}, err - } - fd := int32(-1) - if fdobj != nil { - fd = int32(fdobj.Release()) - } - return handle{ - file: newfile, - fd: fd, - }, nil -} - -func (h *handle) close(ctx context.Context) { - h.file.close(ctx) - h.file = p9file{} - if h.fd >= 0 { - syscall.Close(int(h.fd)) - h.fd = -1 - } -} - -func (h *handle) readToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) { - if dsts.IsEmpty() { - return 0, nil - } - if h.fd >= 0 { - ctx.UninterruptibleSleepStart(false) - n, err := hostPreadv(h.fd, dsts, int64(offset)) - ctx.UninterruptibleSleepFinish(false) - return n, err - } - if dsts.NumBlocks() == 1 && !dsts.Head().NeedSafecopy() { - n, err := h.file.readAt(ctx, dsts.Head().ToSlice(), offset) - return uint64(n), err - } - // Buffer the read since p9.File.ReadAt() takes []byte. - buf := make([]byte, dsts.NumBytes()) - n, err := h.file.readAt(ctx, buf, offset) - if n == 0 { - return 0, err - } - if cp, cperr := safemem.CopySeq(dsts, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:n]))); cperr != nil { - return cp, cperr - } - return uint64(n), err -} - -func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) { - if srcs.IsEmpty() { - return 0, nil - } - if h.fd >= 0 { - ctx.UninterruptibleSleepStart(false) - n, err := hostPwritev(h.fd, srcs, int64(offset)) - ctx.UninterruptibleSleepFinish(false) - return n, err - } - if srcs.NumBlocks() == 1 && !srcs.Head().NeedSafecopy() { - n, err := h.file.writeAt(ctx, srcs.Head().ToSlice(), offset) - return uint64(n), err - } - // Buffer the write since p9.File.WriteAt() takes []byte. - buf := make([]byte, srcs.NumBytes()) - cp, cperr := safemem.CopySeq(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), srcs) - if cp == 0 { - return 0, cperr - } - n, err := h.file.writeAt(ctx, buf[:cp], offset) - if err != nil { - return uint64(n), err - } - return cp, cperr -} - -func (h *handle) sync(ctx context.Context) error { - if h.fd >= 0 { - ctx.UninterruptibleSleepStart(false) - err := syscall.Fsync(int(h.fd)) - ctx.UninterruptibleSleepFinish(false) - return err - } - return h.file.fsync(ctx) -} diff --git a/pkg/sentry/fsimpl/gofer/handle_unsafe.go b/pkg/sentry/fsimpl/gofer/handle_unsafe.go deleted file mode 100644 index 19560ab26..000000000 --- a/pkg/sentry/fsimpl/gofer/handle_unsafe.go +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gofer - -import ( - "syscall" - "unsafe" - - "gvisor.dev/gvisor/pkg/safemem" -) - -// Preconditions: !dsts.IsEmpty(). -func hostPreadv(fd int32, dsts safemem.BlockSeq, off int64) (uint64, error) { - // No buffering is necessary regardless of safecopy; host syscalls will - // return EFAULT if appropriate, instead of raising SIGBUS. - if dsts.NumBlocks() == 1 { - // Use pread() instead of preadv() to avoid iovec allocation and - // copying. - dst := dsts.Head() - n, _, e := syscall.Syscall6(syscall.SYS_PREAD64, uintptr(fd), dst.Addr(), uintptr(dst.Len()), uintptr(off), 0, 0) - if e != 0 { - return 0, e - } - return uint64(n), nil - } - iovs := safemem.IovecsFromBlockSeq(dsts) - n, _, e := syscall.Syscall6(syscall.SYS_PREADV, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(off), 0, 0) - if e != 0 { - return 0, e - } - return uint64(n), nil -} - -// Preconditions: !srcs.IsEmpty(). -func hostPwritev(fd int32, srcs safemem.BlockSeq, off int64) (uint64, error) { - // No buffering is necessary regardless of safecopy; host syscalls will - // return EFAULT if appropriate, instead of raising SIGBUS. - if srcs.NumBlocks() == 1 { - // Use pwrite() instead of pwritev() to avoid iovec allocation and - // copying. - src := srcs.Head() - n, _, e := syscall.Syscall6(syscall.SYS_PWRITE64, uintptr(fd), src.Addr(), uintptr(src.Len()), uintptr(off), 0, 0) - if e != 0 { - return 0, e - } - return uint64(n), nil - } - iovs := safemem.IovecsFromBlockSeq(srcs) - n, _, e := syscall.Syscall6(syscall.SYS_PWRITEV, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(off), 0, 0) - if e != 0 { - return 0, e - } - return uint64(n), nil -} diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go deleted file mode 100644 index 755ac2985..000000000 --- a/pkg/sentry/fsimpl/gofer/p9file.go +++ /dev/null @@ -1,219 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gofer - -import ( - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/syserror" -) - -// p9file is a wrapper around p9.File that provides methods that are -// Context-aware. -type p9file struct { - file p9.File -} - -func (f p9file) isNil() bool { - return f.file == nil -} - -func (f p9file) walk(ctx context.Context, names []string) ([]p9.QID, p9file, error) { - ctx.UninterruptibleSleepStart(false) - qids, newfile, err := f.file.Walk(names) - ctx.UninterruptibleSleepFinish(false) - return qids, p9file{newfile}, err -} - -func (f p9file) walkGetAttr(ctx context.Context, names []string) ([]p9.QID, p9file, p9.AttrMask, p9.Attr, error) { - ctx.UninterruptibleSleepStart(false) - qids, newfile, attrMask, attr, err := f.file.WalkGetAttr(names) - ctx.UninterruptibleSleepFinish(false) - return qids, p9file{newfile}, attrMask, attr, err -} - -// walkGetAttrOne is a wrapper around p9.File.WalkGetAttr that takes a single -// path component and returns a single qid. -func (f p9file) walkGetAttrOne(ctx context.Context, name string) (p9.QID, p9file, p9.AttrMask, p9.Attr, error) { - ctx.UninterruptibleSleepStart(false) - qids, newfile, attrMask, attr, err := f.file.WalkGetAttr([]string{name}) - ctx.UninterruptibleSleepFinish(false) - if err != nil { - return p9.QID{}, p9file{}, p9.AttrMask{}, p9.Attr{}, err - } - if len(qids) != 1 { - ctx.Warningf("p9.File.WalkGetAttr returned %d qids (%v), wanted 1", len(qids), qids) - if newfile != nil { - p9file{newfile}.close(ctx) - } - return p9.QID{}, p9file{}, p9.AttrMask{}, p9.Attr{}, syserror.EIO - } - return qids[0], p9file{newfile}, attrMask, attr, nil -} - -func (f p9file) statFS(ctx context.Context) (p9.FSStat, error) { - ctx.UninterruptibleSleepStart(false) - fsstat, err := f.file.StatFS() - ctx.UninterruptibleSleepFinish(false) - return fsstat, err -} - -func (f p9file) getAttr(ctx context.Context, req p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) { - ctx.UninterruptibleSleepStart(false) - qid, attrMask, attr, err := f.file.GetAttr(req) - ctx.UninterruptibleSleepFinish(false) - return qid, attrMask, attr, err -} - -func (f p9file) setAttr(ctx context.Context, valid p9.SetAttrMask, attr p9.SetAttr) error { - ctx.UninterruptibleSleepStart(false) - err := f.file.SetAttr(valid, attr) - ctx.UninterruptibleSleepFinish(false) - return err -} - -func (f p9file) getXattr(ctx context.Context, name string, size uint64) (string, error) { - ctx.UninterruptibleSleepStart(false) - val, err := f.file.GetXattr(name, size) - ctx.UninterruptibleSleepFinish(false) - return val, err -} - -func (f p9file) setXattr(ctx context.Context, name, value string, flags uint32) error { - ctx.UninterruptibleSleepStart(false) - err := f.file.SetXattr(name, value, flags) - ctx.UninterruptibleSleepFinish(false) - return err -} - -func (f p9file) allocate(ctx context.Context, mode p9.AllocateMode, offset, length uint64) error { - ctx.UninterruptibleSleepStart(false) - err := f.file.Allocate(mode, offset, length) - ctx.UninterruptibleSleepFinish(false) - return err -} - -func (f p9file) close(ctx context.Context) error { - ctx.UninterruptibleSleepStart(false) - err := f.file.Close() - ctx.UninterruptibleSleepFinish(false) - return err -} - -func (f p9file) open(ctx context.Context, flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { - ctx.UninterruptibleSleepStart(false) - fdobj, qid, iounit, err := f.file.Open(flags) - ctx.UninterruptibleSleepFinish(false) - return fdobj, qid, iounit, err -} - -func (f p9file) readAt(ctx context.Context, p []byte, offset uint64) (int, error) { - ctx.UninterruptibleSleepStart(false) - n, err := f.file.ReadAt(p, offset) - ctx.UninterruptibleSleepFinish(false) - return n, err -} - -func (f p9file) writeAt(ctx context.Context, p []byte, offset uint64) (int, error) { - ctx.UninterruptibleSleepStart(false) - n, err := f.file.WriteAt(p, offset) - ctx.UninterruptibleSleepFinish(false) - return n, err -} - -func (f p9file) fsync(ctx context.Context) error { - ctx.UninterruptibleSleepStart(false) - err := f.file.FSync() - ctx.UninterruptibleSleepFinish(false) - return err -} - -func (f p9file) create(ctx context.Context, name string, flags p9.OpenFlags, permissions p9.FileMode, uid p9.UID, gid p9.GID) (*fd.FD, p9file, p9.QID, uint32, error) { - ctx.UninterruptibleSleepStart(false) - fdobj, newfile, qid, iounit, err := f.file.Create(name, flags, permissions, uid, gid) - ctx.UninterruptibleSleepFinish(false) - return fdobj, p9file{newfile}, qid, iounit, err -} - -func (f p9file) mkdir(ctx context.Context, name string, permissions p9.FileMode, uid p9.UID, gid p9.GID) (p9.QID, error) { - ctx.UninterruptibleSleepStart(false) - qid, err := f.file.Mkdir(name, permissions, uid, gid) - ctx.UninterruptibleSleepFinish(false) - return qid, err -} - -func (f p9file) symlink(ctx context.Context, oldName string, newName string, uid p9.UID, gid p9.GID) (p9.QID, error) { - ctx.UninterruptibleSleepStart(false) - qid, err := f.file.Symlink(oldName, newName, uid, gid) - ctx.UninterruptibleSleepFinish(false) - return qid, err -} - -func (f p9file) link(ctx context.Context, target p9file, newName string) error { - ctx.UninterruptibleSleepStart(false) - err := f.file.Link(target.file, newName) - ctx.UninterruptibleSleepFinish(false) - return err -} - -func (f p9file) mknod(ctx context.Context, name string, mode p9.FileMode, major uint32, minor uint32, uid p9.UID, gid p9.GID) (p9.QID, error) { - ctx.UninterruptibleSleepStart(false) - qid, err := f.file.Mknod(name, mode, major, minor, uid, gid) - ctx.UninterruptibleSleepFinish(false) - return qid, err -} - -func (f p9file) rename(ctx context.Context, newDir p9file, newName string) error { - ctx.UninterruptibleSleepStart(false) - err := f.file.Rename(newDir.file, newName) - ctx.UninterruptibleSleepFinish(false) - return err -} - -func (f p9file) unlinkAt(ctx context.Context, name string, flags uint32) error { - ctx.UninterruptibleSleepStart(false) - err := f.file.UnlinkAt(name, flags) - ctx.UninterruptibleSleepFinish(false) - return err -} - -func (f p9file) readdir(ctx context.Context, offset uint64, count uint32) ([]p9.Dirent, error) { - ctx.UninterruptibleSleepStart(false) - dirents, err := f.file.Readdir(offset, count) - ctx.UninterruptibleSleepFinish(false) - return dirents, err -} - -func (f p9file) readlink(ctx context.Context) (string, error) { - ctx.UninterruptibleSleepStart(false) - target, err := f.file.Readlink() - ctx.UninterruptibleSleepFinish(false) - return target, err -} - -func (f p9file) flush(ctx context.Context) error { - ctx.UninterruptibleSleepStart(false) - err := f.file.Flush() - ctx.UninterruptibleSleepFinish(false) - return err -} - -func (f p9file) connect(ctx context.Context, flags p9.ConnectFlags) (*fd.FD, error) { - ctx.UninterruptibleSleepStart(false) - fdobj, err := f.file.Connect(flags) - ctx.UninterruptibleSleepFinish(false) - return fdobj, err -} diff --git a/pkg/sentry/fsimpl/gofer/pagemath.go b/pkg/sentry/fsimpl/gofer/pagemath.go deleted file mode 100644 index 847cb0784..000000000 --- a/pkg/sentry/fsimpl/gofer/pagemath.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gofer - -import ( - "gvisor.dev/gvisor/pkg/usermem" -) - -// This are equivalent to usermem.Addr.RoundDown/Up, but without the -// potentially truncating conversion to usermem.Addr. This is necessary because -// there is no way to define generic "PageRoundDown/Up" functions in Go. - -func pageRoundDown(x uint64) uint64 { - return x &^ (usermem.PageSize - 1) -} - -func pageRoundUp(x uint64) uint64 { - return pageRoundDown(x + usermem.PageSize - 1) -} diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go deleted file mode 100644 index e95209661..000000000 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ /dev/null @@ -1,872 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gofer - -import ( - "fmt" - "io" - "math" - "sync" - "sync/atomic" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/safemem" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -func (d *dentry) isRegularFile() bool { - return d.fileType() == linux.S_IFREG -} - -type regularFileFD struct { - fileDescription - - // off is the file offset. off is protected by mu. - mu sync.Mutex - off int64 -} - -// Release implements vfs.FileDescriptionImpl.Release. -func (fd *regularFileFD) Release() { -} - -// OnClose implements vfs.FileDescriptionImpl.OnClose. -func (fd *regularFileFD) OnClose(ctx context.Context) error { - if !fd.vfsfd.IsWritable() { - return nil - } - // Skip flushing if writes may be buffered by the client, since (as with - // the VFS1 client) we don't flush buffered writes on close anyway. - d := fd.dentry() - if d.fs.opts.interop == InteropModeExclusive { - return nil - } - d.handleMu.RLock() - defer d.handleMu.RUnlock() - return d.handle.file.flush(ctx) -} - -// PRead implements vfs.FileDescriptionImpl.PRead. -func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { - if offset < 0 { - return 0, syserror.EINVAL - } - if opts.Flags != 0 { - return 0, syserror.EOPNOTSUPP - } - - // Check for reading at EOF before calling into MM (but not under - // InteropModeShared, which makes d.size unreliable). - d := fd.dentry() - if d.fs.opts.interop != InteropModeShared && uint64(offset) >= atomic.LoadUint64(&d.size) { - return 0, io.EOF - } - - if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 { - // Lock d.metadataMu for the rest of the read to prevent d.size from - // changing. - d.metadataMu.Lock() - defer d.metadataMu.Unlock() - // Write dirty cached pages that will be touched by the read back to - // the remote file. - if err := d.writeback(ctx, offset, dst.NumBytes()); err != nil { - return 0, err - } - } - - rw := getDentryReadWriter(ctx, d, offset) - if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 { - // Require the read to go to the remote file. - rw.direct = true - } - n, err := dst.CopyOutFrom(ctx, rw) - putDentryReadWriter(rw) - if d.fs.opts.interop != InteropModeShared { - // Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed(). - d.touchAtime(ctx, fd.vfsfd.Mount()) - } - return n, err -} - -// Read implements vfs.FileDescriptionImpl.Read. -func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - fd.mu.Lock() - n, err := fd.PRead(ctx, dst, fd.off, opts) - fd.off += n - fd.mu.Unlock() - return n, err -} - -// PWrite implements vfs.FileDescriptionImpl.PWrite. -func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - if offset < 0 { - return 0, syserror.EINVAL - } - if opts.Flags != 0 { - return 0, syserror.EOPNOTSUPP - } - - d := fd.dentry() - d.metadataMu.Lock() - defer d.metadataMu.Unlock() - if d.fs.opts.interop != InteropModeShared { - // Compare Linux's mm/filemap.c:__generic_file_write_iter() => - // file_update_time(). This is d.touchCMtime(), but without locking - // d.metadataMu (recursively). - if now, ok := nowFromContext(ctx); ok { - atomic.StoreInt64(&d.mtime, now) - atomic.StoreInt64(&d.ctime, now) - } - } - if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 { - // Write dirty cached pages that will be touched by the write back to - // the remote file. - if err := d.writeback(ctx, offset, src.NumBytes()); err != nil { - return 0, err - } - // Remove touched pages from the cache. - pgstart := pageRoundDown(uint64(offset)) - pgend := pageRoundUp(uint64(offset + src.NumBytes())) - if pgend < pgstart { - return 0, syserror.EINVAL - } - mr := memmap.MappableRange{pgstart, pgend} - var freed []platform.FileRange - d.dataMu.Lock() - cseg := d.cache.LowerBoundSegment(mr.Start) - for cseg.Ok() && cseg.Start() < mr.End { - cseg = d.cache.Isolate(cseg, mr) - freed = append(freed, platform.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()}) - cseg = d.cache.Remove(cseg).NextSegment() - } - d.dataMu.Unlock() - // Invalidate mappings of removed pages. - d.mapsMu.Lock() - d.mappings.Invalidate(mr, memmap.InvalidateOpts{}) - d.mapsMu.Unlock() - // Finally free pages removed from the cache. - mf := d.fs.mfp.MemoryFile() - for _, freedFR := range freed { - mf.DecRef(freedFR) - } - } - rw := getDentryReadWriter(ctx, d, offset) - if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 { - // Require the write to go to the remote file. - rw.direct = true - } - n, err := src.CopyInTo(ctx, rw) - putDentryReadWriter(rw) - if n != 0 && fd.vfsfd.StatusFlags()&(linux.O_DSYNC|linux.O_SYNC) != 0 { - // Write dirty cached pages touched by the write back to the remote - // file. - if err := d.writeback(ctx, offset, src.NumBytes()); err != nil { - return 0, err - } - // Request the remote filesystem to sync the remote file. - if err := d.handle.file.fsync(ctx); err != nil { - return 0, err - } - } - return n, err -} - -// Write implements vfs.FileDescriptionImpl.Write. -func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - fd.mu.Lock() - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.off += n - fd.mu.Unlock() - return n, err -} - -type dentryReadWriter struct { - ctx context.Context - d *dentry - off uint64 - direct bool -} - -var dentryReadWriterPool = sync.Pool{ - New: func() interface{} { - return &dentryReadWriter{} - }, -} - -func getDentryReadWriter(ctx context.Context, d *dentry, offset int64) *dentryReadWriter { - rw := dentryReadWriterPool.Get().(*dentryReadWriter) - rw.ctx = ctx - rw.d = d - rw.off = uint64(offset) - rw.direct = false - return rw -} - -func putDentryReadWriter(rw *dentryReadWriter) { - rw.ctx = nil - rw.d = nil - dentryReadWriterPool.Put(rw) -} - -// ReadToBlocks implements safemem.Reader.ReadToBlocks. -func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { - if dsts.IsEmpty() { - return 0, nil - } - - // If we have a mmappable host FD (which must be used here to ensure - // coherence with memory-mapped I/O), or if InteropModeShared is in effect - // (which prevents us from caching file contents and makes dentry.size - // unreliable), or if the file was opened O_DIRECT, read directly from - // dentry.handle without locking dentry.dataMu. - rw.d.handleMu.RLock() - if (rw.d.handle.fd >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { - n, err := rw.d.handle.readToBlocksAt(rw.ctx, dsts, rw.off) - rw.d.handleMu.RUnlock() - rw.off += n - return n, err - } - - // Otherwise read from/through the cache. - mf := rw.d.fs.mfp.MemoryFile() - fillCache := mf.ShouldCacheEvictable() - var dataMuUnlock func() - if fillCache { - rw.d.dataMu.Lock() - dataMuUnlock = rw.d.dataMu.Unlock - } else { - rw.d.dataMu.RLock() - dataMuUnlock = rw.d.dataMu.RUnlock - } - - // Compute the range to read (limited by file size and overflow-checked). - if rw.off >= rw.d.size { - dataMuUnlock() - rw.d.handleMu.RUnlock() - return 0, io.EOF - } - end := rw.d.size - if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end { - end = rend - } - - var done uint64 - seg, gap := rw.d.cache.Find(rw.off) - for rw.off < end { - mr := memmap.MappableRange{rw.off, end} - switch { - case seg.Ok(): - // Get internal mappings from the cache. - ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read) - if err != nil { - dataMuUnlock() - rw.d.handleMu.RUnlock() - return done, err - } - - // Copy from internal mappings. - n, err := safemem.CopySeq(dsts, ims) - done += n - rw.off += n - dsts = dsts.DropFirst64(n) - if err != nil { - dataMuUnlock() - rw.d.handleMu.RUnlock() - return done, err - } - - // Continue. - seg, gap = seg.NextNonEmpty() - - case gap.Ok(): - gapMR := gap.Range().Intersect(mr) - if fillCache { - // Read into the cache, then re-enter the loop to read from the - // cache. - reqMR := memmap.MappableRange{ - Start: pageRoundDown(gapMR.Start), - End: pageRoundUp(gapMR.End), - } - optMR := gap.Range() - err := rw.d.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mf, usage.PageCache, rw.d.handle.readToBlocksAt) - mf.MarkEvictable(rw.d, pgalloc.EvictableRange{optMR.Start, optMR.End}) - seg, gap = rw.d.cache.Find(rw.off) - if !seg.Ok() { - dataMuUnlock() - rw.d.handleMu.RUnlock() - return done, err - } - // err might have occurred in part of gap.Range() outside - // gapMR. Forget about it for now; if the error matters and - // persists, we'll run into it again in a later iteration of - // this loop. - } else { - // Read directly from the file. - gapDsts := dsts.TakeFirst64(gapMR.Length()) - n, err := rw.d.handle.readToBlocksAt(rw.ctx, gapDsts, gapMR.Start) - done += n - rw.off += n - dsts = dsts.DropFirst64(n) - // Partial reads are fine. But we must stop reading. - if n != gapDsts.NumBytes() || err != nil { - dataMuUnlock() - rw.d.handleMu.RUnlock() - return done, err - } - - // Continue. - seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{} - } - } - } - dataMuUnlock() - rw.d.handleMu.RUnlock() - return done, nil -} - -// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. -// -// Preconditions: rw.d.metadataMu must be locked. -func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { - if srcs.IsEmpty() { - return 0, nil - } - - // If we have a mmappable host FD (which must be used here to ensure - // coherence with memory-mapped I/O), or if InteropModeShared is in effect - // (which prevents us from caching file contents), or if the file was - // opened with O_DIRECT, write directly to dentry.handle without locking - // dentry.dataMu. - rw.d.handleMu.RLock() - if (rw.d.handle.fd >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { - n, err := rw.d.handle.writeFromBlocksAt(rw.ctx, srcs, rw.off) - rw.off += n - rw.d.dataMu.Lock() - if rw.off > rw.d.size { - atomic.StoreUint64(&rw.d.size, rw.off) - // The remote file's size will implicitly be extended to the correct - // value when we write back to it. - } - rw.d.dataMu.Unlock() - rw.d.handleMu.RUnlock() - return n, err - } - - // Otherwise write to/through the cache. - mf := rw.d.fs.mfp.MemoryFile() - rw.d.dataMu.Lock() - - // Compute the range to write (overflow-checked). - start := rw.off - end := rw.off + srcs.NumBytes() - if end <= rw.off { - end = math.MaxInt64 - } - - var ( - done uint64 - retErr error - ) - seg, gap := rw.d.cache.Find(rw.off) - for rw.off < end { - mr := memmap.MappableRange{rw.off, end} - switch { - case seg.Ok(): - // Get internal mappings from the cache. - segMR := seg.Range().Intersect(mr) - ims, err := mf.MapInternal(seg.FileRangeOf(segMR), usermem.Write) - if err != nil { - retErr = err - goto exitLoop - } - - // Copy to internal mappings. - n, err := safemem.CopySeq(ims, srcs) - done += n - rw.off += n - srcs = srcs.DropFirst64(n) - rw.d.dirty.MarkDirty(segMR) - if err != nil { - retErr = err - goto exitLoop - } - - // Continue. - seg, gap = seg.NextNonEmpty() - - case gap.Ok(): - // Write directly to the file. At present, we never fill the cache - // when writing, since doing so can convert small writes into - // inefficient read-modify-write cycles, and we have no mechanism - // for detecting or avoiding this. - gapMR := gap.Range().Intersect(mr) - gapSrcs := srcs.TakeFirst64(gapMR.Length()) - n, err := rw.d.handle.writeFromBlocksAt(rw.ctx, gapSrcs, gapMR.Start) - done += n - rw.off += n - srcs = srcs.DropFirst64(n) - // Partial writes are fine. But we must stop writing. - if n != gapSrcs.NumBytes() || err != nil { - retErr = err - goto exitLoop - } - - // Continue. - seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{} - } - } -exitLoop: - if rw.off > rw.d.size { - atomic.StoreUint64(&rw.d.size, rw.off) - // The remote file's size will implicitly be extended to the correct - // value when we write back to it. - } - // If InteropModeWritethrough is in effect, flush written data back to the - // remote filesystem. - if rw.d.fs.opts.interop == InteropModeWritethrough && done != 0 { - if err := fsutil.SyncDirty(rw.ctx, memmap.MappableRange{ - Start: start, - End: rw.off, - }, &rw.d.cache, &rw.d.dirty, rw.d.size, mf, rw.d.handle.writeFromBlocksAt); err != nil { - // We have no idea how many bytes were actually flushed. - rw.off = start - done = 0 - retErr = err - } - } - rw.d.dataMu.Unlock() - rw.d.handleMu.RUnlock() - return done, retErr -} - -func (d *dentry) writeback(ctx context.Context, offset, size int64) error { - if size == 0 { - return nil - } - d.handleMu.RLock() - defer d.handleMu.RUnlock() - d.dataMu.Lock() - defer d.dataMu.Unlock() - // Compute the range of valid bytes (overflow-checked). - if uint64(offset) >= d.size { - return nil - } - end := int64(d.size) - if rend := offset + size; rend > offset && rend < end { - end = rend - } - return fsutil.SyncDirty(ctx, memmap.MappableRange{ - Start: uint64(offset), - End: uint64(end), - }, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt) -} - -// Seek implements vfs.FileDescriptionImpl.Seek. -func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { - fd.mu.Lock() - defer fd.mu.Unlock() - switch whence { - case linux.SEEK_SET: - // Use offset as specified. - case linux.SEEK_CUR: - offset += fd.off - case linux.SEEK_END, linux.SEEK_DATA, linux.SEEK_HOLE: - // Ensure file size is up to date. - d := fd.dentry() - if fd.filesystem().opts.interop == InteropModeShared { - if err := d.updateFromGetattr(ctx); err != nil { - return 0, err - } - } - size := int64(atomic.LoadUint64(&d.size)) - // For SEEK_DATA and SEEK_HOLE, treat the file as a single contiguous - // block of data. - switch whence { - case linux.SEEK_END: - offset += size - case linux.SEEK_DATA: - if offset > size { - return 0, syserror.ENXIO - } - // Use offset as specified. - case linux.SEEK_HOLE: - if offset > size { - return 0, syserror.ENXIO - } - offset = size - } - default: - return 0, syserror.EINVAL - } - if offset < 0 { - return 0, syserror.EINVAL - } - fd.off = offset - return offset, nil -} - -// Sync implements vfs.FileDescriptionImpl.Sync. -func (fd *regularFileFD) Sync(ctx context.Context) error { - return fd.dentry().syncSharedHandle(ctx) -} - -func (d *dentry) syncSharedHandle(ctx context.Context) error { - d.handleMu.RLock() - if !d.handleWritable { - d.handleMu.RUnlock() - return nil - } - d.dataMu.Lock() - // Write dirty cached data to the remote file. - err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt) - d.dataMu.Unlock() - if err == nil { - // Sync the remote file. - err = d.handle.sync(ctx) - } - d.handleMu.RUnlock() - return err -} - -// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. -func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { - d := fd.dentry() - switch d.fs.opts.interop { - case InteropModeExclusive: - // Any mapping is fine. - case InteropModeWritethrough: - // Shared writable mappings require a host FD, since otherwise we can't - // synchronously flush memory-mapped writes to the remote file. - if opts.Private || !opts.MaxPerms.Write { - break - } - fallthrough - case InteropModeShared: - // All mappings require a host FD to be coherent with other filesystem - // users. - if d.fs.opts.forcePageCache { - // Whether or not we have a host FD, we're not allowed to use it. - return syserror.ENODEV - } - d.handleMu.RLock() - haveFD := d.handle.fd >= 0 - d.handleMu.RUnlock() - if !haveFD { - return syserror.ENODEV - } - default: - panic(fmt.Sprintf("unknown InteropMode %v", d.fs.opts.interop)) - } - // After this point, d may be used as a memmap.Mappable. - d.pf.hostFileMapperInitOnce.Do(d.pf.hostFileMapper.Init) - return vfs.GenericConfigureMMap(&fd.vfsfd, d, opts) -} - -func (d *dentry) mayCachePages() bool { - if d.fs.opts.interop == InteropModeShared { - return false - } - if d.fs.opts.forcePageCache { - return true - } - d.handleMu.RLock() - haveFD := d.handle.fd >= 0 - d.handleMu.RUnlock() - return haveFD -} - -// AddMapping implements memmap.Mappable.AddMapping. -func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { - d.mapsMu.Lock() - mapped := d.mappings.AddMapping(ms, ar, offset, writable) - // Do this unconditionally since whether we have a host FD can change - // across save/restore. - for _, r := range mapped { - d.pf.hostFileMapper.IncRefOn(r) - } - if d.mayCachePages() { - // d.Evict() will refuse to evict memory-mapped pages, so tell the - // MemoryFile to not bother trying. - mf := d.fs.mfp.MemoryFile() - for _, r := range mapped { - mf.MarkUnevictable(d, pgalloc.EvictableRange{r.Start, r.End}) - } - } - d.mapsMu.Unlock() - return nil -} - -// RemoveMapping implements memmap.Mappable.RemoveMapping. -func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { - d.mapsMu.Lock() - unmapped := d.mappings.RemoveMapping(ms, ar, offset, writable) - for _, r := range unmapped { - d.pf.hostFileMapper.DecRefOn(r) - } - if d.mayCachePages() { - // Pages that are no longer referenced by any application memory - // mappings are now considered unused; allow MemoryFile to evict them - // when necessary. - mf := d.fs.mfp.MemoryFile() - d.dataMu.Lock() - for _, r := range unmapped { - // Since these pages are no longer mapped, they are no longer - // concurrently dirtyable by a writable memory mapping. - d.dirty.AllowClean(r) - mf.MarkEvictable(d, pgalloc.EvictableRange{r.Start, r.End}) - } - d.dataMu.Unlock() - } - d.mapsMu.Unlock() -} - -// CopyMapping implements memmap.Mappable.CopyMapping. -func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { - return d.AddMapping(ctx, ms, dstAR, offset, writable) -} - -// Translate implements memmap.Mappable.Translate. -func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { - d.handleMu.RLock() - if d.handle.fd >= 0 && !d.fs.opts.forcePageCache { - d.handleMu.RUnlock() - mr := optional - if d.fs.opts.limitHostFDTranslation { - mr = maxFillRange(required, optional) - } - return []memmap.Translation{ - { - Source: mr, - File: &d.pf, - Offset: mr.Start, - Perms: usermem.AnyAccess, - }, - }, nil - } - - d.dataMu.Lock() - - // Constrain translations to d.size (rounded up) to prevent translation to - // pages that may be concurrently truncated. - pgend := pageRoundUp(d.size) - var beyondEOF bool - if required.End > pgend { - if required.Start >= pgend { - d.dataMu.Unlock() - d.handleMu.RUnlock() - return nil, &memmap.BusError{io.EOF} - } - beyondEOF = true - required.End = pgend - } - if optional.End > pgend { - optional.End = pgend - } - - mf := d.fs.mfp.MemoryFile() - cerr := d.cache.Fill(ctx, required, maxFillRange(required, optional), mf, usage.PageCache, d.handle.readToBlocksAt) - - var ts []memmap.Translation - var translatedEnd uint64 - for seg := d.cache.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() { - segMR := seg.Range().Intersect(optional) - // TODO(jamieliu): Make Translations writable even if writability is - // not required if already kept-dirty by another writable translation. - perms := usermem.AccessType{ - Read: true, - Execute: true, - } - if at.Write { - // From this point forward, this memory can be dirtied through the - // mapping at any time. - d.dirty.KeepDirty(segMR) - perms.Write = true - } - ts = append(ts, memmap.Translation{ - Source: segMR, - File: mf, - Offset: seg.FileRangeOf(segMR).Start, - Perms: perms, - }) - translatedEnd = segMR.End - } - - d.dataMu.Unlock() - d.handleMu.RUnlock() - - // Don't return the error returned by c.cache.Fill if it occurred outside - // of required. - if translatedEnd < required.End && cerr != nil { - return ts, &memmap.BusError{cerr} - } - if beyondEOF { - return ts, &memmap.BusError{io.EOF} - } - return ts, nil -} - -func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange { - const maxReadahead = 64 << 10 // 64 KB, chosen arbitrarily - if required.Length() >= maxReadahead { - return required - } - if optional.Length() <= maxReadahead { - return optional - } - optional.Start = required.Start - if optional.Length() <= maxReadahead { - return optional - } - optional.End = optional.Start + maxReadahead - return optional -} - -// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. -func (d *dentry) InvalidateUnsavable(ctx context.Context) error { - // Whether we have a host fd (and consequently what platform.File is - // mapped) can change across save/restore, so invalidate all translations - // unconditionally. - d.mapsMu.Lock() - defer d.mapsMu.Unlock() - d.mappings.InvalidateAll(memmap.InvalidateOpts{}) - - // Write the cache's contents back to the remote file so that if we have a - // host fd after restore, the remote file's contents are coherent. - mf := d.fs.mfp.MemoryFile() - d.dataMu.Lock() - defer d.dataMu.Unlock() - if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, d.handle.writeFromBlocksAt); err != nil { - return err - } - - // Discard the cache so that it's not stored in saved state. This is safe - // because per InvalidateUnsavable invariants, no new translations can have - // been returned after we invalidated all existing translations above. - d.cache.DropAll(mf) - d.dirty.RemoveAll() - - return nil -} - -// Evict implements pgalloc.EvictableMemoryUser.Evict. -func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) { - d.mapsMu.Lock() - defer d.mapsMu.Unlock() - d.dataMu.Lock() - defer d.dataMu.Unlock() - - mr := memmap.MappableRange{er.Start, er.End} - mf := d.fs.mfp.MemoryFile() - // Only allow pages that are no longer memory-mapped to be evicted. - for mgap := d.mappings.LowerBoundGap(mr.Start); mgap.Ok() && mgap.Start() < mr.End; mgap = mgap.NextGap() { - mgapMR := mgap.Range().Intersect(mr) - if mgapMR.Length() == 0 { - continue - } - if err := fsutil.SyncDirty(ctx, mgapMR, &d.cache, &d.dirty, d.size, mf, d.handle.writeFromBlocksAt); err != nil { - log.Warningf("Failed to writeback cached data %v: %v", mgapMR, err) - } - d.cache.Drop(mgapMR, mf) - d.dirty.KeepClean(mgapMR) - } -} - -// dentryPlatformFile implements platform.File. It exists solely because dentry -// cannot implement both vfs.DentryImpl.IncRef and platform.File.IncRef. -// -// dentryPlatformFile is only used when a host FD representing the remote file -// is available (i.e. dentry.handle.fd >= 0), and that FD is used for -// application memory mappings (i.e. !filesystem.opts.forcePageCache). -type dentryPlatformFile struct { - *dentry - - // fdRefs counts references on platform.File offsets. fdRefs is protected - // by dentry.dataMu. - fdRefs fsutil.FrameRefSet - - // If this dentry represents a regular file, and handle.fd >= 0, - // hostFileMapper caches mappings of handle.fd. - hostFileMapper fsutil.HostFileMapper - - // hostFileMapperInitOnce is used to lazily initialize hostFileMapper. - hostFileMapperInitOnce sync.Once -} - -// IncRef implements platform.File.IncRef. -func (d *dentryPlatformFile) IncRef(fr platform.FileRange) { - d.dataMu.Lock() - seg, gap := d.fdRefs.Find(fr.Start) - for { - switch { - case seg.Ok() && seg.Start() < fr.End: - seg = d.fdRefs.Isolate(seg, fr) - seg.SetValue(seg.Value() + 1) - seg, gap = seg.NextNonEmpty() - case gap.Ok() && gap.Start() < fr.End: - newRange := gap.Range().Intersect(fr) - usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped) - seg, gap = d.fdRefs.InsertWithoutMerging(gap, newRange, 1).NextNonEmpty() - default: - d.fdRefs.MergeAdjacent(fr) - d.dataMu.Unlock() - return - } - } -} - -// DecRef implements platform.File.DecRef. -func (d *dentryPlatformFile) DecRef(fr platform.FileRange) { - d.dataMu.Lock() - seg := d.fdRefs.FindSegment(fr.Start) - - for seg.Ok() && seg.Start() < fr.End { - seg = d.fdRefs.Isolate(seg, fr) - if old := seg.Value(); old == 1 { - usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped) - seg = d.fdRefs.Remove(seg).NextSegment() - } else { - seg.SetValue(old - 1) - seg = seg.NextSegment() - } - } - d.fdRefs.MergeAdjacent(fr) - d.dataMu.Unlock() - -} - -// MapInternal implements platform.File.MapInternal. -func (d *dentryPlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { - d.handleMu.RLock() - bs, err := d.hostFileMapper.MapInternal(fr, int(d.handle.fd), at.Write) - d.handleMu.RUnlock() - return bs, err -} - -// FD implements platform.File.FD. -func (d *dentryPlatformFile) FD() int { - d.handleMu.RLock() - fd := d.handle.fd - d.handleMu.RUnlock() - return int(fd) -} diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go deleted file mode 100644 index 08c691c47..000000000 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gofer - -import ( - "sync" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/safemem" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -// specialFileFD implements vfs.FileDescriptionImpl for files other than -// regular files, directories, and symlinks: pipes, sockets, etc. It is also -// used for regular files when filesystemOptions.specialRegularFiles is in -// effect. specialFileFD differs from regularFileFD by using per-FD handles -// instead of shared per-dentry handles, and never buffering I/O. -type specialFileFD struct { - fileDescription - - // handle is immutable. - handle handle - - // off is the file offset. off is protected by mu. (POSIX 2.9.7 only - // requires operations using the file offset to be atomic for regular files - // and symlinks; however, since specialFileFD may be used for regular - // files, we apply this atomicity unconditionally.) - mu sync.Mutex - off int64 -} - -// Release implements vfs.FileDescriptionImpl.Release. -func (fd *specialFileFD) Release() { - fd.handle.close(context.Background()) - fs := fd.vfsfd.Mount().Filesystem().Impl().(*filesystem) - fs.syncMu.Lock() - delete(fs.specialFileFDs, fd) - fs.syncMu.Unlock() -} - -// OnClose implements vfs.FileDescriptionImpl.OnClose. -func (fd *specialFileFD) OnClose(ctx context.Context) error { - if !fd.vfsfd.IsWritable() { - return nil - } - return fd.handle.file.flush(ctx) -} - -// PRead implements vfs.FileDescriptionImpl.PRead. -func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { - if offset < 0 { - return 0, syserror.EINVAL - } - if opts.Flags != 0 { - return 0, syserror.EOPNOTSUPP - } - - // Going through dst.CopyOutFrom() holds MM locks around file operations of - // unknown duration. For regularFileFD, doing so is necessary to support - // mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't - // hold here since specialFileFD doesn't client-cache data. Just buffer the - // read instead. - if d := fd.dentry(); d.fs.opts.interop != InteropModeShared { - d.touchAtime(ctx, fd.vfsfd.Mount()) - } - buf := make([]byte, dst.NumBytes()) - n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) - if n == 0 { - return 0, err - } - if cp, cperr := dst.CopyOut(ctx, buf[:n]); cperr != nil { - return int64(cp), cperr - } - return int64(n), err -} - -// Read implements vfs.FileDescriptionImpl.Read. -func (fd *specialFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - fd.mu.Lock() - n, err := fd.PRead(ctx, dst, fd.off, opts) - fd.off += n - fd.mu.Unlock() - return n, err -} - -// PWrite implements vfs.FileDescriptionImpl.PWrite. -func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - if offset < 0 { - return 0, syserror.EINVAL - } - if opts.Flags != 0 { - return 0, syserror.EOPNOTSUPP - } - - // Do a buffered write. See rationale in PRead. - if d := fd.dentry(); d.fs.opts.interop != InteropModeShared { - d.touchCMtime(ctx) - } - buf := make([]byte, src.NumBytes()) - // Don't do partial writes if we get a partial read from src. - if _, err := src.CopyIn(ctx, buf); err != nil { - return 0, err - } - n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) - return int64(n), err -} - -// Write implements vfs.FileDescriptionImpl.Write. -func (fd *specialFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - fd.mu.Lock() - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.off += n - fd.mu.Unlock() - return n, err -} - -// Seek implements vfs.FileDescriptionImpl.Seek. -func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { - fd.mu.Lock() - defer fd.mu.Unlock() - switch whence { - case linux.SEEK_SET: - // Use offset as given. - case linux.SEEK_CUR: - offset += fd.off - default: - // SEEK_END, SEEK_DATA, and SEEK_HOLE aren't supported since it's not - // clear that file size is even meaningful for these files. - return 0, syserror.EINVAL - } - if offset < 0 { - return 0, syserror.EINVAL - } - fd.off = offset - return offset, nil -} - -// Sync implements vfs.FileDescriptionImpl.Sync. -func (fd *specialFileFD) Sync(ctx context.Context) error { - if !fd.vfsfd.IsWritable() { - return nil - } - return fd.handle.sync(ctx) -} diff --git a/pkg/sentry/fsimpl/gofer/symlink.go b/pkg/sentry/fsimpl/gofer/symlink.go deleted file mode 100644 index adf43be60..000000000 --- a/pkg/sentry/fsimpl/gofer/symlink.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gofer - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -func (d *dentry) isSymlink() bool { - return d.fileType() == linux.S_IFLNK -} - -// Precondition: d.isSymlink(). -func (d *dentry) readlink(ctx context.Context, mnt *vfs.Mount) (string, error) { - if d.fs.opts.interop != InteropModeShared { - d.touchAtime(ctx, mnt) - d.dataMu.Lock() - if d.haveTarget { - target := d.target - d.dataMu.Unlock() - return target, nil - } - } - target, err := d.file.readlink(ctx) - if d.fs.opts.interop != InteropModeShared { - if err == nil { - d.haveTarget = true - d.target = target - } - d.dataMu.Unlock() - } - return target, err -} diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go deleted file mode 100644 index 7598ec6a8..000000000 --- a/pkg/sentry/fsimpl/gofer/time.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gofer - -import ( - "sync/atomic" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -func dentryTimestampFromP9(s, ns uint64) int64 { - return int64(s*1e9 + ns) -} - -func dentryTimestampFromStatx(ts linux.StatxTimestamp) int64 { - return ts.Sec*1e9 + int64(ts.Nsec) -} - -func statxTimestampFromDentry(ns int64) linux.StatxTimestamp { - return linux.StatxTimestamp{ - Sec: ns / 1e9, - Nsec: uint32(ns % 1e9), - } -} - -func nowFromContext(ctx context.Context) (int64, bool) { - if clock := ktime.RealtimeClockFromContext(ctx); clock != nil { - return clock.Now().Nanoseconds(), true - } - return 0, false -} - -// Preconditions: fs.interop != InteropModeShared. -func (d *dentry) touchAtime(ctx context.Context, mnt *vfs.Mount) { - if err := mnt.CheckBeginWrite(); err != nil { - return - } - now, ok := nowFromContext(ctx) - if !ok { - mnt.EndWrite() - return - } - d.metadataMu.Lock() - atomic.StoreInt64(&d.atime, now) - d.metadataMu.Unlock() - mnt.EndWrite() -} - -// Preconditions: fs.interop != InteropModeShared. The caller has successfully -// called vfs.Mount.CheckBeginWrite(). -func (d *dentry) touchCMtime(ctx context.Context) { - now, ok := nowFromContext(ctx) - if !ok { - return - } - d.metadataMu.Lock() - atomic.StoreInt64(&d.mtime, now) - atomic.StoreInt64(&d.ctime, now) - d.metadataMu.Unlock() -} diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD deleted file mode 100644 index 731f192b3..000000000 --- a/pkg/sentry/fsimpl/host/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "host", - srcs = [ - "default_file.go", - "host.go", - "util.go", - ], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/log", - "//pkg/refs", - "//pkg/safemem", - "//pkg/sentry/fsimpl/kernfs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/memmap", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/fsimpl/host/default_file.go b/pkg/sentry/fsimpl/host/default_file.go deleted file mode 100644 index 172cdb161..000000000 --- a/pkg/sentry/fsimpl/host/default_file.go +++ /dev/null @@ -1,233 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "math" - "syscall" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/safemem" - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -// defaultFileFD implements FileDescriptionImpl for non-socket, non-TTY files. -type defaultFileFD struct { - fileDescription - - // canMap specifies whether we allow the file to be memory mapped. - canMap bool - - // mu protects the fields below. - mu sync.Mutex - - // offset specifies the current file offset. - offset int64 -} - -// TODO(gvisor.dev/issue/1672): Implement Waitable interface. - -// PRead implements FileDescriptionImpl. -func (f *defaultFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { - // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null. - if f.inode.isStream { - return 0, syserror.ESPIPE - } - - return readFromHostFD(ctx, f.inode.hostFD, dst, offset, int(opts.Flags)) -} - -// Read implements FileDescriptionImpl. -func (f *defaultFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null. - if f.inode.isStream { - // These files can't be memory mapped, assert this. - if f.canMap { - panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped") - } - - f.mu.Lock() - n, err := readFromHostFD(ctx, f.inode.hostFD, dst, -1, int(opts.Flags)) - f.mu.Unlock() - if isBlockError(err) { - // If we got any data at all, return it as a "completed" partial read - // rather than retrying until complete. - if n != 0 { - err = nil - } else { - err = syserror.ErrWouldBlock - } - } - return n, err - } - // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so. - f.mu.Lock() - n, err := readFromHostFD(ctx, f.inode.hostFD, dst, f.offset, int(opts.Flags)) - f.offset += n - f.mu.Unlock() - return n, err -} - -func readFromHostFD(ctx context.Context, fd int, dst usermem.IOSequence, offset int64, flags int) (int64, error) { - if flags&^(linux.RWF_VALID) != 0 { - return 0, syserror.EOPNOTSUPP - } - - reader := safemem.FromVecReaderFunc{ - func(srcs [][]byte) (int64, error) { - n, err := unix.Preadv2(fd, srcs, offset, flags) - return int64(n), err - }, - } - n, err := dst.CopyOutFrom(ctx, reader) - return int64(n), err -} - -// PWrite implements FileDescriptionImpl. -func (f *defaultFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null. - if f.inode.isStream { - return 0, syserror.ESPIPE - } - - return writeToHostFD(ctx, f.inode.hostFD, src, offset, int(opts.Flags)) -} - -// Write implements FileDescriptionImpl. -func (f *defaultFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null. - if f.inode.isStream { - // These files can't be memory mapped, assert this. - if f.canMap { - panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped") - } - - f.mu.Lock() - n, err := writeToHostFD(ctx, f.inode.hostFD, src, -1, int(opts.Flags)) - f.mu.Unlock() - if isBlockError(err) { - err = syserror.ErrWouldBlock - } - return n, err - } - // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so. - // TODO(gvisor.dev/issue/1672): Write to end of file and update offset if O_APPEND is set on this file. - f.mu.Lock() - n, err := writeToHostFD(ctx, f.inode.hostFD, src, f.offset, int(opts.Flags)) - f.offset += n - f.mu.Unlock() - return n, err -} - -func writeToHostFD(ctx context.Context, fd int, src usermem.IOSequence, offset int64, flags int) (int64, error) { - if flags&^(linux.RWF_VALID) != 0 { - return 0, syserror.EOPNOTSUPP - } - - writer := safemem.FromVecWriterFunc{ - func(srcs [][]byte) (int64, error) { - n, err := unix.Pwritev2(fd, srcs, offset, flags) - return int64(n), err - }, - } - n, err := src.CopyInTo(ctx, writer) - return int64(n), err -} - -// Seek implements FileDescriptionImpl. -// -// Note that we do not support seeking on directories, since we do not even -// allow directory fds to be imported at all. -func (f *defaultFileFD) Seek(_ context.Context, offset int64, whence int32) (int64, error) { - // TODO(b/34716638): Some char devices do support seeking, e.g. /dev/null. - if f.inode.isStream { - return 0, syserror.ESPIPE - } - - f.mu.Lock() - defer f.mu.Unlock() - - switch whence { - case linux.SEEK_SET: - if offset < 0 { - return f.offset, syserror.EINVAL - } - f.offset = offset - - case linux.SEEK_CUR: - // Check for overflow. Note that underflow cannot occur, since f.offset >= 0. - if offset > math.MaxInt64-f.offset { - return f.offset, syserror.EOVERFLOW - } - if f.offset+offset < 0 { - return f.offset, syserror.EINVAL - } - f.offset += offset - - case linux.SEEK_END: - var s syscall.Stat_t - if err := syscall.Fstat(f.inode.hostFD, &s); err != nil { - return f.offset, err - } - size := s.Size - - // Check for overflow. Note that underflow cannot occur, since size >= 0. - if offset > math.MaxInt64-size { - return f.offset, syserror.EOVERFLOW - } - if size+offset < 0 { - return f.offset, syserror.EINVAL - } - f.offset = size + offset - - case linux.SEEK_DATA, linux.SEEK_HOLE: - // Modifying the offset in the host file table should not matter, since - // this is the only place where we use it. - // - // For reading and writing, we always rely on our internal offset. - n, err := unix.Seek(f.inode.hostFD, offset, int(whence)) - if err != nil { - return f.offset, err - } - f.offset = n - - default: - // Invalid whence. - return f.offset, syserror.EINVAL - } - - return f.offset, nil -} - -// Sync implements FileDescriptionImpl. -func (f *defaultFileFD) Sync(context.Context) error { - // TODO(gvisor.dev/issue/1672): Currently we do not support the SyncData optimization, so we always sync everything. - return unix.Fsync(f.inode.hostFD) -} - -// ConfigureMMap implements FileDescriptionImpl. -func (f *defaultFileFD) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error { - if !f.canMap { - return syserror.ENODEV - } - // TODO(gvisor.dev/issue/1672): Implement ConfigureMMap and Mappable interface. - return syserror.ENODEV -} diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go deleted file mode 100644 index c205e6a0b..000000000 --- a/pkg/sentry/fsimpl/host/host.go +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package host provides a filesystem implementation for host files imported as -// file descriptors. -package host - -import ( - "errors" - "fmt" - "syscall" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" -) - -// filesystem implements vfs.FilesystemImpl. -type filesystem struct { - kernfs.Filesystem -} - -// ImportFD sets up and returns a vfs.FileDescription from a donated fd. -func ImportFD(mnt *vfs.Mount, hostFD int, ownerUID auth.KUID, ownerGID auth.KGID, isTTY bool) (*vfs.FileDescription, error) { - // Must be importing to a mount of host.filesystem. - fs, ok := mnt.Filesystem().Impl().(*filesystem) - if !ok { - return nil, fmt.Errorf("can't import host FDs into filesystems of type %T", mnt.Filesystem().Impl()) - } - - // Retrieve metadata. - var s syscall.Stat_t - if err := syscall.Fstat(hostFD, &s); err != nil { - return nil, err - } - - fileMode := linux.FileMode(s.Mode) - fileType := fileMode.FileType() - // Pipes, character devices, and sockets can return EWOULDBLOCK for - // operations that would block. - isStream := fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK - - i := &inode{ - hostFD: hostFD, - isStream: isStream, - isTTY: isTTY, - ino: fs.NextIno(), - mode: fileMode, - uid: ownerUID, - gid: ownerGID, - } - - d := &kernfs.Dentry{} - d.Init(i) - // i.open will take a reference on d. - defer d.DecRef() - - return i.open(d.VFSDentry(), mnt) -} - -// inode implements kernfs.Inode. -type inode struct { - kernfs.InodeNotDirectory - kernfs.InodeNotSymlink - - // When the reference count reaches zero, the host fd is closed. - refs.AtomicRefCount - - // hostFD contains the host fd that this file was originally created from, - // which must be available at time of restore. - // - // This field is initialized at creation time and is immutable. - hostFD int - - // isStream is true if the host fd points to a file representing a stream, - // e.g. a socket or a pipe. Such files are not seekable and can return - // EWOULDBLOCK for I/O operations. - // - // This field is initialized at creation time and is immutable. - isStream bool - - // isTTY is true if this file represents a TTY. - // - // This field is initialized at creation time and is immutable. - isTTY bool - - // ino is an inode number unique within this filesystem. - ino uint64 - - // mu protects the inode metadata below. - mu sync.Mutex - - // mode is the file mode of this inode. Note that this value may become out - // of date if the mode is changed on the host, e.g. with chmod. - mode linux.FileMode - - // uid and gid of the file owner. Note that these refer to the owner of the - // file created on import, not the fd on the host. - uid auth.KUID - gid auth.KGID -} - -// Note that these flags may become out of date, since they can be modified -// on the host, e.g. with fcntl. -func fileFlagsFromHostFD(fd int) (int, error) { - flags, err := unix.FcntlInt(uintptr(fd), syscall.F_GETFL, 0) - if err != nil { - log.Warningf("Failed to get file flags for donated FD %d: %v", fd, err) - return 0, err - } - // TODO(gvisor.dev/issue/1672): implement behavior corresponding to these allowed flags. - flags &= syscall.O_ACCMODE | syscall.O_DIRECT | syscall.O_NONBLOCK | syscall.O_DSYNC | syscall.O_SYNC | syscall.O_APPEND - return flags, nil -} - -// CheckPermissions implements kernfs.Inode. -func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, atx vfs.AccessTypes) error { - return vfs.GenericCheckPermissions(creds, atx, false /* isDir */, uint16(i.mode), i.uid, i.gid) -} - -// Mode implements kernfs.Inode. -func (i *inode) Mode() linux.FileMode { - return i.mode -} - -// Stat implements kernfs.Inode. -func (i *inode) Stat(_ *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - var s unix.Statx_t - if err := unix.Statx(i.hostFD, "", int(unix.AT_EMPTY_PATH|opts.Sync), int(opts.Mask), &s); err != nil { - return linux.Statx{}, err - } - ls := unixToLinuxStatx(s) - - // Use our own internal inode number and file owner. - // - // TODO(gvisor.dev/issue/1672): Use a kernfs-specific device number as well. - // If we use the device number from the host, it may collide with another - // sentry-internal device number. We handle device/inode numbers without - // relying on the host to prevent collisions. - ls.Ino = i.ino - ls.UID = uint32(i.uid) - ls.GID = uint32(i.gid) - - // Update file mode from the host. - i.mode = linux.FileMode(ls.Mode) - - return ls, nil -} - -// SetStat implements kernfs.Inode. -func (i *inode) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error { - s := opts.Stat - - m := s.Mask - if m == 0 { - return nil - } - if m&(linux.STATX_UID|linux.STATX_GID) != 0 { - return syserror.EPERM - } - if m&linux.STATX_MODE != 0 { - if err := syscall.Fchmod(i.hostFD, uint32(s.Mode)); err != nil { - return err - } - i.mode = linux.FileMode(s.Mode) - } - if m&linux.STATX_SIZE != 0 { - if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil { - return err - } - } - if m&(linux.STATX_ATIME|linux.STATX_MTIME) != 0 { - timestamps := []unix.Timespec{ - toTimespec(s.Atime, m&linux.STATX_ATIME == 0), - toTimespec(s.Mtime, m&linux.STATX_MTIME == 0), - } - if err := unix.UtimesNanoAt(i.hostFD, "", timestamps, unix.AT_EMPTY_PATH); err != nil { - return err - } - } - return nil -} - -// DecRef implements kernfs.Inode. -func (i *inode) DecRef() { - i.AtomicRefCount.DecRefWithDestructor(i.Destroy) -} - -// Destroy implements kernfs.Inode. -func (i *inode) Destroy() { - if err := unix.Close(i.hostFD); err != nil { - log.Warningf("failed to close host fd %d: %v", i.hostFD, err) - } -} - -// Open implements kernfs.Inode. -func (i *inode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - return i.open(vfsd, rp.Mount()) -} - -func (i *inode) open(d *vfs.Dentry, mnt *vfs.Mount) (*vfs.FileDescription, error) { - - fileType := i.mode.FileType() - if fileType == syscall.S_IFSOCK { - if i.isTTY { - return nil, errors.New("cannot use host socket as TTY") - } - // TODO(gvisor.dev/issue/1672): support importing sockets. - return nil, errors.New("importing host sockets not supported") - } - - if i.isTTY { - // TODO(gvisor.dev/issue/1672): support importing host fd as TTY. - return nil, errors.New("importing host fd as TTY not supported") - } - - // For simplicity, set offset to 0. Technically, we should - // only set to 0 on files that are not seekable (sockets, pipes, etc.), - // and use the offset from the host fd otherwise. - fd := &defaultFileFD{ - fileDescription: fileDescription{ - inode: i, - }, - canMap: canMap(uint32(fileType)), - mu: sync.Mutex{}, - offset: 0, - } - - vfsfd := &fd.vfsfd - flags, err := fileFlagsFromHostFD(i.hostFD) - if err != nil { - return nil, err - } - - if err := vfsfd.Init(fd, uint32(flags), mnt, d, &vfs.FileDescriptionOptions{}); err != nil { - return nil, err - } - return vfsfd, nil -} - -// fileDescription is embedded by host fd implementations of FileDescriptionImpl. -type fileDescription struct { - vfsfd vfs.FileDescription - vfs.FileDescriptionDefaultImpl - - // inode is vfsfd.Dentry().Impl().(*kernfs.Dentry).Inode().(*inode), but - // cached to reduce indirections and casting. fileDescription does not hold - // a reference on the inode through the inode field (since one is already - // held via the Dentry). - // - // inode is immutable after fileDescription creation. - inode *inode -} - -// SetStat implements vfs.FileDescriptionImpl. -func (f *fileDescription) SetStat(_ context.Context, opts vfs.SetStatOptions) error { - return f.inode.SetStat(nil, opts) -} - -// Stat implements vfs.FileDescriptionImpl. -func (f *fileDescription) Stat(_ context.Context, opts vfs.StatOptions) (linux.Statx, error) { - return f.inode.Stat(nil, opts) -} - -// Release implements vfs.FileDescriptionImpl. -func (f *fileDescription) Release() { - // noop -} diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go deleted file mode 100644 index e1ccacb4d..000000000 --- a/pkg/sentry/fsimpl/host/util.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "syscall" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/syserror" -) - -func toTimespec(ts linux.StatxTimestamp, omit bool) unix.Timespec { - if omit { - return unix.Timespec{ - Sec: 0, - Nsec: unix.UTIME_OMIT, - } - } - return unix.Timespec{ - Sec: int64(ts.Sec), - Nsec: int64(ts.Nsec), - } -} - -func unixToLinuxStatx(s unix.Statx_t) linux.Statx { - return linux.Statx{ - Mask: s.Mask, - Blksize: s.Blksize, - Attributes: s.Attributes, - Nlink: s.Nlink, - UID: s.Uid, - GID: s.Gid, - Mode: s.Mode, - Ino: s.Ino, - Size: s.Size, - Blocks: s.Blocks, - AttributesMask: s.Attributes_mask, - Atime: unixToLinuxStatxTimestamp(s.Atime), - Btime: unixToLinuxStatxTimestamp(s.Btime), - Ctime: unixToLinuxStatxTimestamp(s.Ctime), - Mtime: unixToLinuxStatxTimestamp(s.Mtime), - RdevMajor: s.Rdev_major, - RdevMinor: s.Rdev_minor, - DevMajor: s.Dev_major, - DevMinor: s.Dev_minor, - } -} - -func unixToLinuxStatxTimestamp(ts unix.StatxTimestamp) linux.StatxTimestamp { - return linux.StatxTimestamp{Sec: ts.Sec, Nsec: ts.Nsec} -} - -// wouldBlock returns true for file types that can return EWOULDBLOCK -// for blocking operations, e.g. pipes, character devices, and sockets. -func wouldBlock(fileType uint32) bool { - return fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK -} - -// canMap returns true if a file with fileType is allowed to be memory mapped. -// This is ported over from VFS1, but it's probably not the best way for us -// to check if a file can be memory mapped. -func canMap(fileType uint32) bool { - // TODO(gvisor.dev/issue/1672): Also allow "special files" to be mapped (see fs/host:canMap()). - // - // TODO(b/38213152): Some obscure character devices can be mapped. - return fileType == syscall.S_IFREG -} - -// isBlockError checks if an error is EAGAIN or EWOULDBLOCK. -// If so, they can be transformed into syserror.ErrWouldBlock. -func isBlockError(err error) bool { - return err == syserror.EAGAIN || err == syserror.EWOULDBLOCK -} diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD deleted file mode 100644 index e73f1f857..000000000 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ /dev/null @@ -1,61 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -licenses(["notice"]) - -go_template_instance( - name = "slot_list", - out = "slot_list.go", - package = "kernfs", - prefix = "slot", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*slot", - "Linker": "*slot", - }, -) - -go_library( - name = "kernfs", - srcs = [ - "dynamic_bytes_file.go", - "fd_impl_util.go", - "filesystem.go", - "inode_impl_util.go", - "kernfs.go", - "slot_list.go", - "symlink.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/fspath", - "//pkg/log", - "//pkg/refs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/memmap", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - ], -) - -go_test( - name = "kernfs_test", - size = "small", - srcs = ["kernfs_test.go"], - deps = [ - ":kernfs", - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/contexttest", - "//pkg/sentry/fsimpl/testutil", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/syserror", - "//pkg/usermem", - "@com_github_google_go-cmp//cmp:go_default_library", - ], -) diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go deleted file mode 100644 index 1c026f4d8..000000000 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernfs - -import ( - "fmt" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -// DynamicBytesFile implements kernfs.Inode and represents a read-only -// file whose contents are backed by a vfs.DynamicBytesSource. -// -// Must be instantiated with NewDynamicBytesFile or initialized with Init -// before first use. -// -// +stateify savable -type DynamicBytesFile struct { - InodeAttrs - InodeNoopRefCount - InodeNotDirectory - InodeNotSymlink - - data vfs.DynamicBytesSource -} - -var _ Inode = (*DynamicBytesFile)(nil) - -// Init initializes a dynamic bytes file. -func (f *DynamicBytesFile) Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) { - if perm&^linux.PermissionsMask != 0 { - panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) - } - f.InodeAttrs.Init(creds, ino, linux.ModeRegular|perm) - f.data = data -} - -// Open implements Inode.Open. -func (f *DynamicBytesFile) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd := &DynamicBytesFD{} - if err := fd.Init(rp.Mount(), vfsd, f.data, opts.Flags); err != nil { - return nil, err - } - return &fd.vfsfd, nil -} - -// SetStat implements Inode.SetStat. -func (f *DynamicBytesFile) SetStat(*vfs.Filesystem, vfs.SetStatOptions) error { - // DynamicBytesFiles are immutable. - return syserror.EPERM -} - -// DynamicBytesFD implements vfs.FileDescriptionImpl for an FD backed by a -// DynamicBytesFile. -// -// Must be initialized with Init before first use. -// -// +stateify savable -type DynamicBytesFD struct { - vfs.FileDescriptionDefaultImpl - vfs.DynamicBytesFileDescriptionImpl - - vfsfd vfs.FileDescription - inode Inode -} - -// Init initializes a DynamicBytesFD. -func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, flags uint32) error { - if err := fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}); err != nil { - return err - } - fd.inode = d.Impl().(*Dentry).inode - fd.SetDataSource(data) - return nil -} - -// Seek implements vfs.FileDescriptionImpl.Seek. -func (fd *DynamicBytesFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { - return fd.DynamicBytesFileDescriptionImpl.Seek(ctx, offset, whence) -} - -// Read implmenets vfs.FileDescriptionImpl.Read. -func (fd *DynamicBytesFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - return fd.DynamicBytesFileDescriptionImpl.Read(ctx, dst, opts) -} - -// PRead implmenets vfs.FileDescriptionImpl.PRead. -func (fd *DynamicBytesFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { - return fd.DynamicBytesFileDescriptionImpl.PRead(ctx, dst, offset, opts) -} - -// Write implements vfs.FileDescriptionImpl.Write. -func (fd *DynamicBytesFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - return fd.DynamicBytesFileDescriptionImpl.Write(ctx, src, opts) -} - -// PWrite implements vfs.FileDescriptionImpl.PWrite. -func (fd *DynamicBytesFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - return fd.DynamicBytesFileDescriptionImpl.PWrite(ctx, src, offset, opts) -} - -// Release implements vfs.FileDescriptionImpl.Release. -func (fd *DynamicBytesFD) Release() {} - -// Stat implements vfs.FileDescriptionImpl.Stat. -func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { - fs := fd.vfsfd.VirtualDentry().Mount().Filesystem() - return fd.inode.Stat(fs, opts) -} - -// SetStat implements vfs.FileDescriptionImpl.SetStat. -func (fd *DynamicBytesFD) SetStat(context.Context, vfs.SetStatOptions) error { - // DynamicBytesFiles are immutable. - return syserror.EPERM -} diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go deleted file mode 100644 index da821d524..000000000 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernfs - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -// GenericDirectoryFD implements vfs.FileDescriptionImpl for a generic directory -// inode that uses OrderChildren to track child nodes. GenericDirectoryFD is not -// compatible with dynamic directories. -// -// Note that GenericDirectoryFD holds a lock over OrderedChildren while calling -// IterDirents callback. The IterDirents callback therefore cannot hash or -// unhash children, or recursively call IterDirents on the same underlying -// inode. -// -// Must be initialize with Init before first use. -type GenericDirectoryFD struct { - vfs.FileDescriptionDefaultImpl - vfs.DirectoryFileDescriptionDefaultImpl - - vfsfd vfs.FileDescription - children *OrderedChildren - off int64 -} - -// Init initializes a GenericDirectoryFD. -func (fd *GenericDirectoryFD) Init(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, opts *vfs.OpenOptions) error { - if vfs.AccessTypesForOpenFlags(opts)&vfs.MayWrite != 0 { - // Can't open directories for writing. - return syserror.EISDIR - } - if err := fd.vfsfd.Init(fd, opts.Flags, m, d, &vfs.FileDescriptionOptions{}); err != nil { - return err - } - fd.children = children - return nil -} - -// VFSFileDescription returns a pointer to the vfs.FileDescription representing -// this object. -func (fd *GenericDirectoryFD) VFSFileDescription() *vfs.FileDescription { - return &fd.vfsfd -} - -// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. -func (fd *GenericDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { - return fd.FileDescriptionDefaultImpl.ConfigureMMap(ctx, opts) -} - -// Read implmenets vfs.FileDescriptionImpl.Read. -func (fd *GenericDirectoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - return fd.DirectoryFileDescriptionDefaultImpl.Read(ctx, dst, opts) -} - -// PRead implmenets vfs.FileDescriptionImpl.PRead. -func (fd *GenericDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { - return fd.DirectoryFileDescriptionDefaultImpl.PRead(ctx, dst, offset, opts) -} - -// Write implements vfs.FileDescriptionImpl.Write. -func (fd *GenericDirectoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - return fd.DirectoryFileDescriptionDefaultImpl.Write(ctx, src, opts) -} - -// PWrite implements vfs.FileDescriptionImpl.PWrite. -func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - return fd.DirectoryFileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts) -} - -// Release implements vfs.FileDecriptionImpl.Release. -func (fd *GenericDirectoryFD) Release() {} - -func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem { - return fd.vfsfd.VirtualDentry().Mount().Filesystem() -} - -func (fd *GenericDirectoryFD) inode() Inode { - return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode -} - -// IterDirents implements vfs.FileDecriptionImpl.IterDirents. IterDirents holds -// o.mu when calling cb. -func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { - vfsFS := fd.filesystem() - fs := vfsFS.Impl().(*Filesystem) - vfsd := fd.vfsfd.VirtualDentry().Dentry() - - fs.mu.Lock() - defer fs.mu.Unlock() - - opts := vfs.StatOptions{Mask: linux.STATX_INO} - // Handle ".". - if fd.off == 0 { - stat, err := fd.inode().Stat(vfsFS, opts) - if err != nil { - return err - } - dirent := vfs.Dirent{ - Name: ".", - Type: linux.DT_DIR, - Ino: stat.Ino, - NextOff: 1, - } - if err := cb.Handle(dirent); err != nil { - return err - } - fd.off++ - } - - // Handle "..". - if fd.off == 1 { - parentInode := vfsd.ParentOrSelf().Impl().(*Dentry).inode - stat, err := parentInode.Stat(vfsFS, opts) - if err != nil { - return err - } - dirent := vfs.Dirent{ - Name: "..", - Type: linux.FileMode(stat.Mode).DirentType(), - Ino: stat.Ino, - NextOff: 2, - } - if err := cb.Handle(dirent); err != nil { - return err - } - fd.off++ - } - - // Handle static children. - fd.children.mu.RLock() - defer fd.children.mu.RUnlock() - // fd.off accounts for "." and "..", but fd.children do not track - // these. - childIdx := fd.off - 2 - for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() { - inode := it.Dentry.Impl().(*Dentry).inode - stat, err := inode.Stat(vfsFS, opts) - if err != nil { - return err - } - dirent := vfs.Dirent{ - Name: it.Name, - Type: linux.FileMode(stat.Mode).DirentType(), - Ino: stat.Ino, - NextOff: fd.off + 1, - } - if err := cb.Handle(dirent); err != nil { - return err - } - fd.off++ - } - - var err error - relOffset := fd.off - int64(len(fd.children.set)) - 2 - fd.off, err = fd.inode().IterDirents(ctx, cb, fd.off, relOffset) - return err -} - -// Seek implements vfs.FileDecriptionImpl.Seek. -func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { - fs := fd.filesystem().Impl().(*Filesystem) - fs.mu.Lock() - defer fs.mu.Unlock() - - switch whence { - case linux.SEEK_SET: - // Use offset as given. - case linux.SEEK_CUR: - offset += fd.off - default: - return 0, syserror.EINVAL - } - if offset < 0 { - return 0, syserror.EINVAL - } - fd.off = offset - return offset, nil -} - -// Stat implements vfs.FileDescriptionImpl.Stat. -func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { - fs := fd.filesystem() - inode := fd.inode() - return inode.Stat(fs, opts) -} - -// SetStat implements vfs.FileDescriptionImpl.SetStat. -func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { - fs := fd.filesystem() - inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode - return inode.SetStat(fs, opts) -} diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go deleted file mode 100644 index 3288de290..000000000 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ /dev/null @@ -1,788 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernfs - -// This file implements vfs.FilesystemImpl for kernfs. - -import ( - "fmt" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -// stepExistingLocked resolves rp.Component() in parent directory vfsd. -// -// stepExistingLocked is loosely analogous to fs/namei.c:walk_component(). -// -// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done(). -// -// Postcondition: Caller must call fs.processDeferredDecRefs*. -func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry) (*vfs.Dentry, error) { - d := vfsd.Impl().(*Dentry) - if !d.isDir() { - return nil, syserror.ENOTDIR - } - // Directory searchable? - if err := d.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil { - return nil, err - } -afterSymlink: - name := rp.Component() - // Revalidation must be skipped if name is "." or ".."; d or its parent - // respectively can't be expected to transition from invalidated back to - // valid, so detecting invalidation and retrying would loop forever. This - // is consistent with Linux: fs/namei.c:walk_component() => lookup_fast() - // calls d_revalidate(), but walk_component() => handle_dots() does not. - if name == "." { - rp.Advance() - return vfsd, nil - } - if name == ".." { - nextVFSD, err := rp.ResolveParent(vfsd) - if err != nil { - return nil, err - } - rp.Advance() - return nextVFSD, nil - } - d.dirMu.Lock() - nextVFSD, err := rp.ResolveChild(vfsd, name) - if err != nil { - d.dirMu.Unlock() - return nil, err - } - next, err := fs.revalidateChildLocked(ctx, rp.VirtualFilesystem(), d, name, nextVFSD) - d.dirMu.Unlock() - if err != nil { - return nil, err - } - // Resolve any symlink at current path component. - if rp.ShouldFollowSymlink() && next.isSymlink() { - // TODO: VFS2 needs something extra for /proc/[pid]/fd/ "magic symlinks". - target, err := next.inode.Readlink(ctx) - if err != nil { - return nil, err - } - if err := rp.HandleSymlink(target); err != nil { - return nil, err - } - goto afterSymlink - - } - rp.Advance() - return &next.vfsd, nil -} - -// revalidateChildLocked must be called after a call to parent.vfsd.Child(name) -// or vfs.ResolvingPath.ResolveChild(name) returns childVFSD (which may be -// nil) to verify that the returned child (or lack thereof) is correct. -// -// Preconditions: Filesystem.mu must be locked for at least reading. -// parent.dirMu must be locked. parent.isDir(). name is not "." or "..". -// -// Postconditions: Caller must call fs.processDeferredDecRefs*. -func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *Dentry, name string, childVFSD *vfs.Dentry) (*Dentry, error) { - if childVFSD != nil { - // Cached dentry exists, revalidate. - child := childVFSD.Impl().(*Dentry) - if !child.inode.Valid(ctx) { - vfsObj.ForceDeleteDentry(childVFSD) - fs.deferDecRef(childVFSD) // Reference from Lookup. - childVFSD = nil - } - } - if childVFSD == nil { - // Dentry isn't cached; it either doesn't exist or failed - // revalidation. Attempt to resolve it via Lookup. - // - // FIXME(gvisor.dev/issue/1193): Inode.Lookup() should return - // *(kernfs.)Dentry, not *vfs.Dentry, since (kernfs.)Filesystem assumes - // that all dentries in the filesystem are (kernfs.)Dentry and performs - // vfs.DentryImpl casts accordingly. - var err error - childVFSD, err = parent.inode.Lookup(ctx, name) - if err != nil { - return nil, err - } - // Reference on childVFSD dropped by a corresponding Valid. - parent.insertChildLocked(name, childVFSD) - } - return childVFSD.Impl().(*Dentry), nil -} - -// walkExistingLocked resolves rp to an existing file. -// -// walkExistingLocked is loosely analogous to Linux's -// fs/namei.c:path_lookupat(). -// -// Preconditions: Filesystem.mu must be locked for at least reading. -// -// Postconditions: Caller must call fs.processDeferredDecRefs*. -func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) { - vfsd := rp.Start() - for !rp.Done() { - var err error - vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd) - if err != nil { - return nil, nil, err - } - } - d := vfsd.Impl().(*Dentry) - if rp.MustBeDir() && !d.isDir() { - return nil, nil, syserror.ENOTDIR - } - return vfsd, d.inode, nil -} - -// walkParentDirLocked resolves all but the last path component of rp to an -// existing directory. It does not check that the returned directory is -// searchable by the provider of rp. -// -// walkParentDirLocked is loosely analogous to Linux's -// fs/namei.c:path_parentat(). -// -// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done(). -// -// Postconditions: Caller must call fs.processDeferredDecRefs*. -func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) { - vfsd := rp.Start() - for !rp.Final() { - var err error - vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd) - if err != nil { - return nil, nil, err - } - } - d := vfsd.Impl().(*Dentry) - if !d.isDir() { - return nil, nil, syserror.ENOTDIR - } - return vfsd, d.inode, nil -} - -// checkCreateLocked checks that a file named rp.Component() may be created in -// directory parentVFSD, then returns rp.Component(). -// -// Preconditions: Filesystem.mu must be locked for at least reading. parentInode -// == parentVFSD.Impl().(*Dentry).Inode. isDir(parentInode) == true. -func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode Inode) (string, error) { - if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { - return "", err - } - pc := rp.Component() - if pc == "." || pc == ".." { - return "", syserror.EEXIST - } - childVFSD, err := rp.ResolveChild(parentVFSD, pc) - if err != nil { - return "", err - } - if childVFSD != nil { - return "", syserror.EEXIST - } - if parentVFSD.IsDisowned() { - return "", syserror.ENOENT - } - return pc, nil -} - -// checkDeleteLocked checks that the file represented by vfsd may be deleted. -// -// Preconditions: Filesystem.mu must be locked for at least reading. -func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry) error { - parentVFSD := vfsd.Parent() - if parentVFSD == nil { - return syserror.EBUSY - } - if parentVFSD.IsDisowned() { - return syserror.ENOENT - } - if err := parentVFSD.Impl().(*Dentry).inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { - return err - } - return nil -} - -// Release implements vfs.FilesystemImpl.Release. -func (fs *Filesystem) Release() { -} - -// Sync implements vfs.FilesystemImpl.Sync. -func (fs *Filesystem) Sync(ctx context.Context) error { - // All filesystem state is in-memory. - return nil -} - -// AccessAt implements vfs.Filesystem.Impl.AccessAt. -func (fs *Filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { - fs.mu.RLock() - defer fs.mu.RUnlock() - defer fs.processDeferredDecRefs() - - _, inode, err := fs.walkExistingLocked(ctx, rp) - if err != nil { - return err - } - return inode.CheckPermissions(ctx, creds, ats) -} - -// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. -func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { - fs.mu.RLock() - defer fs.processDeferredDecRefs() - defer fs.mu.RUnlock() - vfsd, inode, err := fs.walkExistingLocked(ctx, rp) - if err != nil { - return nil, err - } - - if opts.CheckSearchable { - d := vfsd.Impl().(*Dentry) - if !d.isDir() { - return nil, syserror.ENOTDIR - } - if err := inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil { - return nil, err - } - } - vfsd.IncRef() // Ownership transferred to caller. - return vfsd, nil -} - -// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt. -func (fs *Filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { - fs.mu.RLock() - defer fs.processDeferredDecRefs() - defer fs.mu.RUnlock() - vfsd, _, err := fs.walkParentDirLocked(ctx, rp) - if err != nil { - return nil, err - } - vfsd.IncRef() // Ownership transferred to caller. - return vfsd, nil -} - -// LinkAt implements vfs.FilesystemImpl.LinkAt. -func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { - if rp.Done() { - return syserror.EEXIST - } - fs.mu.Lock() - defer fs.mu.Unlock() - parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) - fs.processDeferredDecRefsLocked() - if err != nil { - return err - } - pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode) - if err != nil { - return err - } - if rp.Mount() != vd.Mount() { - return syserror.EXDEV - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return err - } - defer rp.Mount().EndWrite() - - d := vd.Dentry().Impl().(*Dentry) - if d.isDir() { - return syserror.EPERM - } - - child, err := parentInode.NewLink(ctx, pc, d.inode) - if err != nil { - return err - } - parentVFSD.Impl().(*Dentry).InsertChild(pc, child) - return nil -} - -// MkdirAt implements vfs.FilesystemImpl.MkdirAt. -func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { - if rp.Done() { - return syserror.EEXIST - } - fs.mu.Lock() - defer fs.mu.Unlock() - parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) - fs.processDeferredDecRefsLocked() - if err != nil { - return err - } - pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode) - if err != nil { - return err - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return err - } - defer rp.Mount().EndWrite() - child, err := parentInode.NewDir(ctx, pc, opts) - if err != nil { - return err - } - parentVFSD.Impl().(*Dentry).InsertChild(pc, child) - return nil -} - -// MknodAt implements vfs.FilesystemImpl.MknodAt. -func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { - if rp.Done() { - return syserror.EEXIST - } - fs.mu.Lock() - defer fs.mu.Unlock() - parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) - fs.processDeferredDecRefsLocked() - if err != nil { - return err - } - pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode) - if err != nil { - return err - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return err - } - defer rp.Mount().EndWrite() - new, err := parentInode.NewNode(ctx, pc, opts) - if err != nil { - return err - } - parentVFSD.Impl().(*Dentry).InsertChild(pc, new) - return nil -} - -// OpenAt implements vfs.FilesystemImpl.OpenAt. -func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - // Filter out flags that are not supported by kernfs. O_DIRECTORY and - // O_NOFOLLOW have no effect here (they're handled by VFS by setting - // appropriate bits in rp), but are returned by - // FileDescriptionImpl.StatusFlags(). - opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW - ats := vfs.AccessTypesForOpenFlags(&opts) - - // Do not create new file. - if opts.Flags&linux.O_CREAT == 0 { - fs.mu.RLock() - defer fs.processDeferredDecRefs() - defer fs.mu.RUnlock() - vfsd, inode, err := fs.walkExistingLocked(ctx, rp) - if err != nil { - return nil, err - } - if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil { - return nil, err - } - return inode.Open(rp, vfsd, opts) - } - - // May create new file. - mustCreate := opts.Flags&linux.O_EXCL != 0 - vfsd := rp.Start() - inode := vfsd.Impl().(*Dentry).inode - fs.mu.Lock() - defer fs.mu.Unlock() - if rp.Done() { - if rp.MustBeDir() { - return nil, syserror.EISDIR - } - if mustCreate { - return nil, syserror.EEXIST - } - if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil { - return nil, err - } - return inode.Open(rp, vfsd, opts) - } -afterTrailingSymlink: - parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) - fs.processDeferredDecRefsLocked() - if err != nil { - return nil, err - } - // Check for search permission in the parent directory. - if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil { - return nil, err - } - // Reject attempts to open directories with O_CREAT. - if rp.MustBeDir() { - return nil, syserror.EISDIR - } - pc := rp.Component() - if pc == "." || pc == ".." { - return nil, syserror.EISDIR - } - // Determine whether or not we need to create a file. - childVFSD, err := rp.ResolveChild(parentVFSD, pc) - if err != nil { - return nil, err - } - if childVFSD == nil { - // Already checked for searchability above; now check for writability. - if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil { - return nil, err - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return nil, err - } - defer rp.Mount().EndWrite() - // Create and open the child. - child, err := parentInode.NewFile(ctx, pc, opts) - if err != nil { - return nil, err - } - parentVFSD.Impl().(*Dentry).InsertChild(pc, child) - return child.Impl().(*Dentry).inode.Open(rp, child, opts) - } - // Open existing file or follow symlink. - if mustCreate { - return nil, syserror.EEXIST - } - childDentry := childVFSD.Impl().(*Dentry) - childInode := childDentry.inode - if rp.ShouldFollowSymlink() { - if childDentry.isSymlink() { - target, err := childInode.Readlink(ctx) - if err != nil { - return nil, err - } - if err := rp.HandleSymlink(target); err != nil { - return nil, err - } - // rp.Final() may no longer be true since we now need to resolve the - // symlink target. - goto afterTrailingSymlink - } - } - if err := childInode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil { - return nil, err - } - return childInode.Open(rp, childVFSD, opts) -} - -// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. -func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { - fs.mu.RLock() - d, inode, err := fs.walkExistingLocked(ctx, rp) - fs.mu.RUnlock() - fs.processDeferredDecRefs() - if err != nil { - return "", err - } - if !d.Impl().(*Dentry).isSymlink() { - return "", syserror.EINVAL - } - return inode.Readlink(ctx) -} - -// RenameAt implements vfs.FilesystemImpl.RenameAt. -func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { - // Only RENAME_NOREPLACE is supported. - if opts.Flags&^linux.RENAME_NOREPLACE != 0 { - return syserror.EINVAL - } - noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0 - - fs.mu.Lock() - defer fs.mu.Lock() - - // Resolve the destination directory first to verify that it's on this - // Mount. - dstDirVFSD, dstDirInode, err := fs.walkParentDirLocked(ctx, rp) - fs.processDeferredDecRefsLocked() - if err != nil { - return err - } - mnt := rp.Mount() - if mnt != oldParentVD.Mount() { - return syserror.EXDEV - } - if err := mnt.CheckBeginWrite(); err != nil { - return err - } - defer mnt.EndWrite() - - srcDirVFSD := oldParentVD.Dentry() - srcDir := srcDirVFSD.Impl().(*Dentry) - srcDir.dirMu.Lock() - src, err := fs.revalidateChildLocked(ctx, rp.VirtualFilesystem(), srcDir, oldName, srcDirVFSD.Child(oldName)) - srcDir.dirMu.Unlock() - fs.processDeferredDecRefsLocked() - if err != nil { - return err - } - srcVFSD := &src.vfsd - - // Can we remove the src dentry? - if err := checkDeleteLocked(ctx, rp, srcVFSD); err != nil { - return err - } - - // Can we create the dst dentry? - var dstVFSD *vfs.Dentry - pc, err := checkCreateLocked(ctx, rp, dstDirVFSD, dstDirInode) - switch err { - case nil: - // Ok, continue with rename as replacement. - case syserror.EEXIST: - if noReplace { - // Won't overwrite existing node since RENAME_NOREPLACE was requested. - return syserror.EEXIST - } - dstVFSD, err = rp.ResolveChild(dstDirVFSD, pc) - if err != nil { - panic(fmt.Sprintf("Child %q for parent Dentry %+v disappeared inside atomic section?", pc, dstDirVFSD)) - } - default: - return err - } - - mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() - virtfs := rp.VirtualFilesystem() - - srcDirDentry := srcDirVFSD.Impl().(*Dentry) - dstDirDentry := dstDirVFSD.Impl().(*Dentry) - - // We can't deadlock here due to lock ordering because we're protected from - // concurrent renames by fs.mu held for writing. - srcDirDentry.dirMu.Lock() - defer srcDirDentry.dirMu.Unlock() - dstDirDentry.dirMu.Lock() - defer dstDirDentry.dirMu.Unlock() - - if err := virtfs.PrepareRenameDentry(mntns, srcVFSD, dstVFSD); err != nil { - return err - } - srcDirInode := srcDirDentry.inode - replaced, err := srcDirInode.Rename(ctx, srcVFSD.Name(), pc, srcVFSD, dstDirVFSD) - if err != nil { - virtfs.AbortRenameDentry(srcVFSD, dstVFSD) - return err - } - virtfs.CommitRenameReplaceDentry(srcVFSD, dstDirVFSD, pc, replaced) - return nil -} - -// RmdirAt implements vfs.FilesystemImpl.RmdirAt. -func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { - fs.mu.Lock() - defer fs.mu.Unlock() - vfsd, inode, err := fs.walkExistingLocked(ctx, rp) - fs.processDeferredDecRefsLocked() - if err != nil { - return err - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return err - } - defer rp.Mount().EndWrite() - if err := checkDeleteLocked(ctx, rp, vfsd); err != nil { - return err - } - if !vfsd.Impl().(*Dentry).isDir() { - return syserror.ENOTDIR - } - if inode.HasChildren() { - return syserror.ENOTEMPTY - } - virtfs := rp.VirtualFilesystem() - parentDentry := vfsd.Parent().Impl().(*Dentry) - parentDentry.dirMu.Lock() - defer parentDentry.dirMu.Unlock() - - mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() - if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil { - return err - } - if err := parentDentry.inode.RmDir(ctx, rp.Component(), vfsd); err != nil { - virtfs.AbortDeleteDentry(vfsd) - return err - } - virtfs.CommitDeleteDentry(vfsd) - return nil -} - -// SetStatAt implements vfs.FilesystemImpl.SetStatAt. -func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { - fs.mu.RLock() - _, inode, err := fs.walkExistingLocked(ctx, rp) - fs.mu.RUnlock() - fs.processDeferredDecRefs() - if err != nil { - return err - } - if opts.Stat.Mask == 0 { - return nil - } - return inode.SetStat(fs.VFSFilesystem(), opts) -} - -// StatAt implements vfs.FilesystemImpl.StatAt. -func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { - fs.mu.RLock() - _, inode, err := fs.walkExistingLocked(ctx, rp) - fs.mu.RUnlock() - fs.processDeferredDecRefs() - if err != nil { - return linux.Statx{}, err - } - return inode.Stat(fs.VFSFilesystem(), opts) -} - -// StatFSAt implements vfs.FilesystemImpl.StatFSAt. -func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { - fs.mu.RLock() - _, _, err := fs.walkExistingLocked(ctx, rp) - fs.mu.RUnlock() - fs.processDeferredDecRefs() - if err != nil { - return linux.Statfs{}, err - } - // TODO: actually implement statfs - return linux.Statfs{}, syserror.ENOSYS -} - -// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. -func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { - if rp.Done() { - return syserror.EEXIST - } - fs.mu.Lock() - defer fs.mu.Unlock() - parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) - fs.processDeferredDecRefsLocked() - if err != nil { - return err - } - pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode) - if err != nil { - return err - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return err - } - defer rp.Mount().EndWrite() - child, err := parentInode.NewSymlink(ctx, pc, target) - if err != nil { - return err - } - parentVFSD.Impl().(*Dentry).InsertChild(pc, child) - return nil -} - -// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. -func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { - fs.mu.Lock() - defer fs.mu.Unlock() - vfsd, _, err := fs.walkExistingLocked(ctx, rp) - fs.processDeferredDecRefsLocked() - if err != nil { - return err - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return err - } - defer rp.Mount().EndWrite() - if err := checkDeleteLocked(ctx, rp, vfsd); err != nil { - return err - } - if vfsd.Impl().(*Dentry).isDir() { - return syserror.EISDIR - } - virtfs := rp.VirtualFilesystem() - parentDentry := vfsd.Parent().Impl().(*Dentry) - parentDentry.dirMu.Lock() - defer parentDentry.dirMu.Unlock() - mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() - if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil { - return err - } - if err := parentDentry.inode.Unlink(ctx, rp.Component(), vfsd); err != nil { - virtfs.AbortDeleteDentry(vfsd) - return err - } - virtfs.CommitDeleteDentry(vfsd) - return nil -} - -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { - fs.mu.RLock() - _, _, err := fs.walkExistingLocked(ctx, rp) - fs.mu.RUnlock() - fs.processDeferredDecRefs() - if err != nil { - return nil, err - } - // kernfs currently does not support extended attributes. - return nil, syserror.ENOTSUP -} - -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { - fs.mu.RLock() - _, _, err := fs.walkExistingLocked(ctx, rp) - fs.mu.RUnlock() - fs.processDeferredDecRefs() - if err != nil { - return "", err - } - // kernfs currently does not support extended attributes. - return "", syserror.ENOTSUP -} - -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { - fs.mu.RLock() - _, _, err := fs.walkExistingLocked(ctx, rp) - fs.mu.RUnlock() - fs.processDeferredDecRefs() - if err != nil { - return err - } - // kernfs currently does not support extended attributes. - return syserror.ENOTSUP -} - -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *Filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { - fs.mu.RLock() - _, _, err := fs.walkExistingLocked(ctx, rp) - fs.mu.RUnlock() - fs.processDeferredDecRefs() - if err != nil { - return err - } - // kernfs currently does not support extended attributes. - return syserror.ENOTSUP -} - -// PrependPath implements vfs.FilesystemImpl.PrependPath. -func (fs *Filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { - fs.mu.RLock() - defer fs.mu.RUnlock() - return vfs.GenericPrependPath(vfsroot, vd, b) -} diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go deleted file mode 100644 index d50018b18..000000000 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ /dev/null @@ -1,556 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernfs - -import ( - "fmt" - "sync/atomic" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" -) - -// InodeNoopRefCount partially implements the Inode interface, specifically the -// inodeRefs sub interface. InodeNoopRefCount implements a simple reference -// count for inodes, performing no extra actions when references are obtained or -// released. This is suitable for simple file inodes that don't reference any -// resources. -type InodeNoopRefCount struct { -} - -// IncRef implements Inode.IncRef. -func (InodeNoopRefCount) IncRef() { -} - -// DecRef implements Inode.DecRef. -func (InodeNoopRefCount) DecRef() { -} - -// TryIncRef implements Inode.TryIncRef. -func (InodeNoopRefCount) TryIncRef() bool { - return true -} - -// Destroy implements Inode.Destroy. -func (InodeNoopRefCount) Destroy() { -} - -// InodeDirectoryNoNewChildren partially implements the Inode interface. -// InodeDirectoryNoNewChildren represents a directory inode which does not -// support creation of new children. -type InodeDirectoryNoNewChildren struct{} - -// NewFile implements Inode.NewFile. -func (InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) { - return nil, syserror.EPERM -} - -// NewDir implements Inode.NewDir. -func (InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) { - return nil, syserror.EPERM -} - -// NewLink implements Inode.NewLink. -func (InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) { - return nil, syserror.EPERM -} - -// NewSymlink implements Inode.NewSymlink. -func (InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { - return nil, syserror.EPERM -} - -// NewNode implements Inode.NewNode. -func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { - return nil, syserror.EPERM -} - -// InodeNotDirectory partially implements the Inode interface, specifically the -// inodeDirectory and inodeDynamicDirectory sub interfaces. Inodes that do not -// represent directories can embed this to provide no-op implementations for -// directory-related functions. -type InodeNotDirectory struct { -} - -// HasChildren implements Inode.HasChildren. -func (InodeNotDirectory) HasChildren() bool { - return false -} - -// NewFile implements Inode.NewFile. -func (InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) { - panic("NewFile called on non-directory inode") -} - -// NewDir implements Inode.NewDir. -func (InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) { - panic("NewDir called on non-directory inode") -} - -// NewLink implements Inode.NewLinkink. -func (InodeNotDirectory) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) { - panic("NewLink called on non-directory inode") -} - -// NewSymlink implements Inode.NewSymlink. -func (InodeNotDirectory) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { - panic("NewSymlink called on non-directory inode") -} - -// NewNode implements Inode.NewNode. -func (InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { - panic("NewNode called on non-directory inode") -} - -// Unlink implements Inode.Unlink. -func (InodeNotDirectory) Unlink(context.Context, string, *vfs.Dentry) error { - panic("Unlink called on non-directory inode") -} - -// RmDir implements Inode.RmDir. -func (InodeNotDirectory) RmDir(context.Context, string, *vfs.Dentry) error { - panic("RmDir called on non-directory inode") -} - -// Rename implements Inode.Rename. -func (InodeNotDirectory) Rename(context.Context, string, string, *vfs.Dentry, *vfs.Dentry) (*vfs.Dentry, error) { - panic("Rename called on non-directory inode") -} - -// Lookup implements Inode.Lookup. -func (InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { - panic("Lookup called on non-directory inode") -} - -// IterDirents implements Inode.IterDirents. -func (InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { - panic("IterDirents called on non-directory inode") -} - -// Valid implements Inode.Valid. -func (InodeNotDirectory) Valid(context.Context) bool { - return true -} - -// InodeNoDynamicLookup partially implements the Inode interface, specifically -// the inodeDynamicLookup sub interface. Directory inodes that do not support -// dymanic entries (i.e. entries that are not "hashed" into the -// vfs.Dentry.children) can embed this to provide no-op implementations for -// functions related to dynamic entries. -type InodeNoDynamicLookup struct{} - -// Lookup implements Inode.Lookup. -func (InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { - return nil, syserror.ENOENT -} - -// IterDirents implements Inode.IterDirents. -func (InodeNoDynamicLookup) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { - return offset, nil -} - -// Valid implements Inode.Valid. -func (InodeNoDynamicLookup) Valid(ctx context.Context) bool { - return true -} - -// InodeNotSymlink partially implements the Inode interface, specifically the -// inodeSymlink sub interface. All inodes that are not symlinks may embed this -// to return the appropriate errors from symlink-related functions. -type InodeNotSymlink struct{} - -// Readlink implements Inode.Readlink. -func (InodeNotSymlink) Readlink(context.Context) (string, error) { - return "", syserror.EINVAL -} - -// InodeAttrs partially implements the Inode interface, specifically the -// inodeMetadata sub interface. InodeAttrs provides functionality related to -// inode attributes. -// -// Must be initialized by Init prior to first use. -type InodeAttrs struct { - ino uint64 - mode uint32 - uid uint32 - gid uint32 - nlink uint32 -} - -// Init initializes this InodeAttrs. -func (a *InodeAttrs) Init(creds *auth.Credentials, ino uint64, mode linux.FileMode) { - if mode.FileType() == 0 { - panic(fmt.Sprintf("No file type specified in 'mode' for InodeAttrs.Init(): mode=0%o", mode)) - } - - nlink := uint32(1) - if mode.FileType() == linux.ModeDirectory { - nlink = 2 - } - atomic.StoreUint64(&a.ino, ino) - atomic.StoreUint32(&a.mode, uint32(mode)) - atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID)) - atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID)) - atomic.StoreUint32(&a.nlink, nlink) -} - -// Mode implements Inode.Mode. -func (a *InodeAttrs) Mode() linux.FileMode { - return linux.FileMode(atomic.LoadUint32(&a.mode)) -} - -// Stat partially implements Inode.Stat. Note that this function doesn't provide -// all the stat fields, and the embedder should consider extending the result -// with filesystem-specific fields. -func (a *InodeAttrs) Stat(*vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) { - var stat linux.Statx - stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK - stat.Ino = atomic.LoadUint64(&a.ino) - stat.Mode = uint16(a.Mode()) - stat.UID = atomic.LoadUint32(&a.uid) - stat.GID = atomic.LoadUint32(&a.gid) - stat.Nlink = atomic.LoadUint32(&a.nlink) - - // TODO: Implement other stat fields like timestamps. - - return stat, nil -} - -// SetStat implements Inode.SetStat. -func (a *InodeAttrs) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error { - stat := opts.Stat - if stat.Mask&linux.STATX_MODE != 0 { - for { - old := atomic.LoadUint32(&a.mode) - new := old | uint32(stat.Mode & ^uint16(linux.S_IFMT)) - if swapped := atomic.CompareAndSwapUint32(&a.mode, old, new); swapped { - break - } - } - } - - if stat.Mask&linux.STATX_UID != 0 { - atomic.StoreUint32(&a.uid, stat.UID) - } - if stat.Mask&linux.STATX_GID != 0 { - atomic.StoreUint32(&a.gid, stat.GID) - } - - // Note that not all fields are modifiable. For example, the file type and - // inode numbers are immutable after node creation. - - // TODO: Implement other stat fields like timestamps. - - return nil -} - -// CheckPermissions implements Inode.CheckPermissions. -func (a *InodeAttrs) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { - mode := a.Mode() - return vfs.GenericCheckPermissions( - creds, - ats, - mode.FileType() == linux.ModeDirectory, - uint16(mode), - auth.KUID(atomic.LoadUint32(&a.uid)), - auth.KGID(atomic.LoadUint32(&a.gid)), - ) -} - -// IncLinks implements Inode.IncLinks. -func (a *InodeAttrs) IncLinks(n uint32) { - if atomic.AddUint32(&a.nlink, n) <= n { - panic("InodeLink.IncLinks called with no existing links") - } -} - -// DecLinks implements Inode.DecLinks. -func (a *InodeAttrs) DecLinks() { - if nlink := atomic.AddUint32(&a.nlink, ^uint32(0)); nlink == ^uint32(0) { - // Negative overflow - panic("Inode.DecLinks called at 0 links") - } -} - -type slot struct { - Name string - Dentry *vfs.Dentry - slotEntry -} - -// OrderedChildrenOptions contains initialization options for OrderedChildren. -type OrderedChildrenOptions struct { - // Writable indicates whether vfs.FilesystemImpl methods implemented by - // OrderedChildren may modify the tracked children. This applies to - // operations related to rename, unlink and rmdir. If an OrderedChildren is - // not writable, these operations all fail with EPERM. - Writable bool -} - -// OrderedChildren partially implements the Inode interface. OrderedChildren can -// be embedded in directory inodes to keep track of the children in the -// directory, and can then be used to implement a generic directory FD -- see -// GenericDirectoryFD. OrderedChildren is not compatible with dynamic -// directories. -// -// Must be initialize with Init before first use. -type OrderedChildren struct { - refs.AtomicRefCount - - // Can children be modified by user syscalls? It set to false, interface - // methods that would modify the children return EPERM. Immutable. - writable bool - - mu sync.RWMutex - order slotList - set map[string]*slot -} - -// Init initializes an OrderedChildren. -func (o *OrderedChildren) Init(opts OrderedChildrenOptions) { - o.writable = opts.Writable - o.set = make(map[string]*slot) -} - -// DecRef implements Inode.DecRef. -func (o *OrderedChildren) DecRef() { - o.AtomicRefCount.DecRefWithDestructor(o.Destroy) -} - -// Destroy cleans up resources referenced by this OrderedChildren. -func (o *OrderedChildren) Destroy() { - o.mu.Lock() - defer o.mu.Unlock() - o.order.Reset() - o.set = nil -} - -// Populate inserts children into this OrderedChildren, and d's dentry -// cache. Populate returns the number of directories inserted, which the caller -// may use to update the link count for the parent directory. -// -// Precondition: d.Impl() must be a kernfs Dentry. d must represent a directory -// inode. children must not contain any conflicting entries already in o. -func (o *OrderedChildren) Populate(d *Dentry, children map[string]*Dentry) uint32 { - var links uint32 - for name, child := range children { - if child.isDir() { - links++ - } - if err := o.Insert(name, child.VFSDentry()); err != nil { - panic(fmt.Sprintf("Collision when attempting to insert child %q (%+v) into %+v", name, child, d)) - } - d.InsertChild(name, child.VFSDentry()) - } - return links -} - -// HasChildren implements Inode.HasChildren. -func (o *OrderedChildren) HasChildren() bool { - o.mu.RLock() - defer o.mu.RUnlock() - return len(o.set) > 0 -} - -// Insert inserts child into o. This ignores the writability of o, as this is -// not part of the vfs.FilesystemImpl interface, and is a lower-level operation. -func (o *OrderedChildren) Insert(name string, child *vfs.Dentry) error { - o.mu.Lock() - defer o.mu.Unlock() - if _, ok := o.set[name]; ok { - return syserror.EEXIST - } - s := &slot{ - Name: name, - Dentry: child, - } - o.order.PushBack(s) - o.set[name] = s - return nil -} - -// Precondition: caller must hold o.mu for writing. -func (o *OrderedChildren) removeLocked(name string) { - if s, ok := o.set[name]; ok { - delete(o.set, name) - o.order.Remove(s) - } -} - -// Precondition: caller must hold o.mu for writing. -func (o *OrderedChildren) replaceChildLocked(name string, new *vfs.Dentry) *vfs.Dentry { - if s, ok := o.set[name]; ok { - // Existing slot with given name, simply replace the dentry. - var old *vfs.Dentry - old, s.Dentry = s.Dentry, new - return old - } - - // No existing slot with given name, create and hash new slot. - s := &slot{ - Name: name, - Dentry: new, - } - o.order.PushBack(s) - o.set[name] = s - return nil -} - -// Precondition: caller must hold o.mu for reading or writing. -func (o *OrderedChildren) checkExistingLocked(name string, child *vfs.Dentry) error { - s, ok := o.set[name] - if !ok { - return syserror.ENOENT - } - if s.Dentry != child { - panic(fmt.Sprintf("Dentry hashed into inode doesn't match what vfs thinks! OrderedChild: %+v, vfs: %+v", s.Dentry, child)) - } - return nil -} - -// Unlink implements Inode.Unlink. -func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *vfs.Dentry) error { - if !o.writable { - return syserror.EPERM - } - o.mu.Lock() - defer o.mu.Unlock() - if err := o.checkExistingLocked(name, child); err != nil { - return err - } - o.removeLocked(name) - return nil -} - -// Rmdir implements Inode.Rmdir. -func (o *OrderedChildren) RmDir(ctx context.Context, name string, child *vfs.Dentry) error { - // We're not responsible for checking that child is a directory, that it's - // empty, or updating any link counts; so this is the same as unlink. - return o.Unlink(ctx, name, child) -} - -type renameAcrossDifferentImplementationsError struct{} - -func (renameAcrossDifferentImplementationsError) Error() string { - return "rename across inodes with different implementations" -} - -// Rename implements Inode.Rename. -// -// Precondition: Rename may only be called across two directory inodes with -// identical implementations of Rename. Practically, this means filesystems that -// implement Rename by embedding OrderedChildren for any directory -// implementation must use OrderedChildren for all directory implementations -// that will support Rename. -// -// Postcondition: reference on any replaced dentry transferred to caller. -func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (*vfs.Dentry, error) { - dst, ok := dstDir.Impl().(*Dentry).inode.(interface{}).(*OrderedChildren) - if !ok { - return nil, renameAcrossDifferentImplementationsError{} - } - if !o.writable || !dst.writable { - return nil, syserror.EPERM - } - // Note: There's a potential deadlock below if concurrent calls to Rename - // refer to the same src and dst directories in reverse. We avoid any - // ordering issues because the caller is required to serialize concurrent - // calls to Rename in accordance with the interface declaration. - o.mu.Lock() - defer o.mu.Unlock() - if dst != o { - dst.mu.Lock() - defer dst.mu.Unlock() - } - if err := o.checkExistingLocked(oldname, child); err != nil { - return nil, err - } - replaced := dst.replaceChildLocked(newname, child) - return replaced, nil -} - -// nthLocked returns an iterator to the nth child tracked by this object. The -// iterator is valid until the caller releases o.mu. Returns nil if the -// requested index falls out of bounds. -// -// Preconditon: Caller must hold o.mu for reading. -func (o *OrderedChildren) nthLocked(i int64) *slot { - for it := o.order.Front(); it != nil && i >= 0; it = it.Next() { - if i == 0 { - return it - } - i-- - } - return nil -} - -// InodeSymlink partially implements Inode interface for symlinks. -type InodeSymlink struct { - InodeNotDirectory -} - -// Open implements Inode.Open. -func (InodeSymlink) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - return nil, syserror.ELOOP -} - -// StaticDirectory is a standard implementation of a directory with static -// contents. -// -// +stateify savable -type StaticDirectory struct { - InodeNotSymlink - InodeDirectoryNoNewChildren - InodeAttrs - InodeNoDynamicLookup - OrderedChildren -} - -var _ Inode = (*StaticDirectory)(nil) - -// NewStaticDir creates a new static directory and returns its dentry. -func NewStaticDir(creds *auth.Credentials, ino uint64, perm linux.FileMode, children map[string]*Dentry) *Dentry { - inode := &StaticDirectory{} - inode.Init(creds, ino, perm) - - dentry := &Dentry{} - dentry.Init(inode) - - inode.OrderedChildren.Init(OrderedChildrenOptions{}) - links := inode.OrderedChildren.Populate(dentry, children) - inode.IncLinks(links) - - return dentry -} - -// Init initializes StaticDirectory. -func (s *StaticDirectory) Init(creds *auth.Credentials, ino uint64, perm linux.FileMode) { - if perm&^linux.PermissionsMask != 0 { - panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) - } - s.InodeAttrs.Init(creds, ino, linux.ModeDirectory|perm) -} - -// Open implements kernfs.Inode. -func (s *StaticDirectory) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd := &GenericDirectoryFD{} - fd.Init(rp.Mount(), vfsd, &s.OrderedChildren, &opts) - return fd.VFSFileDescription(), nil -} diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go deleted file mode 100644 index a8ab2a2ba..000000000 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ /dev/null @@ -1,421 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package kernfs provides the tools to implement inode-based filesystems. -// Kernfs has two main features: -// -// 1. The Inode interface, which maps VFS2's path-based filesystem operations to -// specific filesystem nodes. Kernfs uses the Inode interface to provide a -// blanket implementation for the vfs.FilesystemImpl. Kernfs also serves as -// the synchronization mechanism for all filesystem operations by holding a -// filesystem-wide lock across all operations. -// -// 2. Various utility types which provide generic implementations for various -// parts of the Inode and vfs.FileDescription interfaces. Client filesystems -// based on kernfs can embed the appropriate set of these to avoid having to -// reimplement common filesystem operations. See inode_impl_util.go and -// fd_impl_util.go. -// -// Reference Model: -// -// Kernfs dentries represents named pointers to inodes. Dentries and inode have -// independent lifetimes and reference counts. A child dentry unconditionally -// holds a reference on its parent directory's dentry. A dentry also holds a -// reference on the inode it points to. Multiple dentries can point to the same -// inode (for example, in the case of hardlinks). File descriptors hold a -// reference to the dentry they're opened on. -// -// Dentries are guaranteed to exist while holding Filesystem.mu for -// reading. Dropping dentries require holding Filesystem.mu for writing. To -// queue dentries for destruction from a read critical section, see -// Filesystem.deferDecRef. -// -// Lock ordering: -// -// kernfs.Filesystem.mu -// kernfs.Dentry.dirMu -// vfs.VirtualFilesystem.mountMu -// vfs.Dentry.mu -// kernfs.Filesystem.droppedDentriesMu -// (inode implementation locks, if any) -package kernfs - -import ( - "fmt" - "sync/atomic" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sync" -) - -// FilesystemType implements vfs.FilesystemType. -type FilesystemType struct{} - -// Filesystem mostly implements vfs.FilesystemImpl for a generic in-memory -// filesystem. Concrete implementations are expected to embed this in their own -// Filesystem type. -type Filesystem struct { - vfsfs vfs.Filesystem - - droppedDentriesMu sync.Mutex - - // droppedDentries is a list of dentries waiting to be DecRef()ed. This is - // used to defer dentry destruction until mu can be acquired for - // writing. Protected by droppedDentriesMu. - droppedDentries []*vfs.Dentry - - // mu synchronizes the lifetime of Dentries on this filesystem. Holding it - // for reading guarantees continued existence of any resolved dentries, but - // the dentry tree may be modified. - // - // Kernfs dentries can only be DecRef()ed while holding mu for writing. For - // example: - // - // fs.mu.Lock() - // defer fs.mu.Unlock() - // ... - // dentry1.DecRef() - // defer dentry2.DecRef() // Ok, will run before Unlock. - // - // If discarding dentries in a read context, use Filesystem.deferDecRef. For - // example: - // - // fs.mu.RLock() - // fs.mu.processDeferredDecRefs() - // defer fs.mu.RUnlock() - // ... - // fs.deferDecRef(dentry) - mu sync.RWMutex - - // nextInoMinusOne is used to to allocate inode numbers on this - // filesystem. Must be accessed by atomic operations. - nextInoMinusOne uint64 -} - -// deferDecRef defers dropping a dentry ref until the next call to -// processDeferredDecRefs{,Locked}. See comment on Filesystem.mu. -// -// Precondition: d must not already be pending destruction. -func (fs *Filesystem) deferDecRef(d *vfs.Dentry) { - fs.droppedDentriesMu.Lock() - fs.droppedDentries = append(fs.droppedDentries, d) - fs.droppedDentriesMu.Unlock() -} - -// processDeferredDecRefs calls vfs.Dentry.DecRef on all dentries in the -// droppedDentries list. See comment on Filesystem.mu. -func (fs *Filesystem) processDeferredDecRefs() { - fs.mu.Lock() - fs.processDeferredDecRefsLocked() - fs.mu.Unlock() -} - -// Precondition: fs.mu must be held for writing. -func (fs *Filesystem) processDeferredDecRefsLocked() { - fs.droppedDentriesMu.Lock() - for _, d := range fs.droppedDentries { - d.DecRef() - } - fs.droppedDentries = fs.droppedDentries[:0] // Keep slice memory for reuse. - fs.droppedDentriesMu.Unlock() -} - -// Init initializes a kernfs filesystem. This should be called from during -// vfs.FilesystemType.NewFilesystem for the concrete filesystem embedding -// kernfs. -func (fs *Filesystem) Init(vfsObj *vfs.VirtualFilesystem) { - fs.vfsfs.Init(vfsObj, fs) -} - -// VFSFilesystem returns the generic vfs filesystem object. -func (fs *Filesystem) VFSFilesystem() *vfs.Filesystem { - return &fs.vfsfs -} - -// NextIno allocates a new inode number on this filesystem. -func (fs *Filesystem) NextIno() uint64 { - return atomic.AddUint64(&fs.nextInoMinusOne, 1) -} - -// These consts are used in the Dentry.flags field. -const ( - // Dentry points to a directory inode. - dflagsIsDir = 1 << iota - - // Dentry points to a symlink inode. - dflagsIsSymlink -) - -// Dentry implements vfs.DentryImpl. -// -// A kernfs dentry is similar to a dentry in a traditional filesystem: it's a -// named reference to an inode. A dentry generally lives as long as it's part of -// a mounted filesystem tree. Kernfs doesn't cache dentries once all references -// to them are removed. Dentries hold a single reference to the inode they point -// to, and child dentries hold a reference on their parent. -// -// Must be initialized by Init prior to first use. -type Dentry struct { - refs.AtomicRefCount - - vfsd vfs.Dentry - inode Inode - - // flags caches useful information about the dentry from the inode. See the - // dflags* consts above. Must be accessed by atomic ops. - flags uint32 - - // dirMu protects vfsd.children for directory dentries. - dirMu sync.Mutex -} - -// Init initializes this dentry. -// -// Precondition: Caller must hold a reference on inode. -// -// Postcondition: Caller's reference on inode is transferred to the dentry. -func (d *Dentry) Init(inode Inode) { - d.vfsd.Init(d) - d.inode = inode - ftype := inode.Mode().FileType() - if ftype == linux.ModeDirectory { - d.flags |= dflagsIsDir - } - if ftype == linux.ModeSymlink { - d.flags |= dflagsIsSymlink - } -} - -// VFSDentry returns the generic vfs dentry for this kernfs dentry. -func (d *Dentry) VFSDentry() *vfs.Dentry { - return &d.vfsd -} - -// isDir checks whether the dentry points to a directory inode. -func (d *Dentry) isDir() bool { - return atomic.LoadUint32(&d.flags)&dflagsIsDir != 0 -} - -// isSymlink checks whether the dentry points to a symlink inode. -func (d *Dentry) isSymlink() bool { - return atomic.LoadUint32(&d.flags)&dflagsIsSymlink != 0 -} - -// DecRef implements vfs.DentryImpl.DecRef. -func (d *Dentry) DecRef() { - d.AtomicRefCount.DecRefWithDestructor(d.destroy) -} - -// Precondition: Dentry must be removed from VFS' dentry cache. -func (d *Dentry) destroy() { - d.inode.DecRef() // IncRef from Init. - d.inode = nil - if parent := d.vfsd.Parent(); parent != nil { - parent.DecRef() // IncRef from Dentry.InsertChild. - } -} - -// InsertChild inserts child into the vfs dentry cache with the given name under -// this dentry. This does not update the directory inode, so calling this on -// it's own isn't sufficient to insert a child into a directory. InsertChild -// updates the link count on d if required. -// -// Precondition: d must represent a directory inode. -func (d *Dentry) InsertChild(name string, child *vfs.Dentry) { - d.dirMu.Lock() - d.insertChildLocked(name, child) - d.dirMu.Unlock() -} - -// insertChildLocked is equivalent to InsertChild, with additional -// preconditions. -// -// Precondition: d.dirMu must be locked. -func (d *Dentry) insertChildLocked(name string, child *vfs.Dentry) { - if !d.isDir() { - panic(fmt.Sprintf("InsertChild called on non-directory Dentry: %+v.", d)) - } - vfsDentry := d.VFSDentry() - vfsDentry.IncRef() // DecRef in child's Dentry.destroy. - vfsDentry.InsertChild(child, name) -} - -// The Inode interface maps filesystem-level operations that operate on paths to -// equivalent operations on specific filesystem nodes. -// -// The interface methods are groups into logical categories as sub interfaces -// below. Generally, an implementation for each sub interface can be provided by -// embedding an appropriate type from inode_impl_utils.go. The sub interfaces -// are purely organizational. Methods declared directly in the main interface -// have no generic implementations, and should be explicitly provided by the -// client filesystem. -// -// Generally, implementations are not responsible for tasks that are common to -// all filesystems. These include: -// -// - Checking that dentries passed to methods are of the appropriate file type. -// - Checking permissions. -// - Updating link and reference counts. -// -// Specific responsibilities of implementations are documented below. -type Inode interface { - // Methods related to reference counting. A generic implementation is - // provided by InodeNoopRefCount. These methods are generally called by the - // equivalent Dentry methods. - inodeRefs - - // Methods related to node metadata. A generic implementation is provided by - // InodeAttrs. - inodeMetadata - - // Method for inodes that represent symlink. InodeNotSymlink provides a - // blanket implementation for all non-symlink inodes. - inodeSymlink - - // Method for inodes that represent directories. InodeNotDirectory provides - // a blanket implementation for all non-directory inodes. - inodeDirectory - - // Method for inodes that represent dynamic directories and their - // children. InodeNoDynamicLookup provides a blanket implementation for all - // non-dynamic-directory inodes. - inodeDynamicLookup - - // Open creates a file description for the filesystem object represented by - // this inode. The returned file description should hold a reference on the - // inode for its lifetime. - // - // Precondition: rp.Done(). vfsd.Impl() must be the kernfs Dentry containing - // the inode on which Open() is being called. - Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) -} - -type inodeRefs interface { - IncRef() - DecRef() - TryIncRef() bool - // Destroy is called when the inode reaches zero references. Destroy release - // all resources (references) on objects referenced by the inode, including - // any child dentries. - Destroy() -} - -type inodeMetadata interface { - // CheckPermissions checks that creds may access this inode for the - // requested access type, per the the rules of - // fs/namei.c:generic_permission(). - CheckPermissions(ctx context.Context, creds *auth.Credentials, atx vfs.AccessTypes) error - - // Mode returns the (struct stat)::st_mode value for this inode. This is - // separated from Stat for performance. - Mode() linux.FileMode - - // Stat returns the metadata for this inode. This corresponds to - // vfs.FilesystemImpl.StatAt. - Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) - - // SetStat updates the metadata for this inode. This corresponds to - // vfs.FilesystemImpl.SetStatAt. - SetStat(fs *vfs.Filesystem, opts vfs.SetStatOptions) error -} - -// Precondition: All methods in this interface may only be called on directory -// inodes. -type inodeDirectory interface { - // The New{File,Dir,Node,Symlink} methods below should return a new inode - // hashed into this inode. - // - // These inode constructors are inode-level operations rather than - // filesystem-level operations to allow client filesystems to mix different - // implementations based on the new node's location in the - // filesystem. - - // HasChildren returns true if the directory inode has any children. - HasChildren() bool - - // NewFile creates a new regular file inode. - NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) - - // NewDir creates a new directory inode. - NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) - - // NewLink creates a new hardlink to a specified inode in this - // directory. Implementations should create a new kernfs Dentry pointing to - // target, and update target's link count. - NewLink(ctx context.Context, name string, target Inode) (*vfs.Dentry, error) - - // NewSymlink creates a new symbolic link inode. - NewSymlink(ctx context.Context, name, target string) (*vfs.Dentry, error) - - // NewNode creates a new filesystem node for a mknod syscall. - NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*vfs.Dentry, error) - - // Unlink removes a child dentry from this directory inode. - Unlink(ctx context.Context, name string, child *vfs.Dentry) error - - // RmDir removes an empty child directory from this directory - // inode. Implementations must update the parent directory's link count, - // if required. Implementations are not responsible for checking that child - // is a directory, checking for an empty directory. - RmDir(ctx context.Context, name string, child *vfs.Dentry) error - - // Rename is called on the source directory containing an inode being - // renamed. child should point to the resolved child in the source - // directory. If Rename replaces a dentry in the destination directory, it - // should return the replaced dentry or nil otherwise. - // - // Precondition: Caller must serialize concurrent calls to Rename. - Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (replaced *vfs.Dentry, err error) -} - -type inodeDynamicLookup interface { - // Lookup should return an appropriate dentry if name should resolve to a - // child of this dynamic directory inode. This gives the directory an - // opportunity on every lookup to resolve additional entries that aren't - // hashed into the directory. This is only called when the inode is a - // directory. If the inode is not a directory, or if the directory only - // contains a static set of children, the implementer can unconditionally - // return an appropriate error (ENOTDIR and ENOENT respectively). - // - // The child returned by Lookup will be hashed into the VFS dentry tree. Its - // lifetime can be controlled by the filesystem implementation with an - // appropriate implementation of Valid. - // - // Lookup returns the child with an extra reference and the caller owns this - // reference. - Lookup(ctx context.Context, name string) (*vfs.Dentry, error) - - // Valid should return true if this inode is still valid, or needs to - // be resolved again by a call to Lookup. - Valid(ctx context.Context) bool - - // IterDirents is used to iterate over dynamically created entries. It invokes - // cb on each entry in the directory represented by the FileDescription. - // 'offset' is the offset for the entire IterDirents call, which may include - // results from the caller. 'relOffset' is the offset inside the entries - // returned by this IterDirents invocation. In other words, - // 'offset+relOffset+1' is the value that should be set in vfs.Dirent.NextOff, - // while 'relOffset' is the place where iteration should start from. - IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) -} - -type inodeSymlink interface { - // Readlink resolves the target of a symbolic link. If an inode is not a - // symlink, the implementation should return EINVAL. - Readlink(ctx context.Context) (string, error) -} diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go deleted file mode 100644 index 0459fb305..000000000 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernfs_test - -import ( - "bytes" - "fmt" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -const defaultMode linux.FileMode = 01777 -const staticFileContent = "This is sample content for a static test file." - -// RootDentryFn is a generator function for creating the root dentry of a test -// filesystem. See newTestSystem. -type RootDentryFn func(*auth.Credentials, *filesystem) *kernfs.Dentry - -// newTestSystem sets up a minimal environment for running a test, including an -// instance of a test filesystem. Tests can control the contents of the -// filesystem by providing an appropriate rootFn, which should return a -// pre-populated root dentry. -func newTestSystem(t *testing.T, rootFn RootDentryFn) *testutil.System { - ctx := contexttest.Context(t) - creds := auth.CredentialsFromContext(ctx) - v := &vfs.VirtualFilesystem{} - if err := v.Init(); err != nil { - t.Fatalf("VFS init: %v", err) - } - v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.GetFilesystemOptions{}) - if err != nil { - t.Fatalf("Failed to create testfs root mount: %v", err) - } - return testutil.NewSystem(ctx, t, v, mns) -} - -type fsType struct { - rootFn RootDentryFn -} - -type filesystem struct { - kernfs.Filesystem -} - -type file struct { - kernfs.DynamicBytesFile - content string -} - -func (fs *filesystem) newFile(creds *auth.Credentials, content string) *kernfs.Dentry { - f := &file{} - f.content = content - f.DynamicBytesFile.Init(creds, fs.NextIno(), f, 0777) - - d := &kernfs.Dentry{} - d.Init(f) - return d -} - -func (f *file) Generate(ctx context.Context, buf *bytes.Buffer) error { - fmt.Fprintf(buf, "%s", f.content) - return nil -} - -type attrs struct { - kernfs.InodeAttrs -} - -func (a *attrs) SetStat(fs *vfs.Filesystem, opt vfs.SetStatOptions) error { - return syserror.EPERM -} - -type readonlyDir struct { - attrs - kernfs.InodeNotSymlink - kernfs.InodeNoDynamicLookup - kernfs.InodeDirectoryNoNewChildren - - kernfs.OrderedChildren - dentry kernfs.Dentry -} - -func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry { - dir := &readonlyDir{} - dir.attrs.Init(creds, fs.NextIno(), linux.ModeDirectory|mode) - dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - dir.dentry.Init(dir) - - dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents)) - - return &dir.dentry -} - -func (d *readonlyDir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd := &kernfs.GenericDirectoryFD{} - if err := fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, &opts); err != nil { - return nil, err - } - return fd.VFSFileDescription(), nil -} - -type dir struct { - attrs - kernfs.InodeNotSymlink - kernfs.InodeNoDynamicLookup - - fs *filesystem - dentry kernfs.Dentry - kernfs.OrderedChildren -} - -func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry { - dir := &dir{} - dir.fs = fs - dir.attrs.Init(creds, fs.NextIno(), linux.ModeDirectory|mode) - dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true}) - dir.dentry.Init(dir) - - dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents)) - - return &dir.dentry -} - -func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd := &kernfs.GenericDirectoryFD{} - fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, &opts) - return fd.VFSFileDescription(), nil -} - -func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) { - creds := auth.CredentialsFromContext(ctx) - dir := d.fs.newDir(creds, opts.Mode, nil) - dirVFSD := dir.VFSDentry() - if err := d.OrderedChildren.Insert(name, dirVFSD); err != nil { - dir.DecRef() - return nil, err - } - d.IncLinks(1) - return dirVFSD, nil -} - -func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) { - creds := auth.CredentialsFromContext(ctx) - f := d.fs.newFile(creds, "") - fVFSD := f.VFSDentry() - if err := d.OrderedChildren.Insert(name, fVFSD); err != nil { - f.DecRef() - return nil, err - } - return fVFSD, nil -} - -func (*dir) NewLink(context.Context, string, kernfs.Inode) (*vfs.Dentry, error) { - return nil, syserror.EPERM -} - -func (*dir) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { - return nil, syserror.EPERM -} - -func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { - return nil, syserror.EPERM -} - -func (fst *fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - fs := &filesystem{} - fs.Init(vfsObj) - root := fst.rootFn(creds, fs) - return fs.VFSFilesystem(), root.VFSDentry(), nil -} - -// -------------------- Remainder of the file are test cases -------------------- - -func TestBasic(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { - return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ - "file1": fs.newFile(creds, staticFileContent), - }) - }) - defer sys.Destroy() - sys.GetDentryOrDie(sys.PathOpAtRoot("file1")).DecRef() -} - -func TestMkdirGetDentry(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { - return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ - "dir1": fs.newDir(creds, 0755, nil), - }) - }) - defer sys.Destroy() - - pop := sys.PathOpAtRoot("dir1/a new directory") - if err := sys.VFS.MkdirAt(sys.Ctx, sys.Creds, pop, &vfs.MkdirOptions{Mode: 0755}); err != nil { - t.Fatalf("MkdirAt for PathOperation %+v failed: %v", pop, err) - } - sys.GetDentryOrDie(pop).DecRef() -} - -func TestReadStaticFile(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { - return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ - "file1": fs.newFile(creds, staticFileContent), - }) - }) - defer sys.Destroy() - - pop := sys.PathOpAtRoot("file1") - fd, err := sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }) - if err != nil { - t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) - } - defer fd.DecRef() - - content, err := sys.ReadToEnd(fd) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - if diff := cmp.Diff(staticFileContent, content); diff != "" { - t.Fatalf("Read returned unexpected data:\n--- want\n+++ got\n%v", diff) - } -} - -func TestCreateNewFileInStaticDir(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { - return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ - "dir1": fs.newDir(creds, 0755, nil), - }) - }) - defer sys.Destroy() - - pop := sys.PathOpAtRoot("dir1/newfile") - opts := &vfs.OpenOptions{Flags: linux.O_CREAT | linux.O_EXCL, Mode: defaultMode} - fd, err := sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, opts) - if err != nil { - t.Fatalf("OpenAt(pop:%+v, opts:%+v) failed: %v", pop, opts, err) - } - - // Close the file. The file should persist. - fd.DecRef() - - fd, err = sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }) - if err != nil { - t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err) - } - fd.DecRef() -} - -func TestDirFDReadWrite(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { - return fs.newReadonlyDir(creds, 0755, nil) - }) - defer sys.Destroy() - - pop := sys.PathOpAtRoot("/") - fd, err := sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }) - if err != nil { - t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) - } - defer fd.DecRef() - - // Read/Write should fail for directory FDs. - if _, err := fd.Read(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); err != syserror.EISDIR { - t.Fatalf("Read for directory FD failed with unexpected error: %v", err) - } - if _, err := fd.Write(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.WriteOptions{}); err != syserror.EBADF { - t.Fatalf("Write for directory FD failed with unexpected error: %v", err) - } -} - -func TestDirFDIterDirents(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { - return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ - // Fill root with nodes backed by various inode implementations. - "dir1": fs.newReadonlyDir(creds, 0755, nil), - "dir2": fs.newDir(creds, 0755, map[string]*kernfs.Dentry{ - "dir3": fs.newDir(creds, 0755, nil), - }), - "file1": fs.newFile(creds, staticFileContent), - }) - }) - defer sys.Destroy() - - pop := sys.PathOpAtRoot("/") - sys.AssertAllDirentTypes(sys.ListDirents(pop), map[string]testutil.DirentType{ - "dir1": linux.DT_DIR, - "dir2": linux.DT_DIR, - "file1": linux.DT_REG, - }) -} diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go deleted file mode 100644 index 0ee7eb9b7..000000000 --- a/pkg/sentry/fsimpl/kernfs/symlink.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernfs - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" -) - -// StaticSymlink provides an Inode implementation for symlinks that point to -// a immutable target. -type StaticSymlink struct { - InodeAttrs - InodeNoopRefCount - InodeSymlink - - target string -} - -var _ Inode = (*StaticSymlink)(nil) - -// NewStaticSymlink creates a new symlink file pointing to 'target'. -func NewStaticSymlink(creds *auth.Credentials, ino uint64, target string) *Dentry { - inode := &StaticSymlink{} - inode.Init(creds, ino, target) - - d := &Dentry{} - d.Init(inode) - return d -} - -// Init initializes the instance. -func (s *StaticSymlink) Init(creds *auth.Credentials, ino uint64, target string) { - s.target = target - s.InodeAttrs.Init(creds, ino, linux.ModeSymlink|0777) -} - -// Readlink implements Inode. -func (s *StaticSymlink) Readlink(_ context.Context) (string, error) { - return s.target, nil -} diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD deleted file mode 100644 index bb609a305..000000000 --- a/pkg/sentry/fsimpl/proc/BUILD +++ /dev/null @@ -1,63 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -licenses(["notice"]) - -go_library( - name = "proc", - srcs = [ - "filesystem.go", - "subtasks.go", - "task.go", - "task_files.go", - "task_net.go", - "tasks.go", - "tasks_files.go", - "tasks_sys.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/log", - "//pkg/safemem", - "//pkg/sentry/fs", - "//pkg/sentry/fsimpl/kernfs", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/mm", - "//pkg/sentry/socket", - "//pkg/sentry/socket/unix", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usage", - "//pkg/sentry/vfs", - "//pkg/syserror", - "//pkg/tcpip/header", - "//pkg/usermem", - ], -) - -go_test( - name = "proc_test", - size = "small", - srcs = [ - "tasks_sys_test.go", - "tasks_test.go", - ], - library = ":proc", - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/fspath", - "//pkg/sentry/contexttest", - "//pkg/sentry/fsimpl/testutil", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/syserror", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go deleted file mode 100644 index 5c19d5522..000000000 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package proc implements a partial in-memory file system for procfs. -package proc - -import ( - "fmt" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -// Name is the default filesystem name. -const Name = "proc" - -// FilesystemType is the factory class for procfs. -// -// +stateify savable -type FilesystemType struct{} - -var _ vfs.FilesystemType = (*FilesystemType)(nil) - -// GetFilesystem implements vfs.FilesystemType. -func (ft *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - k := kernel.KernelFromContext(ctx) - if k == nil { - return nil, nil, fmt.Errorf("procfs requires a kernel") - } - pidns := kernel.PIDNamespaceFromContext(ctx) - if pidns == nil { - return nil, nil, fmt.Errorf("procfs requires a PID namespace") - } - - procfs := &kernfs.Filesystem{} - procfs.VFSFilesystem().Init(vfsObj, procfs) - - var cgroups map[string]string - if opts.InternalData != nil { - data := opts.InternalData.(*InternalData) - cgroups = data.Cgroups - } - - _, dentry := newTasksInode(procfs, k, pidns, cgroups) - return procfs.VFSFilesystem(), dentry.VFSDentry(), nil -} - -// dynamicInode is an overfitted interface for common Inodes with -// dynamicByteSource types used in procfs. -type dynamicInode interface { - kernfs.Inode - vfs.DynamicBytesSource - - Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) -} - -func newDentry(creds *auth.Credentials, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry { - inode.Init(creds, ino, inode, perm) - - d := &kernfs.Dentry{} - d.Init(inode) - return d -} - -type staticFile struct { - kernfs.DynamicBytesFile - vfs.StaticData -} - -var _ dynamicInode = (*staticFile)(nil) - -func newStaticFile(data string) *staticFile { - return &staticFile{StaticData: vfs.StaticData{Data: data}} -} - -// InternalData contains internal data passed in to the procfs mount via -// vfs.GetFilesystemOptions.InternalData. -type InternalData struct { - Cgroups map[string]string -} diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go deleted file mode 100644 index 611645f3f..000000000 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "sort" - "strconv" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -// subtasksInode represents the inode for /proc/[pid]/task/ directory. -// -// +stateify savable -type subtasksInode struct { - kernfs.InodeNotSymlink - kernfs.InodeDirectoryNoNewChildren - kernfs.InodeAttrs - kernfs.OrderedChildren - - task *kernel.Task - pidns *kernel.PIDNamespace - inoGen InoGenerator - cgroupControllers map[string]string -} - -var _ kernfs.Inode = (*subtasksInode)(nil) - -func newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, inoGen InoGenerator, cgroupControllers map[string]string) *kernfs.Dentry { - subInode := &subtasksInode{ - task: task, - pidns: pidns, - inoGen: inoGen, - cgroupControllers: cgroupControllers, - } - // Note: credentials are overridden by taskOwnedInode. - subInode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555) - subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - - inode := &taskOwnedInode{Inode: subInode, owner: task} - dentry := &kernfs.Dentry{} - dentry.Init(inode) - - return dentry -} - -// Valid implements kernfs.inodeDynamicLookup. -func (i *subtasksInode) Valid(ctx context.Context) bool { - return true -} - -// Lookup implements kernfs.inodeDynamicLookup. -func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { - tid, err := strconv.ParseUint(name, 10, 32) - if err != nil { - return nil, syserror.ENOENT - } - - subTask := i.pidns.TaskWithID(kernel.ThreadID(tid)) - if subTask == nil { - return nil, syserror.ENOENT - } - if subTask.ThreadGroup() != i.task.ThreadGroup() { - return nil, syserror.ENOENT - } - - subTaskDentry := newTaskInode(i.inoGen, subTask, i.pidns, false, i.cgroupControllers) - return subTaskDentry.VFSDentry(), nil -} - -// IterDirents implements kernfs.inodeDynamicLookup. -func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { - tasks := i.task.ThreadGroup().MemberIDs(i.pidns) - if len(tasks) == 0 { - return offset, syserror.ENOENT - } - - tids := make([]int, 0, len(tasks)) - for _, tid := range tasks { - tids = append(tids, int(tid)) - } - - sort.Ints(tids) - for _, tid := range tids[relOffset:] { - dirent := vfs.Dirent{ - Name: strconv.FormatUint(uint64(tid), 10), - Type: linux.DT_DIR, - Ino: i.inoGen.NextIno(), - NextOff: offset + 1, - } - if err := cb.Handle(dirent); err != nil { - return offset, err - } - offset++ - } - return offset, nil -} - -// Open implements kernfs.Inode. -func (i *subtasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd := &kernfs.GenericDirectoryFD{} - fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, &opts) - return fd.VFSFileDescription(), nil -} - -// Stat implements kernfs.Inode. -func (i *subtasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - stat, err := i.InodeAttrs.Stat(vsfs, opts) - if err != nil { - return linux.Statx{}, err - } - if opts.Mask&linux.STATX_NLINK != 0 { - stat.Nlink += uint32(i.task.ThreadGroup().Count()) - } - return stat, nil -} diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go deleted file mode 100644 index 493acbd1b..000000000 --- a/pkg/sentry/fsimpl/proc/task.go +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "bytes" - "fmt" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -// taskInode represents the inode for /proc/PID/ directory. -// -// +stateify savable -type taskInode struct { - kernfs.InodeNotSymlink - kernfs.InodeDirectoryNoNewChildren - kernfs.InodeNoDynamicLookup - kernfs.InodeAttrs - kernfs.OrderedChildren - - task *kernel.Task -} - -var _ kernfs.Inode = (*taskInode)(nil) - -func newTaskInode(inoGen InoGenerator, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) *kernfs.Dentry { - contents := map[string]*kernfs.Dentry{ - "auxv": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &auxvData{task: task}), - "cmdline": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}), - "comm": newComm(task, inoGen.NextIno(), 0444), - "environ": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}), - //"exe": newExe(t, msrc), - //"fd": newFdDir(t, msrc), - //"fdinfo": newFdInfoDir(t, msrc), - "gid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: true}), - "io": newTaskOwnedFile(task, inoGen.NextIno(), 0400, newIO(task, isThreadGroup)), - "maps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mapsData{task: task}), - //"mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), - //"mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), - "net": newTaskNetDir(task, inoGen), - "ns": newTaskOwnedDir(task, inoGen.NextIno(), 0511, map[string]*kernfs.Dentry{ - "net": newNamespaceSymlink(task, inoGen.NextIno(), "net"), - "pid": newNamespaceSymlink(task, inoGen.NextIno(), "pid"), - "user": newNamespaceSymlink(task, inoGen.NextIno(), "user"), - }), - "oom_score": newTaskOwnedFile(task, inoGen.NextIno(), 0444, newStaticFile("0\n")), - "oom_score_adj": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &oomScoreAdj{task: task}), - "smaps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &smapsData{task: task}), - "stat": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}), - "statm": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statmData{task: task}), - "status": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statusData{task: task, pidns: pidns}), - "uid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: false}), - } - if isThreadGroup { - contents["task"] = newSubtasks(task, pidns, inoGen, cgroupControllers) - } - if len(cgroupControllers) > 0 { - contents["cgroup"] = newTaskOwnedFile(task, inoGen.NextIno(), 0444, newCgroupData(cgroupControllers)) - } - - taskInode := &taskInode{task: task} - // Note: credentials are overridden by taskOwnedInode. - taskInode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555) - - inode := &taskOwnedInode{Inode: taskInode, owner: task} - dentry := &kernfs.Dentry{} - dentry.Init(inode) - - taskInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - links := taskInode.OrderedChildren.Populate(dentry, contents) - taskInode.IncLinks(links) - - return dentry -} - -// Valid implements kernfs.inodeDynamicLookup. This inode remains valid as long -// as the task is still running. When it's dead, another tasks with the same -// PID could replace it. -func (i *taskInode) Valid(ctx context.Context) bool { - return i.task.ExitState() != kernel.TaskExitDead -} - -// Open implements kernfs.Inode. -func (i *taskInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd := &kernfs.GenericDirectoryFD{} - fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, &opts) - return fd.VFSFileDescription(), nil -} - -// SetStat implements kernfs.Inode. -func (i *taskInode) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error { - stat := opts.Stat - if stat.Mask&linux.STATX_MODE != 0 { - return syserror.EPERM - } - return nil -} - -// taskOwnedInode implements kernfs.Inode and overrides inode owner with task -// effective user and group. -type taskOwnedInode struct { - kernfs.Inode - - // owner is the task that owns this inode. - owner *kernel.Task -} - -var _ kernfs.Inode = (*taskOwnedInode)(nil) - -func newTaskOwnedFile(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry { - // Note: credentials are overridden by taskOwnedInode. - inode.Init(task.Credentials(), ino, inode, perm) - - taskInode := &taskOwnedInode{Inode: inode, owner: task} - d := &kernfs.Dentry{} - d.Init(taskInode) - return d -} - -func newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]*kernfs.Dentry) *kernfs.Dentry { - dir := &kernfs.StaticDirectory{} - - // Note: credentials are overridden by taskOwnedInode. - dir.Init(task.Credentials(), ino, perm) - - inode := &taskOwnedInode{Inode: dir, owner: task} - d := &kernfs.Dentry{} - d.Init(inode) - - dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - links := dir.OrderedChildren.Populate(d, children) - dir.IncLinks(links) - - return d -} - -// Stat implements kernfs.Inode. -func (i *taskOwnedInode) Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - stat, err := i.Inode.Stat(fs, opts) - if err != nil { - return linux.Statx{}, err - } - if opts.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 { - uid, gid := i.getOwner(linux.FileMode(stat.Mode)) - if opts.Mask&linux.STATX_UID != 0 { - stat.UID = uint32(uid) - } - if opts.Mask&linux.STATX_GID != 0 { - stat.GID = uint32(gid) - } - } - return stat, nil -} - -// CheckPermissions implements kernfs.Inode. -func (i *taskOwnedInode) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { - mode := i.Mode() - uid, gid := i.getOwner(mode) - return vfs.GenericCheckPermissions( - creds, - ats, - mode.FileType() == linux.ModeDirectory, - uint16(mode), - uid, - gid, - ) -} - -func (i *taskOwnedInode) getOwner(mode linux.FileMode) (auth.KUID, auth.KGID) { - // By default, set the task owner as the file owner. - creds := i.owner.Credentials() - uid := creds.EffectiveKUID - gid := creds.EffectiveKGID - - // Linux doesn't apply dumpability adjustments to world readable/executable - // directories so that applications can stat /proc/PID to determine the - // effective UID of a process. See fs/proc/base.c:task_dump_owner. - if mode.FileType() == linux.ModeDirectory && mode.Permissions() == 0555 { - return uid, gid - } - - // If the task is not dumpable, then root (in the namespace preferred) - // owns the file. - m := getMM(i.owner) - if m == nil { - return auth.RootKUID, auth.RootKGID - } - if m.Dumpability() != mm.UserDumpable { - uid = auth.RootKUID - if kuid := creds.UserNamespace.MapToKUID(auth.RootUID); kuid.Ok() { - uid = kuid - } - gid = auth.RootKGID - if kgid := creds.UserNamespace.MapToKGID(auth.RootGID); kgid.Ok() { - gid = kgid - } - } - return uid, gid -} - -func newIO(t *kernel.Task, isThreadGroup bool) *ioData { - if isThreadGroup { - return &ioData{ioUsage: t.ThreadGroup()} - } - return &ioData{ioUsage: t} -} - -func newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry { - // Namespace symlinks should contain the namespace name and the inode number - // for the namespace instance, so for example user:[123456]. We currently fake - // the inode number by sticking the symlink inode in its place. - target := fmt.Sprintf("%s:[%d]", ns, ino) - - inode := &kernfs.StaticSymlink{} - // Note: credentials are overridden by taskOwnedInode. - inode.Init(task.Credentials(), ino, target) - - taskInode := &taskOwnedInode{Inode: inode, owner: task} - d := &kernfs.Dentry{} - d.Init(taskInode) - return d -} - -// newCgroupData creates inode that shows cgroup information. -// From man 7 cgroups: "For each cgroup hierarchy of which the process is a -// member, there is one entry containing three colon-separated fields: -// hierarchy-ID:controller-list:cgroup-path" -func newCgroupData(controllers map[string]string) dynamicInode { - var buf bytes.Buffer - - // The hierarchy ids must be positive integers (for cgroup v1), but the - // exact number does not matter, so long as they are unique. We can - // just use a counter, but since linux sorts this file in descending - // order, we must count down to preserve this behavior. - i := len(controllers) - for name, dir := range controllers { - fmt.Fprintf(&buf, "%d:%s:%s\n", i, name, dir) - i-- - } - return newStaticFile(buf.String()) -} diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go deleted file mode 100644 index 4d3332771..000000000 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ /dev/null @@ -1,572 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "bytes" - "fmt" - "io" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/safemem" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -// mm gets the kernel task's MemoryManager. No additional reference is taken on -// mm here. This is safe because MemoryManager.destroy is required to leave the -// MemoryManager in a state where it's still usable as a DynamicBytesSource. -func getMM(task *kernel.Task) *mm.MemoryManager { - var tmm *mm.MemoryManager - task.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - tmm = mm - } - }) - return tmm -} - -// getMMIncRef returns t's MemoryManager. If getMMIncRef succeeds, the -// MemoryManager's users count is incremented, and must be decremented by the -// caller when it is no longer in use. -func getMMIncRef(task *kernel.Task) (*mm.MemoryManager, error) { - if task.ExitState() == kernel.TaskExitDead { - return nil, syserror.ESRCH - } - var m *mm.MemoryManager - task.WithMuLocked(func(t *kernel.Task) { - m = t.MemoryManager() - }) - if m == nil || !m.IncUsers() { - return nil, io.EOF - } - return m, nil -} - -type bufferWriter struct { - buf *bytes.Buffer -} - -// WriteFromBlocks writes up to srcs.NumBytes() bytes from srcs and returns -// the number of bytes written. It may return a partial write without an -// error (i.e. (n, nil) where 0 < n < srcs.NumBytes()). It should not -// return a full write with an error (i.e. srcs.NumBytes(), err) where err -// != nil). -func (w *bufferWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { - written := srcs.NumBytes() - for !srcs.IsEmpty() { - w.buf.Write(srcs.Head().ToSlice()) - srcs = srcs.Tail() - } - return written, nil -} - -// auxvData implements vfs.DynamicBytesSource for /proc/[pid]/auxv. -// -// +stateify savable -type auxvData struct { - kernfs.DynamicBytesFile - - task *kernel.Task -} - -var _ dynamicInode = (*auxvData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (d *auxvData) Generate(ctx context.Context, buf *bytes.Buffer) error { - m, err := getMMIncRef(d.task) - if err != nil { - return err - } - defer m.DecUsers(ctx) - - // Space for buffer with AT_NULL (0) terminator at the end. - auxv := m.Auxv() - buf.Grow((len(auxv) + 1) * 16) - for _, e := range auxv { - var tmp [8]byte - usermem.ByteOrder.PutUint64(tmp[:], e.Key) - buf.Write(tmp[:]) - - usermem.ByteOrder.PutUint64(tmp[:], uint64(e.Value)) - buf.Write(tmp[:]) - } - return nil -} - -// execArgType enumerates the types of exec arguments that are exposed through -// proc. -type execArgType int - -const ( - cmdlineDataArg execArgType = iota - environDataArg -) - -// cmdlineData implements vfs.DynamicBytesSource for /proc/[pid]/cmdline. -// -// +stateify savable -type cmdlineData struct { - kernfs.DynamicBytesFile - - task *kernel.Task - - // arg is the type of exec argument this file contains. - arg execArgType -} - -var _ dynamicInode = (*cmdlineData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error { - m, err := getMMIncRef(d.task) - if err != nil { - return err - } - defer m.DecUsers(ctx) - - // Figure out the bounds of the exec arg we are trying to read. - var ar usermem.AddrRange - switch d.arg { - case cmdlineDataArg: - ar = usermem.AddrRange{ - Start: m.ArgvStart(), - End: m.ArgvEnd(), - } - case environDataArg: - ar = usermem.AddrRange{ - Start: m.EnvvStart(), - End: m.EnvvEnd(), - } - default: - panic(fmt.Sprintf("unknown exec arg type %v", d.arg)) - } - if ar.Start == 0 || ar.End == 0 { - // Don't attempt to read before the start/end are set up. - return io.EOF - } - - // N.B. Technically this should be usermem.IOOpts.IgnorePermissions = true - // until Linux 4.9 (272ddc8b3735 "proc: don't use FOLL_FORCE for reading - // cmdline and environment"). - writer := &bufferWriter{buf: buf} - if n, err := m.CopyInTo(ctx, usermem.AddrRangeSeqOf(ar), writer, usermem.IOOpts{}); n == 0 || err != nil { - // Nothing to copy or something went wrong. - return err - } - - // On Linux, if the NULL byte at the end of the argument vector has been - // overwritten, it continues reading the environment vector as part of - // the argument vector. - if d.arg == cmdlineDataArg && buf.Bytes()[buf.Len()-1] != 0 { - if end := bytes.IndexByte(buf.Bytes(), 0); end != -1 { - // If we found a NULL character somewhere else in argv, truncate the - // return up to the NULL terminator (including it). - buf.Truncate(end) - return nil - } - - // There is no NULL terminator in the string, return into envp. - arEnvv := usermem.AddrRange{ - Start: m.EnvvStart(), - End: m.EnvvEnd(), - } - - // Upstream limits the returned amount to one page of slop. - // https://elixir.bootlin.com/linux/v4.20/source/fs/proc/base.c#L208 - // we'll return one page total between argv and envp because of the - // above page restrictions. - if buf.Len() >= usermem.PageSize { - // Returned at least one page already, nothing else to add. - return nil - } - remaining := usermem.PageSize - buf.Len() - if int(arEnvv.Length()) > remaining { - end, ok := arEnvv.Start.AddLength(uint64(remaining)) - if !ok { - return syserror.EFAULT - } - arEnvv.End = end - } - if _, err := m.CopyInTo(ctx, usermem.AddrRangeSeqOf(arEnvv), writer, usermem.IOOpts{}); err != nil { - return err - } - - // Linux will return envp up to and including the first NULL character, - // so find it. - if end := bytes.IndexByte(buf.Bytes()[ar.Length():], 0); end != -1 { - buf.Truncate(end) - } - } - - return nil -} - -// +stateify savable -type commInode struct { - kernfs.DynamicBytesFile - - task *kernel.Task -} - -func newComm(task *kernel.Task, ino uint64, perm linux.FileMode) *kernfs.Dentry { - inode := &commInode{task: task} - inode.DynamicBytesFile.Init(task.Credentials(), ino, &commData{task: task}, perm) - - d := &kernfs.Dentry{} - d.Init(inode) - return d -} - -func (i *commInode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { - // This file can always be read or written by members of the same thread - // group. See fs/proc/base.c:proc_tid_comm_permission. - // - // N.B. This check is currently a no-op as we don't yet support writing and - // this file is world-readable anyways. - t := kernel.TaskFromContext(ctx) - if t != nil && t.ThreadGroup() == i.task.ThreadGroup() && !ats.MayExec() { - return nil - } - - return i.DynamicBytesFile.CheckPermissions(ctx, creds, ats) -} - -// commData implements vfs.DynamicBytesSource for /proc/[pid]/comm. -// -// +stateify savable -type commData struct { - kernfs.DynamicBytesFile - - task *kernel.Task -} - -var _ dynamicInode = (*commData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (d *commData) Generate(ctx context.Context, buf *bytes.Buffer) error { - buf.WriteString(d.task.Name()) - buf.WriteString("\n") - return nil -} - -// idMapData implements vfs.DynamicBytesSource for /proc/[pid]/{gid_map|uid_map}. -// -// +stateify savable -type idMapData struct { - kernfs.DynamicBytesFile - - task *kernel.Task - gids bool -} - -var _ dynamicInode = (*idMapData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (d *idMapData) Generate(ctx context.Context, buf *bytes.Buffer) error { - var entries []auth.IDMapEntry - if d.gids { - entries = d.task.UserNamespace().GIDMap() - } else { - entries = d.task.UserNamespace().UIDMap() - } - for _, e := range entries { - fmt.Fprintf(buf, "%10d %10d %10d\n", e.FirstID, e.FirstParentID, e.Length) - } - return nil -} - -// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps. -// -// +stateify savable -type mapsData struct { - kernfs.DynamicBytesFile - - task *kernel.Task -} - -var _ dynamicInode = (*mapsData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (d *mapsData) Generate(ctx context.Context, buf *bytes.Buffer) error { - if mm := getMM(d.task); mm != nil { - mm.ReadMapsDataInto(ctx, buf) - } - return nil -} - -// smapsData implements vfs.DynamicBytesSource for /proc/[pid]/smaps. -// -// +stateify savable -type smapsData struct { - kernfs.DynamicBytesFile - - task *kernel.Task -} - -var _ dynamicInode = (*smapsData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (d *smapsData) Generate(ctx context.Context, buf *bytes.Buffer) error { - if mm := getMM(d.task); mm != nil { - mm.ReadSmapsDataInto(ctx, buf) - } - return nil -} - -// +stateify savable -type taskStatData struct { - kernfs.DynamicBytesFile - - task *kernel.Task - - // If tgstats is true, accumulate fault stats (not implemented) and CPU - // time across all tasks in t's thread group. - tgstats bool - - // pidns is the PID namespace associated with the proc filesystem that - // includes the file using this statData. - pidns *kernel.PIDNamespace -} - -var _ dynamicInode = (*taskStatData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error { - fmt.Fprintf(buf, "%d ", s.pidns.IDOfTask(s.task)) - fmt.Fprintf(buf, "(%s) ", s.task.Name()) - fmt.Fprintf(buf, "%c ", s.task.StateStatus()[0]) - ppid := kernel.ThreadID(0) - if parent := s.task.Parent(); parent != nil { - ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup()) - } - fmt.Fprintf(buf, "%d ", ppid) - fmt.Fprintf(buf, "%d ", s.pidns.IDOfProcessGroup(s.task.ThreadGroup().ProcessGroup())) - fmt.Fprintf(buf, "%d ", s.pidns.IDOfSession(s.task.ThreadGroup().Session())) - fmt.Fprintf(buf, "0 0 " /* tty_nr tpgid */) - fmt.Fprintf(buf, "0 " /* flags */) - fmt.Fprintf(buf, "0 0 0 0 " /* minflt cminflt majflt cmajflt */) - var cputime usage.CPUStats - if s.tgstats { - cputime = s.task.ThreadGroup().CPUStats() - } else { - cputime = s.task.CPUStats() - } - fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime)) - cputime = s.task.ThreadGroup().JoinedChildCPUStats() - fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime)) - fmt.Fprintf(buf, "%d %d ", s.task.Priority(), s.task.Niceness()) - fmt.Fprintf(buf, "%d ", s.task.ThreadGroup().Count()) - - // itrealvalue. Since kernel 2.6.17, this field is no longer - // maintained, and is hard coded as 0. - fmt.Fprintf(buf, "0 ") - - // Start time is relative to boot time, expressed in clock ticks. - fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.task.StartTime().Sub(s.task.Kernel().Timekeeper().BootTime()))) - - var vss, rss uint64 - s.task.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - } - }) - fmt.Fprintf(buf, "%d %d ", vss, rss/usermem.PageSize) - - // rsslim. - fmt.Fprintf(buf, "%d ", s.task.ThreadGroup().Limits().Get(limits.Rss).Cur) - - fmt.Fprintf(buf, "0 0 0 0 0 " /* startcode endcode startstack kstkesp kstkeip */) - fmt.Fprintf(buf, "0 0 0 0 0 " /* signal blocked sigignore sigcatch wchan */) - fmt.Fprintf(buf, "0 0 " /* nswap cnswap */) - terminationSignal := linux.Signal(0) - if s.task == s.task.ThreadGroup().Leader() { - terminationSignal = s.task.ThreadGroup().TerminationSignal() - } - fmt.Fprintf(buf, "%d ", terminationSignal) - fmt.Fprintf(buf, "0 0 0 " /* processor rt_priority policy */) - fmt.Fprintf(buf, "0 0 0 " /* delayacct_blkio_ticks guest_time cguest_time */) - fmt.Fprintf(buf, "0 0 0 0 0 0 0 " /* start_data end_data start_brk arg_start arg_end env_start env_end */) - fmt.Fprintf(buf, "0\n" /* exit_code */) - - return nil -} - -// statmData implements vfs.DynamicBytesSource for /proc/[pid]/statm. -// -// +stateify savable -type statmData struct { - kernfs.DynamicBytesFile - - task *kernel.Task -} - -var _ dynamicInode = (*statmData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error { - var vss, rss uint64 - s.task.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - } - }) - - fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize) - return nil -} - -// statusData implements vfs.DynamicBytesSource for /proc/[pid]/status. -// -// +stateify savable -type statusData struct { - kernfs.DynamicBytesFile - - task *kernel.Task - pidns *kernel.PIDNamespace -} - -var _ dynamicInode = (*statusData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error { - fmt.Fprintf(buf, "Name:\t%s\n", s.task.Name()) - fmt.Fprintf(buf, "State:\t%s\n", s.task.StateStatus()) - fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.task.ThreadGroup())) - fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.task)) - ppid := kernel.ThreadID(0) - if parent := s.task.Parent(); parent != nil { - ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup()) - } - fmt.Fprintf(buf, "PPid:\t%d\n", ppid) - tpid := kernel.ThreadID(0) - if tracer := s.task.Tracer(); tracer != nil { - tpid = s.pidns.IDOfTask(tracer) - } - fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid) - var fds int - var vss, rss, data uint64 - s.task.WithMuLocked(func(t *kernel.Task) { - if fdTable := t.FDTable(); fdTable != nil { - fds = fdTable.Size() - } - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - data = mm.VirtualDataSize() - } - }) - fmt.Fprintf(buf, "FDSize:\t%d\n", fds) - fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10) - fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10) - fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10) - fmt.Fprintf(buf, "Threads:\t%d\n", s.task.ThreadGroup().Count()) - creds := s.task.Credentials() - fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps) - fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps) - fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps) - fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps) - fmt.Fprintf(buf, "Seccomp:\t%d\n", s.task.SeccompMode()) - // We unconditionally report a single NUMA node. See - // pkg/sentry/syscalls/linux/sys_mempolicy.go. - fmt.Fprintf(buf, "Mems_allowed:\t1\n") - fmt.Fprintf(buf, "Mems_allowed_list:\t0\n") - return nil -} - -// ioUsage is the /proc/<pid>/io and /proc/<pid>/task/<tid>/io data provider. -type ioUsage interface { - // IOUsage returns the io usage data. - IOUsage() *usage.IO -} - -// +stateify savable -type ioData struct { - kernfs.DynamicBytesFile - - ioUsage -} - -var _ dynamicInode = (*ioData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (i *ioData) Generate(ctx context.Context, buf *bytes.Buffer) error { - io := usage.IO{} - io.Accumulate(i.IOUsage()) - - fmt.Fprintf(buf, "char: %d\n", io.CharsRead) - fmt.Fprintf(buf, "wchar: %d\n", io.CharsWritten) - fmt.Fprintf(buf, "syscr: %d\n", io.ReadSyscalls) - fmt.Fprintf(buf, "syscw: %d\n", io.WriteSyscalls) - fmt.Fprintf(buf, "read_bytes: %d\n", io.BytesRead) - fmt.Fprintf(buf, "write_bytes: %d\n", io.BytesWritten) - fmt.Fprintf(buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled) - return nil -} - -// oomScoreAdj is a stub of the /proc/<pid>/oom_score_adj file. -// -// +stateify savable -type oomScoreAdj struct { - kernfs.DynamicBytesFile - - task *kernel.Task -} - -var _ vfs.WritableDynamicBytesSource = (*oomScoreAdj)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (o *oomScoreAdj) Generate(ctx context.Context, buf *bytes.Buffer) error { - if o.task.ExitState() == kernel.TaskExitDead { - return syserror.ESRCH - } - fmt.Fprintf(buf, "%d\n", o.task.OOMScoreAdj()) - return nil -} - -// Write implements vfs.WritableDynamicBytesSource.Write. -func (o *oomScoreAdj) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { - if src.NumBytes() == 0 { - return 0, nil - } - - // Limit input size so as not to impact performance if input size is large. - src = src.TakeFirst(usermem.PageSize - 1) - - var v int32 - n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) - if err != nil { - return 0, err - } - - if o.task.ExitState() == kernel.TaskExitDead { - return 0, syserror.ESRCH - } - if err := o.task.SetOOMScoreAdj(v); err != nil { - return 0, err - } - - return n, nil -} diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go deleted file mode 100644 index 373a7b17d..000000000 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ /dev/null @@ -1,790 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "bytes" - "fmt" - "io" - "reflect" - "time" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" - "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/sentry/socket/unix" - "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/usermem" -) - -func newTaskNetDir(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry { - k := task.Kernel() - pidns := task.PIDNamespace() - root := auth.NewRootCredentials(pidns.UserNamespace()) - - var contents map[string]*kernfs.Dentry - if stack := task.NetworkNamespace().Stack(); stack != nil { - const ( - arp = "IP address HW type Flags HW address Mask Device\n" - netlink = "sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n" - packet = "sk RefCnt Type Proto Iface R Rmem User Inode\n" - protocols = "protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em\n" - ptype = "Type Device Function\n" - upd6 = " sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n" - ) - psched := fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)) - - // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task - // network namespace. - contents = map[string]*kernfs.Dentry{ - "dev": newDentry(root, inoGen.NextIno(), 0444, &netDevData{stack: stack}), - "snmp": newDentry(root, inoGen.NextIno(), 0444, &netSnmpData{stack: stack}), - - // The following files are simple stubs until they are implemented in - // netstack, if the file contains a header the stub is just the header - // otherwise it is an empty file. - "arp": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(arp)), - "netlink": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(netlink)), - "netstat": newDentry(root, inoGen.NextIno(), 0444, &netStatData{}), - "packet": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(packet)), - "protocols": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(protocols)), - - // Linux sets psched values to: nsec per usec, psched tick in ns, 1000000, - // high res timer ticks per sec (ClockGetres returns 1ns resolution). - "psched": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(psched)), - "ptype": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(ptype)), - "route": newDentry(root, inoGen.NextIno(), 0444, &netRouteData{stack: stack}), - "tcp": newDentry(root, inoGen.NextIno(), 0444, &netTCPData{kernel: k}), - "udp": newDentry(root, inoGen.NextIno(), 0444, &netUDPData{kernel: k}), - "unix": newDentry(root, inoGen.NextIno(), 0444, &netUnixData{kernel: k}), - } - - if stack.SupportsIPv6() { - contents["if_inet6"] = newDentry(root, inoGen.NextIno(), 0444, &ifinet6{stack: stack}) - contents["ipv6_route"] = newDentry(root, inoGen.NextIno(), 0444, newStaticFile("")) - contents["tcp6"] = newDentry(root, inoGen.NextIno(), 0444, &netTCP6Data{kernel: k}) - contents["udp6"] = newDentry(root, inoGen.NextIno(), 0444, newStaticFile(upd6)) - } - } - - return newTaskOwnedDir(task, inoGen.NextIno(), 0555, contents) -} - -// ifinet6 implements vfs.DynamicBytesSource for /proc/net/if_inet6. -// -// +stateify savable -type ifinet6 struct { - kernfs.DynamicBytesFile - - stack inet.Stack -} - -var _ dynamicInode = (*ifinet6)(nil) - -func (n *ifinet6) contents() []string { - var lines []string - nics := n.stack.Interfaces() - for id, naddrs := range n.stack.InterfaceAddrs() { - nic, ok := nics[id] - if !ok { - // NIC was added after NICNames was called. We'll just ignore it. - continue - } - - for _, a := range naddrs { - // IPv6 only. - if a.Family != linux.AF_INET6 { - continue - } - - // Fields: - // IPv6 address displayed in 32 hexadecimal chars without colons - // Netlink device number (interface index) in hexadecimal (use nic id) - // Prefix length in hexadecimal - // Scope value (use 0) - // Interface flags - // Device name - lines = append(lines, fmt.Sprintf("%032x %02x %02x %02x %02x %8s\n", a.Addr, id, a.PrefixLen, 0, a.Flags, nic.Name)) - } - } - return lines -} - -// Generate implements vfs.DynamicBytesSource.Generate. -func (n *ifinet6) Generate(ctx context.Context, buf *bytes.Buffer) error { - for _, l := range n.contents() { - buf.WriteString(l) - } - return nil -} - -// netDevData implements vfs.DynamicBytesSource for /proc/net/dev. -// -// +stateify savable -type netDevData struct { - kernfs.DynamicBytesFile - - stack inet.Stack -} - -var _ dynamicInode = (*netDevData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (n *netDevData) Generate(ctx context.Context, buf *bytes.Buffer) error { - interfaces := n.stack.Interfaces() - buf.WriteString("Inter-| Receive | Transmit\n") - buf.WriteString(" face |bytes packets errs drop fifo frame compressed multicast|bytes packets errs drop fifo colls carrier compressed\n") - - for _, i := range interfaces { - // Implements the same format as - // net/core/net-procfs.c:dev_seq_printf_stats. - var stats inet.StatDev - if err := n.stack.Statistics(&stats, i.Name); err != nil { - log.Warningf("Failed to retrieve interface statistics for %v: %v", i.Name, err) - continue - } - fmt.Fprintf( - buf, - "%6s: %7d %7d %4d %4d %4d %5d %10d %9d %8d %7d %4d %4d %4d %5d %7d %10d\n", - i.Name, - // Received - stats[0], // bytes - stats[1], // packets - stats[2], // errors - stats[3], // dropped - stats[4], // fifo - stats[5], // frame - stats[6], // compressed - stats[7], // multicast - // Transmitted - stats[8], // bytes - stats[9], // packets - stats[10], // errors - stats[11], // dropped - stats[12], // fifo - stats[13], // frame - stats[14], // compressed - stats[15], // multicast - ) - } - - return nil -} - -// netUnixData implements vfs.DynamicBytesSource for /proc/net/unix. -// -// +stateify savable -type netUnixData struct { - kernfs.DynamicBytesFile - - kernel *kernel.Kernel -} - -var _ dynamicInode = (*netUnixData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error { - buf.WriteString("Num RefCount Protocol Flags Type St Inode Path\n") - for _, se := range n.kernel.ListSockets() { - s := se.Sock.Get() - if s == nil { - log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock) - continue - } - sfile := s.(*fs.File) - if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX { - s.DecRef() - // Not a unix socket. - continue - } - sops := sfile.FileOperations.(*unix.SocketOperations) - - addr, err := sops.Endpoint().GetLocalAddress() - if err != nil { - log.Warningf("Failed to retrieve socket name from %+v: %v", sfile, err) - addr.Addr = "<unknown>" - } - - sockFlags := 0 - if ce, ok := sops.Endpoint().(transport.ConnectingEndpoint); ok { - if ce.Listening() { - // For unix domain sockets, linux reports a single flag - // value if the socket is listening, of __SO_ACCEPTCON. - sockFlags = linux.SO_ACCEPTCON - } - } - - // In the socket entry below, the value for the 'Num' field requires - // some consideration. Linux prints the address to the struct - // unix_sock representing a socket in the kernel, but may redact the - // value for unprivileged users depending on the kptr_restrict - // sysctl. - // - // One use for this field is to allow a privileged user to - // introspect into the kernel memory to determine information about - // a socket not available through procfs, such as the socket's peer. - // - // In gvisor, returning a pointer to our internal structures would - // be pointless, as it wouldn't match the memory layout for struct - // unix_sock, making introspection difficult. We could populate a - // struct unix_sock with the appropriate data, but even that - // requires consideration for which kernel version to emulate, as - // the definition of this struct changes over time. - // - // For now, we always redact this pointer. - fmt.Fprintf(buf, "%#016p: %08X %08X %08X %04X %02X %5d", - (*unix.SocketOperations)(nil), // Num, pointer to kernel socket struct. - sfile.ReadRefs()-1, // RefCount, don't count our own ref. - 0, // Protocol, always 0 for UDS. - sockFlags, // Flags. - sops.Endpoint().Type(), // Type. - sops.State(), // State. - sfile.InodeID(), // Inode. - ) - - // Path - if len(addr.Addr) != 0 { - if addr.Addr[0] == 0 { - // Abstract path. - fmt.Fprintf(buf, " @%s", string(addr.Addr[1:])) - } else { - fmt.Fprintf(buf, " %s", string(addr.Addr)) - } - } - fmt.Fprintf(buf, "\n") - - s.DecRef() - } - return nil -} - -func networkToHost16(n uint16) uint16 { - // n is in network byte order, so is big-endian. The most-significant byte - // should be stored in the lower address. - // - // We manually inline binary.BigEndian.Uint16() because Go does not support - // non-primitive consts, so binary.BigEndian is a (mutable) var, so calls to - // binary.BigEndian.Uint16() require a read of binary.BigEndian and an - // interface method call, defeating inlining. - buf := [2]byte{byte(n >> 8 & 0xff), byte(n & 0xff)} - return usermem.ByteOrder.Uint16(buf[:]) -} - -func writeInetAddr(w io.Writer, family int, i linux.SockAddr) { - switch family { - case linux.AF_INET: - var a linux.SockAddrInet - if i != nil { - a = *i.(*linux.SockAddrInet) - } - - // linux.SockAddrInet.Port is stored in the network byte order and is - // printed like a number in host byte order. Note that all numbers in host - // byte order are printed with the most-significant byte first when - // formatted with %X. See get_tcp4_sock() and udp4_format_sock() in Linux. - port := networkToHost16(a.Port) - - // linux.SockAddrInet.Addr is stored as a byte slice in big-endian order - // (i.e. most-significant byte in index 0). Linux represents this as a - // __be32 which is a typedef for an unsigned int, and is printed with - // %X. This means that for a little-endian machine, Linux prints the - // least-significant byte of the address first. To emulate this, we first - // invert the byte order for the address using usermem.ByteOrder.Uint32, - // which makes it have the equivalent encoding to a __be32 on a little - // endian machine. Note that this operation is a no-op on a big endian - // machine. Then similar to Linux, we format it with %X, which will print - // the most-significant byte of the __be32 address first, which is now - // actually the least-significant byte of the original address in - // linux.SockAddrInet.Addr on little endian machines, due to the conversion. - addr := usermem.ByteOrder.Uint32(a.Addr[:]) - - fmt.Fprintf(w, "%08X:%04X ", addr, port) - case linux.AF_INET6: - var a linux.SockAddrInet6 - if i != nil { - a = *i.(*linux.SockAddrInet6) - } - - port := networkToHost16(a.Port) - addr0 := usermem.ByteOrder.Uint32(a.Addr[0:4]) - addr1 := usermem.ByteOrder.Uint32(a.Addr[4:8]) - addr2 := usermem.ByteOrder.Uint32(a.Addr[8:12]) - addr3 := usermem.ByteOrder.Uint32(a.Addr[12:16]) - fmt.Fprintf(w, "%08X%08X%08X%08X:%04X ", addr0, addr1, addr2, addr3, port) - } -} - -func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel, family int) error { - // t may be nil here if our caller is not part of a task goroutine. This can - // happen for example if we're here for "sentryctl cat". When t is nil, - // degrade gracefully and retrieve what we can. - t := kernel.TaskFromContext(ctx) - - for _, se := range k.ListSockets() { - s := se.Sock.Get() - if s == nil { - log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID) - continue - } - sfile := s.(*fs.File) - sops, ok := sfile.FileOperations.(socket.Socket) - if !ok { - panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile)) - } - if fa, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) { - s.DecRef() - // Not tcp4 sockets. - continue - } - - // Linux's documentation for the fields below can be found at - // https://www.kernel.org/doc/Documentation/networking/proc_net_tcp.txt. - // For Linux's implementation, see net/ipv4/tcp_ipv4.c:get_tcp4_sock(). - // Note that the header doesn't contain labels for all the fields. - - // Field: sl; entry number. - fmt.Fprintf(buf, "%4d: ", se.ID) - - // Field: local_adddress. - var localAddr linux.SockAddr - if t != nil { - if local, _, err := sops.GetSockName(t); err == nil { - localAddr = local - } - } - writeInetAddr(buf, family, localAddr) - - // Field: rem_address. - var remoteAddr linux.SockAddr - if t != nil { - if remote, _, err := sops.GetPeerName(t); err == nil { - remoteAddr = remote - } - } - writeInetAddr(buf, family, remoteAddr) - - // Field: state; socket state. - fmt.Fprintf(buf, "%02X ", sops.State()) - - // Field: tx_queue, rx_queue; number of packets in the transmit and - // receive queue. Unimplemented. - fmt.Fprintf(buf, "%08X:%08X ", 0, 0) - - // Field: tr, tm->when; timer active state and number of jiffies - // until timer expires. Unimplemented. - fmt.Fprintf(buf, "%02X:%08X ", 0, 0) - - // Field: retrnsmt; number of unrecovered RTO timeouts. - // Unimplemented. - fmt.Fprintf(buf, "%08X ", 0) - - // Field: uid. - uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx) - if err != nil { - log.Warningf("Failed to retrieve unstable attr for socket file: %v", err) - fmt.Fprintf(buf, "%5d ", 0) - } else { - creds := auth.CredentialsFromContext(ctx) - fmt.Fprintf(buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow())) - } - - // Field: timeout; number of unanswered 0-window probes. - // Unimplemented. - fmt.Fprintf(buf, "%8d ", 0) - - // Field: inode. - fmt.Fprintf(buf, "%8d ", sfile.InodeID()) - - // Field: refcount. Don't count the ref we obtain while deferencing - // the weakref to this socket. - fmt.Fprintf(buf, "%d ", sfile.ReadRefs()-1) - - // Field: Socket struct address. Redacted due to the same reason as - // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData. - fmt.Fprintf(buf, "%#016p ", (*socket.Socket)(nil)) - - // Field: retransmit timeout. Unimplemented. - fmt.Fprintf(buf, "%d ", 0) - - // Field: predicted tick of soft clock (delayed ACK control data). - // Unimplemented. - fmt.Fprintf(buf, "%d ", 0) - - // Field: (ack.quick<<1)|ack.pingpong, Unimplemented. - fmt.Fprintf(buf, "%d ", 0) - - // Field: sending congestion window, Unimplemented. - fmt.Fprintf(buf, "%d ", 0) - - // Field: Slow start size threshold, -1 if threshold >= 0xFFFF. - // Unimplemented, report as large threshold. - fmt.Fprintf(buf, "%d", -1) - - fmt.Fprintf(buf, "\n") - - s.DecRef() - } - - return nil -} - -// netTCPData implements vfs.DynamicBytesSource for /proc/net/tcp. -// -// +stateify savable -type netTCPData struct { - kernfs.DynamicBytesFile - - kernel *kernel.Kernel -} - -var _ dynamicInode = (*netTCPData)(nil) - -func (d *netTCPData) Generate(ctx context.Context, buf *bytes.Buffer) error { - buf.WriteString(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n") - return commonGenerateTCP(ctx, buf, d.kernel, linux.AF_INET) -} - -// netTCP6Data implements vfs.DynamicBytesSource for /proc/net/tcp6. -// -// +stateify savable -type netTCP6Data struct { - kernfs.DynamicBytesFile - - kernel *kernel.Kernel -} - -var _ dynamicInode = (*netTCP6Data)(nil) - -func (d *netTCP6Data) Generate(ctx context.Context, buf *bytes.Buffer) error { - buf.WriteString(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n") - return commonGenerateTCP(ctx, buf, d.kernel, linux.AF_INET6) -} - -// netUDPData implements vfs.DynamicBytesSource for /proc/net/udp. -// -// +stateify savable -type netUDPData struct { - kernfs.DynamicBytesFile - - kernel *kernel.Kernel -} - -var _ dynamicInode = (*netUDPData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error { - // t may be nil here if our caller is not part of a task goroutine. This can - // happen for example if we're here for "sentryctl cat". When t is nil, - // degrade gracefully and retrieve what we can. - t := kernel.TaskFromContext(ctx) - - for _, se := range d.kernel.ListSockets() { - s := se.Sock.Get() - if s == nil { - log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID) - continue - } - sfile := s.(*fs.File) - sops, ok := sfile.FileOperations.(socket.Socket) - if !ok { - panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile)) - } - if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM { - s.DecRef() - // Not udp4 socket. - continue - } - - // For Linux's implementation, see net/ipv4/udp.c:udp4_format_sock(). - - // Field: sl; entry number. - fmt.Fprintf(buf, "%5d: ", se.ID) - - // Field: local_adddress. - var localAddr linux.SockAddrInet - if t != nil { - if local, _, err := sops.GetSockName(t); err == nil { - localAddr = *local.(*linux.SockAddrInet) - } - } - writeInetAddr(buf, linux.AF_INET, &localAddr) - - // Field: rem_address. - var remoteAddr linux.SockAddrInet - if t != nil { - if remote, _, err := sops.GetPeerName(t); err == nil { - remoteAddr = *remote.(*linux.SockAddrInet) - } - } - writeInetAddr(buf, linux.AF_INET, &remoteAddr) - - // Field: state; socket state. - fmt.Fprintf(buf, "%02X ", sops.State()) - - // Field: tx_queue, rx_queue; number of packets in the transmit and - // receive queue. Unimplemented. - fmt.Fprintf(buf, "%08X:%08X ", 0, 0) - - // Field: tr, tm->when. Always 0 for UDP. - fmt.Fprintf(buf, "%02X:%08X ", 0, 0) - - // Field: retrnsmt. Always 0 for UDP. - fmt.Fprintf(buf, "%08X ", 0) - - // Field: uid. - uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx) - if err != nil { - log.Warningf("Failed to retrieve unstable attr for socket file: %v", err) - fmt.Fprintf(buf, "%5d ", 0) - } else { - creds := auth.CredentialsFromContext(ctx) - fmt.Fprintf(buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow())) - } - - // Field: timeout. Always 0 for UDP. - fmt.Fprintf(buf, "%8d ", 0) - - // Field: inode. - fmt.Fprintf(buf, "%8d ", sfile.InodeID()) - - // Field: ref; reference count on the socket inode. Don't count the ref - // we obtain while deferencing the weakref to this socket. - fmt.Fprintf(buf, "%d ", sfile.ReadRefs()-1) - - // Field: Socket struct address. Redacted due to the same reason as - // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData. - fmt.Fprintf(buf, "%#016p ", (*socket.Socket)(nil)) - - // Field: drops; number of dropped packets. Unimplemented. - fmt.Fprintf(buf, "%d", 0) - - fmt.Fprintf(buf, "\n") - - s.DecRef() - } - return nil -} - -// netSnmpData implements vfs.DynamicBytesSource for /proc/net/snmp. -// -// +stateify savable -type netSnmpData struct { - kernfs.DynamicBytesFile - - stack inet.Stack -} - -var _ dynamicInode = (*netSnmpData)(nil) - -type snmpLine struct { - prefix string - header string -} - -var snmp = []snmpLine{ - { - prefix: "Ip", - header: "Forwarding DefaultTTL InReceives InHdrErrors InAddrErrors ForwDatagrams InUnknownProtos InDiscards InDelivers OutRequests OutDiscards OutNoRoutes ReasmTimeout ReasmReqds ReasmOKs ReasmFails FragOKs FragFails FragCreates", - }, - { - prefix: "Icmp", - header: "InMsgs InErrors InCsumErrors InDestUnreachs InTimeExcds InParmProbs InSrcQuenchs InRedirects InEchos InEchoReps InTimestamps InTimestampReps InAddrMasks InAddrMaskReps OutMsgs OutErrors OutDestUnreachs OutTimeExcds OutParmProbs OutSrcQuenchs OutRedirects OutEchos OutEchoReps OutTimestamps OutTimestampReps OutAddrMasks OutAddrMaskReps", - }, - { - prefix: "IcmpMsg", - }, - { - prefix: "Tcp", - header: "RtoAlgorithm RtoMin RtoMax MaxConn ActiveOpens PassiveOpens AttemptFails EstabResets CurrEstab InSegs OutSegs RetransSegs InErrs OutRsts InCsumErrors", - }, - { - prefix: "Udp", - header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti", - }, - { - prefix: "UdpLite", - header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti", - }, -} - -func toSlice(a interface{}) []uint64 { - v := reflect.Indirect(reflect.ValueOf(a)) - return v.Slice(0, v.Len()).Interface().([]uint64) -} - -func sprintSlice(s []uint64) string { - if len(s) == 0 { - return "" - } - r := fmt.Sprint(s) - return r[1 : len(r)-1] // Remove "[]" introduced by fmt of slice. -} - -// Generate implements vfs.DynamicBytesSource. -func (d *netSnmpData) Generate(ctx context.Context, buf *bytes.Buffer) error { - types := []interface{}{ - &inet.StatSNMPIP{}, - &inet.StatSNMPICMP{}, - nil, // TODO(gvisor.dev/issue/628): Support IcmpMsg stats. - &inet.StatSNMPTCP{}, - &inet.StatSNMPUDP{}, - &inet.StatSNMPUDPLite{}, - } - for i, stat := range types { - line := snmp[i] - if stat == nil { - fmt.Fprintf(buf, "%s:\n", line.prefix) - fmt.Fprintf(buf, "%s:\n", line.prefix) - continue - } - if err := d.stack.Statistics(stat, line.prefix); err != nil { - if err == syserror.EOPNOTSUPP { - log.Infof("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) - } else { - log.Warningf("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) - } - } - - fmt.Fprintf(buf, "%s: %s\n", line.prefix, line.header) - - if line.prefix == "Tcp" { - tcp := stat.(*inet.StatSNMPTCP) - // "Tcp" needs special processing because MaxConn is signed. RFC 2012. - fmt.Sprintf("%s: %s %d %s\n", line.prefix, sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:])) - } else { - fmt.Sprintf("%s: %s\n", line.prefix, sprintSlice(toSlice(stat))) - } - } - return nil -} - -// netRouteData implements vfs.DynamicBytesSource for /proc/net/route. -// -// +stateify savable -type netRouteData struct { - kernfs.DynamicBytesFile - - stack inet.Stack -} - -var _ dynamicInode = (*netRouteData)(nil) - -// Generate implements vfs.DynamicBytesSource. -// See Linux's net/ipv4/fib_trie.c:fib_route_seq_show. -func (d *netRouteData) Generate(ctx context.Context, buf *bytes.Buffer) error { - fmt.Fprintf(buf, "%-127s\n", "Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT") - - interfaces := d.stack.Interfaces() - for _, rt := range d.stack.RouteTable() { - // /proc/net/route only includes ipv4 routes. - if rt.Family != linux.AF_INET { - continue - } - - // /proc/net/route does not include broadcast or multicast routes. - if rt.Type == linux.RTN_BROADCAST || rt.Type == linux.RTN_MULTICAST { - continue - } - - iface, ok := interfaces[rt.OutputInterface] - if !ok || iface.Name == "lo" { - continue - } - - var ( - gw uint32 - prefix uint32 - flags = linux.RTF_UP - ) - if len(rt.GatewayAddr) == header.IPv4AddressSize { - flags |= linux.RTF_GATEWAY - gw = usermem.ByteOrder.Uint32(rt.GatewayAddr) - } - if len(rt.DstAddr) == header.IPv4AddressSize { - prefix = usermem.ByteOrder.Uint32(rt.DstAddr) - } - l := fmt.Sprintf( - "%s\t%08X\t%08X\t%04X\t%d\t%d\t%d\t%08X\t%d\t%d\t%d", - iface.Name, - prefix, - gw, - flags, - 0, // RefCnt. - 0, // Use. - 0, // Metric. - (uint32(1)<<rt.DstLen)-1, - 0, // MTU. - 0, // Window. - 0, // RTT. - ) - fmt.Fprintf(buf, "%-127s\n", l) - } - return nil -} - -// netStatData implements vfs.DynamicBytesSource for /proc/net/netstat. -// -// +stateify savable -type netStatData struct { - kernfs.DynamicBytesFile - - stack inet.Stack -} - -var _ dynamicInode = (*netStatData)(nil) - -// Generate implements vfs.DynamicBytesSource. -// See Linux's net/ipv4/fib_trie.c:fib_route_seq_show. -func (d *netStatData) Generate(ctx context.Context, buf *bytes.Buffer) error { - buf.WriteString("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed " + - "EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps " + - "LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive " + - "PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost " + - "ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog " + - "TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser " + - "TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging " + - "TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo " + - "TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit " + - "TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans " + - "TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes " + - "TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail " + - "TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent " + - "TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose " + - "TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed " + - "TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld " + - "TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected " + - "TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback " + - "TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter " + - "TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail " + - "TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK " + - "TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail " + - "TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow " + - "TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets " + - "TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv " + - "TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect " + - "TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd " + - "TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq " + - "TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge " + - "TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess\n") - return nil -} diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go deleted file mode 100644 index d203cebd4..000000000 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "bytes" - "sort" - "strconv" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -const ( - selfName = "self" - threadSelfName = "thread-self" -) - -// InoGenerator generates unique inode numbers for a given filesystem. -type InoGenerator interface { - NextIno() uint64 -} - -// tasksInode represents the inode for /proc/ directory. -// -// +stateify savable -type tasksInode struct { - kernfs.InodeNotSymlink - kernfs.InodeDirectoryNoNewChildren - kernfs.InodeAttrs - kernfs.OrderedChildren - - inoGen InoGenerator - pidns *kernel.PIDNamespace - - // '/proc/self' and '/proc/thread-self' have custom directory offsets in - // Linux. So handle them outside of OrderedChildren. - selfSymlink *vfs.Dentry - threadSelfSymlink *vfs.Dentry - - // cgroupControllers is a map of controller name to directory in the - // cgroup hierarchy. These controllers are immutable and will be listed - // in /proc/pid/cgroup if not nil. - cgroupControllers map[string]string -} - -var _ kernfs.Inode = (*tasksInode)(nil) - -func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) (*tasksInode, *kernfs.Dentry) { - root := auth.NewRootCredentials(pidns.UserNamespace()) - contents := map[string]*kernfs.Dentry{ - "cpuinfo": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(cpuInfoData(k))), - //"filesystems": newDentry(root, inoGen.NextIno(), 0444, &filesystemsData{}), - "loadavg": newDentry(root, inoGen.NextIno(), 0444, &loadavgData{}), - "sys": newSysDir(root, inoGen, k), - "meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}), - "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"), - "net": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/net"), - "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{k: k}), - "uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}), - "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{k: k}), - } - - inode := &tasksInode{ - pidns: pidns, - inoGen: inoGen, - selfSymlink: newSelfSymlink(root, inoGen.NextIno(), 0444, pidns).VFSDentry(), - threadSelfSymlink: newThreadSelfSymlink(root, inoGen.NextIno(), 0444, pidns).VFSDentry(), - cgroupControllers: cgroupControllers, - } - inode.InodeAttrs.Init(root, inoGen.NextIno(), linux.ModeDirectory|0555) - - dentry := &kernfs.Dentry{} - dentry.Init(inode) - - inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - links := inode.OrderedChildren.Populate(dentry, contents) - inode.IncLinks(links) - - return inode, dentry -} - -// Lookup implements kernfs.inodeDynamicLookup. -func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { - // Try to lookup a corresponding task. - tid, err := strconv.ParseUint(name, 10, 64) - if err != nil { - // If it failed to parse, check if it's one of the special handled files. - switch name { - case selfName: - return i.selfSymlink, nil - case threadSelfName: - return i.threadSelfSymlink, nil - } - return nil, syserror.ENOENT - } - - task := i.pidns.TaskWithID(kernel.ThreadID(tid)) - if task == nil { - return nil, syserror.ENOENT - } - - taskDentry := newTaskInode(i.inoGen, task, i.pidns, true, i.cgroupControllers) - return taskDentry.VFSDentry(), nil -} - -// Valid implements kernfs.inodeDynamicLookup. -func (i *tasksInode) Valid(ctx context.Context) bool { - return true -} - -// IterDirents implements kernfs.inodeDynamicLookup. -func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) { - // fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256 - const FIRST_PROCESS_ENTRY = 256 - - // Use maxTaskID to shortcut searches that will result in 0 entries. - const maxTaskID = kernel.TasksLimit + 1 - if offset >= maxTaskID { - return offset, nil - } - - // According to Linux (fs/proc/base.c:proc_pid_readdir()), process directories - // start at offset FIRST_PROCESS_ENTRY with '/proc/self', followed by - // '/proc/thread-self' and then '/proc/[pid]'. - if offset < FIRST_PROCESS_ENTRY { - offset = FIRST_PROCESS_ENTRY - } - - if offset == FIRST_PROCESS_ENTRY { - dirent := vfs.Dirent{ - Name: selfName, - Type: linux.DT_LNK, - Ino: i.inoGen.NextIno(), - NextOff: offset + 1, - } - if err := cb.Handle(dirent); err != nil { - return offset, err - } - offset++ - } - if offset == FIRST_PROCESS_ENTRY+1 { - dirent := vfs.Dirent{ - Name: threadSelfName, - Type: linux.DT_LNK, - Ino: i.inoGen.NextIno(), - NextOff: offset + 1, - } - if err := cb.Handle(dirent); err != nil { - return offset, err - } - offset++ - } - - // Collect all tasks that TGIDs are greater than the offset specified. Per - // Linux we only include in directory listings if it's the leader. But for - // whatever crazy reason, you can still walk to the given node. - var tids []int - startTid := offset - FIRST_PROCESS_ENTRY - 2 - for _, tg := range i.pidns.ThreadGroups() { - tid := i.pidns.IDOfThreadGroup(tg) - if int64(tid) < startTid { - continue - } - if leader := tg.Leader(); leader != nil { - tids = append(tids, int(tid)) - } - } - - if len(tids) == 0 { - return offset, nil - } - - sort.Ints(tids) - for _, tid := range tids { - dirent := vfs.Dirent{ - Name: strconv.FormatUint(uint64(tid), 10), - Type: linux.DT_DIR, - Ino: i.inoGen.NextIno(), - NextOff: FIRST_PROCESS_ENTRY + 2 + int64(tid) + 1, - } - if err := cb.Handle(dirent); err != nil { - return offset, err - } - offset++ - } - return maxTaskID, nil -} - -// Open implements kernfs.Inode. -func (i *tasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd := &kernfs.GenericDirectoryFD{} - fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, &opts) - return fd.VFSFileDescription(), nil -} - -func (i *tasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - stat, err := i.InodeAttrs.Stat(vsfs, opts) - if err != nil { - return linux.Statx{}, err - } - - if opts.Mask&linux.STATX_NLINK != 0 { - // Add dynamic children to link count. - for _, tg := range i.pidns.ThreadGroups() { - if leader := tg.Leader(); leader != nil { - stat.Nlink++ - } - } - } - - return stat, nil -} - -func cpuInfoData(k *kernel.Kernel) string { - features := k.FeatureSet() - if features == nil { - // Kernel is always initialized with a FeatureSet. - panic("cpuinfo read with nil FeatureSet") - } - var buf bytes.Buffer - for i, max := uint(0), k.ApplicationCores(); i < max; i++ { - features.WriteCPUInfoTo(i, &buf) - } - return buf.String() -} - -func shmData(v uint64) dynamicInode { - return newStaticFile(strconv.FormatUint(v, 10)) -} diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go deleted file mode 100644 index 434998910..000000000 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ /dev/null @@ -1,337 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "bytes" - "fmt" - "strconv" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -type selfSymlink struct { - kernfs.InodeAttrs - kernfs.InodeNoopRefCount - kernfs.InodeSymlink - - pidns *kernel.PIDNamespace -} - -var _ kernfs.Inode = (*selfSymlink)(nil) - -func newSelfSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, pidns *kernel.PIDNamespace) *kernfs.Dentry { - inode := &selfSymlink{pidns: pidns} - inode.Init(creds, ino, linux.ModeSymlink|perm) - - d := &kernfs.Dentry{} - d.Init(inode) - return d -} - -func (s *selfSymlink) Readlink(ctx context.Context) (string, error) { - t := kernel.TaskFromContext(ctx) - if t == nil { - // Who is reading this link? - return "", syserror.EINVAL - } - tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup()) - if tgid == 0 { - return "", syserror.ENOENT - } - return strconv.FormatUint(uint64(tgid), 10), nil -} - -type threadSelfSymlink struct { - kernfs.InodeAttrs - kernfs.InodeNoopRefCount - kernfs.InodeSymlink - - pidns *kernel.PIDNamespace -} - -var _ kernfs.Inode = (*threadSelfSymlink)(nil) - -func newThreadSelfSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, pidns *kernel.PIDNamespace) *kernfs.Dentry { - inode := &threadSelfSymlink{pidns: pidns} - inode.Init(creds, ino, linux.ModeSymlink|perm) - - d := &kernfs.Dentry{} - d.Init(inode) - return d -} - -func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) { - t := kernel.TaskFromContext(ctx) - if t == nil { - // Who is reading this link? - return "", syserror.EINVAL - } - tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup()) - tid := s.pidns.IDOfTask(t) - if tid == 0 || tgid == 0 { - return "", syserror.ENOENT - } - return fmt.Sprintf("%d/task/%d", tgid, tid), nil -} - -// cpuStats contains the breakdown of CPU time for /proc/stat. -type cpuStats struct { - // user is time spent in userspace tasks with non-positive niceness. - user uint64 - - // nice is time spent in userspace tasks with positive niceness. - nice uint64 - - // system is time spent in non-interrupt kernel context. - system uint64 - - // idle is time spent idle. - idle uint64 - - // ioWait is time spent waiting for IO. - ioWait uint64 - - // irq is time spent in interrupt context. - irq uint64 - - // softirq is time spent in software interrupt context. - softirq uint64 - - // steal is involuntary wait time. - steal uint64 - - // guest is time spent in guests with non-positive niceness. - guest uint64 - - // guestNice is time spent in guests with positive niceness. - guestNice uint64 -} - -// String implements fmt.Stringer. -func (c cpuStats) String() string { - return fmt.Sprintf("%d %d %d %d %d %d %d %d %d %d", c.user, c.nice, c.system, c.idle, c.ioWait, c.irq, c.softirq, c.steal, c.guest, c.guestNice) -} - -// statData implements vfs.DynamicBytesSource for /proc/stat. -// -// +stateify savable -type statData struct { - kernfs.DynamicBytesFile - - // k is the owning Kernel. - k *kernel.Kernel -} - -var _ dynamicInode = (*statData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error { - // TODO(b/37226836): We currently export only zero CPU stats. We could - // at least provide some aggregate stats. - var cpu cpuStats - fmt.Fprintf(buf, "cpu %s\n", cpu) - - for c, max := uint(0), s.k.ApplicationCores(); c < max; c++ { - fmt.Fprintf(buf, "cpu%d %s\n", c, cpu) - } - - // The total number of interrupts is dependent on the CPUs and PCI - // devices on the system. See arch_probe_nr_irqs. - // - // Since we don't report real interrupt stats, just choose an arbitrary - // value from a representative VM. - const numInterrupts = 256 - - // The Kernel doesn't handle real interrupts, so report all zeroes. - // TODO(b/37226836): We could count page faults as #PF. - fmt.Fprintf(buf, "intr 0") // total - for i := 0; i < numInterrupts; i++ { - fmt.Fprintf(buf, " 0") - } - fmt.Fprintf(buf, "\n") - - // Total number of context switches. - // TODO(b/37226836): Count this. - fmt.Fprintf(buf, "ctxt 0\n") - - // CLOCK_REALTIME timestamp from boot, in seconds. - fmt.Fprintf(buf, "btime %d\n", s.k.Timekeeper().BootTime().Seconds()) - - // Total number of clones. - // TODO(b/37226836): Count this. - fmt.Fprintf(buf, "processes 0\n") - - // Number of runnable tasks. - // TODO(b/37226836): Count this. - fmt.Fprintf(buf, "procs_running 0\n") - - // Number of tasks waiting on IO. - // TODO(b/37226836): Count this. - fmt.Fprintf(buf, "procs_blocked 0\n") - - // Number of each softirq handled. - fmt.Fprintf(buf, "softirq 0") // total - for i := 0; i < linux.NumSoftIRQ; i++ { - fmt.Fprintf(buf, " 0") - } - fmt.Fprintf(buf, "\n") - return nil -} - -// loadavgData backs /proc/loadavg. -// -// +stateify savable -type loadavgData struct { - kernfs.DynamicBytesFile -} - -var _ dynamicInode = (*loadavgData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (d *loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error { - // TODO(b/62345059): Include real data in fields. - // Column 1-3: CPU and IO utilization of the last 1, 5, and 10 minute periods. - // Column 4-5: currently running processes and the total number of processes. - // Column 6: the last process ID used. - fmt.Fprintf(buf, "%.2f %.2f %.2f %d/%d %d\n", 0.00, 0.00, 0.00, 0, 0, 0) - return nil -} - -// meminfoData implements vfs.DynamicBytesSource for /proc/meminfo. -// -// +stateify savable -type meminfoData struct { - kernfs.DynamicBytesFile - - // k is the owning Kernel. - k *kernel.Kernel -} - -var _ dynamicInode = (*meminfoData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (d *meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error { - mf := d.k.MemoryFile() - mf.UpdateUsage() - snapshot, totalUsage := usage.MemoryAccounting.Copy() - totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage) - anon := snapshot.Anonymous + snapshot.Tmpfs - file := snapshot.PageCache + snapshot.Mapped - // We don't actually have active/inactive LRUs, so just make up numbers. - activeFile := (file / 2) &^ (usermem.PageSize - 1) - inactiveFile := file - activeFile - - fmt.Fprintf(buf, "MemTotal: %8d kB\n", totalSize/1024) - memFree := (totalSize - totalUsage) / 1024 - // We use MemFree as MemAvailable because we don't swap. - // TODO(rahat): When reclaim is implemented the value of MemAvailable - // should change. - fmt.Fprintf(buf, "MemFree: %8d kB\n", memFree) - fmt.Fprintf(buf, "MemAvailable: %8d kB\n", memFree) - fmt.Fprintf(buf, "Buffers: 0 kB\n") // memory usage by block devices - fmt.Fprintf(buf, "Cached: %8d kB\n", (file+snapshot.Tmpfs)/1024) - // Emulate a system with no swap, which disables inactivation of anon pages. - fmt.Fprintf(buf, "SwapCache: 0 kB\n") - fmt.Fprintf(buf, "Active: %8d kB\n", (anon+activeFile)/1024) - fmt.Fprintf(buf, "Inactive: %8d kB\n", inactiveFile/1024) - fmt.Fprintf(buf, "Active(anon): %8d kB\n", anon/1024) - fmt.Fprintf(buf, "Inactive(anon): 0 kB\n") - fmt.Fprintf(buf, "Active(file): %8d kB\n", activeFile/1024) - fmt.Fprintf(buf, "Inactive(file): %8d kB\n", inactiveFile/1024) - fmt.Fprintf(buf, "Unevictable: 0 kB\n") // TODO(b/31823263) - fmt.Fprintf(buf, "Mlocked: 0 kB\n") // TODO(b/31823263) - fmt.Fprintf(buf, "SwapTotal: 0 kB\n") - fmt.Fprintf(buf, "SwapFree: 0 kB\n") - fmt.Fprintf(buf, "Dirty: 0 kB\n") - fmt.Fprintf(buf, "Writeback: 0 kB\n") - fmt.Fprintf(buf, "AnonPages: %8d kB\n", anon/1024) - fmt.Fprintf(buf, "Mapped: %8d kB\n", file/1024) // doesn't count mapped tmpfs, which we don't know - fmt.Fprintf(buf, "Shmem: %8d kB\n", snapshot.Tmpfs/1024) - return nil -} - -// uptimeData implements vfs.DynamicBytesSource for /proc/uptime. -// -// +stateify savable -type uptimeData struct { - kernfs.DynamicBytesFile -} - -var _ dynamicInode = (*uptimeData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (*uptimeData) Generate(ctx context.Context, buf *bytes.Buffer) error { - k := kernel.KernelFromContext(ctx) - now := time.NowFromContext(ctx) - - // Pretend that we've spent zero time sleeping (second number). - fmt.Fprintf(buf, "%.2f 0.00\n", now.Sub(k.Timekeeper().BootTime()).Seconds()) - return nil -} - -// versionData implements vfs.DynamicBytesSource for /proc/version. -// -// +stateify savable -type versionData struct { - kernfs.DynamicBytesFile - - // k is the owning Kernel. - k *kernel.Kernel -} - -var _ dynamicInode = (*versionData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (v *versionData) Generate(ctx context.Context, buf *bytes.Buffer) error { - init := v.k.GlobalInit() - if init == nil { - // Attempted to read before the init Task is created. This can - // only occur during startup, which should never need to read - // this file. - panic("Attempted to read version before initial Task is available") - } - - // /proc/version takes the form: - // - // "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST) - // (COMPILER_VERSION) VERSION" - // - // where: - // - SYSNAME, RELEASE, and VERSION are the same as returned by - // sys_utsname - // - COMPILE_USER is the user that build the kernel - // - COMPILE_HOST is the hostname of the machine on which the kernel - // was built - // - COMPILER_VERSION is the version reported by the building compiler - // - // Since we don't really want to expose build information to - // applications, those fields are omitted. - // - // FIXME(mpratt): Using Version from the init task SyscallTable - // disregards the different version a task may have (e.g., in a uts - // namespace). - ver := init.Leader().SyscallTable().Version - fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version) - return nil -} diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go deleted file mode 100644 index 3d5dc463c..000000000 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "bytes" - "fmt" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" - "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -// newSysDir returns the dentry corresponding to /proc/sys directory. -func newSysDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry { - return kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{ - "kernel": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{ - "hostname": newDentry(root, inoGen.NextIno(), 0444, &hostnameData{}), - "shmall": newDentry(root, inoGen.NextIno(), 0444, shmData(linux.SHMALL)), - "shmmax": newDentry(root, inoGen.NextIno(), 0444, shmData(linux.SHMMAX)), - "shmmni": newDentry(root, inoGen.NextIno(), 0444, shmData(linux.SHMMNI)), - }), - "vm": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{ - "mmap_min_addr": newDentry(root, inoGen.NextIno(), 0444, &mmapMinAddrData{}), - "overcommit_memory": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0\n")), - }), - "net": newSysNetDir(root, inoGen, k), - }) -} - -// newSysNetDir returns the dentry corresponding to /proc/sys/net directory. -func newSysNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry { - var contents map[string]*kernfs.Dentry - - // TODO(gvisor.dev/issue/1833): Support for using the network stack in the - // network namespace of the calling process. - if stack := k.RootNetworkNamespace().Stack(); stack != nil { - contents = map[string]*kernfs.Dentry{ - "ipv4": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{ - "tcp_sack": newDentry(root, inoGen.NextIno(), 0644, &tcpSackData{stack: stack}), - - // The following files are simple stubs until they are implemented in - // netstack, most of these files are configuration related. We use the - // value closest to the actual netstack behavior or any empty file, all - // of these files will have mode 0444 (read-only for all users). - "ip_local_port_range": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("16000 65535")), - "ip_local_reserved_ports": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("")), - "ipfrag_time": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("30")), - "ip_nonlocal_bind": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "ip_no_pmtu_disc": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")), - - // tcp_allowed_congestion_control tell the user what they are able to - // do as an unprivledged process so we leave it empty. - "tcp_allowed_congestion_control": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("")), - "tcp_available_congestion_control": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("reno")), - "tcp_congestion_control": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("reno")), - - // Many of the following stub files are features netstack doesn't - // support. The unsupported features return "0" to indicate they are - // disabled. - "tcp_base_mss": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1280")), - "tcp_dsack": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_early_retrans": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_fack": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_fastopen": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_fastopen_key": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("")), - "tcp_invalid_ratelimit": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_keepalive_intvl": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_keepalive_probes": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_keepalive_time": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("7200")), - "tcp_mtu_probing": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_no_metrics_save": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")), - "tcp_probe_interval": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_probe_threshold": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_retries1": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("3")), - "tcp_retries2": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("15")), - "tcp_rfc1337": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")), - "tcp_slow_start_after_idle": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")), - "tcp_synack_retries": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("5")), - "tcp_syn_retries": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("3")), - "tcp_timestamps": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")), - }), - "core": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{ - "default_qdisc": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("pfifo_fast")), - "message_burst": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("10")), - "message_cost": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("5")), - "optmem_max": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "rmem_default": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("212992")), - "rmem_max": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("212992")), - "somaxconn": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("128")), - "wmem_default": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("212992")), - "wmem_max": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("212992")), - }), - } - } - - return kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{ - "net": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, contents), - }) -} - -// mmapMinAddrData implements vfs.DynamicBytesSource for -// /proc/sys/vm/mmap_min_addr. -// -// +stateify savable -type mmapMinAddrData struct { - kernfs.DynamicBytesFile - - k *kernel.Kernel -} - -var _ dynamicInode = (*mmapMinAddrData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (d *mmapMinAddrData) Generate(ctx context.Context, buf *bytes.Buffer) error { - fmt.Fprintf(buf, "%d\n", d.k.Platform.MinUserAddress()) - return nil -} - -// hostnameData implements vfs.DynamicBytesSource for /proc/sys/kernel/hostname. -// -// +stateify savable -type hostnameData struct { - kernfs.DynamicBytesFile -} - -var _ dynamicInode = (*hostnameData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (*hostnameData) Generate(ctx context.Context, buf *bytes.Buffer) error { - utsns := kernel.UTSNamespaceFromContext(ctx) - buf.WriteString(utsns.HostName()) - buf.WriteString("\n") - return nil -} - -// tcpSackData implements vfs.WritableDynamicBytesSource for -// /proc/sys/net/tcp_sack. -// -// +stateify savable -type tcpSackData struct { - kernfs.DynamicBytesFile - - stack inet.Stack `state:"wait"` - enabled *bool -} - -var _ vfs.WritableDynamicBytesSource = (*tcpSackData)(nil) - -// Generate implements vfs.DynamicBytesSource. -func (d *tcpSackData) Generate(ctx context.Context, buf *bytes.Buffer) error { - if d.enabled == nil { - sack, err := d.stack.TCPSACKEnabled() - if err != nil { - return err - } - d.enabled = &sack - } - - val := "0\n" - if *d.enabled { - // Technically, this is not quite compatible with Linux. Linux stores these - // as an integer, so if you write "2" into tcp_sack, you should get 2 back. - // Tough luck. - val = "1\n" - } - buf.WriteString(val) - return nil -} - -func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { - if offset != 0 { - // No need to handle partial writes thus far. - return 0, syserror.EINVAL - } - if src.NumBytes() == 0 { - return 0, nil - } - - // Limit the amount of memory allocated. - src = src.TakeFirst(usermem.PageSize - 1) - - var v int32 - n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) - if err != nil { - return n, err - } - if d.enabled == nil { - d.enabled = new(bool) - } - *d.enabled = v != 0 - return n, d.stack.SetTCPSACKEnabled(*d.enabled) -} diff --git a/pkg/sentry/fsimpl/proc/tasks_sys_test.go b/pkg/sentry/fsimpl/proc/tasks_sys_test.go deleted file mode 100644 index be54897bb..000000000 --- a/pkg/sentry/fsimpl/proc/tasks_sys_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "bytes" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/inet" -) - -func newIPv6TestStack() *inet.TestStack { - s := inet.NewTestStack() - s.SupportsIPv6Flag = true - return s -} - -func TestIfinet6NoAddresses(t *testing.T) { - n := &ifinet6{stack: newIPv6TestStack()} - var buf bytes.Buffer - n.Generate(contexttest.Context(t), &buf) - if buf.Len() > 0 { - t.Errorf("n.Generate() generated = %v, want = %v", buf.Bytes(), []byte{}) - } -} - -func TestIfinet6(t *testing.T) { - s := newIPv6TestStack() - s.InterfacesMap[1] = inet.Interface{Name: "eth0"} - s.InterfaceAddrsMap[1] = []inet.InterfaceAddr{ - { - Family: linux.AF_INET6, - PrefixLen: 128, - Addr: []byte("\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f"), - }, - } - s.InterfacesMap[2] = inet.Interface{Name: "eth1"} - s.InterfaceAddrsMap[2] = []inet.InterfaceAddr{ - { - Family: linux.AF_INET6, - PrefixLen: 128, - Addr: []byte("\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"), - }, - } - want := map[string]struct{}{ - "000102030405060708090a0b0c0d0e0f 01 80 00 00 eth0\n": {}, - "101112131415161718191a1b1c1d1e1f 02 80 00 00 eth1\n": {}, - } - - n := &ifinet6{stack: s} - contents := n.contents() - if len(contents) != len(want) { - t.Errorf("Got len(n.contents()) = %d, want = %d", len(contents), len(want)) - } - got := map[string]struct{}{} - for _, l := range contents { - got[l] = struct{}{} - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("Got n.contents() = %v, want = %v", got, want) - } -} diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go deleted file mode 100644 index 1bb9430c0..000000000 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ /dev/null @@ -1,456 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "fmt" - "math" - "path" - "strconv" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -var ( - // Next offset 256 by convention. Adds 1 for the next offset. - selfLink = vfs.Dirent{Type: linux.DT_LNK, NextOff: 256 + 0 + 1} - threadSelfLink = vfs.Dirent{Type: linux.DT_LNK, NextOff: 256 + 1 + 1} - - // /proc/[pid] next offset starts at 256+2 (files above), then adds the - // PID, and adds 1 for the next offset. - proc1 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 1 + 1} - proc2 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 2 + 1} - proc3 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 3 + 1} -) - -var ( - tasksStaticFiles = map[string]testutil.DirentType{ - "cpuinfo": linux.DT_REG, - "loadavg": linux.DT_REG, - "meminfo": linux.DT_REG, - "mounts": linux.DT_LNK, - "net": linux.DT_LNK, - "self": linux.DT_LNK, - "stat": linux.DT_REG, - "sys": linux.DT_DIR, - "thread-self": linux.DT_LNK, - "uptime": linux.DT_REG, - "version": linux.DT_REG, - } - tasksStaticFilesNextOffs = map[string]int64{ - "self": selfLink.NextOff, - "thread-self": threadSelfLink.NextOff, - } - taskStaticFiles = map[string]testutil.DirentType{ - "auxv": linux.DT_REG, - "cgroup": linux.DT_REG, - "cmdline": linux.DT_REG, - "comm": linux.DT_REG, - "environ": linux.DT_REG, - "gid_map": linux.DT_REG, - "io": linux.DT_REG, - "maps": linux.DT_REG, - "net": linux.DT_DIR, - "ns": linux.DT_DIR, - "oom_score": linux.DT_REG, - "oom_score_adj": linux.DT_REG, - "smaps": linux.DT_REG, - "stat": linux.DT_REG, - "statm": linux.DT_REG, - "status": linux.DT_REG, - "task": linux.DT_DIR, - "uid_map": linux.DT_REG, - } -) - -func setup(t *testing.T) *testutil.System { - k, err := testutil.Boot() - if err != nil { - t.Fatalf("Error creating kernel: %v", err) - } - - ctx := k.SupervisorContext() - creds := auth.CredentialsFromContext(ctx) - - k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - fsOpts := vfs.GetFilesystemOptions{ - InternalData: &InternalData{ - Cgroups: map[string]string{ - "cpuset": "/foo/cpuset", - "memory": "/foo/memory", - }, - }, - } - mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", Name, &fsOpts) - if err != nil { - t.Fatalf("NewMountNamespace(): %v", err) - } - return testutil.NewSystem(ctx, t, k.VFS(), mntns) -} - -func TestTasksEmpty(t *testing.T) { - s := setup(t) - defer s.Destroy() - - collector := s.ListDirents(s.PathOpAtRoot("/")) - s.AssertAllDirentTypes(collector, tasksStaticFiles) - s.AssertDirentOffsets(collector, tasksStaticFilesNextOffs) -} - -func TestTasks(t *testing.T) { - s := setup(t) - defer s.Destroy() - - expectedDirents := make(map[string]testutil.DirentType) - for n, d := range tasksStaticFiles { - expectedDirents[n] = d - } - - k := kernel.KernelFromContext(s.Ctx) - var tasks []*kernel.Task - for i := 0; i < 5; i++ { - tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) - task, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root) - if err != nil { - t.Fatalf("CreateTask(): %v", err) - } - tasks = append(tasks, task) - expectedDirents[fmt.Sprintf("%d", i+1)] = linux.DT_DIR - } - - collector := s.ListDirents(s.PathOpAtRoot("/")) - s.AssertAllDirentTypes(collector, expectedDirents) - s.AssertDirentOffsets(collector, tasksStaticFilesNextOffs) - - lastPid := 0 - dirents := collector.OrderedDirents() - doneSkippingNonTaskDirs := false - for _, d := range dirents { - pid, err := strconv.Atoi(d.Name) - if err != nil { - if !doneSkippingNonTaskDirs { - // We haven't gotten to the task dirs yet. - continue - } - t.Fatalf("Invalid process directory %q", d.Name) - } - doneSkippingNonTaskDirs = true - if lastPid > pid { - t.Errorf("pids not in order: %v", dirents) - } - found := false - for _, t := range tasks { - if k.TaskSet().Root.IDOfTask(t) == kernel.ThreadID(pid) { - found = true - } - } - if !found { - t.Errorf("Additional task ID %d listed: %v", pid, tasks) - } - // Next offset starts at 256+2 ('self' and 'thread-self'), then adds the - // PID, and adds 1 for the next offset. - if want := int64(256 + 2 + pid + 1); d.NextOff != want { - t.Errorf("Wrong dirent offset want: %d got: %d: %+v", want, d.NextOff, d) - } - } - if !doneSkippingNonTaskDirs { - t.Fatalf("Never found any process directories.") - } - - // Test lookup. - for _, path := range []string{"/1", "/2"} { - fd, err := s.VFS.OpenAt( - s.Ctx, - s.Creds, - s.PathOpAtRoot(path), - &vfs.OpenOptions{}, - ) - if err != nil { - t.Fatalf("vfsfs.OpenAt(%q) failed: %v", path, err) - } - buf := make([]byte, 1) - bufIOSeq := usermem.BytesIOSequence(buf) - if _, err := fd.Read(s.Ctx, bufIOSeq, vfs.ReadOptions{}); err != syserror.EISDIR { - t.Errorf("wrong error reading directory: %v", err) - } - } - - if _, err := s.VFS.OpenAt( - s.Ctx, - s.Creds, - s.PathOpAtRoot("/9999"), - &vfs.OpenOptions{}, - ); err != syserror.ENOENT { - t.Fatalf("wrong error from vfsfs.OpenAt(/9999): %v", err) - } -} - -func TestTasksOffset(t *testing.T) { - s := setup(t) - defer s.Destroy() - - k := kernel.KernelFromContext(s.Ctx) - for i := 0; i < 3; i++ { - tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) - if _, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root); err != nil { - t.Fatalf("CreateTask(): %v", err) - } - } - - for _, tc := range []struct { - name string - offset int64 - wants map[string]vfs.Dirent - }{ - { - name: "small offset", - offset: 100, - wants: map[string]vfs.Dirent{ - "self": selfLink, - "thread-self": threadSelfLink, - "1": proc1, - "2": proc2, - "3": proc3, - }, - }, - { - name: "offset at start", - offset: 256, - wants: map[string]vfs.Dirent{ - "self": selfLink, - "thread-self": threadSelfLink, - "1": proc1, - "2": proc2, - "3": proc3, - }, - }, - { - name: "skip /proc/self", - offset: 257, - wants: map[string]vfs.Dirent{ - "thread-self": threadSelfLink, - "1": proc1, - "2": proc2, - "3": proc3, - }, - }, - { - name: "skip symlinks", - offset: 258, - wants: map[string]vfs.Dirent{ - "1": proc1, - "2": proc2, - "3": proc3, - }, - }, - { - name: "skip first process", - offset: 260, - wants: map[string]vfs.Dirent{ - "2": proc2, - "3": proc3, - }, - }, - { - name: "last process", - offset: 261, - wants: map[string]vfs.Dirent{ - "3": proc3, - }, - }, - { - name: "after last", - offset: 262, - wants: nil, - }, - { - name: "TaskLimit+1", - offset: kernel.TasksLimit + 1, - wants: nil, - }, - { - name: "max", - offset: math.MaxInt64, - wants: nil, - }, - } { - t.Run(tc.name, func(t *testing.T) { - s := s.WithSubtest(t) - fd, err := s.VFS.OpenAt( - s.Ctx, - s.Creds, - s.PathOpAtRoot("/"), - &vfs.OpenOptions{}, - ) - if err != nil { - t.Fatalf("vfsfs.OpenAt(/) failed: %v", err) - } - if _, err := fd.Seek(s.Ctx, tc.offset, linux.SEEK_SET); err != nil { - t.Fatalf("Seek(%d, SEEK_SET): %v", tc.offset, err) - } - - var collector testutil.DirentCollector - if err := fd.IterDirents(s.Ctx, &collector); err != nil { - t.Fatalf("IterDirent(): %v", err) - } - - expectedTypes := make(map[string]testutil.DirentType) - expectedOffsets := make(map[string]int64) - for name, want := range tc.wants { - expectedTypes[name] = want.Type - if want.NextOff != 0 { - expectedOffsets[name] = want.NextOff - } - } - - collector.SkipDotsChecks(true) // We seek()ed past the dots. - s.AssertAllDirentTypes(&collector, expectedTypes) - s.AssertDirentOffsets(&collector, expectedOffsets) - }) - } -} - -func TestTask(t *testing.T) { - s := setup(t) - defer s.Destroy() - - k := kernel.KernelFromContext(s.Ctx) - tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) - _, err := testutil.CreateTask(s.Ctx, "name", tc, s.MntNs, s.Root, s.Root) - if err != nil { - t.Fatalf("CreateTask(): %v", err) - } - - collector := s.ListDirents(s.PathOpAtRoot("/1")) - s.AssertAllDirentTypes(collector, taskStaticFiles) -} - -func TestProcSelf(t *testing.T) { - s := setup(t) - defer s.Destroy() - - k := kernel.KernelFromContext(s.Ctx) - tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) - task, err := testutil.CreateTask(s.Ctx, "name", tc, s.MntNs, s.Root, s.Root) - if err != nil { - t.Fatalf("CreateTask(): %v", err) - } - - collector := s.WithTemporaryContext(task).ListDirents(&vfs.PathOperation{ - Root: s.Root, - Start: s.Root, - Path: fspath.Parse("/self/"), - FollowFinalSymlink: true, - }) - s.AssertAllDirentTypes(collector, taskStaticFiles) -} - -func iterateDir(ctx context.Context, t *testing.T, s *testutil.System, fd *vfs.FileDescription) { - t.Logf("Iterating: /proc%s", fd.MappedName(ctx)) - - var collector testutil.DirentCollector - if err := fd.IterDirents(ctx, &collector); err != nil { - t.Fatalf("IterDirents(): %v", err) - } - if err := collector.Contains(".", linux.DT_DIR); err != nil { - t.Error(err.Error()) - } - if err := collector.Contains("..", linux.DT_DIR); err != nil { - t.Error(err.Error()) - } - - for _, d := range collector.Dirents() { - if d.Name == "." || d.Name == ".." { - continue - } - childPath := path.Join(fd.MappedName(ctx), d.Name) - if d.Type == linux.DT_LNK { - link, err := s.VFS.ReadlinkAt( - ctx, - auth.CredentialsFromContext(ctx), - &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse(childPath)}, - ) - if err != nil { - t.Errorf("vfsfs.ReadlinkAt(%v) failed: %v", childPath, err) - } else { - t.Logf("Skipping symlink: /proc%s => %s", childPath, link) - } - continue - } - - t.Logf("Opening: /proc%s", childPath) - child, err := s.VFS.OpenAt( - ctx, - auth.CredentialsFromContext(ctx), - &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse(childPath)}, - &vfs.OpenOptions{}, - ) - if err != nil { - t.Errorf("vfsfs.OpenAt(%v) failed: %v", childPath, err) - continue - } - stat, err := child.Stat(ctx, vfs.StatOptions{}) - if err != nil { - t.Errorf("Stat(%v) failed: %v", childPath, err) - } - if got := linux.FileMode(stat.Mode).DirentType(); got != d.Type { - t.Errorf("wrong file mode, stat: %v, dirent: %v", got, d.Type) - } - if d.Type == linux.DT_DIR { - // Found another dir, let's do it again! - iterateDir(ctx, t, s, child) - } - } -} - -// TestTree iterates all directories and stats every file. -func TestTree(t *testing.T) { - s := setup(t) - defer s.Destroy() - - k := kernel.KernelFromContext(s.Ctx) - var tasks []*kernel.Task - for i := 0; i < 5; i++ { - tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) - task, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root) - if err != nil { - t.Fatalf("CreateTask(): %v", err) - } - tasks = append(tasks, task) - } - - ctx := tasks[0] - fd, err := s.VFS.OpenAt( - ctx, - auth.CredentialsFromContext(s.Ctx), - &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse("/")}, - &vfs.OpenOptions{}, - ) - if err != nil { - t.Fatalf("vfsfs.OpenAt(/) failed: %v", err) - } - iterateDir(ctx, t, s, fd) -} diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD deleted file mode 100644 index a741e2bb6..000000000 --- a/pkg/sentry/fsimpl/sys/BUILD +++ /dev/null @@ -1,34 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -licenses(["notice"]) - -go_library( - name = "sys", - srcs = [ - "sys.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/fsimpl/kernfs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/syserror", - ], -) - -go_test( - name = "sys_test", - srcs = ["sys_test.go"], - deps = [ - ":sys", - "//pkg/abi/linux", - "//pkg/sentry/fsimpl/testutil", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "@com_github_google_go-cmp//cmp:go_default_library", - ], -) diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go deleted file mode 100644 index c36c4fa11..000000000 --- a/pkg/sentry/fsimpl/sys/sys.go +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package sys implements sysfs. -package sys - -import ( - "bytes" - "fmt" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -// Name is the default filesystem name. -const Name = "sysfs" - -// FilesystemType implements vfs.FilesystemType. -type FilesystemType struct{} - -// filesystem implements vfs.FilesystemImpl. -type filesystem struct { - kernfs.Filesystem -} - -// GetFilesystem implements vfs.FilesystemType.GetFilesystem. -func (FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - fs := &filesystem{} - fs.Filesystem.Init(vfsObj) - k := kernel.KernelFromContext(ctx) - maxCPUCores := k.ApplicationCores() - defaultSysDirMode := linux.FileMode(0755) - - root := fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{ - "block": fs.newDir(creds, defaultSysDirMode, nil), - "bus": fs.newDir(creds, defaultSysDirMode, nil), - "class": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{ - "power_supply": fs.newDir(creds, defaultSysDirMode, nil), - }), - "dev": fs.newDir(creds, defaultSysDirMode, nil), - "devices": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{ - "system": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{ - "cpu": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{ - "online": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), - "possible": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), - "present": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), - }), - }), - }), - "firmware": fs.newDir(creds, defaultSysDirMode, nil), - "fs": fs.newDir(creds, defaultSysDirMode, nil), - "kernel": fs.newDir(creds, defaultSysDirMode, nil), - "module": fs.newDir(creds, defaultSysDirMode, nil), - "power": fs.newDir(creds, defaultSysDirMode, nil), - }) - return fs.VFSFilesystem(), root.VFSDentry(), nil -} - -// dir implements kernfs.Inode. -type dir struct { - kernfs.InodeAttrs - kernfs.InodeNoDynamicLookup - kernfs.InodeNotSymlink - kernfs.InodeDirectoryNoNewChildren - - kernfs.OrderedChildren - dentry kernfs.Dentry -} - -func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry { - d := &dir{} - d.InodeAttrs.Init(creds, fs.NextIno(), linux.ModeDirectory|0755) - d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - d.dentry.Init(d) - - d.IncLinks(d.OrderedChildren.Populate(&d.dentry, contents)) - - return &d.dentry -} - -// SetStat implements kernfs.Inode.SetStat. -func (d *dir) SetStat(fs *vfs.Filesystem, opts vfs.SetStatOptions) error { - return syserror.EPERM -} - -// Open implements kernfs.Inode.Open. -func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd := &kernfs.GenericDirectoryFD{} - fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, &opts) - return fd.VFSFileDescription(), nil -} - -// cpuFile implements kernfs.Inode. -type cpuFile struct { - kernfs.DynamicBytesFile - maxCores uint -} - -// Generate implements vfs.DynamicBytesSource.Generate. -func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error { - fmt.Fprintf(buf, "0-%d", c.maxCores-1) - return nil -} - -func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode linux.FileMode) *kernfs.Dentry { - c := &cpuFile{maxCores: maxCores} - c.DynamicBytesFile.Init(creds, fs.NextIno(), c, mode) - d := &kernfs.Dentry{} - d.Init(c) - return d -} diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go deleted file mode 100644 index 4b3602d47..000000000 --- a/pkg/sentry/fsimpl/sys/sys_test.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sys_test - -import ( - "fmt" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -func newTestSystem(t *testing.T) *testutil.System { - k, err := testutil.Boot() - if err != nil { - t.Fatalf("Failed to create test kernel: %v", err) - } - ctx := k.SupervisorContext() - creds := auth.CredentialsFromContext(ctx) - k.VFS().MustRegisterFilesystemType(sys.Name, sys.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - - mns, err := k.VFS().NewMountNamespace(ctx, creds, "", sys.Name, &vfs.GetFilesystemOptions{}) - if err != nil { - t.Fatalf("Failed to create new mount namespace: %v", err) - } - return testutil.NewSystem(ctx, t, k.VFS(), mns) -} - -func TestReadCPUFile(t *testing.T) { - s := newTestSystem(t) - defer s.Destroy() - k := kernel.KernelFromContext(s.Ctx) - maxCPUCores := k.ApplicationCores() - - expected := fmt.Sprintf("0-%d", maxCPUCores-1) - - for _, fname := range []string{"online", "possible", "present"} { - pop := s.PathOpAtRoot(fmt.Sprintf("devices/system/cpu/%s", fname)) - fd, err := s.VFS.OpenAt(s.Ctx, s.Creds, pop, &vfs.OpenOptions{}) - if err != nil { - t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err) - } - defer fd.DecRef() - content, err := s.ReadToEnd(fd) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - if diff := cmp.Diff(expected, content); diff != "" { - t.Fatalf("Read returned unexpected data:\n--- want\n+++ got\n%v", diff) - } - } -} - -func TestSysRootContainsExpectedEntries(t *testing.T) { - s := newTestSystem(t) - defer s.Destroy() - pop := s.PathOpAtRoot("/") - s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{ - "block": linux.DT_DIR, - "bus": linux.DT_DIR, - "class": linux.DT_DIR, - "dev": linux.DT_DIR, - "devices": linux.DT_DIR, - "firmware": linux.DT_DIR, - "fs": linux.DT_DIR, - "kernel": linux.DT_DIR, - "module": linux.DT_DIR, - "power": linux.DT_DIR, - }) -} diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD deleted file mode 100644 index e4f36f4ae..000000000 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "testutil", - testonly = 1, - srcs = [ - "kernel.go", - "testutil.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/cpuid", - "//pkg/fspath", - "//pkg/memutil", - "//pkg/sentry/fsimpl/tmpfs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/sched", - "//pkg/sentry/limits", - "//pkg/sentry/loader", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/platform/kvm", - "//pkg/sentry/platform/ptrace", - "//pkg/sentry/time", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/usermem", - "@com_github_google_go-cmp//cmp:go_default_library", - ], -) diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go deleted file mode 100644 index 488478e29..000000000 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testutil - -import ( - "flag" - "fmt" - "os" - "runtime" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/cpuid" - "gvisor.dev/gvisor/pkg/memutil" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/sched" - "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/loader" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/time" - "gvisor.dev/gvisor/pkg/sentry/vfs" - - // Platforms are plugable. - _ "gvisor.dev/gvisor/pkg/sentry/platform/kvm" - _ "gvisor.dev/gvisor/pkg/sentry/platform/ptrace" -) - -var ( - platformFlag = flag.String("platform", "ptrace", "specify which platform to use") -) - -// Boot initializes a new bare bones kernel for test. -func Boot() (*kernel.Kernel, error) { - platformCtr, err := platform.Lookup(*platformFlag) - if err != nil { - return nil, fmt.Errorf("platform not found: %v", err) - } - deviceFile, err := platformCtr.OpenDevice() - if err != nil { - return nil, fmt.Errorf("creating platform: %v", err) - } - plat, err := platformCtr.New(deviceFile) - if err != nil { - return nil, fmt.Errorf("creating platform: %v", err) - } - - k := &kernel.Kernel{ - Platform: plat, - } - - mf, err := createMemoryFile() - if err != nil { - return nil, err - } - k.SetMemoryFile(mf) - - // Pass k as the platform since it is savable, unlike the actual platform. - vdso, err := loader.PrepareVDSO(nil, k) - if err != nil { - return nil, fmt.Errorf("creating vdso: %v", err) - } - - // Create timekeeper. - tk, err := kernel.NewTimekeeper(k, vdso.ParamPage.FileRange()) - if err != nil { - return nil, fmt.Errorf("creating timekeeper: %v", err) - } - tk.SetClocks(time.NewCalibratedClocks()) - - creds := auth.NewRootCredentials(auth.NewRootUserNamespace()) - - // Initiate the Kernel object, which is required by the Context passed - // to createVFS in order to mount (among other things) procfs. - if err = k.Init(kernel.InitKernelArgs{ - ApplicationCores: uint(runtime.GOMAXPROCS(-1)), - FeatureSet: cpuid.HostFeatureSet(), - Timekeeper: tk, - RootUserNamespace: creds.UserNamespace, - Vdso: vdso, - RootUTSNamespace: kernel.NewUTSNamespace("hostname", "domain", creds.UserNamespace), - RootIPCNamespace: kernel.NewIPCNamespace(creds.UserNamespace), - RootAbstractSocketNamespace: kernel.NewAbstractSocketNamespace(), - PIDNamespace: kernel.NewRootPIDNamespace(creds.UserNamespace), - }); err != nil { - return nil, fmt.Errorf("initializing kernel: %v", err) - } - - kernel.VFS2Enabled = true - - if err := k.VFS().Init(); err != nil { - return nil, fmt.Errorf("VFS init: %v", err) - } - k.VFS().MustRegisterFilesystemType(tmpfs.Name, &tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - AllowUserList: true, - }) - - ls, err := limits.NewLinuxLimitSet() - if err != nil { - return nil, err - } - tg := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, ls) - k.TestOnly_SetGlobalInit(tg) - - return k, nil -} - -// CreateTask creates a new bare bones task for tests. -func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns *vfs.MountNamespace, root, cwd vfs.VirtualDentry) (*kernel.Task, error) { - k := kernel.KernelFromContext(ctx) - config := &kernel.TaskConfig{ - Kernel: k, - ThreadGroup: tc, - TaskContext: &kernel.TaskContext{Name: name}, - Credentials: auth.CredentialsFromContext(ctx), - NetworkNamespace: k.RootNetworkNamespace(), - AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()), - UTSNamespace: kernel.UTSNamespaceFromContext(ctx), - IPCNamespace: kernel.IPCNamespaceFromContext(ctx), - AbstractSocketNamespace: kernel.NewAbstractSocketNamespace(), - MountNamespaceVFS2: mntns, - FSContext: kernel.NewFSContextVFS2(root, cwd, 0022), - } - return k.TaskSet().NewTask(config) -} - -func createMemoryFile() (*pgalloc.MemoryFile, error) { - const memfileName = "test-memory" - memfd, err := memutil.CreateMemFD(memfileName, 0) - if err != nil { - return nil, fmt.Errorf("error creating memfd: %v", err) - } - memfile := os.NewFile(uintptr(memfd), memfileName) - mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{}) - if err != nil { - memfile.Close() - return nil, fmt.Errorf("error creating pgalloc.MemoryFile: %v", err) - } - return mf, nil -} diff --git a/pkg/sentry/fsimpl/testutil/testutil.go b/pkg/sentry/fsimpl/testutil/testutil.go deleted file mode 100644 index e16808c63..000000000 --- a/pkg/sentry/fsimpl/testutil/testutil.go +++ /dev/null @@ -1,281 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testutil provides common test utilities for kernfs-based -// filesystems. -package testutil - -import ( - "fmt" - "io" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/usermem" -) - -// System represents the context for a single test. -// -// Test systems must be explicitly destroyed with System.Destroy. -type System struct { - t *testing.T - Ctx context.Context - Creds *auth.Credentials - VFS *vfs.VirtualFilesystem - Root vfs.VirtualDentry - MntNs *vfs.MountNamespace -} - -// NewSystem constructs a System. -// -// Precondition: Caller must hold a reference on MntNs, whose ownership -// is transferred to the new System. -func NewSystem(ctx context.Context, t *testing.T, v *vfs.VirtualFilesystem, mns *vfs.MountNamespace) *System { - s := &System{ - t: t, - Ctx: ctx, - Creds: auth.CredentialsFromContext(ctx), - VFS: v, - MntNs: mns, - Root: mns.Root(), - } - return s -} - -// WithSubtest creates a temporary test system with a new test harness, -// referencing all other resources from the original system. This is useful when -// a system is reused for multiple subtests, and the T needs to change for each -// case. Note that this is safe when test cases run in parallel, as all -// resources referenced by the system are immutable, or handle interior -// mutations in a thread-safe manner. -// -// The returned system must not outlive the original and should not be destroyed -// via System.Destroy. -func (s *System) WithSubtest(t *testing.T) *System { - return &System{ - t: t, - Ctx: s.Ctx, - Creds: s.Creds, - VFS: s.VFS, - MntNs: s.MntNs, - Root: s.Root, - } -} - -// WithTemporaryContext constructs a temporary test system with a new context -// ctx. The temporary system borrows all resources and references from the -// original system. The returned temporary system must not outlive the original -// system, and should not be destroyed via System.Destroy. -func (s *System) WithTemporaryContext(ctx context.Context) *System { - return &System{ - t: s.t, - Ctx: ctx, - Creds: s.Creds, - VFS: s.VFS, - MntNs: s.MntNs, - Root: s.Root, - } -} - -// Destroy release resources associated with a test system. -func (s *System) Destroy() { - s.Root.DecRef() - s.MntNs.DecRef() // Reference on MntNs passed to NewSystem. -} - -// ReadToEnd reads the contents of fd until EOF to a string. -func (s *System) ReadToEnd(fd *vfs.FileDescription) (string, error) { - buf := make([]byte, usermem.PageSize) - bufIOSeq := usermem.BytesIOSequence(buf) - opts := vfs.ReadOptions{} - - var content strings.Builder - for { - n, err := fd.Read(s.Ctx, bufIOSeq, opts) - if n == 0 || err != nil { - if err == io.EOF { - err = nil - } - return content.String(), err - } - content.Write(buf[:n]) - } -} - -// PathOpAtRoot constructs a PathOperation with the given path from -// the root of the filesystem. -func (s *System) PathOpAtRoot(path string) *vfs.PathOperation { - return &vfs.PathOperation{ - Root: s.Root, - Start: s.Root, - Path: fspath.Parse(path), - } -} - -// GetDentryOrDie attempts to resolve a dentry referred to by the -// provided path operation. If unsuccessful, the test fails. -func (s *System) GetDentryOrDie(pop *vfs.PathOperation) vfs.VirtualDentry { - vd, err := s.VFS.GetDentryAt(s.Ctx, s.Creds, pop, &vfs.GetDentryOptions{}) - if err != nil { - s.t.Fatalf("GetDentryAt(pop:%+v) failed: %v", pop, err) - } - return vd -} - -// DirentType is an alias for values for linux_dirent64.d_type. -type DirentType = uint8 - -// ListDirents lists the Dirents for a directory at pop. -func (s *System) ListDirents(pop *vfs.PathOperation) *DirentCollector { - fd, err := s.VFS.OpenAt(s.Ctx, s.Creds, pop, &vfs.OpenOptions{Flags: linux.O_RDONLY}) - if err != nil { - s.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) - } - defer fd.DecRef() - - collector := &DirentCollector{} - if err := fd.IterDirents(s.Ctx, collector); err != nil { - s.t.Fatalf("IterDirent failed: %v", err) - } - return collector -} - -// AssertAllDirentTypes verifies that the set of dirents in collector contains -// exactly the specified set of expected entries. AssertAllDirentTypes respects -// collector.skipDots, and implicitly checks for "." and ".." accordingly. -func (s *System) AssertAllDirentTypes(collector *DirentCollector, expected map[string]DirentType) { - // Also implicitly check for "." and "..", if enabled. - if !collector.skipDots { - expected["."] = linux.DT_DIR - expected[".."] = linux.DT_DIR - } - - dentryTypes := make(map[string]DirentType) - collector.mu.Lock() - for _, dirent := range collector.dirents { - dentryTypes[dirent.Name] = dirent.Type - } - collector.mu.Unlock() - if diff := cmp.Diff(expected, dentryTypes); diff != "" { - s.t.Fatalf("IterDirent had unexpected results:\n--- want\n+++ got\n%v", diff) - } -} - -// AssertDirentOffsets verifies that collector contains at least the entries -// specified in expected, with the given NextOff field. Entries specified in -// expected but missing from collector result in failure. Extra entries in -// collector are ignored. AssertDirentOffsets respects collector.skipDots, and -// implicitly checks for "." and ".." accordingly. -func (s *System) AssertDirentOffsets(collector *DirentCollector, expected map[string]int64) { - // Also implicitly check for "." and "..", if enabled. - if !collector.skipDots { - expected["."] = 1 - expected[".."] = 2 - } - - dentryNextOffs := make(map[string]int64) - collector.mu.Lock() - for _, dirent := range collector.dirents { - // Ignore extra entries in dentries that are not in expected. - if _, ok := expected[dirent.Name]; ok { - dentryNextOffs[dirent.Name] = dirent.NextOff - } - } - collector.mu.Unlock() - if diff := cmp.Diff(expected, dentryNextOffs); diff != "" { - s.t.Fatalf("IterDirent had unexpected results:\n--- want\n+++ got\n%v", diff) - } -} - -// DirentCollector provides an implementation for vfs.IterDirentsCallback for -// testing. It simply iterates to the end of a given directory FD and collects -// all dirents emitted by the callback. -type DirentCollector struct { - mu sync.Mutex - order []*vfs.Dirent - dirents map[string]*vfs.Dirent - // When the collector is used in various Assert* functions, should "." and - // ".." be implicitly checked? - skipDots bool -} - -// SkipDotsChecks enables or disables the implicit checks on "." and ".." when -// the collector is used in various Assert* functions. Note that "." and ".." -// are still collected if passed to d.Handle, so the caller should only disable -// the checks when they aren't expected. -func (d *DirentCollector) SkipDotsChecks(value bool) { - d.skipDots = value -} - -// Handle implements vfs.IterDirentsCallback.Handle. -func (d *DirentCollector) Handle(dirent vfs.Dirent) error { - d.mu.Lock() - if d.dirents == nil { - d.dirents = make(map[string]*vfs.Dirent) - } - d.order = append(d.order, &dirent) - d.dirents[dirent.Name] = &dirent - d.mu.Unlock() - return nil -} - -// Count returns the number of dirents currently in the collector. -func (d *DirentCollector) Count() int { - d.mu.Lock() - defer d.mu.Unlock() - return len(d.dirents) -} - -// Contains checks whether the collector has a dirent with the given name and -// type. -func (d *DirentCollector) Contains(name string, typ uint8) error { - d.mu.Lock() - defer d.mu.Unlock() - dirent, ok := d.dirents[name] - if !ok { - return fmt.Errorf("No dirent named %q found", name) - } - if dirent.Type != typ { - return fmt.Errorf("Dirent named %q found, but was expecting type %s, got: %+v", name, linux.DirentType.Parse(uint64(typ)), dirent) - } - return nil -} - -// Dirents returns all dirents discovered by this collector. -func (d *DirentCollector) Dirents() map[string]*vfs.Dirent { - d.mu.Lock() - dirents := make(map[string]*vfs.Dirent) - for n, d := range d.dirents { - dirents[n] = d - } - d.mu.Unlock() - return dirents -} - -// OrderedDirents returns an ordered list of dirents as discovered by this -// collector. -func (d *DirentCollector) OrderedDirents() []*vfs.Dirent { - d.mu.Lock() - dirents := make([]*vfs.Dirent, len(d.order)) - copy(dirents, d.order) - d.mu.Unlock() - return dirents -} diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD deleted file mode 100644 index 57abd5583..000000000 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ /dev/null @@ -1,98 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -licenses(["notice"]) - -go_template_instance( - name = "dentry_list", - out = "dentry_list.go", - package = "tmpfs", - prefix = "dentry", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*dentry", - "Linker": "*dentry", - }, -) - -go_library( - name = "tmpfs", - srcs = [ - "dentry_list.go", - "device_file.go", - "directory.go", - "filesystem.go", - "named_pipe.go", - "regular_file.go", - "symlink.go", - "tmpfs.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/amutex", - "//pkg/context", - "//pkg/fspath", - "//pkg/log", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/pipe", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/usage", - "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - ], -) - -go_test( - name = "benchmark_test", - size = "small", - srcs = ["benchmark_test.go"], - deps = [ - ":tmpfs", - "//pkg/abi/linux", - "//pkg/context", - "//pkg/fspath", - "//pkg/refs", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/fs/tmpfs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/syserror", - ], -) - -go_test( - name = "tmpfs_test", - size = "small", - srcs = [ - "pipe_test.go", - "regular_file_test.go", - "stat_test.go", - ], - library = ":tmpfs", - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/fspath", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs/lock", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/contexttest", - "//pkg/sentry/vfs", - "//pkg/syserror", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go deleted file mode 100644 index 383133e44..000000000 --- a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go +++ /dev/null @@ -1,493 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package benchmark_test - -import ( - "fmt" - "runtime" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -// Differences from stat_benchmark: -// -// - Syscall interception, CopyInPath, copyOutStat, and overlayfs overheads are -// not included. -// -// - *MountStat benchmarks use a tmpfs root mount and a tmpfs submount at /tmp. -// Non-MountStat benchmarks use a tmpfs root mount and no submounts. -// stat_benchmark uses a varying root mount, a tmpfs submount at /tmp, and a -// subdirectory /tmp/<top_dir> (assuming TEST_TMPDIR == "/tmp"). Thus -// stat_benchmark at depth 1 does a comparable amount of work to *MountStat -// benchmarks at depth 2, and non-MountStat benchmarks at depth 3. -var depths = []int{1, 2, 3, 8, 64, 100} - -const ( - mountPointName = "tmp" - filename = "gvisor_test_temp_0_1557494568" -) - -// This is copied from syscalls/linux/sys_file.go, with the dependency on -// kernel.Task stripped out. -func fileOpOn(ctx context.Context, mntns *fs.MountNamespace, root, wd *fs.Dirent, dirFD int32, path string, resolve bool, fn func(root *fs.Dirent, d *fs.Dirent) error) error { - var ( - d *fs.Dirent // The file. - rel *fs.Dirent // The relative directory for search (if required.) - err error - ) - - // Extract the working directory (maybe). - if len(path) > 0 && path[0] == '/' { - // Absolute path; rel can be nil. - } else if dirFD == linux.AT_FDCWD { - // Need to reference the working directory. - rel = wd - } else { - // Need to extract the given FD. - return syserror.EBADF - } - - // Lookup the node. - remainingTraversals := uint(linux.MaxSymlinkTraversals) - if resolve { - d, err = mntns.FindInode(ctx, root, rel, path, &remainingTraversals) - } else { - d, err = mntns.FindLink(ctx, root, rel, path, &remainingTraversals) - } - if err != nil { - return err - } - - err = fn(root, d) - d.DecRef() - return err -} - -func BenchmarkVFS1TmpfsStat(b *testing.B) { - for _, depth := range depths { - b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) { - ctx := contexttest.Context(b) - - // Create VFS. - tmpfsFS, ok := fs.FindFilesystem("tmpfs") - if !ok { - b.Fatalf("failed to find tmpfs filesystem type") - } - rootInode, err := tmpfsFS.Mount(ctx, "tmpfs", fs.MountSourceFlags{}, "", nil) - if err != nil { - b.Fatalf("failed to create tmpfs root mount: %v", err) - } - mntns, err := fs.NewMountNamespace(ctx, rootInode) - if err != nil { - b.Fatalf("failed to create mount namespace: %v", err) - } - defer mntns.DecRef() - - var filePathBuilder strings.Builder - filePathBuilder.WriteByte('/') - - // Create nested directories with given depth. - root := mntns.Root() - defer root.DecRef() - d := root - d.IncRef() - defer d.DecRef() - for i := depth; i > 0; i-- { - name := fmt.Sprintf("%d", i) - if err := d.Inode.CreateDirectory(ctx, d, name, fs.FilePermsFromMode(0755)); err != nil { - b.Fatalf("failed to create directory %q: %v", name, err) - } - next, err := d.Walk(ctx, root, name) - if err != nil { - b.Fatalf("failed to walk to directory %q: %v", name, err) - } - d.DecRef() - d = next - filePathBuilder.WriteString(name) - filePathBuilder.WriteByte('/') - } - - // Create the file that will be stat'd. - file, err := d.Inode.Create(ctx, d, filename, fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0644)) - if err != nil { - b.Fatalf("failed to create file %q: %v", filename, err) - } - file.DecRef() - filePathBuilder.WriteString(filename) - filePath := filePathBuilder.String() - - dirPath := false - runtime.GC() - b.ResetTimer() - for i := 0; i < b.N; i++ { - err := fileOpOn(ctx, mntns, root, root, linux.AT_FDCWD, filePath, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent) error { - if dirPath && !fs.IsDir(d.Inode.StableAttr) { - return syserror.ENOTDIR - } - uattr, err := d.Inode.UnstableAttr(ctx) - if err != nil { - return err - } - // Sanity check. - if uattr.Perms.User.Execute { - b.Fatalf("got wrong permissions (%0o)", uattr.Perms.LinuxMode()) - } - return nil - }) - if err != nil { - b.Fatalf("stat(%q) failed: %v", filePath, err) - } - } - // Don't include deferred cleanup in benchmark time. - b.StopTimer() - }) - } -} - -func BenchmarkVFS2MemfsStat(b *testing.B) { - for _, depth := range depths { - b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) { - ctx := contexttest.Context(b) - creds := auth.CredentialsFromContext(ctx) - - // Create VFS. - vfsObj := vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { - b.Fatalf("VFS init: %v", err) - } - vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) - if err != nil { - b.Fatalf("failed to create tmpfs root mount: %v", err) - } - defer mntns.DecRef() - - var filePathBuilder strings.Builder - filePathBuilder.WriteByte('/') - - // Create nested directories with given depth. - root := mntns.Root() - defer root.DecRef() - vd := root - vd.IncRef() - for i := depth; i > 0; i-- { - name := fmt.Sprintf("%d", i) - pop := vfs.PathOperation{ - Root: root, - Start: vd, - Path: fspath.Parse(name), - } - if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{ - Mode: 0755, - }); err != nil { - b.Fatalf("failed to create directory %q: %v", name, err) - } - nextVD, err := vfsObj.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{}) - if err != nil { - b.Fatalf("failed to walk to directory %q: %v", name, err) - } - vd.DecRef() - vd = nextVD - filePathBuilder.WriteString(name) - filePathBuilder.WriteByte('/') - } - - // Create the file that will be stat'd. - fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: vd, - Path: fspath.Parse(filename), - FollowFinalSymlink: true, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, - Mode: 0644, - }) - vd.DecRef() - vd = vfs.VirtualDentry{} - if err != nil { - b.Fatalf("failed to create file %q: %v", filename, err) - } - defer fd.DecRef() - filePathBuilder.WriteString(filename) - filePath := filePathBuilder.String() - - runtime.GC() - b.ResetTimer() - for i := 0; i < b.N; i++ { - stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filePath), - FollowFinalSymlink: true, - }, &vfs.StatOptions{}) - if err != nil { - b.Fatalf("stat(%q) failed: %v", filePath, err) - } - // Sanity check. - if stat.Mode&^linux.S_IFMT != 0644 { - b.Fatalf("got wrong permissions (%0o)", stat.Mode) - } - } - // Don't include deferred cleanup in benchmark time. - b.StopTimer() - }) - } -} - -func BenchmarkVFS1TmpfsMountStat(b *testing.B) { - for _, depth := range depths { - b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) { - ctx := contexttest.Context(b) - - // Create VFS. - tmpfsFS, ok := fs.FindFilesystem("tmpfs") - if !ok { - b.Fatalf("failed to find tmpfs filesystem type") - } - rootInode, err := tmpfsFS.Mount(ctx, "tmpfs", fs.MountSourceFlags{}, "", nil) - if err != nil { - b.Fatalf("failed to create tmpfs root mount: %v", err) - } - mntns, err := fs.NewMountNamespace(ctx, rootInode) - if err != nil { - b.Fatalf("failed to create mount namespace: %v", err) - } - defer mntns.DecRef() - - var filePathBuilder strings.Builder - filePathBuilder.WriteByte('/') - - // Create and mount the submount. - root := mntns.Root() - defer root.DecRef() - if err := root.Inode.CreateDirectory(ctx, root, mountPointName, fs.FilePermsFromMode(0755)); err != nil { - b.Fatalf("failed to create mount point: %v", err) - } - mountPoint, err := root.Walk(ctx, root, mountPointName) - if err != nil { - b.Fatalf("failed to walk to mount point: %v", err) - } - defer mountPoint.DecRef() - submountInode, err := tmpfsFS.Mount(ctx, "tmpfs", fs.MountSourceFlags{}, "", nil) - if err != nil { - b.Fatalf("failed to create tmpfs submount: %v", err) - } - if err := mntns.Mount(ctx, mountPoint, submountInode); err != nil { - b.Fatalf("failed to mount tmpfs submount: %v", err) - } - filePathBuilder.WriteString(mountPointName) - filePathBuilder.WriteByte('/') - - // Create nested directories with given depth. - d, err := root.Walk(ctx, root, mountPointName) - if err != nil { - b.Fatalf("failed to walk to mount root: %v", err) - } - defer d.DecRef() - for i := depth; i > 0; i-- { - name := fmt.Sprintf("%d", i) - if err := d.Inode.CreateDirectory(ctx, d, name, fs.FilePermsFromMode(0755)); err != nil { - b.Fatalf("failed to create directory %q: %v", name, err) - } - next, err := d.Walk(ctx, root, name) - if err != nil { - b.Fatalf("failed to walk to directory %q: %v", name, err) - } - d.DecRef() - d = next - filePathBuilder.WriteString(name) - filePathBuilder.WriteByte('/') - } - - // Create the file that will be stat'd. - file, err := d.Inode.Create(ctx, d, filename, fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0644)) - if err != nil { - b.Fatalf("failed to create file %q: %v", filename, err) - } - file.DecRef() - filePathBuilder.WriteString(filename) - filePath := filePathBuilder.String() - - dirPath := false - runtime.GC() - b.ResetTimer() - for i := 0; i < b.N; i++ { - err := fileOpOn(ctx, mntns, root, root, linux.AT_FDCWD, filePath, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent) error { - if dirPath && !fs.IsDir(d.Inode.StableAttr) { - return syserror.ENOTDIR - } - uattr, err := d.Inode.UnstableAttr(ctx) - if err != nil { - return err - } - // Sanity check. - if uattr.Perms.User.Execute { - b.Fatalf("got wrong permissions (%0o)", uattr.Perms.LinuxMode()) - } - return nil - }) - if err != nil { - b.Fatalf("stat(%q) failed: %v", filePath, err) - } - } - // Don't include deferred cleanup in benchmark time. - b.StopTimer() - }) - } -} - -func BenchmarkVFS2MemfsMountStat(b *testing.B) { - for _, depth := range depths { - b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) { - ctx := contexttest.Context(b) - creds := auth.CredentialsFromContext(ctx) - - // Create VFS. - vfsObj := vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { - b.Fatalf("VFS init: %v", err) - } - vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) - if err != nil { - b.Fatalf("failed to create tmpfs root mount: %v", err) - } - defer mntns.DecRef() - - var filePathBuilder strings.Builder - filePathBuilder.WriteByte('/') - - // Create the mount point. - root := mntns.Root() - defer root.DecRef() - pop := vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(mountPointName), - } - if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{ - Mode: 0755, - }); err != nil { - b.Fatalf("failed to create mount point: %v", err) - } - // Save the mount point for later use. - mountPoint, err := vfsObj.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{}) - if err != nil { - b.Fatalf("failed to walk to mount point: %v", err) - } - defer mountPoint.DecRef() - // Create and mount the submount. - if err := vfsObj.MountAt(ctx, creds, "", &pop, "tmpfs", &vfs.MountOptions{}); err != nil { - b.Fatalf("failed to mount tmpfs submount: %v", err) - } - filePathBuilder.WriteString(mountPointName) - filePathBuilder.WriteByte('/') - - // Create nested directories with given depth. - vd, err := vfsObj.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{}) - if err != nil { - b.Fatalf("failed to walk to mount root: %v", err) - } - for i := depth; i > 0; i-- { - name := fmt.Sprintf("%d", i) - pop := vfs.PathOperation{ - Root: root, - Start: vd, - Path: fspath.Parse(name), - } - if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{ - Mode: 0755, - }); err != nil { - b.Fatalf("failed to create directory %q: %v", name, err) - } - nextVD, err := vfsObj.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{}) - if err != nil { - b.Fatalf("failed to walk to directory %q: %v", name, err) - } - vd.DecRef() - vd = nextVD - filePathBuilder.WriteString(name) - filePathBuilder.WriteByte('/') - } - - // Verify that we didn't create any directories under the mount - // point (i.e. they were all created on the submount). - firstDirName := fmt.Sprintf("%d", depth) - if child := mountPoint.Dentry().Child(firstDirName); child != nil { - b.Fatalf("created directory %q under root mount, not submount", firstDirName) - } - - // Create the file that will be stat'd. - fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: vd, - Path: fspath.Parse(filename), - FollowFinalSymlink: true, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, - Mode: 0644, - }) - vd.DecRef() - if err != nil { - b.Fatalf("failed to create file %q: %v", filename, err) - } - fd.DecRef() - filePathBuilder.WriteString(filename) - filePath := filePathBuilder.String() - - runtime.GC() - b.ResetTimer() - for i := 0; i < b.N; i++ { - stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filePath), - FollowFinalSymlink: true, - }, &vfs.StatOptions{}) - if err != nil { - b.Fatalf("stat(%q) failed: %v", filePath, err) - } - // Sanity check. - if stat.Mode&^linux.S_IFMT != 0644 { - b.Fatalf("got wrong permissions (%0o)", stat.Mode) - } - } - // Don't include deferred cleanup in benchmark time. - b.StopTimer() - }) - } -} - -func init() { - // Turn off reference leak checking for a fair comparison between vfs1 and - // vfs2. - refs.SetLeakMode(refs.NoLeakChecking) -} diff --git a/pkg/sentry/fsimpl/tmpfs/device_file.go b/pkg/sentry/fsimpl/tmpfs/device_file.go deleted file mode 100644 index 84b181b90..000000000 --- a/pkg/sentry/fsimpl/tmpfs/device_file.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmpfs - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -type deviceFile struct { - inode inode - kind vfs.DeviceKind - major uint32 - minor uint32 -} - -func (fs *filesystem) newDeviceFile(creds *auth.Credentials, mode linux.FileMode, kind vfs.DeviceKind, major, minor uint32) *inode { - file := &deviceFile{ - kind: kind, - major: major, - minor: minor, - } - file.inode.init(file, fs, creds, mode) - file.inode.nlink = 1 // from parent directory - return &file.inode -} diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go deleted file mode 100644 index b4380af38..000000000 --- a/pkg/sentry/fsimpl/tmpfs/directory.go +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmpfs - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -type directory struct { - inode inode - - // childList is a list containing (1) child Dentries and (2) fake Dentries - // (with inode == nil) that represent the iteration position of - // directoryFDs. childList is used to support directoryFD.IterDirents() - // efficiently. childList is protected by filesystem.mu. - childList dentryList -} - -func (fs *filesystem) newDirectory(creds *auth.Credentials, mode linux.FileMode) *inode { - dir := &directory{} - dir.inode.init(dir, fs, creds, mode) - dir.inode.nlink = 2 // from "." and parent directory or ".." for root - return &dir.inode -} - -func (i *inode) isDir() bool { - _, ok := i.impl.(*directory) - return ok -} - -type directoryFD struct { - fileDescription - vfs.DirectoryFileDescriptionDefaultImpl - - // Protected by filesystem.mu. - iter *dentry - off int64 -} - -// Release implements vfs.FileDescriptionImpl.Release. -func (fd *directoryFD) Release() { - if fd.iter != nil { - fs := fd.filesystem() - dir := fd.inode().impl.(*directory) - fs.mu.Lock() - dir.childList.Remove(fd.iter) - fs.mu.Unlock() - fd.iter = nil - } -} - -// IterDirents implements vfs.FileDescriptionImpl.IterDirents. -func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { - fs := fd.filesystem() - vfsd := fd.vfsfd.VirtualDentry().Dentry() - - fs.mu.Lock() - defer fs.mu.Unlock() - - if fd.off == 0 { - if err := cb.Handle(vfs.Dirent{ - Name: ".", - Type: linux.DT_DIR, - Ino: vfsd.Impl().(*dentry).inode.ino, - NextOff: 1, - }); err != nil { - return err - } - fd.off++ - } - if fd.off == 1 { - parentInode := vfsd.ParentOrSelf().Impl().(*dentry).inode - if err := cb.Handle(vfs.Dirent{ - Name: "..", - Type: parentInode.direntType(), - Ino: parentInode.ino, - NextOff: 2, - }); err != nil { - return err - } - fd.off++ - } - - dir := vfsd.Impl().(*dentry).inode.impl.(*directory) - var child *dentry - if fd.iter == nil { - // Start iteration at the beginning of dir. - child = dir.childList.Front() - fd.iter = &dentry{} - } else { - // Continue iteration from where we left off. - child = fd.iter.Next() - dir.childList.Remove(fd.iter) - } - for child != nil { - // Skip other directoryFD iterators. - if child.inode != nil { - if err := cb.Handle(vfs.Dirent{ - Name: child.vfsd.Name(), - Type: child.inode.direntType(), - Ino: child.inode.ino, - NextOff: fd.off + 1, - }); err != nil { - dir.childList.InsertBefore(child, fd.iter) - return err - } - fd.off++ - } - child = child.Next() - } - dir.childList.PushBack(fd.iter) - return nil -} - -// Seek implements vfs.FileDescriptionImpl.Seek. -func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { - fs := fd.filesystem() - fs.mu.Lock() - defer fs.mu.Unlock() - - switch whence { - case linux.SEEK_SET: - // Use offset as given. - case linux.SEEK_CUR: - offset += fd.off - default: - return 0, syserror.EINVAL - } - if offset < 0 { - return 0, syserror.EINVAL - } - - // If the offset isn't changing (e.g. due to lseek(0, SEEK_CUR)), don't - // seek even if doing so might reposition the iterator due to concurrent - // mutation of the directory. Compare fs/libfs.c:dcache_dir_lseek(). - if fd.off == offset { - return offset, nil - } - - fd.off = offset - // Compensate for "." and "..". - remChildren := int64(0) - if offset >= 2 { - remChildren = offset - 2 - } - - dir := fd.inode().impl.(*directory) - - // Ensure that fd.iter exists and is not linked into dir.childList. - if fd.iter == nil { - fd.iter = &dentry{} - } else { - dir.childList.Remove(fd.iter) - } - // Insert fd.iter before the remChildren'th child, or at the end of the - // list if remChildren >= number of children. - child := dir.childList.Front() - for child != nil { - // Skip other directoryFD iterators. - if child.inode != nil { - if remChildren == 0 { - dir.childList.InsertBefore(child, fd.iter) - return offset, nil - } - remChildren-- - } - child = child.Next() - } - dir.childList.PushBack(fd.iter) - return offset, nil -} diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go deleted file mode 100644 index 02637fca6..000000000 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ /dev/null @@ -1,712 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmpfs - -import ( - "fmt" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -// Sync implements vfs.FilesystemImpl.Sync. -func (fs *filesystem) Sync(ctx context.Context) error { - // All filesystem state is in-memory. - return nil -} - -// stepLocked resolves rp.Component() to an existing file, starting from the -// given directory. -// -// stepLocked is loosely analogous to fs/namei.c:walk_component(). -// -// Preconditions: filesystem.mu must be locked. !rp.Done(). -func stepLocked(rp *vfs.ResolvingPath, d *dentry) (*dentry, error) { - if !d.inode.isDir() { - return nil, syserror.ENOTDIR - } - if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil { - return nil, err - } -afterSymlink: - nextVFSD, err := rp.ResolveComponent(&d.vfsd) - if err != nil { - return nil, err - } - if nextVFSD == nil { - // Since the Dentry tree is the sole source of truth for tmpfs, if it's - // not in the Dentry tree, it doesn't exist. - return nil, syserror.ENOENT - } - next := nextVFSD.Impl().(*dentry) - if symlink, ok := next.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() { - // TODO(gvisor.dev/issues/1197): Symlink traversals updates - // access time. - if err := rp.HandleSymlink(symlink.target); err != nil { - return nil, err - } - goto afterSymlink // don't check the current directory again - } - rp.Advance() - return next, nil -} - -// walkParentDirLocked resolves all but the last path component of rp to an -// existing directory, starting from the given directory (which is usually -// rp.Start().Impl().(*dentry)). It does not check that the returned directory -// is searchable by the provider of rp. -// -// walkParentDirLocked is loosely analogous to Linux's -// fs/namei.c:path_parentat(). -// -// Preconditions: filesystem.mu must be locked. !rp.Done(). -func walkParentDirLocked(rp *vfs.ResolvingPath, d *dentry) (*dentry, error) { - for !rp.Final() { - next, err := stepLocked(rp, d) - if err != nil { - return nil, err - } - d = next - } - if !d.inode.isDir() { - return nil, syserror.ENOTDIR - } - return d, nil -} - -// resolveLocked resolves rp to an existing file. -// -// resolveLocked is loosely analogous to Linux's fs/namei.c:path_lookupat(). -// -// Preconditions: filesystem.mu must be locked. -func resolveLocked(rp *vfs.ResolvingPath) (*dentry, error) { - d := rp.Start().Impl().(*dentry) - for !rp.Done() { - next, err := stepLocked(rp, d) - if err != nil { - return nil, err - } - d = next - } - if rp.MustBeDir() && !d.inode.isDir() { - return nil, syserror.ENOTDIR - } - return d, nil -} - -// doCreateAt checks that creating a file at rp is permitted, then invokes -// create to do so. -// -// doCreateAt is loosely analogous to a conjunction of Linux's -// fs/namei.c:filename_create() and done_path_create(). -// -// Preconditions: !rp.Done(). For the final path component in rp, -// !rp.ShouldFollowSymlink(). -func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string) error) error { - fs.mu.Lock() - defer fs.mu.Unlock() - parent, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) - if err != nil { - return err - } - if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil { - return err - } - name := rp.Component() - if name == "." || name == ".." { - return syserror.EEXIST - } - // Call parent.vfsd.Child() instead of stepLocked() or rp.ResolveChild(), - // because if the child exists we want to return EEXIST immediately instead - // of attempting symlink/mount traversal. - if parent.vfsd.Child(name) != nil { - return syserror.EEXIST - } - if !dir && rp.MustBeDir() { - return syserror.ENOENT - } - // In memfs, the only way to cause a dentry to be disowned is by removing - // it from the filesystem, so this check is equivalent to checking if - // parent has been removed. - if parent.vfsd.IsDisowned() { - return syserror.ENOENT - } - mnt := rp.Mount() - if err := mnt.CheckBeginWrite(); err != nil { - return err - } - defer mnt.EndWrite() - return create(parent, name) -} - -// AccessAt implements vfs.Filesystem.Impl.AccessAt. -func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { - fs.mu.RLock() - defer fs.mu.RUnlock() - d, err := resolveLocked(rp) - if err != nil { - return err - } - return d.inode.checkPermissions(creds, ats, d.inode.isDir()) -} - -// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. -func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { - fs.mu.RLock() - defer fs.mu.RUnlock() - d, err := resolveLocked(rp) - if err != nil { - return nil, err - } - if opts.CheckSearchable { - if !d.inode.isDir() { - return nil, syserror.ENOTDIR - } - if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true /* isDir */); err != nil { - return nil, err - } - } - d.IncRef() - return &d.vfsd, nil -} - -// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt. -func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { - fs.mu.RLock() - defer fs.mu.RUnlock() - d, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) - if err != nil { - return nil, err - } - d.IncRef() - return &d.vfsd, nil -} - -// LinkAt implements vfs.FilesystemImpl.LinkAt. -func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { - return fs.doCreateAt(rp, false /* dir */, func(parent *dentry, name string) error { - if rp.Mount() != vd.Mount() { - return syserror.EXDEV - } - d := vd.Dentry().Impl().(*dentry) - if d.inode.isDir() { - return syserror.EPERM - } - if d.inode.nlink == 0 { - return syserror.ENOENT - } - if d.inode.nlink == maxLinks { - return syserror.EMLINK - } - d.inode.incLinksLocked() - child := fs.newDentry(d.inode) - parent.vfsd.InsertChild(&child.vfsd, name) - parent.inode.impl.(*directory).childList.PushBack(child) - return nil - }) -} - -// MkdirAt implements vfs.FilesystemImpl.MkdirAt. -func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { - return fs.doCreateAt(rp, true /* dir */, func(parent *dentry, name string) error { - if parent.inode.nlink == maxLinks { - return syserror.EMLINK - } - parent.inode.incLinksLocked() // from child's ".." - child := fs.newDentry(fs.newDirectory(rp.Credentials(), opts.Mode)) - parent.vfsd.InsertChild(&child.vfsd, name) - parent.inode.impl.(*directory).childList.PushBack(child) - return nil - }) -} - -// MknodAt implements vfs.FilesystemImpl.MknodAt. -func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { - return fs.doCreateAt(rp, false /* dir */, func(parent *dentry, name string) error { - var childInode *inode - switch opts.Mode.FileType() { - case 0, linux.S_IFREG: - childInode = fs.newRegularFile(rp.Credentials(), opts.Mode) - case linux.S_IFIFO: - childInode = fs.newNamedPipe(rp.Credentials(), opts.Mode) - case linux.S_IFBLK: - childInode = fs.newDeviceFile(rp.Credentials(), opts.Mode, vfs.BlockDevice, opts.DevMajor, opts.DevMinor) - case linux.S_IFCHR: - childInode = fs.newDeviceFile(rp.Credentials(), opts.Mode, vfs.CharDevice, opts.DevMajor, opts.DevMinor) - case linux.S_IFSOCK: - // Not yet supported. - return syserror.EPERM - default: - return syserror.EINVAL - } - child := fs.newDentry(childInode) - parent.vfsd.InsertChild(&child.vfsd, name) - parent.inode.impl.(*directory).childList.PushBack(child) - return nil - }) -} - -// OpenAt implements vfs.FilesystemImpl.OpenAt. -func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - if opts.Flags&linux.O_TMPFILE != 0 { - // Not yet supported. - return nil, syserror.EOPNOTSUPP - } - - // Handle O_CREAT and !O_CREAT separately, since in the latter case we - // don't need fs.mu for writing. - if opts.Flags&linux.O_CREAT == 0 { - fs.mu.RLock() - defer fs.mu.RUnlock() - d, err := resolveLocked(rp) - if err != nil { - return nil, err - } - return d.open(ctx, rp, &opts, false /* afterCreate */) - } - - mustCreate := opts.Flags&linux.O_EXCL != 0 - start := rp.Start().Impl().(*dentry) - fs.mu.Lock() - defer fs.mu.Unlock() - if rp.Done() { - // Reject attempts to open directories with O_CREAT. - if rp.MustBeDir() { - return nil, syserror.EISDIR - } - if mustCreate { - return nil, syserror.EEXIST - } - return start.open(ctx, rp, &opts, false /* afterCreate */) - } -afterTrailingSymlink: - parent, err := walkParentDirLocked(rp, start) - if err != nil { - return nil, err - } - // Check for search permission in the parent directory. - if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil { - return nil, err - } - // Reject attempts to open directories with O_CREAT. - if rp.MustBeDir() { - return nil, syserror.EISDIR - } - name := rp.Component() - if name == "." || name == ".." { - return nil, syserror.EISDIR - } - // Determine whether or not we need to create a file. - child, err := stepLocked(rp, parent) - if err == syserror.ENOENT { - // Already checked for searchability above; now check for writability. - if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil { - return nil, err - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return nil, err - } - defer rp.Mount().EndWrite() - // Create and open the child. - child := fs.newDentry(fs.newRegularFile(rp.Credentials(), opts.Mode)) - parent.vfsd.InsertChild(&child.vfsd, name) - parent.inode.impl.(*directory).childList.PushBack(child) - return child.open(ctx, rp, &opts, true) - } - if err != nil { - return nil, err - } - // Do we need to resolve a trailing symlink? - if !rp.Done() { - start = parent - goto afterTrailingSymlink - } - // Open existing file. - if mustCreate { - return nil, syserror.EEXIST - } - return child.open(ctx, rp, &opts, false) -} - -func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, afterCreate bool) (*vfs.FileDescription, error) { - ats := vfs.AccessTypesForOpenFlags(opts) - if !afterCreate { - if err := d.inode.checkPermissions(rp.Credentials(), ats, d.inode.isDir()); err != nil { - return nil, err - } - } - switch impl := d.inode.impl.(type) { - case *regularFile: - var fd regularFileFD - if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { - return nil, err - } - if opts.Flags&linux.O_TRUNC != 0 { - if _, err := impl.truncate(0); err != nil { - return nil, err - } - } - return &fd.vfsfd, nil - case *directory: - // Can't open directories writably. - if ats&vfs.MayWrite != 0 { - return nil, syserror.EISDIR - } - var fd directoryFD - if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { - return nil, err - } - return &fd.vfsfd, nil - case *symlink: - // Can't open symlinks without O_PATH (which is unimplemented). - return nil, syserror.ELOOP - case *namedPipe: - return newNamedPipeFD(ctx, impl, rp, &d.vfsd, opts.Flags) - case *deviceFile: - return rp.VirtualFilesystem().OpenDeviceSpecialFile(ctx, rp.Mount(), &d.vfsd, impl.kind, impl.major, impl.minor, opts) - default: - panic(fmt.Sprintf("unknown inode type: %T", d.inode.impl)) - } -} - -// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. -func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { - fs.mu.RLock() - defer fs.mu.RUnlock() - d, err := resolveLocked(rp) - if err != nil { - return "", err - } - symlink, ok := d.inode.impl.(*symlink) - if !ok { - return "", syserror.EINVAL - } - return symlink.target, nil -} - -// RenameAt implements vfs.FilesystemImpl.RenameAt. -func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { - if opts.Flags != 0 { - // TODO(b/145974740): Support renameat2 flags. - return syserror.EINVAL - } - - // Resolve newParent first to verify that it's on this Mount. - fs.mu.Lock() - defer fs.mu.Unlock() - newParent, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) - if err != nil { - return err - } - newName := rp.Component() - if newName == "." || newName == ".." { - return syserror.EBUSY - } - mnt := rp.Mount() - if mnt != oldParentVD.Mount() { - return syserror.EXDEV - } - if err := mnt.CheckBeginWrite(); err != nil { - return err - } - defer mnt.EndWrite() - - oldParent := oldParentVD.Dentry().Impl().(*dentry) - if err := oldParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil { - return err - } - // Call vfs.Dentry.Child() instead of stepLocked() or rp.ResolveChild(), - // because if the existing child is a symlink or mount point then we want - // to rename over it rather than follow it. - renamedVFSD := oldParent.vfsd.Child(oldName) - if renamedVFSD == nil { - return syserror.ENOENT - } - renamed := renamedVFSD.Impl().(*dentry) - if renamed.inode.isDir() { - if renamed == newParent || renamedVFSD.IsAncestorOf(&newParent.vfsd) { - return syserror.EINVAL - } - if oldParent != newParent { - // Writability is needed to change renamed's "..". - if err := renamed.inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true /* isDir */); err != nil { - return err - } - } - } else { - if opts.MustBeDir || rp.MustBeDir() { - return syserror.ENOTDIR - } - } - - if err := newParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil { - return err - } - replacedVFSD := newParent.vfsd.Child(newName) - var replaced *dentry - if replacedVFSD != nil { - replaced = replacedVFSD.Impl().(*dentry) - if replaced.inode.isDir() { - if !renamed.inode.isDir() { - return syserror.EISDIR - } - if replaced.vfsd.HasChildren() { - return syserror.ENOTEMPTY - } - } else { - if rp.MustBeDir() { - return syserror.ENOTDIR - } - if renamed.inode.isDir() { - return syserror.ENOTDIR - } - } - } else { - if renamed.inode.isDir() && newParent.inode.nlink == maxLinks { - return syserror.EMLINK - } - } - if newParent.vfsd.IsDisowned() { - return syserror.ENOENT - } - - // Linux places this check before some of those above; we do it here for - // simplicity, under the assumption that applications are not intentionally - // doing noop renames expecting them to succeed where non-noop renames - // would fail. - if renamedVFSD == replacedVFSD { - return nil - } - vfsObj := rp.VirtualFilesystem() - oldParentDir := oldParent.inode.impl.(*directory) - newParentDir := newParent.inode.impl.(*directory) - mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() - if err := vfsObj.PrepareRenameDentry(mntns, renamedVFSD, replacedVFSD); err != nil { - return err - } - if replaced != nil { - newParentDir.childList.Remove(replaced) - if replaced.inode.isDir() { - newParent.inode.decLinksLocked() // from replaced's ".." - } - replaced.inode.decLinksLocked() - } - oldParentDir.childList.Remove(renamed) - newParentDir.childList.PushBack(renamed) - if renamed.inode.isDir() { - oldParent.inode.decLinksLocked() - newParent.inode.incLinksLocked() - } - // TODO(gvisor.dev/issues/1197): Update timestamps and parent directory - // sizes. - vfsObj.CommitRenameReplaceDentry(renamedVFSD, &newParent.vfsd, newName, replacedVFSD) - return nil -} - -// RmdirAt implements vfs.FilesystemImpl.RmdirAt. -func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { - fs.mu.Lock() - defer fs.mu.Unlock() - parent, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) - if err != nil { - return err - } - if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil { - return err - } - name := rp.Component() - if name == "." { - return syserror.EINVAL - } - if name == ".." { - return syserror.ENOTEMPTY - } - childVFSD := parent.vfsd.Child(name) - if childVFSD == nil { - return syserror.ENOENT - } - child := childVFSD.Impl().(*dentry) - if !child.inode.isDir() { - return syserror.ENOTDIR - } - if childVFSD.HasChildren() { - return syserror.ENOTEMPTY - } - mnt := rp.Mount() - if err := mnt.CheckBeginWrite(); err != nil { - return err - } - defer mnt.EndWrite() - vfsObj := rp.VirtualFilesystem() - mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() - if err := vfsObj.PrepareDeleteDentry(mntns, childVFSD); err != nil { - return err - } - parent.inode.impl.(*directory).childList.Remove(child) - parent.inode.decLinksLocked() // from child's ".." - child.inode.decLinksLocked() - vfsObj.CommitDeleteDentry(childVFSD) - return nil -} - -// SetStatAt implements vfs.FilesystemImpl.SetStatAt. -func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { - fs.mu.RLock() - defer fs.mu.RUnlock() - d, err := resolveLocked(rp) - if err != nil { - return err - } - return d.inode.setStat(opts.Stat) -} - -// StatAt implements vfs.FilesystemImpl.StatAt. -func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { - fs.mu.RLock() - defer fs.mu.RUnlock() - d, err := resolveLocked(rp) - if err != nil { - return linux.Statx{}, err - } - var stat linux.Statx - d.inode.statTo(&stat) - return stat, nil -} - -// StatFSAt implements vfs.FilesystemImpl.StatFSAt. -func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { - fs.mu.RLock() - defer fs.mu.RUnlock() - _, err := resolveLocked(rp) - if err != nil { - return linux.Statfs{}, err - } - // TODO(gvisor.dev/issues/1197): Actually implement statfs. - return linux.Statfs{}, syserror.ENOSYS -} - -// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. -func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { - return fs.doCreateAt(rp, false /* dir */, func(parent *dentry, name string) error { - child := fs.newDentry(fs.newSymlink(rp.Credentials(), target)) - parent.vfsd.InsertChild(&child.vfsd, name) - parent.inode.impl.(*directory).childList.PushBack(child) - return nil - }) -} - -// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. -func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { - fs.mu.Lock() - defer fs.mu.Unlock() - parent, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) - if err != nil { - return err - } - if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil { - return err - } - name := rp.Component() - if name == "." || name == ".." { - return syserror.EISDIR - } - childVFSD := parent.vfsd.Child(name) - if childVFSD == nil { - return syserror.ENOENT - } - child := childVFSD.Impl().(*dentry) - if child.inode.isDir() { - return syserror.EISDIR - } - if rp.MustBeDir() { - return syserror.ENOTDIR - } - mnt := rp.Mount() - if err := mnt.CheckBeginWrite(); err != nil { - return err - } - defer mnt.EndWrite() - vfsObj := rp.VirtualFilesystem() - mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() - if err := vfsObj.PrepareDeleteDentry(mntns, childVFSD); err != nil { - return err - } - parent.inode.impl.(*directory).childList.Remove(child) - child.inode.decLinksLocked() - vfsObj.CommitDeleteDentry(childVFSD) - return nil -} - -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { - fs.mu.RLock() - defer fs.mu.RUnlock() - _, err := resolveLocked(rp) - if err != nil { - return nil, err - } - // TODO(b/127675828): support extended attributes - return nil, syserror.ENOTSUP -} - -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { - fs.mu.RLock() - defer fs.mu.RUnlock() - _, err := resolveLocked(rp) - if err != nil { - return "", err - } - // TODO(b/127675828): support extended attributes - return "", syserror.ENOTSUP -} - -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { - fs.mu.RLock() - defer fs.mu.RUnlock() - _, err := resolveLocked(rp) - if err != nil { - return err - } - // TODO(b/127675828): support extended attributes - return syserror.ENOTSUP -} - -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { - fs.mu.RLock() - defer fs.mu.RUnlock() - _, err := resolveLocked(rp) - if err != nil { - return err - } - // TODO(b/127675828): support extended attributes - return syserror.ENOTSUP -} - -// PrependPath implements vfs.FilesystemImpl.PrependPath. -func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { - fs.mu.RLock() - defer fs.mu.RUnlock() - return vfs.GenericPrependPath(vfsroot, vd, b) -} diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go deleted file mode 100644 index 0c57fdca3..000000000 --- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmpfs - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/usermem" -) - -type namedPipe struct { - inode inode - - pipe *pipe.VFSPipe -} - -// Preconditions: -// * fs.mu must be locked. -// * rp.Mount().CheckBeginWrite() has been called successfully. -func (fs *filesystem) newNamedPipe(creds *auth.Credentials, mode linux.FileMode) *inode { - file := &namedPipe{pipe: pipe.NewVFSPipe(pipe.DefaultPipeSize, usermem.PageSize)} - file.inode.init(file, fs, creds, mode) - file.inode.nlink = 1 // Only the parent has a link. - return &file.inode -} - -// namedPipeFD implements vfs.FileDescriptionImpl. Methods are implemented -// entirely via struct embedding. -type namedPipeFD struct { - fileDescription - - *pipe.VFSPipeFD -} - -func newNamedPipeFD(ctx context.Context, np *namedPipe, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { - var err error - var fd namedPipeFD - fd.VFSPipeFD, err = np.pipe.NewVFSPipeFD(ctx, vfsd, &fd.vfsfd, flags) - if err != nil { - return nil, err - } - fd.vfsfd.Init(&fd, flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}) - return &fd.vfsfd, nil -} diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go deleted file mode 100644 index 1614f2c39..000000000 --- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go +++ /dev/null @@ -1,238 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmpfs - -import ( - "bytes" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -const fileName = "mypipe" - -func TestSeparateFDs(t *testing.T) { - ctx, creds, vfsObj, root := setup(t) - defer root.DecRef() - - // Open the read side. This is done in a concurrently because opening - // One end the pipe blocks until the other end is opened. - pop := vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(fileName), - FollowFinalSymlink: true, - } - rfdchan := make(chan *vfs.FileDescription) - go func() { - openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY} - rfd, _ := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) - rfdchan <- rfd - }() - - // Open the write side. - openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY} - wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) - if err != nil { - t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) - } - defer wfd.DecRef() - - rfd, ok := <-rfdchan - if !ok { - t.Fatalf("failed to open pipe for reading %q", fileName) - } - defer rfd.DecRef() - - const msg = "vamos azul" - checkEmpty(ctx, t, rfd) - checkWrite(ctx, t, wfd, msg) - checkRead(ctx, t, rfd, msg) -} - -func TestNonblockingRead(t *testing.T) { - ctx, creds, vfsObj, root := setup(t) - defer root.DecRef() - - // Open the read side as nonblocking. - pop := vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(fileName), - FollowFinalSymlink: true, - } - openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_NONBLOCK} - rfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) - if err != nil { - t.Fatalf("failed to open pipe for reading %q: %v", fileName, err) - } - defer rfd.DecRef() - - // Open the write side. - openOpts = vfs.OpenOptions{Flags: linux.O_WRONLY} - wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) - if err != nil { - t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) - } - defer wfd.DecRef() - - const msg = "geh blau" - checkEmpty(ctx, t, rfd) - checkWrite(ctx, t, wfd, msg) - checkRead(ctx, t, rfd, msg) -} - -func TestNonblockingWriteError(t *testing.T) { - ctx, creds, vfsObj, root := setup(t) - defer root.DecRef() - - // Open the write side as nonblocking, which should return ENXIO. - pop := vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(fileName), - FollowFinalSymlink: true, - } - openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY | linux.O_NONBLOCK} - _, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) - if err != syserror.ENXIO { - t.Fatalf("expected ENXIO, but got error: %v", err) - } -} - -func TestSingleFD(t *testing.T) { - ctx, creds, vfsObj, root := setup(t) - defer root.DecRef() - - // Open the pipe as readable and writable. - pop := vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(fileName), - FollowFinalSymlink: true, - } - openOpts := vfs.OpenOptions{Flags: linux.O_RDWR} - fd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) - if err != nil { - t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) - } - defer fd.DecRef() - - const msg = "forza blu" - checkEmpty(ctx, t, fd) - checkWrite(ctx, t, fd, msg) - checkRead(ctx, t, fd, msg) -} - -// setup creates a VFS with a pipe in the root directory at path fileName. The -// returned VirtualDentry must be DecRef()'d be the caller. It calls t.Fatal -// upon failure. -func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesystem, vfs.VirtualDentry) { - ctx := contexttest.Context(t) - creds := auth.CredentialsFromContext(ctx) - - // Create VFS. - vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { - t.Fatalf("VFS init: %v", err) - } - vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) - if err != nil { - t.Fatalf("failed to create tmpfs root mount: %v", err) - } - - // Create the pipe. - root := mntns.Root() - pop := vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(fileName), - } - mknodOpts := vfs.MknodOptions{Mode: linux.ModeNamedPipe | 0644} - if err := vfsObj.MknodAt(ctx, creds, &pop, &mknodOpts); err != nil { - t.Fatalf("failed to create file %q: %v", fileName, err) - } - - // Sanity check: the file pipe exists and has the correct mode. - stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(fileName), - FollowFinalSymlink: true, - }, &vfs.StatOptions{}) - if err != nil { - t.Fatalf("stat(%q) failed: %v", fileName, err) - } - if stat.Mode&^linux.S_IFMT != 0644 { - t.Errorf("got wrong permissions (%0o)", stat.Mode) - } - if stat.Mode&linux.S_IFMT != linux.ModeNamedPipe { - t.Errorf("got wrong file type (%0o)", stat.Mode) - } - - return ctx, creds, vfsObj, root -} - -// checkEmpty calls t.Fatal if the pipe in fd is not empty. -func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) { - readData := make([]byte, 1) - dst := usermem.BytesIOSequence(readData) - bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{}) - if err != syserror.ErrWouldBlock { - t.Fatalf("expected ErrWouldBlock reading from empty pipe %q, but got: %v", fileName, err) - } - if bytesRead != 0 { - t.Fatalf("expected to read 0 bytes, but got %d", bytesRead) - } -} - -// checkWrite calls t.Fatal if it fails to write all of msg to fd. -func checkWrite(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) { - writeData := []byte(msg) - src := usermem.BytesIOSequence(writeData) - bytesWritten, err := fd.Write(ctx, src, vfs.WriteOptions{}) - if err != nil { - t.Fatalf("error writing to pipe %q: %v", fileName, err) - } - if bytesWritten != int64(len(writeData)) { - t.Fatalf("expected to write %d bytes, but wrote %d", len(writeData), bytesWritten) - } -} - -// checkRead calls t.Fatal if it fails to read msg from fd. -func checkRead(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) { - readData := make([]byte, len(msg)) - dst := usermem.BytesIOSequence(readData) - bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{}) - if err != nil { - t.Fatalf("error reading from pipe %q: %v", fileName, err) - } - if bytesRead != int64(len(msg)) { - t.Fatalf("expected to read %d bytes, but got %d", len(msg), bytesRead) - } - if !bytes.Equal(readData, []byte(msg)) { - t.Fatalf("expected to read %q from pipe, but got %q", msg, string(readData)) - } -} diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go deleted file mode 100644 index 711442424..000000000 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ /dev/null @@ -1,570 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmpfs - -import ( - "fmt" - "io" - "math" - "sync/atomic" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/safemem" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/fs/lock" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -// regularFile is a regular (=S_IFREG) tmpfs file. -type regularFile struct { - inode inode - - // memFile is a platform.File used to allocate pages to this regularFile. - memFile *pgalloc.MemoryFile - - // mapsMu protects mappings. - mapsMu sync.Mutex `state:"nosave"` - - // mappings tracks mappings of the file into memmap.MappingSpaces. - // - // Protected by mapsMu. - mappings memmap.MappingSet - - // writableMappingPages tracks how many pages of virtual memory are mapped - // as potentially writable from this file. If a page has multiple mappings, - // each mapping is counted separately. - // - // This counter is susceptible to overflow as we can potentially count - // mappings from many VMAs. We count pages rather than bytes to slightly - // mitigate this. - // - // Protected by mapsMu. - writableMappingPages uint64 - - // dataMu protects the fields below. - dataMu sync.RWMutex - - // data maps offsets into the file to offsets into memFile that store - // the file's data. - // - // Protected by dataMu. - data fsutil.FileRangeSet - - // seals represents file seals on this inode. - // - // Protected by dataMu. - seals uint32 - - // size is the size of data. - // - // Protected by both dataMu and inode.mu; reading it requires holding - // either mutex, while writing requires holding both AND using atomics. - // Readers that do not require consistency (like Stat) may read the - // value atomically without holding either lock. - size uint64 -} - -func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode { - file := ®ularFile{ - memFile: fs.memFile, - } - file.inode.init(file, fs, creds, mode) - file.inode.nlink = 1 // from parent directory - return &file.inode -} - -// truncate grows or shrinks the file to the given size. It returns true if the -// file size was updated. -func (rf *regularFile) truncate(newSize uint64) (bool, error) { - rf.inode.mu.Lock() - defer rf.inode.mu.Unlock() - return rf.truncateLocked(newSize) -} - -// Preconditions: rf.inode.mu must be held. -func (rf *regularFile) truncateLocked(newSize uint64) (bool, error) { - oldSize := rf.size - if newSize == oldSize { - // Nothing to do. - return false, nil - } - - // Need to hold inode.mu and dataMu while modifying size. - rf.dataMu.Lock() - if newSize > oldSize { - // Can we grow the file? - if rf.seals&linux.F_SEAL_GROW != 0 { - rf.dataMu.Unlock() - return false, syserror.EPERM - } - // We only need to update the file size. - atomic.StoreUint64(&rf.size, newSize) - rf.dataMu.Unlock() - return true, nil - } - - // We are shrinking the file. First check if this is allowed. - if rf.seals&linux.F_SEAL_SHRINK != 0 { - rf.dataMu.Unlock() - return false, syserror.EPERM - } - - // Update the file size. - atomic.StoreUint64(&rf.size, newSize) - rf.dataMu.Unlock() - - // Invalidate past translations of truncated pages. - oldpgend := fs.OffsetPageEnd(int64(oldSize)) - newpgend := fs.OffsetPageEnd(int64(newSize)) - if newpgend < oldpgend { - rf.mapsMu.Lock() - rf.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ - // Compare Linux's mm/shmem.c:shmem_setattr() => - // mm/memory.c:unmap_mapping_range(evencows=1). - InvalidatePrivate: true, - }) - rf.mapsMu.Unlock() - } - - // We are now guaranteed that there are no translations of truncated pages, - // and can remove them. - rf.dataMu.Lock() - rf.data.Truncate(newSize, rf.memFile) - rf.dataMu.Unlock() - return true, nil -} - -// AddMapping implements memmap.Mappable.AddMapping. -func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { - rf.mapsMu.Lock() - defer rf.mapsMu.Unlock() - rf.dataMu.RLock() - defer rf.dataMu.RUnlock() - - // Reject writable mapping if F_SEAL_WRITE is set. - if rf.seals&linux.F_SEAL_WRITE != 0 && writable { - return syserror.EPERM - } - - rf.mappings.AddMapping(ms, ar, offset, writable) - if writable { - pagesBefore := rf.writableMappingPages - - // ar is guaranteed to be page aligned per memmap.Mappable. - rf.writableMappingPages += uint64(ar.Length() / usermem.PageSize) - - if rf.writableMappingPages < pagesBefore { - panic(fmt.Sprintf("Overflow while mapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages)) - } - } - - return nil -} - -// RemoveMapping implements memmap.Mappable.RemoveMapping. -func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { - rf.mapsMu.Lock() - defer rf.mapsMu.Unlock() - - rf.mappings.RemoveMapping(ms, ar, offset, writable) - - if writable { - pagesBefore := rf.writableMappingPages - - // ar is guaranteed to be page aligned per memmap.Mappable. - rf.writableMappingPages -= uint64(ar.Length() / usermem.PageSize) - - if rf.writableMappingPages > pagesBefore { - panic(fmt.Sprintf("Underflow while unmapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages)) - } - } -} - -// CopyMapping implements memmap.Mappable.CopyMapping. -func (rf *regularFile) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { - return rf.AddMapping(ctx, ms, dstAR, offset, writable) -} - -// Translate implements memmap.Mappable.Translate. -func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { - rf.dataMu.Lock() - defer rf.dataMu.Unlock() - - // Constrain translations to f.attr.Size (rounded up) to prevent - // translation to pages that may be concurrently truncated. - pgend := fs.OffsetPageEnd(int64(rf.size)) - var beyondEOF bool - if required.End > pgend { - if required.Start >= pgend { - return nil, &memmap.BusError{io.EOF} - } - beyondEOF = true - required.End = pgend - } - if optional.End > pgend { - optional.End = pgend - } - - cerr := rf.data.Fill(ctx, required, optional, rf.memFile, usage.Tmpfs, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) { - // Newly-allocated pages are zeroed, so we don't need to do anything. - return dsts.NumBytes(), nil - }) - - var ts []memmap.Translation - var translatedEnd uint64 - for seg := rf.data.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() { - segMR := seg.Range().Intersect(optional) - ts = append(ts, memmap.Translation{ - Source: segMR, - File: rf.memFile, - Offset: seg.FileRangeOf(segMR).Start, - Perms: usermem.AnyAccess, - }) - translatedEnd = segMR.End - } - - // Don't return the error returned by f.data.Fill if it occurred outside of - // required. - if translatedEnd < required.End && cerr != nil { - return ts, &memmap.BusError{cerr} - } - if beyondEOF { - return ts, &memmap.BusError{io.EOF} - } - return ts, nil -} - -// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. -func (*regularFile) InvalidateUnsavable(context.Context) error { - return nil -} - -type regularFileFD struct { - fileDescription - - // off is the file offset. off is accessed using atomic memory operations. - // offMu serializes operations that may mutate off. - off int64 - offMu sync.Mutex -} - -// Release implements vfs.FileDescriptionImpl.Release. -func (fd *regularFileFD) Release() { - // noop -} - -// PRead implements vfs.FileDescriptionImpl.PRead. -func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { - if offset < 0 { - return 0, syserror.EINVAL - } - if dst.NumBytes() == 0 { - return 0, nil - } - f := fd.inode().impl.(*regularFile) - rw := getRegularFileReadWriter(f, offset) - n, err := dst.CopyOutFrom(ctx, rw) - putRegularFileReadWriter(rw) - return int64(n), err -} - -// Read implements vfs.FileDescriptionImpl.Read. -func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - fd.offMu.Lock() - n, err := fd.PRead(ctx, dst, fd.off, opts) - fd.off += n - fd.offMu.Unlock() - return n, err -} - -// PWrite implements vfs.FileDescriptionImpl.PWrite. -func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - if offset < 0 { - return 0, syserror.EINVAL - } - srclen := src.NumBytes() - if srclen == 0 { - return 0, nil - } - f := fd.inode().impl.(*regularFile) - end := offset + srclen - if end < offset { - // Overflow. - return 0, syserror.EFBIG - } - f.inode.mu.Lock() - rw := getRegularFileReadWriter(f, offset) - n, err := src.CopyInTo(ctx, rw) - f.inode.mu.Unlock() - putRegularFileReadWriter(rw) - return n, err -} - -// Write implements vfs.FileDescriptionImpl.Write. -func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - fd.offMu.Lock() - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.off += n - fd.offMu.Unlock() - return n, err -} - -// Seek implements vfs.FileDescriptionImpl.Seek. -func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { - fd.offMu.Lock() - defer fd.offMu.Unlock() - switch whence { - case linux.SEEK_SET: - // use offset as specified - case linux.SEEK_CUR: - offset += fd.off - case linux.SEEK_END: - offset += int64(atomic.LoadUint64(&fd.inode().impl.(*regularFile).size)) - default: - return 0, syserror.EINVAL - } - if offset < 0 { - return 0, syserror.EINVAL - } - fd.off = offset - return offset, nil -} - -// Sync implements vfs.FileDescriptionImpl.Sync. -func (fd *regularFileFD) Sync(ctx context.Context) error { - return nil -} - -// LockBSD implements vfs.FileDescriptionImpl.LockBSD. -func (fd *regularFileFD) LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error { - return fd.inode().lockBSD(uid, t, block) -} - -// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD. -func (fd *regularFileFD) UnlockBSD(ctx context.Context, uid lock.UniqueID) error { - fd.inode().unlockBSD(uid) - return nil -} - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *regularFileFD) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, rng lock.LockRange, block lock.Blocker) error { - return fd.inode().lockPOSIX(uid, t, rng, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng lock.LockRange) error { - fd.inode().unlockPOSIX(uid, rng) - return nil -} - -// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. -func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { - file := fd.inode().impl.(*regularFile) - return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts) -} - -// regularFileReadWriter implements safemem.Reader and Safemem.Writer. -type regularFileReadWriter struct { - file *regularFile - - // Offset into the file to read/write at. Note that this may be - // different from the FD offset if PRead/PWrite is used. - off uint64 -} - -var regularFileReadWriterPool = sync.Pool{ - New: func() interface{} { - return ®ularFileReadWriter{} - }, -} - -func getRegularFileReadWriter(file *regularFile, offset int64) *regularFileReadWriter { - rw := regularFileReadWriterPool.Get().(*regularFileReadWriter) - rw.file = file - rw.off = uint64(offset) - return rw -} - -func putRegularFileReadWriter(rw *regularFileReadWriter) { - rw.file = nil - regularFileReadWriterPool.Put(rw) -} - -// ReadToBlocks implements safemem.Reader.ReadToBlocks. -func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { - rw.file.dataMu.RLock() - defer rw.file.dataMu.RUnlock() - size := rw.file.size - - // Compute the range to read (limited by file size and overflow-checked). - if rw.off >= size { - return 0, io.EOF - } - end := size - if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end { - end = rend - } - - var done uint64 - seg, gap := rw.file.data.Find(uint64(rw.off)) - for rw.off < end { - mr := memmap.MappableRange{uint64(rw.off), uint64(end)} - switch { - case seg.Ok(): - // Get internal mappings. - ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read) - if err != nil { - return done, err - } - - // Copy from internal mappings. - n, err := safemem.CopySeq(dsts, ims) - done += n - rw.off += uint64(n) - dsts = dsts.DropFirst64(n) - if err != nil { - return done, err - } - - // Continue. - seg, gap = seg.NextNonEmpty() - - case gap.Ok(): - // Tmpfs holes are zero-filled. - gapmr := gap.Range().Intersect(mr) - dst := dsts.TakeFirst64(gapmr.Length()) - n, err := safemem.ZeroSeq(dst) - done += n - rw.off += uint64(n) - dsts = dsts.DropFirst64(n) - if err != nil { - return done, err - } - - // Continue. - seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{} - } - } - return done, nil -} - -// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. -// -// Preconditions: inode.mu must be held. -func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { - // Hold dataMu so we can modify size. - rw.file.dataMu.Lock() - defer rw.file.dataMu.Unlock() - - // Compute the range to write (overflow-checked). - end := rw.off + srcs.NumBytes() - if end <= rw.off { - end = math.MaxInt64 - } - - // Check if seals prevent either file growth or all writes. - switch { - case rw.file.seals&linux.F_SEAL_WRITE != 0: // Write sealed - return 0, syserror.EPERM - case end > rw.file.size && rw.file.seals&linux.F_SEAL_GROW != 0: // Grow sealed - // When growth is sealed, Linux effectively allows writes which would - // normally grow the file to partially succeed up to the current EOF, - // rounded down to the page boundary before the EOF. - // - // This happens because writes (and thus the growth check) for tmpfs - // files proceed page-by-page on Linux, and the final write to the page - // containing EOF fails, resulting in a partial write up to the start of - // that page. - // - // To emulate this behaviour, artifically truncate the write to the - // start of the page containing the current EOF. - // - // See Linux, mm/filemap.c:generic_perform_write() and - // mm/shmem.c:shmem_write_begin(). - if pgstart := uint64(usermem.Addr(rw.file.size).RoundDown()); end > pgstart { - end = pgstart - } - if end <= rw.off { - // Truncation would result in no data being written. - return 0, syserror.EPERM - } - } - - // Page-aligned mr for when we need to allocate memory. RoundUp can't - // overflow since end is an int64. - pgstartaddr := usermem.Addr(rw.off).RoundDown() - pgendaddr, _ := usermem.Addr(end).RoundUp() - pgMR := memmap.MappableRange{uint64(pgstartaddr), uint64(pgendaddr)} - - var ( - done uint64 - retErr error - ) - seg, gap := rw.file.data.Find(uint64(rw.off)) - for rw.off < end { - mr := memmap.MappableRange{uint64(rw.off), uint64(end)} - switch { - case seg.Ok(): - // Get internal mappings. - ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Write) - if err != nil { - retErr = err - goto exitLoop - } - - // Copy to internal mappings. - n, err := safemem.CopySeq(ims, srcs) - done += n - rw.off += uint64(n) - srcs = srcs.DropFirst64(n) - if err != nil { - retErr = err - goto exitLoop - } - - // Continue. - seg, gap = seg.NextNonEmpty() - - case gap.Ok(): - // Allocate memory for the write. - gapMR := gap.Range().Intersect(pgMR) - fr, err := rw.file.memFile.Allocate(gapMR.Length(), usage.Tmpfs) - if err != nil { - retErr = err - goto exitLoop - } - - // Write to that memory as usual. - seg, gap = rw.file.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{} - } - } -exitLoop: - // If the write ends beyond the file's previous size, it causes the - // file to grow. - if rw.off > rw.file.size { - rw.file.size = rw.off - } - - return done, retErr -} diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go deleted file mode 100644 index 0399725cf..000000000 --- a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go +++ /dev/null @@ -1,496 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmpfs - -import ( - "bytes" - "fmt" - "io" - "sync/atomic" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/fs/lock" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -// nextFileID is used to generate unique file names. -var nextFileID int64 - -// newTmpfsRoot creates a new tmpfs mount, and returns the root. If the error -// is not nil, then cleanup should be called when the root is no longer needed. -func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentry, func(), error) { - creds := auth.CredentialsFromContext(ctx) - - vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { - return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err) - } - - vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) - if err != nil { - return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err) - } - root := mntns.Root() - return vfsObj, root, func() { - root.DecRef() - mntns.DecRef() - }, nil -} - -// newFileFD creates a new file in a new tmpfs mount, and returns the FD. If -// the returned err is not nil, then cleanup should be called when the FD is no -// longer needed. -func newFileFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) { - creds := auth.CredentialsFromContext(ctx) - vfsObj, root, cleanup, err := newTmpfsRoot(ctx) - if err != nil { - return nil, nil, err - } - - filename := fmt.Sprintf("tmpfs-test-file-%d", atomic.AddInt64(&nextFileID, 1)) - - // Create the file that will be write/read. - fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, - Mode: linux.ModeRegular | mode, - }) - if err != nil { - cleanup() - return nil, nil, fmt.Errorf("failed to create file %q: %v", filename, err) - } - - return fd, cleanup, nil -} - -// newDirFD is like newFileFD, but for directories. -func newDirFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) { - creds := auth.CredentialsFromContext(ctx) - vfsObj, root, cleanup, err := newTmpfsRoot(ctx) - if err != nil { - return nil, nil, err - } - - dirname := fmt.Sprintf("tmpfs-test-dir-%d", atomic.AddInt64(&nextFileID, 1)) - - // Create the dir. - if err := vfsObj.MkdirAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(dirname), - }, &vfs.MkdirOptions{ - Mode: linux.ModeDirectory | mode, - }); err != nil { - cleanup() - return nil, nil, fmt.Errorf("failed to create directory %q: %v", dirname, err) - } - - // Open the dir and return it. - fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(dirname), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY | linux.O_DIRECTORY, - }) - if err != nil { - cleanup() - return nil, nil, fmt.Errorf("failed to open directory %q: %v", dirname, err) - } - - return fd, cleanup, nil -} - -// newPipeFD is like newFileFD, but for pipes. -func newPipeFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) { - creds := auth.CredentialsFromContext(ctx) - vfsObj, root, cleanup, err := newTmpfsRoot(ctx) - if err != nil { - return nil, nil, err - } - - pipename := fmt.Sprintf("tmpfs-test-pipe-%d", atomic.AddInt64(&nextFileID, 1)) - - // Create the pipe. - if err := vfsObj.MknodAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(pipename), - }, &vfs.MknodOptions{ - Mode: linux.ModeNamedPipe | mode, - }); err != nil { - cleanup() - return nil, nil, fmt.Errorf("failed to create pipe %q: %v", pipename, err) - } - - // Open the pipe and return it. - fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(pipename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) - if err != nil { - cleanup() - return nil, nil, fmt.Errorf("failed to open pipe %q: %v", pipename, err) - } - - return fd, cleanup, nil -} - -// Test that we can write some data to a file and read it back.` -func TestSimpleWriteRead(t *testing.T) { - ctx := contexttest.Context(t) - fd, cleanup, err := newFileFD(ctx, 0644) - if err != nil { - t.Fatal(err) - } - defer cleanup() - - // Write. - data := []byte("foobarbaz") - n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) - if err != nil { - t.Fatalf("fd.Write failed: %v", err) - } - if n != int64(len(data)) { - t.Errorf("fd.Write got short write length %d, want %d", n, len(data)) - } - if got, want := fd.Impl().(*regularFileFD).off, int64(len(data)); got != want { - t.Errorf("fd.Write left offset at %d, want %d", got, want) - } - - // Seek back to beginning. - if _, err := fd.Seek(ctx, 0, linux.SEEK_SET); err != nil { - t.Fatalf("fd.Seek failed: %v", err) - } - if got, want := fd.Impl().(*regularFileFD).off, int64(0); got != want { - t.Errorf("fd.Seek(0) left offset at %d, want %d", got, want) - } - - // Read. - buf := make([]byte, len(data)) - n, err = fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) - if err != nil && err != io.EOF { - t.Fatalf("fd.Read failed: %v", err) - } - if n != int64(len(data)) { - t.Errorf("fd.Read got short read length %d, want %d", n, len(data)) - } - if got, want := string(buf), string(data); got != want { - t.Errorf("Read got %q want %s", got, want) - } - if got, want := fd.Impl().(*regularFileFD).off, int64(len(data)); got != want { - t.Errorf("fd.Write left offset at %d, want %d", got, want) - } -} - -func TestPWrite(t *testing.T) { - ctx := contexttest.Context(t) - fd, cleanup, err := newFileFD(ctx, 0644) - if err != nil { - t.Fatal(err) - } - defer cleanup() - - // Fill file with 1k 'a's. - data := bytes.Repeat([]byte{'a'}, 1000) - n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) - if err != nil { - t.Fatalf("fd.Write failed: %v", err) - } - if n != int64(len(data)) { - t.Errorf("fd.Write got short write length %d, want %d", n, len(data)) - } - - // Write "gVisor is awesome" at various offsets. - buf := []byte("gVisor is awesome") - offsets := []int{0, 1, 2, 10, 20, 50, 100, len(data) - 100, len(data) - 1, len(data), len(data) + 1} - for _, offset := range offsets { - name := fmt.Sprintf("PWrite offset=%d", offset) - t.Run(name, func(t *testing.T) { - n, err := fd.PWrite(ctx, usermem.BytesIOSequence(buf), int64(offset), vfs.WriteOptions{}) - if err != nil { - t.Errorf("fd.PWrite got err %v want nil", err) - } - if n != int64(len(buf)) { - t.Errorf("fd.PWrite got %d bytes want %d", n, len(buf)) - } - - // Update data to reflect expected file contents. - if len(data) < offset+len(buf) { - data = append(data, make([]byte, (offset+len(buf))-len(data))...) - } - copy(data[offset:], buf) - - // Read the whole file and compare with data. - readBuf := make([]byte, len(data)) - n, err = fd.PRead(ctx, usermem.BytesIOSequence(readBuf), 0, vfs.ReadOptions{}) - if err != nil { - t.Fatalf("fd.PRead failed: %v", err) - } - if n != int64(len(data)) { - t.Errorf("fd.PRead got short read length %d, want %d", n, len(data)) - } - if got, want := string(readBuf), string(data); got != want { - t.Errorf("PRead got %q want %s", got, want) - } - - }) - } -} - -func TestLocks(t *testing.T) { - ctx := contexttest.Context(t) - fd, cleanup, err := newFileFD(ctx, 0644) - if err != nil { - t.Fatal(err) - } - defer cleanup() - - var ( - uid1 lock.UniqueID - uid2 lock.UniqueID - // Non-blocking. - block lock.Blocker - ) - - uid1 = 123 - uid2 = 456 - - if err := fd.Impl().LockBSD(ctx, uid1, lock.ReadLock, block); err != nil { - t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) - } - if err := fd.Impl().LockBSD(ctx, uid2, lock.ReadLock, block); err != nil { - t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) - } - if got, want := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, block), syserror.ErrWouldBlock; got != want { - t.Fatalf("fd.Impl().LockBSD failed: got = %v, want = %v", got, want) - } - if err := fd.Impl().UnlockBSD(ctx, uid1); err != nil { - t.Fatalf("fd.Impl().UnlockBSD failed: err = %v", err) - } - if err := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, block); err != nil { - t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) - } - - rng1 := lock.LockRange{0, 1} - rng2 := lock.LockRange{1, 2} - - if err := fd.Impl().LockPOSIX(ctx, uid1, lock.ReadLock, rng1, block); err != nil { - t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) - } - if err := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, rng2, block); err != nil { - t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) - } - if err := fd.Impl().LockPOSIX(ctx, uid1, lock.WriteLock, rng1, block); err != nil { - t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) - } - if got, want := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, rng1, block), syserror.ErrWouldBlock; got != want { - t.Fatalf("fd.Impl().LockPOSIX failed: got = %v, want = %v", got, want) - } - if err := fd.Impl().UnlockPOSIX(ctx, uid1, rng1); err != nil { - t.Fatalf("fd.Impl().UnlockPOSIX failed: err = %v", err) - } -} - -func TestPRead(t *testing.T) { - ctx := contexttest.Context(t) - fd, cleanup, err := newFileFD(ctx, 0644) - if err != nil { - t.Fatal(err) - } - defer cleanup() - - // Write 100 sequences of 'gVisor is awesome'. - data := bytes.Repeat([]byte("gVisor is awsome"), 100) - n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) - if err != nil { - t.Fatalf("fd.Write failed: %v", err) - } - if n != int64(len(data)) { - t.Errorf("fd.Write got short write length %d, want %d", n, len(data)) - } - - // Read various sizes from various offsets. - sizes := []int{0, 1, 2, 10, 20, 50, 100, 1000} - offsets := []int{0, 1, 2, 10, 20, 50, 100, 1000, len(data) - 100, len(data) - 1, len(data), len(data) + 1} - - for _, size := range sizes { - for _, offset := range offsets { - name := fmt.Sprintf("PRead offset=%d size=%d", offset, size) - t.Run(name, func(t *testing.T) { - var ( - wantRead []byte - wantErr error - ) - if offset < len(data) { - wantRead = data[offset:] - } else if size > 0 { - wantErr = io.EOF - } - if offset+size < len(data) { - wantRead = wantRead[:size] - } - buf := make([]byte, size) - n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), int64(offset), vfs.ReadOptions{}) - if err != wantErr { - t.Errorf("fd.PRead got err %v want %v", err, wantErr) - } - if n != int64(len(wantRead)) { - t.Errorf("fd.PRead got %d bytes want %d", n, len(wantRead)) - } - if got := string(buf[:n]); got != string(wantRead) { - t.Errorf("fd.PRead got %q want %q", got, string(wantRead)) - } - }) - } - } -} - -func TestTruncate(t *testing.T) { - ctx := contexttest.Context(t) - fd, cleanup, err := newFileFD(ctx, 0644) - if err != nil { - t.Fatal(err) - } - defer cleanup() - - // Fill the file with some data. - data := bytes.Repeat([]byte("gVisor is awsome"), 100) - written, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) - if err != nil { - t.Fatalf("fd.Write failed: %v", err) - } - - // Size should be same as written. - sizeStatOpts := vfs.StatOptions{Mask: linux.STATX_SIZE} - stat, err := fd.Stat(ctx, sizeStatOpts) - if err != nil { - t.Fatalf("fd.Stat failed: %v", err) - } - if got, want := int64(stat.Size), written; got != want { - t.Errorf("fd.Stat got size %d, want %d", got, want) - } - - // Truncate down. - newSize := uint64(10) - if err := fd.SetStat(ctx, vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: linux.STATX_SIZE, - Size: newSize, - }, - }); err != nil { - t.Errorf("fd.Truncate failed: %v", err) - } - // Size should be updated. - statAfterTruncateDown, err := fd.Stat(ctx, sizeStatOpts) - if err != nil { - t.Fatalf("fd.Stat failed: %v", err) - } - if got, want := statAfterTruncateDown.Size, newSize; got != want { - t.Errorf("fd.Stat got size %d, want %d", got, want) - } - // We should only read newSize worth of data. - buf := make([]byte, 1000) - if n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0, vfs.ReadOptions{}); err != nil && err != io.EOF { - t.Fatalf("fd.PRead failed: %v", err) - } else if uint64(n) != newSize { - t.Errorf("fd.PRead got size %d, want %d", n, newSize) - } - // Mtime and Ctime should be bumped. - if got := statAfterTruncateDown.Mtime.ToNsec(); got <= stat.Mtime.ToNsec() { - t.Errorf("fd.Stat got Mtime %v, want > %v", got, stat.Mtime) - } - if got := statAfterTruncateDown.Ctime.ToNsec(); got <= stat.Ctime.ToNsec() { - t.Errorf("fd.Stat got Ctime %v, want > %v", got, stat.Ctime) - } - - // Truncate up. - newSize = 100 - if err := fd.SetStat(ctx, vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: linux.STATX_SIZE, - Size: newSize, - }, - }); err != nil { - t.Errorf("fd.Truncate failed: %v", err) - } - // Size should be updated. - statAfterTruncateUp, err := fd.Stat(ctx, sizeStatOpts) - if err != nil { - t.Fatalf("fd.Stat failed: %v", err) - } - if got, want := statAfterTruncateUp.Size, newSize; got != want { - t.Errorf("fd.Stat got size %d, want %d", got, want) - } - // We should read newSize worth of data. - buf = make([]byte, 1000) - if n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0, vfs.ReadOptions{}); err != nil && err != io.EOF { - t.Fatalf("fd.PRead failed: %v", err) - } else if uint64(n) != newSize { - t.Errorf("fd.PRead got size %d, want %d", n, newSize) - } - // Bytes should be null after 10, since we previously truncated to 10. - for i := uint64(10); i < newSize; i++ { - if buf[i] != 0 { - t.Errorf("fd.PRead got byte %d=%x, want 0", i, buf[i]) - break - } - } - // Mtime and Ctime should be bumped. - if got := statAfterTruncateUp.Mtime.ToNsec(); got <= statAfterTruncateDown.Mtime.ToNsec() { - t.Errorf("fd.Stat got Mtime %v, want > %v", got, statAfterTruncateDown.Mtime) - } - if got := statAfterTruncateUp.Ctime.ToNsec(); got <= statAfterTruncateDown.Ctime.ToNsec() { - t.Errorf("fd.Stat got Ctime %v, want > %v", got, stat.Ctime) - } - - // Truncate to the current size. - newSize = statAfterTruncateUp.Size - if err := fd.SetStat(ctx, vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: linux.STATX_SIZE, - Size: newSize, - }, - }); err != nil { - t.Errorf("fd.Truncate failed: %v", err) - } - statAfterTruncateNoop, err := fd.Stat(ctx, sizeStatOpts) - if err != nil { - t.Fatalf("fd.Stat failed: %v", err) - } - // Mtime and Ctime should not be bumped, since operation is a noop. - if got := statAfterTruncateNoop.Mtime.ToNsec(); got != statAfterTruncateUp.Mtime.ToNsec() { - t.Errorf("fd.Stat got Mtime %v, want %v", got, statAfterTruncateUp.Mtime) - } - if got := statAfterTruncateNoop.Ctime.ToNsec(); got != statAfterTruncateUp.Ctime.ToNsec() { - t.Errorf("fd.Stat got Ctime %v, want %v", got, statAfterTruncateUp.Ctime) - } -} diff --git a/pkg/sentry/fsimpl/tmpfs/stat_test.go b/pkg/sentry/fsimpl/tmpfs/stat_test.go deleted file mode 100644 index ebe035dee..000000000 --- a/pkg/sentry/fsimpl/tmpfs/stat_test.go +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmpfs - -import ( - "fmt" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -func TestStatAfterCreate(t *testing.T) { - ctx := contexttest.Context(t) - mode := linux.FileMode(0644) - - // Run with different file types. - // TODO(gvisor.dev/issues/1197): Also test symlinks and sockets. - for _, typ := range []string{"file", "dir", "pipe"} { - t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) { - var ( - fd *vfs.FileDescription - cleanup func() - err error - ) - switch typ { - case "file": - fd, cleanup, err = newFileFD(ctx, mode) - case "dir": - fd, cleanup, err = newDirFD(ctx, mode) - case "pipe": - fd, cleanup, err = newPipeFD(ctx, mode) - default: - panic(fmt.Sprintf("unknown typ %q", typ)) - } - if err != nil { - t.Fatal(err) - } - defer cleanup() - - got, err := fd.Stat(ctx, vfs.StatOptions{}) - if err != nil { - t.Fatalf("Stat failed: %v", err) - } - - // Atime, Ctime, Mtime should all be current time (non-zero). - atime, ctime, mtime := got.Atime.ToNsec(), got.Ctime.ToNsec(), got.Mtime.ToNsec() - if atime != ctime || ctime != mtime { - t.Errorf("got atime=%d ctime=%d mtime=%d, wanted equal values", atime, ctime, mtime) - } - if atime == 0 { - t.Errorf("got atime=%d, want non-zero", atime) - } - - // Btime should be 0, as it is not set by tmpfs. - if btime := got.Btime.ToNsec(); btime != 0 { - t.Errorf("got btime %d, want 0", got.Btime.ToNsec()) - } - - // Size should be 0. - if got.Size != 0 { - t.Errorf("got size %d, want 0", got.Size) - } - - // Nlink should be 1 for files, 2 for dirs. - wantNlink := uint32(1) - if typ == "dir" { - wantNlink = 2 - } - if got.Nlink != wantNlink { - t.Errorf("got nlink %d, want %d", got.Nlink, wantNlink) - } - - // UID and GID are set from context creds. - creds := auth.CredentialsFromContext(ctx) - if got.UID != uint32(creds.EffectiveKUID) { - t.Errorf("got uid %d, want %d", got.UID, uint32(creds.EffectiveKUID)) - } - if got.GID != uint32(creds.EffectiveKGID) { - t.Errorf("got gid %d, want %d", got.GID, uint32(creds.EffectiveKGID)) - } - - // Mode. - wantMode := uint16(mode) - switch typ { - case "file": - wantMode |= linux.S_IFREG - case "dir": - wantMode |= linux.S_IFDIR - case "pipe": - wantMode |= linux.S_IFIFO - default: - panic(fmt.Sprintf("unknown typ %q", typ)) - } - - if got.Mode != wantMode { - t.Errorf("got mode %x, want %x", got.Mode, wantMode) - } - - // Ino. - if got.Ino == 0 { - t.Errorf("got ino %d, want not 0", got.Ino) - } - }) - } -} - -func TestSetStatAtime(t *testing.T) { - ctx := contexttest.Context(t) - fd, cleanup, err := newFileFD(ctx, 0644) - if err != nil { - t.Fatal(err) - } - defer cleanup() - - allStatOptions := vfs.StatOptions{Mask: linux.STATX_ALL} - - // Get initial stat. - initialStat, err := fd.Stat(ctx, allStatOptions) - if err != nil { - t.Fatalf("Stat failed: %v", err) - } - - // Set atime, but without the mask. - if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: linux.Statx{ - Mask: 0, - Atime: linux.NsecToStatxTimestamp(100), - }}); err != nil { - t.Errorf("SetStat atime without mask failed: %v") - } - // Atime should be unchanged. - if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { - t.Errorf("Stat got error: %v", err) - } else if gotStat.Atime != initialStat.Atime { - t.Errorf("Stat got atime %d, want %d", gotStat.Atime, initialStat.Atime) - } - - // Set atime, this time included in the mask. - setStat := linux.Statx{ - Mask: linux.STATX_ATIME, - Atime: linux.NsecToStatxTimestamp(100), - } - if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil { - t.Errorf("SetStat atime with mask failed: %v") - } - if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { - t.Errorf("Stat got error: %v", err) - } else if gotStat.Atime != setStat.Atime { - t.Errorf("Stat got atime %d, want %d", gotStat.Atime, setStat.Atime) - } -} - -func TestSetStat(t *testing.T) { - ctx := contexttest.Context(t) - mode := linux.FileMode(0644) - - // Run with different file types. - // TODO(gvisor.dev/issues/1197): Also test symlinks and sockets. - for _, typ := range []string{"file", "dir", "pipe"} { - t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) { - var ( - fd *vfs.FileDescription - cleanup func() - err error - ) - switch typ { - case "file": - fd, cleanup, err = newFileFD(ctx, mode) - case "dir": - fd, cleanup, err = newDirFD(ctx, mode) - case "pipe": - fd, cleanup, err = newPipeFD(ctx, mode) - default: - panic(fmt.Sprintf("unknown typ %q", typ)) - } - if err != nil { - t.Fatal(err) - } - defer cleanup() - - allStatOptions := vfs.StatOptions{Mask: linux.STATX_ALL} - - // Get initial stat. - initialStat, err := fd.Stat(ctx, allStatOptions) - if err != nil { - t.Fatalf("Stat failed: %v", err) - } - - // Set atime, but without the mask. - if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: linux.Statx{ - Mask: 0, - Atime: linux.NsecToStatxTimestamp(100), - }}); err != nil { - t.Errorf("SetStat atime without mask failed: %v") - } - // Atime should be unchanged. - if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { - t.Errorf("Stat got error: %v", err) - } else if gotStat.Atime != initialStat.Atime { - t.Errorf("Stat got atime %d, want %d", gotStat.Atime, initialStat.Atime) - } - - // Set atime, this time included in the mask. - setStat := linux.Statx{ - Mask: linux.STATX_ATIME, - Atime: linux.NsecToStatxTimestamp(100), - } - if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil { - t.Errorf("SetStat atime with mask failed: %v") - } - if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { - t.Errorf("Stat got error: %v", err) - } else if gotStat.Atime != setStat.Atime { - t.Errorf("Stat got atime %d, want %d", gotStat.Atime, setStat.Atime) - } - }) - } -} diff --git a/pkg/sentry/fsimpl/tmpfs/symlink.go b/pkg/sentry/fsimpl/tmpfs/symlink.go deleted file mode 100644 index 5246aca84..000000000 --- a/pkg/sentry/fsimpl/tmpfs/symlink.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmpfs - -import ( - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" -) - -type symlink struct { - inode inode - target string // immutable -} - -func (fs *filesystem) newSymlink(creds *auth.Credentials, target string) *inode { - link := &symlink{ - target: target, - } - link.inode.init(link, fs, creds, 0777) - link.inode.nlink = 1 // from parent directory - return &link.inode -} - -// O_PATH is unimplemented, so there's no way to get a FileDescription -// representing a symlink yet. diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go deleted file mode 100644 index 521206305..000000000 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ /dev/null @@ -1,461 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package tmpfs provides a filesystem implementation that behaves like tmpfs: -// the Dentry tree is the sole source of truth for the state of the filesystem. -// -// Lock order: -// -// filesystem.mu -// inode.mu -// regularFileFD.offMu -// regularFile.mapsMu -// regularFile.dataMu -package tmpfs - -import ( - "fmt" - "math" - "sync/atomic" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" -) - -// Name is the default filesystem name. -const Name = "tmpfs" - -// FilesystemType implements vfs.FilesystemType. -type FilesystemType struct{} - -// filesystem implements vfs.FilesystemImpl. -type filesystem struct { - vfsfs vfs.Filesystem - - // memFile is used to allocate pages to for regular files. - memFile *pgalloc.MemoryFile - - // clock is a realtime clock used to set timestamps in file operations. - clock time.Clock - - // mu serializes changes to the Dentry tree. - mu sync.RWMutex - - nextInoMinusOne uint64 // accessed using atomic memory operations -} - -// GetFilesystem implements vfs.FilesystemType.GetFilesystem. -func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx) - if memFileProvider == nil { - panic("MemoryFileProviderFromContext returned nil") - } - clock := time.RealtimeClockFromContext(ctx) - fs := filesystem{ - memFile: memFileProvider.MemoryFile(), - clock: clock, - } - fs.vfsfs.Init(vfsObj, &fs) - root := fs.newDentry(fs.newDirectory(creds, 01777)) - return &fs.vfsfs, &root.vfsd, nil -} - -// Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { -} - -// dentry implements vfs.DentryImpl. -type dentry struct { - vfsd vfs.Dentry - - // inode is the inode represented by this dentry. Multiple Dentries may - // share a single non-directory inode (with hard links). inode is - // immutable. - inode *inode - - // tmpfs doesn't count references on dentries; because the dentry tree is - // the sole source of truth, it is by definition always consistent with the - // state of the filesystem. However, it does count references on inodes, - // because inode resources are released when all references are dropped. - // (tmpfs doesn't really have resources to release, but we implement - // reference counting because tmpfs regular files will.) - - // dentryEntry (ugh) links dentries into their parent directory.childList. - dentryEntry -} - -func (fs *filesystem) newDentry(inode *inode) *dentry { - d := &dentry{ - inode: inode, - } - d.vfsd.Init(d) - return d -} - -// IncRef implements vfs.DentryImpl.IncRef. -func (d *dentry) IncRef() { - d.inode.incRef() -} - -// TryIncRef implements vfs.DentryImpl.TryIncRef. -func (d *dentry) TryIncRef() bool { - return d.inode.tryIncRef() -} - -// DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef() { - d.inode.decRef() -} - -// inode represents a filesystem object. -type inode struct { - // clock is a realtime clock used to set timestamps in file operations. - clock time.Clock - - // refs is a reference count. refs is accessed using atomic memory - // operations. - // - // A reference is held on all inodes that are reachable in the filesystem - // tree. For non-directories (which may have multiple hard links), this - // means that a reference is dropped when nlink reaches 0. For directories, - // nlink never reaches 0 due to the "." entry; instead, - // filesystem.RmdirAt() drops the reference. - refs int64 - - // Inode metadata. Writing multiple fields atomically requires holding - // mu, othewise atomic operations can be used. - mu sync.Mutex - mode uint32 // excluding file type bits, which are based on impl - nlink uint32 // protected by filesystem.mu instead of inode.mu - uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic - gid uint32 // auth.KGID, but ... - ino uint64 // immutable - - // Linux's tmpfs has no concept of btime. - atime int64 // nanoseconds - ctime int64 // nanoseconds - mtime int64 // nanoseconds - - // Only meaningful for device special files. - rdevMajor uint32 - rdevMinor uint32 - - // Advisory file locks, which lock at the inode level. - locks lock.FileLocks - - impl interface{} // immutable -} - -const maxLinks = math.MaxUint32 - -func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) { - i.clock = fs.clock - i.refs = 1 - i.mode = uint32(mode) - i.uid = uint32(creds.EffectiveKUID) - i.gid = uint32(creds.EffectiveKGID) - i.ino = atomic.AddUint64(&fs.nextInoMinusOne, 1) - // Tmpfs creation sets atime, ctime, and mtime to current time. - now := i.clock.Now().Nanoseconds() - i.atime = now - i.ctime = now - i.mtime = now - // i.nlink initialized by caller - i.impl = impl -} - -// incLinksLocked increments i's link count. -// -// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0. -// i.nlink < maxLinks. -func (i *inode) incLinksLocked() { - if i.nlink == 0 { - panic("tmpfs.inode.incLinksLocked() called with no existing links") - } - if i.nlink == maxLinks { - panic("memfs.inode.incLinksLocked() called with maximum link count") - } - atomic.AddUint32(&i.nlink, 1) -} - -// decLinksLocked decrements i's link count. -// -// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0. -func (i *inode) decLinksLocked() { - if i.nlink == 0 { - panic("tmpfs.inode.decLinksLocked() called with no existing links") - } - atomic.AddUint32(&i.nlink, ^uint32(0)) -} - -func (i *inode) incRef() { - if atomic.AddInt64(&i.refs, 1) <= 1 { - panic("tmpfs.inode.incRef() called without holding a reference") - } -} - -func (i *inode) tryIncRef() bool { - for { - refs := atomic.LoadInt64(&i.refs) - if refs == 0 { - return false - } - if atomic.CompareAndSwapInt64(&i.refs, refs, refs+1) { - return true - } - } -} - -func (i *inode) decRef() { - if refs := atomic.AddInt64(&i.refs, -1); refs == 0 { - if regFile, ok := i.impl.(*regularFile); ok { - // Hold inode.mu and regFile.dataMu while mutating - // size. - i.mu.Lock() - regFile.dataMu.Lock() - regFile.data.DropAll(regFile.memFile) - atomic.StoreUint64(®File.size, 0) - regFile.dataMu.Unlock() - i.mu.Unlock() - } - } else if refs < 0 { - panic("tmpfs.inode.decRef() called without holding a reference") - } -} - -func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, isDir bool) error { - return vfs.GenericCheckPermissions(creds, ats, isDir, uint16(atomic.LoadUint32(&i.mode)), auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))) -} - -// Go won't inline this function, and returning linux.Statx (which is quite -// big) means spending a lot of time in runtime.duffcopy(), so instead it's an -// output parameter. -// -// Note that Linux does not guarantee to return consistent data (in the case of -// a concurrent modification), so we do not require holding inode.mu. -func (i *inode) statTo(stat *linux.Statx) { - stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | - linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_ATIME | - linux.STATX_BTIME | linux.STATX_CTIME | linux.STATX_MTIME - stat.Blksize = 1 // usermem.PageSize in tmpfs - stat.Nlink = atomic.LoadUint32(&i.nlink) - stat.UID = atomic.LoadUint32(&i.uid) - stat.GID = atomic.LoadUint32(&i.gid) - stat.Mode = uint16(atomic.LoadUint32(&i.mode)) - stat.Ino = i.ino - // Linux's tmpfs has no concept of btime, so zero-value is returned. - stat.Atime = linux.NsecToStatxTimestamp(i.atime) - stat.Ctime = linux.NsecToStatxTimestamp(i.ctime) - stat.Mtime = linux.NsecToStatxTimestamp(i.mtime) - // TODO(gvisor.dev/issues/1197): Device number. - switch impl := i.impl.(type) { - case *regularFile: - stat.Mode |= linux.S_IFREG - stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS - stat.Size = uint64(atomic.LoadUint64(&impl.size)) - // In tmpfs, this will be FileRangeSet.Span() / 512 (but also cached in - // a uint64 accessed using atomic memory operations to avoid taking - // locks). - stat.Blocks = allocatedBlocksForSize(stat.Size) - case *directory: - stat.Mode |= linux.S_IFDIR - case *symlink: - stat.Mode |= linux.S_IFLNK - stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS - stat.Size = uint64(len(impl.target)) - stat.Blocks = allocatedBlocksForSize(stat.Size) - case *namedPipe: - stat.Mode |= linux.S_IFIFO - case *deviceFile: - switch impl.kind { - case vfs.BlockDevice: - stat.Mode |= linux.S_IFBLK - case vfs.CharDevice: - stat.Mode |= linux.S_IFCHR - } - stat.RdevMajor = impl.major - stat.RdevMinor = impl.minor - default: - panic(fmt.Sprintf("unknown inode type: %T", i.impl)) - } -} - -func (i *inode) setStat(stat linux.Statx) error { - if stat.Mask == 0 { - return nil - } - i.mu.Lock() - var ( - needsMtimeBump bool - needsCtimeBump bool - ) - mask := stat.Mask - if mask&linux.STATX_MODE != 0 { - atomic.StoreUint32(&i.mode, uint32(stat.Mode)) - needsCtimeBump = true - } - if mask&linux.STATX_UID != 0 { - atomic.StoreUint32(&i.uid, stat.UID) - needsCtimeBump = true - } - if mask&linux.STATX_GID != 0 { - atomic.StoreUint32(&i.gid, stat.GID) - needsCtimeBump = true - } - if mask&linux.STATX_SIZE != 0 { - switch impl := i.impl.(type) { - case *regularFile: - updated, err := impl.truncateLocked(stat.Size) - if err != nil { - return err - } - if updated { - needsMtimeBump = true - needsCtimeBump = true - } - case *directory: - return syserror.EISDIR - default: - return syserror.EINVAL - } - } - if mask&linux.STATX_ATIME != 0 { - atomic.StoreInt64(&i.atime, stat.Atime.ToNsecCapped()) - needsCtimeBump = true - } - if mask&linux.STATX_MTIME != 0 { - atomic.StoreInt64(&i.mtime, stat.Mtime.ToNsecCapped()) - needsCtimeBump = true - // Ignore the mtime bump, since we just set it ourselves. - needsMtimeBump = false - } - if mask&linux.STATX_CTIME != 0 { - atomic.StoreInt64(&i.ctime, stat.Ctime.ToNsecCapped()) - // Ignore the ctime bump, since we just set it ourselves. - needsCtimeBump = false - } - now := i.clock.Now().Nanoseconds() - if needsMtimeBump { - atomic.StoreInt64(&i.mtime, now) - } - if needsCtimeBump { - atomic.StoreInt64(&i.ctime, now) - } - i.mu.Unlock() - return nil -} - -// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular. -func (i *inode) lockBSD(uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { - switch i.impl.(type) { - case *regularFile: - return i.locks.LockBSD(uid, t, block) - } - return syserror.EBADF -} - -// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular. -func (i *inode) unlockBSD(uid fslock.UniqueID) error { - switch i.impl.(type) { - case *regularFile: - i.locks.UnlockBSD(uid) - return nil - } - return syserror.EBADF -} - -// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular. -func (i *inode) lockPOSIX(uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error { - switch i.impl.(type) { - case *regularFile: - return i.locks.LockPOSIX(uid, t, rng, block) - } - return syserror.EBADF -} - -// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular. -func (i *inode) unlockPOSIX(uid fslock.UniqueID, rng fslock.LockRange) error { - switch i.impl.(type) { - case *regularFile: - i.locks.UnlockPOSIX(uid, rng) - return nil - } - return syserror.EBADF -} - -// allocatedBlocksForSize returns the number of 512B blocks needed to -// accommodate the given size in bytes, as appropriate for struct -// stat::st_blocks and struct statx::stx_blocks. (Note that this 512B block -// size is independent of the "preferred block size for I/O", struct -// stat::st_blksize and struct statx::stx_blksize.) -func allocatedBlocksForSize(size uint64) uint64 { - return (size + 511) / 512 -} - -func (i *inode) direntType() uint8 { - switch impl := i.impl.(type) { - case *regularFile: - return linux.DT_REG - case *directory: - return linux.DT_DIR - case *symlink: - return linux.DT_LNK - case *deviceFile: - switch impl.kind { - case vfs.BlockDevice: - return linux.DT_BLK - case vfs.CharDevice: - return linux.DT_CHR - default: - panic(fmt.Sprintf("unknown vfs.DeviceKind: %v", impl.kind)) - } - default: - panic(fmt.Sprintf("unknown inode type: %T", i.impl)) - } -} - -// fileDescription is embedded by tmpfs implementations of -// vfs.FileDescriptionImpl. -type fileDescription struct { - vfsfd vfs.FileDescription - vfs.FileDescriptionDefaultImpl -} - -func (fd *fileDescription) filesystem() *filesystem { - return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem) -} - -func (fd *fileDescription) inode() *inode { - return fd.vfsfd.Dentry().Impl().(*dentry).inode -} - -// Stat implements vfs.FileDescriptionImpl.Stat. -func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { - var stat linux.Statx - fd.inode().statTo(&stat) - return stat, nil -} - -// SetStat implements vfs.FileDescriptionImpl.SetStat. -func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { - return fd.inode().setStat(opts.Stat) -} diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD deleted file mode 100644 index e6933aa70..000000000 --- a/pkg/sentry/hostcpu/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "hostcpu", - srcs = [ - "getcpu_amd64.s", - "getcpu_arm64.s", - "hostcpu.go", - ], - visibility = ["//:sandbox"], -) - -go_test( - name = "hostcpu_test", - size = "small", - srcs = ["hostcpu_test.go"], - library = ":hostcpu", -) diff --git a/pkg/sentry/hostcpu/getcpu_arm64.s b/pkg/sentry/hostcpu/getcpu_arm64.s index caf9abb89..caf9abb89 100644..100755 --- a/pkg/sentry/hostcpu/getcpu_arm64.s +++ b/pkg/sentry/hostcpu/getcpu_arm64.s diff --git a/pkg/sentry/hostcpu/hostcpu_state_autogen.go b/pkg/sentry/hostcpu/hostcpu_state_autogen.go new file mode 100755 index 000000000..97d33d8bf --- /dev/null +++ b/pkg/sentry/hostcpu/hostcpu_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package hostcpu diff --git a/pkg/sentry/hostcpu/hostcpu_test.go b/pkg/sentry/hostcpu/hostcpu_test.go deleted file mode 100644 index 7d6885c9e..000000000 --- a/pkg/sentry/hostcpu/hostcpu_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package hostcpu - -import ( - "fmt" - "testing" -) - -func TestMaxValueInLinuxBitmap(t *testing.T) { - for _, test := range []struct { - str string - max uint64 - }{ - {"0", 0}, - {"0\n", 0}, - {"0,2", 2}, - {"0-63", 63}, - {"0-3,8-11", 11}, - } { - t.Run(fmt.Sprintf("%q", test.str), func(t *testing.T) { - max, err := maxValueInLinuxBitmap(test.str) - if err != nil || max != test.max { - t.Errorf("maxValueInLinuxBitmap: got (%d, %v), wanted (%d, nil)", max, err, test.max) - } - }) - } -} - -func TestMaxValueInLinuxBitmapErrors(t *testing.T) { - for _, str := range []string{"", "\n"} { - t.Run(fmt.Sprintf("%q", str), func(t *testing.T) { - max, err := maxValueInLinuxBitmap(str) - if err == nil { - t.Errorf("maxValueInLinuxBitmap: got (%d, nil), wanted (_, error)", max) - } - t.Log(err) - }) - } -} diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD deleted file mode 100644 index 61c78569d..000000000 --- a/pkg/sentry/hostmm/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "hostmm", - srcs = [ - "cgroup.go", - "hostmm.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/fd", - "//pkg/log", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/hostmm/hostmm_state_autogen.go b/pkg/sentry/hostmm/hostmm_state_autogen.go new file mode 100755 index 000000000..925c56e14 --- /dev/null +++ b/pkg/sentry/hostmm/hostmm_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package hostmm diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD deleted file mode 100644 index 07bf39fed..000000000 --- a/pkg/sentry/inet/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -go_library( - name = "inet", - srcs = [ - "context.go", - "inet.go", - "namespace.go", - "test_stack.go", - ], - deps = [ - "//pkg/context", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/sentry/inet/inet_state_autogen.go b/pkg/sentry/inet/inet_state_autogen.go new file mode 100755 index 000000000..d2985113b --- /dev/null +++ b/pkg/sentry/inet/inet_state_autogen.go @@ -0,0 +1,40 @@ +// automatically generated by stateify. + +package inet + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *TCPBufferSize) beforeSave() {} +func (x *TCPBufferSize) save(m state.Map) { + x.beforeSave() + m.Save("Min", &x.Min) + m.Save("Default", &x.Default) + m.Save("Max", &x.Max) +} + +func (x *TCPBufferSize) afterLoad() {} +func (x *TCPBufferSize) load(m state.Map) { + m.Load("Min", &x.Min) + m.Load("Default", &x.Default) + m.Load("Max", &x.Max) +} + +func (x *Namespace) beforeSave() {} +func (x *Namespace) save(m state.Map) { + x.beforeSave() + m.Save("creator", &x.creator) + m.Save("isRoot", &x.isRoot) +} + +func (x *Namespace) load(m state.Map) { + m.LoadWait("creator", &x.creator) + m.Load("isRoot", &x.isRoot) + m.AfterLoad(x.afterLoad) +} + +func init() { + state.Register("pkg/sentry/inet.TCPBufferSize", (*TCPBufferSize)(nil), state.Fns{Save: (*TCPBufferSize).save, Load: (*TCPBufferSize).load}) + state.Register("pkg/sentry/inet.Namespace", (*Namespace)(nil), state.Fns{Save: (*Namespace).save, Load: (*Namespace).load}) +} diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go index 029af3025..029af3025 100644..100755 --- a/pkg/sentry/inet/namespace.go +++ b/pkg/sentry/inet/namespace.go diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD deleted file mode 100644 index beba29a09..000000000 --- a/pkg/sentry/kernel/BUILD +++ /dev/null @@ -1,234 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test", "proto_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "pending_signals_list", - out = "pending_signals_list.go", - package = "kernel", - prefix = "pendingSignal", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*pendingSignal", - "Linker": "*pendingSignal", - }, -) - -go_template_instance( - name = "process_group_list", - out = "process_group_list.go", - package = "kernel", - prefix = "processGroup", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*ProcessGroup", - "Linker": "*ProcessGroup", - }, -) - -go_template_instance( - name = "seqatomic_taskgoroutineschedinfo", - out = "seqatomic_taskgoroutineschedinfo_unsafe.go", - package = "kernel", - suffix = "TaskGoroutineSchedInfo", - template = "//pkg/sync:generic_seqatomic", - types = { - "Value": "TaskGoroutineSchedInfo", - }, -) - -go_template_instance( - name = "session_list", - out = "session_list.go", - package = "kernel", - prefix = "session", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Session", - "Linker": "*Session", - }, -) - -go_template_instance( - name = "task_list", - out = "task_list.go", - package = "kernel", - prefix = "task", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Task", - "Linker": "*Task", - }, -) - -go_template_instance( - name = "socket_list", - out = "socket_list.go", - package = "kernel", - prefix = "socket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*SocketEntry", - "Linker": "*SocketEntry", - }, -) - -proto_library( - name = "uncaught_signal", - srcs = ["uncaught_signal.proto"], - visibility = ["//visibility:public"], - deps = ["//pkg/sentry/arch:registers_proto"], -) - -go_library( - name = "kernel", - srcs = [ - "abstract_socket_namespace.go", - "context.go", - "fd_table.go", - "fd_table_unsafe.go", - "fs_context.go", - "ipc_namespace.go", - "kernel.go", - "kernel_opts.go", - "kernel_state.go", - "pending_signals.go", - "pending_signals_list.go", - "pending_signals_state.go", - "posixtimer.go", - "process_group_list.go", - "ptrace.go", - "ptrace_amd64.go", - "ptrace_arm64.go", - "rseq.go", - "seccomp.go", - "seqatomic_taskgoroutineschedinfo_unsafe.go", - "session_list.go", - "sessions.go", - "signal.go", - "signal_handlers.go", - "socket_list.go", - "syscalls.go", - "syscalls_state.go", - "syslog.go", - "task.go", - "task_acct.go", - "task_block.go", - "task_clone.go", - "task_context.go", - "task_exec.go", - "task_exit.go", - "task_futex.go", - "task_identity.go", - "task_list.go", - "task_log.go", - "task_net.go", - "task_run.go", - "task_sched.go", - "task_signals.go", - "task_start.go", - "task_stop.go", - "task_syscall.go", - "task_usermem.go", - "thread_group.go", - "threads.go", - "timekeeper.go", - "timekeeper_state.go", - "tty.go", - "uts_namespace.go", - "vdso.go", - "version.go", - ], - imports = [ - "gvisor.dev/gvisor/pkg/bpf", - "gvisor.dev/gvisor/pkg/sentry/device", - "gvisor.dev/gvisor/pkg/tcpip", - ], - visibility = ["//:sandbox"], - deps = [ - ":uncaught_signal_go_proto", - "//pkg/abi", - "//pkg/abi/linux", - "//pkg/amutex", - "//pkg/binary", - "//pkg/bits", - "//pkg/bpf", - "//pkg/context", - "//pkg/cpuid", - "//pkg/eventchannel", - "//pkg/fspath", - "//pkg/log", - "//pkg/metric", - "//pkg/refs", - "//pkg/safemem", - "//pkg/secio", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fs/timerfd", - "//pkg/sentry/fsbridge", - "//pkg/sentry/hostcpu", - "//pkg/sentry/inet", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/epoll", - "//pkg/sentry/kernel/futex", - "//pkg/sentry/kernel/sched", - "//pkg/sentry/kernel/semaphore", - "//pkg/sentry/kernel/shm", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/loader", - "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/socket/netlink/port", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/time", - "//pkg/sentry/unimpl", - "//pkg/sentry/unimpl:unimplemented_syscall_go_proto", - "//pkg/sentry/uniqueid", - "//pkg/sentry/usage", - "//pkg/sentry/vfs", - "//pkg/state", - "//pkg/state/statefile", - "//pkg/sync", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/tcpip/stack", - "//pkg/usermem", - "//pkg/waiter", - "//tools/go_marshal/marshal", - ], -) - -go_test( - name = "kernel_test", - size = "small", - srcs = [ - "fd_table_test.go", - "table_test.go", - "task_test.go", - "timekeeper_test.go", - ], - library = ":kernel", - deps = [ - "//pkg/abi", - "//pkg/context", - "//pkg/sentry/arch", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/fs/filetest", - "//pkg/sentry/kernel/sched", - "//pkg/sentry/limits", - "//pkg/sentry/pgalloc", - "//pkg/sentry/time", - "//pkg/sentry/usage", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/kernel/README.md b/pkg/sentry/kernel/README.md deleted file mode 100644 index 427311be8..000000000 --- a/pkg/sentry/kernel/README.md +++ /dev/null @@ -1,108 +0,0 @@ -This package contains: - -- A (partial) emulation of the "core Linux kernel", which governs task - execution and scheduling, system call dispatch, and signal handling. See - below for details. - -- The top-level interface for the sentry's Linux kernel emulation in general, - used by the `main` function of all versions of the sentry. This interface - revolves around the `Env` type (defined in `kernel.go`). - -# Background - -In Linux, each schedulable context is referred to interchangeably as a "task" or -"thread". Tasks can be divided into userspace and kernel tasks. In the sentry, -scheduling is managed by the Go runtime, so each schedulable context is a -goroutine; only "userspace" (application) contexts are referred to as tasks, and -represented by Task objects. (From this point forward, "task" refers to the -sentry's notion of a task unless otherwise specified.) - -At a high level, Linux application threads can be thought of as repeating a "run -loop": - -- Some amount of application code is executed in userspace. - -- A trap (explicit syscall invocation, hardware interrupt or exception, etc.) - causes control flow to switch to the kernel. - -- Some amount of kernel code is executed in kernelspace, e.g. to handle the - cause of the trap. - -- The kernel "returns from the trap" into application code. - -Analogously, each task in the sentry is associated with a *task goroutine* that -executes that task's run loop (`Task.run` in `task_run.go`). However, the -sentry's task run loop differs in structure in order to support saving execution -state to, and resuming execution from, checkpoints. - -While in kernelspace, a Linux thread can be descheduled (cease execution) in a -variety of ways: - -- It can yield or be preempted, becoming temporarily descheduled but still - runnable. At present, the sentry delegates scheduling of runnable threads to - the Go runtime. - -- It can exit, becoming permanently descheduled. The sentry's equivalent is - returning from `Task.run`, terminating the task goroutine. - -- It can enter interruptible sleep, a state in which it can be woken by a - caller-defined wakeup or the receipt of a signal. In the sentry, - interruptible sleep (which is ambiguously referred to as *blocking*) is - implemented by making all events that can end blocking (including signal - notifications) communicated via Go channels and using `select` to multiplex - wakeup sources; see `task_block.go`. - -- It can enter uninterruptible sleep, a state in which it can only be woken by - a caller-defined wakeup. Killable sleep is a closely related variant in - which the task can also be woken by SIGKILL. (These definitions also include - Linux's "group-stopped" (`TASK_STOPPED`) and "ptrace-stopped" - (`TASK_TRACED`) states.) - -To maximize compatibility with Linux, sentry checkpointing appears as a spurious -signal-delivery interrupt on all tasks; interrupted system calls return `EINTR` -or are automatically restarted as usual. However, these semantics require that -uninterruptible and killable sleeps do not appear to be interrupted. In other -words, the state of the task, including its progress through the interrupted -operation, must be preserved by checkpointing. For many such sleeps, the wakeup -condition is application-controlled, making it infeasible to wait for the sleep -to end before checkpointing. Instead, we must support checkpointing progress -through sleeping operations. - -# Implementation - -We break the task's control flow graph into *states*, delimited by: - -1. Points where uninterruptible and killable sleeps may occur. For example, - there exists a state boundary between signal dequeueing and signal delivery - because there may be an intervening ptrace signal-delivery-stop. - -2. Points where sleep-induced branches may "rejoin" normal execution. For - example, the syscall exit state exists because it can be reached immediately - following a synchronous syscall, or after a task that is sleeping in - `execve()` or `vfork()` resumes execution. - -3. Points containing large branches. This is strictly for organizational - purposes. For example, the state that processes interrupt-signaled - conditions is kept separate from the main "app" state to reduce the size of - the latter. - -4. `SyscallReinvoke`, which does not correspond to anything in Linux, and - exists solely to serve the autosave feature. - -![dot -Tpng -Goverlap=false -orun_states.png run_states.dot](g3doc/run_states.png "Task control flow graph") - -States before which a stop may occur are represented as implementations of the -`taskRunState` interface named `run(state)`, allowing them to be saved and -restored. States that cannot be immediately preceded by a stop are simply `Task` -methods named `do(state)`. - -Conditions that can require task goroutines to cease execution for unknown -lengths of time are called *stops*. Stops are divided into *internal stops*, -which are stops whose start and end conditions are implemented within the -sentry, and *external stops*, which are stops whose start and end conditions are -not known to the sentry. Hence all uninterruptible and killable sleeps are -internal stops, and the existence of a pending checkpoint operation is an -external stop. Internal stops are reified into instances of the `TaskStop` type, -while external stops are merely counted. The task run loop alternates between -checking for stops and advancing the task's state. This allows checkpointing to -hold tasks in a stopped state while waiting for all tasks in the system to stop. diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD deleted file mode 100644 index 2bc49483a..000000000 --- a/pkg/sentry/kernel/auth/BUILD +++ /dev/null @@ -1,69 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "atomicptr_credentials", - out = "atomicptr_credentials_unsafe.go", - package = "auth", - suffix = "Credentials", - template = "//pkg/sync:generic_atomicptr", - types = { - "Value": "Credentials", - }, -) - -go_template_instance( - name = "id_map_range", - out = "id_map_range.go", - package = "auth", - prefix = "idMap", - template = "//pkg/segment:generic_range", - types = { - "T": "uint32", - }, -) - -go_template_instance( - name = "id_map_set", - out = "id_map_set.go", - consts = { - "minDegree": "3", - }, - package = "auth", - prefix = "idMap", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint32", - "Range": "idMapRange", - "Value": "uint32", - "Functions": "idMapFunctions", - }, -) - -go_library( - name = "auth", - srcs = [ - "atomicptr_credentials_unsafe.go", - "auth.go", - "capability_set.go", - "context.go", - "credentials.go", - "id.go", - "id_map.go", - "id_map_functions.go", - "id_map_range.go", - "id_map_set.go", - "user_namespace.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/bits", - "//pkg/context", - "//pkg/log", - "//pkg/sync", - "//pkg/syserror", - ], -) diff --git a/pkg/sync/atomicptr_unsafe.go b/pkg/sentry/kernel/auth/atomicptr_credentials_unsafe.go index 525c4beed..4535c958f 100644..100755 --- a/pkg/sync/atomicptr_unsafe.go +++ b/pkg/sentry/kernel/auth/atomicptr_credentials_unsafe.go @@ -1,20 +1,10 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package template doesn't exist. This file must be instantiated using the -// go_template_instance rule in tools/go_generics/defs.bzl. -package template +package auth import ( "sync/atomic" "unsafe" ) -// Value is a required type parameter. -type Value struct{} - // An AtomicPtr is a pointer to a value of type Value that can be atomically // loaded and stored. The zero value of an AtomicPtr represents nil. // @@ -23,25 +13,25 @@ type Value struct{} // this case, do `dst.Store(src.Load())` instead. // // +stateify savable -type AtomicPtr struct { - ptr unsafe.Pointer `state:".(*Value)"` +type AtomicPtrCredentials struct { + ptr unsafe.Pointer `state:".(*Credentials)"` } -func (p *AtomicPtr) savePtr() *Value { +func (p *AtomicPtrCredentials) savePtr() *Credentials { return p.Load() } -func (p *AtomicPtr) loadPtr(v *Value) { +func (p *AtomicPtrCredentials) loadPtr(v *Credentials) { p.Store(v) } // Load returns the value set by the most recent Store. It returns nil if there // has been no previous call to Store. -func (p *AtomicPtr) Load() *Value { - return (*Value)(atomic.LoadPointer(&p.ptr)) +func (p *AtomicPtrCredentials) Load() *Credentials { + return (*Credentials)(atomic.LoadPointer(&p.ptr)) } // Store sets the value returned by Load to x. -func (p *AtomicPtr) Store(x *Value) { +func (p *AtomicPtrCredentials) Store(x *Credentials) { atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x)) } diff --git a/pkg/sentry/kernel/auth/auth_state_autogen.go b/pkg/sentry/kernel/auth/auth_state_autogen.go new file mode 100755 index 000000000..09ca564b8 --- /dev/null +++ b/pkg/sentry/kernel/auth/auth_state_autogen.go @@ -0,0 +1,164 @@ +// automatically generated by stateify. + +package auth + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *AtomicPtrCredentials) beforeSave() {} +func (x *AtomicPtrCredentials) save(m state.Map) { + x.beforeSave() + var ptr *Credentials = x.savePtr() + m.SaveValue("ptr", ptr) +} + +func (x *AtomicPtrCredentials) afterLoad() {} +func (x *AtomicPtrCredentials) load(m state.Map) { + m.LoadValue("ptr", new(*Credentials), func(y interface{}) { x.loadPtr(y.(*Credentials)) }) +} + +func (x *Credentials) beforeSave() {} +func (x *Credentials) save(m state.Map) { + x.beforeSave() + m.Save("RealKUID", &x.RealKUID) + m.Save("EffectiveKUID", &x.EffectiveKUID) + m.Save("SavedKUID", &x.SavedKUID) + m.Save("RealKGID", &x.RealKGID) + m.Save("EffectiveKGID", &x.EffectiveKGID) + m.Save("SavedKGID", &x.SavedKGID) + m.Save("ExtraKGIDs", &x.ExtraKGIDs) + m.Save("PermittedCaps", &x.PermittedCaps) + m.Save("InheritableCaps", &x.InheritableCaps) + m.Save("EffectiveCaps", &x.EffectiveCaps) + m.Save("BoundingCaps", &x.BoundingCaps) + m.Save("KeepCaps", &x.KeepCaps) + m.Save("UserNamespace", &x.UserNamespace) +} + +func (x *Credentials) afterLoad() {} +func (x *Credentials) load(m state.Map) { + m.Load("RealKUID", &x.RealKUID) + m.Load("EffectiveKUID", &x.EffectiveKUID) + m.Load("SavedKUID", &x.SavedKUID) + m.Load("RealKGID", &x.RealKGID) + m.Load("EffectiveKGID", &x.EffectiveKGID) + m.Load("SavedKGID", &x.SavedKGID) + m.Load("ExtraKGIDs", &x.ExtraKGIDs) + m.Load("PermittedCaps", &x.PermittedCaps) + m.Load("InheritableCaps", &x.InheritableCaps) + m.Load("EffectiveCaps", &x.EffectiveCaps) + m.Load("BoundingCaps", &x.BoundingCaps) + m.Load("KeepCaps", &x.KeepCaps) + m.Load("UserNamespace", &x.UserNamespace) +} + +func (x *IDMapEntry) beforeSave() {} +func (x *IDMapEntry) save(m state.Map) { + x.beforeSave() + m.Save("FirstID", &x.FirstID) + m.Save("FirstParentID", &x.FirstParentID) + m.Save("Length", &x.Length) +} + +func (x *IDMapEntry) afterLoad() {} +func (x *IDMapEntry) load(m state.Map) { + m.Load("FirstID", &x.FirstID) + m.Load("FirstParentID", &x.FirstParentID) + m.Load("Length", &x.Length) +} + +func (x *idMapRange) beforeSave() {} +func (x *idMapRange) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) +} + +func (x *idMapRange) afterLoad() {} +func (x *idMapRange) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) +} + +func (x *idMapSet) beforeSave() {} +func (x *idMapSet) save(m state.Map) { + x.beforeSave() + var root *idMapSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *idMapSet) afterLoad() {} +func (x *idMapSet) load(m state.Map) { + m.LoadValue("root", new(*idMapSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*idMapSegmentDataSlices)) }) +} + +func (x *idMapnode) beforeSave() {} +func (x *idMapnode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *idMapnode) afterLoad() {} +func (x *idMapnode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *idMapSegmentDataSlices) beforeSave() {} +func (x *idMapSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *idMapSegmentDataSlices) afterLoad() {} +func (x *idMapSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func (x *UserNamespace) beforeSave() {} +func (x *UserNamespace) save(m state.Map) { + x.beforeSave() + m.Save("parent", &x.parent) + m.Save("owner", &x.owner) + m.Save("uidMapFromParent", &x.uidMapFromParent) + m.Save("uidMapToParent", &x.uidMapToParent) + m.Save("gidMapFromParent", &x.gidMapFromParent) + m.Save("gidMapToParent", &x.gidMapToParent) +} + +func (x *UserNamespace) afterLoad() {} +func (x *UserNamespace) load(m state.Map) { + m.Load("parent", &x.parent) + m.Load("owner", &x.owner) + m.Load("uidMapFromParent", &x.uidMapFromParent) + m.Load("uidMapToParent", &x.uidMapToParent) + m.Load("gidMapFromParent", &x.gidMapFromParent) + m.Load("gidMapToParent", &x.gidMapToParent) +} + +func init() { + state.Register("pkg/sentry/kernel/auth.AtomicPtrCredentials", (*AtomicPtrCredentials)(nil), state.Fns{Save: (*AtomicPtrCredentials).save, Load: (*AtomicPtrCredentials).load}) + state.Register("pkg/sentry/kernel/auth.Credentials", (*Credentials)(nil), state.Fns{Save: (*Credentials).save, Load: (*Credentials).load}) + state.Register("pkg/sentry/kernel/auth.IDMapEntry", (*IDMapEntry)(nil), state.Fns{Save: (*IDMapEntry).save, Load: (*IDMapEntry).load}) + state.Register("pkg/sentry/kernel/auth.idMapRange", (*idMapRange)(nil), state.Fns{Save: (*idMapRange).save, Load: (*idMapRange).load}) + state.Register("pkg/sentry/kernel/auth.idMapSet", (*idMapSet)(nil), state.Fns{Save: (*idMapSet).save, Load: (*idMapSet).load}) + state.Register("pkg/sentry/kernel/auth.idMapnode", (*idMapnode)(nil), state.Fns{Save: (*idMapnode).save, Load: (*idMapnode).load}) + state.Register("pkg/sentry/kernel/auth.idMapSegmentDataSlices", (*idMapSegmentDataSlices)(nil), state.Fns{Save: (*idMapSegmentDataSlices).save, Load: (*idMapSegmentDataSlices).load}) + state.Register("pkg/sentry/kernel/auth.UserNamespace", (*UserNamespace)(nil), state.Fns{Save: (*UserNamespace).save, Load: (*UserNamespace).load}) +} diff --git a/pkg/sentry/kernel/auth/id_map_range.go b/pkg/sentry/kernel/auth/id_map_range.go new file mode 100755 index 000000000..833fa3518 --- /dev/null +++ b/pkg/sentry/kernel/auth/id_map_range.go @@ -0,0 +1,62 @@ +package auth + +// A Range represents a contiguous range of T. +// +// +stateify savable +type idMapRange struct { + // Start is the inclusive start of the range. + Start uint32 + + // End is the exclusive end of the range. + End uint32 +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +func (r idMapRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +func (r idMapRange) Length() uint32 { + return r.End - r.Start +} + +// Contains returns true if r contains x. +func (r idMapRange) Contains(x uint32) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +func (r idMapRange) Overlaps(r2 idMapRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +func (r idMapRange) IsSupersetOf(r2 idMapRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +func (r idMapRange) Intersect(r2 idMapRange) idMapRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +func (r idMapRange) CanSplitAt(x uint32) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/sentry/kernel/auth/id_map_set.go b/pkg/sentry/kernel/auth/id_map_set.go new file mode 100755 index 000000000..73a17f281 --- /dev/null +++ b/pkg/sentry/kernel/auth/id_map_set.go @@ -0,0 +1,1270 @@ +package auth + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + idMapminDegree = 3 + + idMapmaxDegree = 2 * idMapminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type idMapSet struct { + root idMapnode `state:".(*idMapSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *idMapSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *idMapSet) IsEmptyRange(r idMapRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *idMapSet) Span() uint32 { + var sz uint32 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *idMapSet) SpanRange(r idMapRange) uint32 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint32 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *idMapSet) FirstSegment() idMapIterator { + if s.root.nrSegments == 0 { + return idMapIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *idMapSet) LastSegment() idMapIterator { + if s.root.nrSegments == 0 { + return idMapIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *idMapSet) FirstGap() idMapGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return idMapGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *idMapSet) LastGap() idMapGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return idMapGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *idMapSet) Find(key uint32) (idMapIterator, idMapGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return idMapIterator{n, i}, idMapGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return idMapIterator{}, idMapGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *idMapSet) FindSegment(key uint32) idMapIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *idMapSet) LowerBoundSegment(min uint32) idMapIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *idMapSet) UpperBoundSegment(max uint32) idMapIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *idMapSet) FindGap(key uint32) idMapGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *idMapSet) LowerBoundGap(min uint32) idMapGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *idMapSet) UpperBoundGap(max uint32) idMapGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *idMapSet) Add(r idMapRange, val uint32) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *idMapSet) AddWithoutMerging(r idMapRange, val uint32) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *idMapSet) Insert(gap idMapGapIterator, r idMapRange, val uint32) idMapIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (idMapFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (idMapFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (idMapFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *idMapSet) InsertWithoutMerging(gap idMapGapIterator, r idMapRange, val uint32) idMapIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *idMapSet) InsertWithoutMergingUnchecked(gap idMapGapIterator, r idMapRange, val uint32) idMapIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return idMapIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *idMapSet) Remove(seg idMapIterator) idMapGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + idMapFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(idMapGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *idMapSet) RemoveAll() { + s.root = idMapnode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *idMapSet) RemoveRange(r idMapRange) idMapGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *idMapSet) Merge(first, second idMapIterator) idMapIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *idMapSet) MergeUnchecked(first, second idMapIterator) idMapIterator { + if first.End() == second.Start() { + if mval, ok := (idMapFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return idMapIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *idMapSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *idMapSet) MergeRange(r idMapRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *idMapSet) MergeAdjacent(r idMapRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *idMapSet) Split(seg idMapIterator, split uint32) (idMapIterator, idMapIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *idMapSet) SplitUnchecked(seg idMapIterator, split uint32) (idMapIterator, idMapIterator) { + val1, val2 := (idMapFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), idMapRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *idMapSet) SplitAt(split uint32) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *idMapSet) Isolate(seg idMapIterator, r idMapRange) idMapIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *idMapSet) ApplyContiguous(r idMapRange, fn func(seg idMapIterator)) idMapGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return idMapGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return idMapGapIterator{} + } + } +} + +// +stateify savable +type idMapnode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *idMapnode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [idMapmaxDegree - 1]idMapRange + values [idMapmaxDegree - 1]uint32 + children [idMapmaxDegree]*idMapnode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *idMapnode) firstSegment() idMapIterator { + for n.hasChildren { + n = n.children[0] + } + return idMapIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *idMapnode) lastSegment() idMapIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return idMapIterator{n, n.nrSegments - 1} +} + +func (n *idMapnode) prevSibling() *idMapnode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *idMapnode) nextSibling() *idMapnode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *idMapnode) rebalanceBeforeInsert(gap idMapGapIterator) idMapGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < idMapmaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &idMapnode{ + nrSegments: idMapminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &idMapnode{ + nrSegments: idMapminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:idMapminDegree-1], n.keys[:idMapminDegree-1]) + copy(left.values[:idMapminDegree-1], n.values[:idMapminDegree-1]) + copy(right.keys[:idMapminDegree-1], n.keys[idMapminDegree:]) + copy(right.values[:idMapminDegree-1], n.values[idMapminDegree:]) + n.keys[0], n.values[0] = n.keys[idMapminDegree-1], n.values[idMapminDegree-1] + idMapzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:idMapminDegree], n.children[:idMapminDegree]) + copy(right.children[:idMapminDegree], n.children[idMapminDegree:]) + idMapzeroNodeSlice(n.children[2:]) + for i := 0; i < idMapminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < idMapminDegree { + return idMapGapIterator{left, gap.index} + } + return idMapGapIterator{right, gap.index - idMapminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[idMapminDegree-1], n.values[idMapminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &idMapnode{ + nrSegments: idMapminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:idMapminDegree-1], n.keys[idMapminDegree:]) + copy(sibling.values[:idMapminDegree-1], n.values[idMapminDegree:]) + idMapzeroValueSlice(n.values[idMapminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:idMapminDegree], n.children[idMapminDegree:]) + idMapzeroNodeSlice(n.children[idMapminDegree:]) + for i := 0; i < idMapminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = idMapminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < idMapminDegree { + return gap + } + return idMapGapIterator{sibling, gap.index - idMapminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *idMapnode) rebalanceAfterRemove(gap idMapGapIterator) idMapGapIterator { + for { + if n.nrSegments >= idMapminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= idMapminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + idMapFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return idMapGapIterator{n, 0} + } + if gap.node == n { + return idMapGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= idMapminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + idMapFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return idMapGapIterator{n, n.nrSegments} + } + return idMapGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return idMapGapIterator{p, gap.index} + } + if gap.node == right { + return idMapGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *idMapnode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = idMapGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + idMapFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type idMapIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *idMapnode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg idMapIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg idMapIterator) Range() idMapRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg idMapIterator) Start() uint32 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg idMapIterator) End() uint32 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg idMapIterator) SetRangeUnchecked(r idMapRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg idMapIterator) SetRange(r idMapRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg idMapIterator) SetStartUnchecked(start uint32) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg idMapIterator) SetStart(start uint32) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg idMapIterator) SetEndUnchecked(end uint32) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg idMapIterator) SetEnd(end uint32) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg idMapIterator) Value() uint32 { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg idMapIterator) ValuePtr() *uint32 { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg idMapIterator) SetValue(val uint32) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg idMapIterator) PrevSegment() idMapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return idMapIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return idMapIterator{} + } + return idMapsegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg idMapIterator) NextSegment() idMapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return idMapIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return idMapIterator{} + } + return idMapsegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg idMapIterator) PrevGap() idMapGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return idMapGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg idMapIterator) NextGap() idMapGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return idMapGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg idMapIterator) PrevNonEmpty() (idMapIterator, idMapGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return idMapIterator{}, gap + } + return gap.PrevSegment(), idMapGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg idMapIterator) NextNonEmpty() (idMapIterator, idMapGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return idMapIterator{}, gap + } + return gap.NextSegment(), idMapGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type idMapGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *idMapnode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap idMapGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap idMapGapIterator) Range() idMapRange { + return idMapRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap idMapGapIterator) Start() uint32 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return idMapFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap idMapGapIterator) End() uint32 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return idMapFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap idMapGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap idMapGapIterator) PrevSegment() idMapIterator { + return idMapsegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap idMapGapIterator) NextSegment() idMapIterator { + return idMapsegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap idMapGapIterator) PrevGap() idMapGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return idMapGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap idMapGapIterator) NextGap() idMapGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return idMapGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func idMapsegmentBeforePosition(n *idMapnode, i int) idMapIterator { + for i == 0 { + if n.parent == nil { + return idMapIterator{} + } + n, i = n.parent, n.parentIndex + } + return idMapIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func idMapsegmentAfterPosition(n *idMapnode, i int) idMapIterator { + for i == n.nrSegments { + if n.parent == nil { + return idMapIterator{} + } + n, i = n.parent, n.parentIndex + } + return idMapIterator{n, i} +} + +func idMapzeroValueSlice(slice []uint32) { + + for i := range slice { + idMapFunctions{}.ClearValue(&slice[i]) + } +} + +func idMapzeroNodeSlice(slice []*idMapnode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *idMapSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *idMapnode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *idMapnode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type idMapSegmentDataSlices struct { + Start []uint32 + End []uint32 + Values []uint32 +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *idMapSet) ExportSortedSlices() *idMapSegmentDataSlices { + var sds idMapSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *idMapSet) ImportSortedSlices(sds *idMapSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := idMapRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *idMapSet) saveRoot() *idMapSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *idMapSet) loadRoot(sds *idMapSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/kernel/contexttest/BUILD b/pkg/sentry/kernel/contexttest/BUILD deleted file mode 100644 index 9d26392c0..000000000 --- a/pkg/sentry/kernel/contexttest/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "contexttest", - testonly = 1, - srcs = ["contexttest.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/sentry/contexttest", - "//pkg/sentry/kernel", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - ], -) diff --git a/pkg/sentry/kernel/contexttest/contexttest.go b/pkg/sentry/kernel/contexttest/contexttest.go deleted file mode 100644 index 22c340e56..000000000 --- a/pkg/sentry/kernel/contexttest/contexttest.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package contexttest provides a test context.Context which includes -// a dummy kernel pointing to a valid platform. -package contexttest - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" -) - -// Context returns a Context that may be used in tests. Uses ptrace as the -// platform.Platform, and provides a stub kernel that only serves to point to -// the platform. -func Context(tb testing.TB) context.Context { - ctx := contexttest.Context(tb) - k := &kernel.Kernel{ - Platform: platform.FromContext(ctx), - } - k.SetMemoryFile(pgalloc.MemoryFileFromContext(ctx)) - ctx.(*contexttest.TestContext).RegisterValue(kernel.CtxKernel, k) - return ctx -} diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD deleted file mode 100644 index dedf0fa15..000000000 --- a/pkg/sentry/kernel/epoll/BUILD +++ /dev/null @@ -1,50 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "epoll_list", - out = "epoll_list.go", - package = "epoll", - prefix = "pollEntry", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*pollEntry", - "Linker": "*pollEntry", - }, -) - -go_library( - name = "epoll", - srcs = [ - "epoll.go", - "epoll_list.go", - "epoll_state.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/refs", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sync", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "epoll_test", - size = "small", - srcs = [ - "epoll_test.go", - ], - library = ":epoll", - deps = [ - "//pkg/sentry/contexttest", - "//pkg/sentry/fs/filetest", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/epoll/epoll_list.go b/pkg/sentry/kernel/epoll/epoll_list.go new file mode 100755 index 000000000..37f757fa8 --- /dev/null +++ b/pkg/sentry/kernel/epoll/epoll_list.go @@ -0,0 +1,186 @@ +package epoll + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type pollEntryElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (pollEntryElementMapper) linkerFor(elem *pollEntry) *pollEntry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type pollEntryList struct { + head *pollEntry + tail *pollEntry +} + +// Reset resets list l to the empty state. +func (l *pollEntryList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *pollEntryList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *pollEntryList) Front() *pollEntry { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *pollEntryList) Back() *pollEntry { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *pollEntryList) PushFront(e *pollEntry) { + linker := pollEntryElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + pollEntryElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *pollEntryList) PushBack(e *pollEntry) { + linker := pollEntryElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + pollEntryElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *pollEntryList) PushBackList(m *pollEntryList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + pollEntryElementMapper{}.linkerFor(l.tail).SetNext(m.head) + pollEntryElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *pollEntryList) InsertAfter(b, e *pollEntry) { + bLinker := pollEntryElementMapper{}.linkerFor(b) + eLinker := pollEntryElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + pollEntryElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *pollEntryList) InsertBefore(a, e *pollEntry) { + aLinker := pollEntryElementMapper{}.linkerFor(a) + eLinker := pollEntryElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + pollEntryElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *pollEntryList) Remove(e *pollEntry) { + linker := pollEntryElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + pollEntryElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + pollEntryElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type pollEntryEntry struct { + next *pollEntry + prev *pollEntry +} + +// Next returns the entry that follows e in the list. +func (e *pollEntryEntry) Next() *pollEntry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *pollEntryEntry) Prev() *pollEntry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *pollEntryEntry) SetNext(elem *pollEntry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *pollEntryEntry) SetPrev(elem *pollEntry) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/epoll/epoll_state_autogen.go b/pkg/sentry/kernel/epoll/epoll_state_autogen.go new file mode 100755 index 000000000..afdea5bcf --- /dev/null +++ b/pkg/sentry/kernel/epoll/epoll_state_autogen.go @@ -0,0 +1,113 @@ +// automatically generated by stateify. + +package epoll + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *FileIdentifier) beforeSave() {} +func (x *FileIdentifier) save(m state.Map) { + x.beforeSave() + m.Save("File", &x.File) + m.Save("Fd", &x.Fd) +} + +func (x *FileIdentifier) afterLoad() {} +func (x *FileIdentifier) load(m state.Map) { + m.LoadWait("File", &x.File) + m.Load("Fd", &x.Fd) +} + +func (x *pollEntry) beforeSave() {} +func (x *pollEntry) save(m state.Map) { + x.beforeSave() + m.Save("pollEntryEntry", &x.pollEntryEntry) + m.Save("id", &x.id) + m.Save("userData", &x.userData) + m.Save("mask", &x.mask) + m.Save("flags", &x.flags) + m.Save("epoll", &x.epoll) +} + +func (x *pollEntry) load(m state.Map) { + m.Load("pollEntryEntry", &x.pollEntryEntry) + m.LoadWait("id", &x.id) + m.Load("userData", &x.userData) + m.Load("mask", &x.mask) + m.Load("flags", &x.flags) + m.Load("epoll", &x.epoll) + m.AfterLoad(x.afterLoad) +} + +func (x *EventPoll) beforeSave() {} +func (x *EventPoll) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.FilePipeSeek) { + m.Failf("FilePipeSeek is %v, expected zero", x.FilePipeSeek) + } + if !state.IsZeroValue(x.FileNotDirReaddir) { + m.Failf("FileNotDirReaddir is %v, expected zero", x.FileNotDirReaddir) + } + if !state.IsZeroValue(x.FileNoFsync) { + m.Failf("FileNoFsync is %v, expected zero", x.FileNoFsync) + } + if !state.IsZeroValue(x.FileNoopFlush) { + m.Failf("FileNoopFlush is %v, expected zero", x.FileNoopFlush) + } + if !state.IsZeroValue(x.FileNoIoctl) { + m.Failf("FileNoIoctl is %v, expected zero", x.FileNoIoctl) + } + if !state.IsZeroValue(x.FileNoMMap) { + m.Failf("FileNoMMap is %v, expected zero", x.FileNoMMap) + } + if !state.IsZeroValue(x.Queue) { + m.Failf("Queue is %v, expected zero", x.Queue) + } + m.Save("files", &x.files) + m.Save("readyList", &x.readyList) + m.Save("waitingList", &x.waitingList) + m.Save("disabledList", &x.disabledList) +} + +func (x *EventPoll) load(m state.Map) { + m.Load("files", &x.files) + m.Load("readyList", &x.readyList) + m.Load("waitingList", &x.waitingList) + m.Load("disabledList", &x.disabledList) + m.AfterLoad(x.afterLoad) +} + +func (x *pollEntryList) beforeSave() {} +func (x *pollEntryList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *pollEntryList) afterLoad() {} +func (x *pollEntryList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *pollEntryEntry) beforeSave() {} +func (x *pollEntryEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *pollEntryEntry) afterLoad() {} +func (x *pollEntryEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("pkg/sentry/kernel/epoll.FileIdentifier", (*FileIdentifier)(nil), state.Fns{Save: (*FileIdentifier).save, Load: (*FileIdentifier).load}) + state.Register("pkg/sentry/kernel/epoll.pollEntry", (*pollEntry)(nil), state.Fns{Save: (*pollEntry).save, Load: (*pollEntry).load}) + state.Register("pkg/sentry/kernel/epoll.EventPoll", (*EventPoll)(nil), state.Fns{Save: (*EventPoll).save, Load: (*EventPoll).load}) + state.Register("pkg/sentry/kernel/epoll.pollEntryList", (*pollEntryList)(nil), state.Fns{Save: (*pollEntryList).save, Load: (*pollEntryList).load}) + state.Register("pkg/sentry/kernel/epoll.pollEntryEntry", (*pollEntryEntry)(nil), state.Fns{Save: (*pollEntryEntry).save, Load: (*pollEntryEntry).load}) +} diff --git a/pkg/sentry/kernel/epoll/epoll_test.go b/pkg/sentry/kernel/epoll/epoll_test.go deleted file mode 100644 index 22630e9c5..000000000 --- a/pkg/sentry/kernel/epoll/epoll_test.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package epoll - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs/filetest" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestFileDestroyed(t *testing.T) { - f := filetest.NewTestFile(t) - id := FileIdentifier{f, 12} - - efile := NewEventPoll(contexttest.Context(t)) - e := efile.FileOperations.(*EventPoll) - if err := e.AddEntry(id, 0, waiter.EventIn, [2]int32{}); err != nil { - t.Fatalf("addEntry failed: %v", err) - } - - // Check that we get an event reported twice in a row. - evt := e.ReadEvents(1) - if len(evt) != 1 { - t.Fatalf("Unexpected number of ready events: want %v, got %v", 1, len(evt)) - } - - evt = e.ReadEvents(1) - if len(evt) != 1 { - t.Fatalf("Unexpected number of ready events: want %v, got %v", 1, len(evt)) - } - - // Destroy the file. Check that we get no more events. - f.DecRef() - - evt = e.ReadEvents(1) - if len(evt) != 0 { - t.Fatalf("Unexpected number of ready events: want %v, got %v", 0, len(evt)) - } - -} diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD deleted file mode 100644 index 9983a32e5..000000000 --- a/pkg/sentry/kernel/eventfd/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "eventfd", - srcs = ["eventfd.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/fdnotifier", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "eventfd_test", - size = "small", - srcs = ["eventfd_test.go"], - library = ":eventfd", - deps = [ - "//pkg/sentry/contexttest", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/eventfd/eventfd_state_autogen.go b/pkg/sentry/kernel/eventfd/eventfd_state_autogen.go new file mode 100755 index 000000000..9cf0ac817 --- /dev/null +++ b/pkg/sentry/kernel/eventfd/eventfd_state_autogen.go @@ -0,0 +1,29 @@ +// automatically generated by stateify. + +package eventfd + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *EventOperations) beforeSave() {} +func (x *EventOperations) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.wq) { + m.Failf("wq is %v, expected zero", x.wq) + } + m.Save("val", &x.val) + m.Save("semMode", &x.semMode) + m.Save("hostfd", &x.hostfd) +} + +func (x *EventOperations) afterLoad() {} +func (x *EventOperations) load(m state.Map) { + m.Load("val", &x.val) + m.Load("semMode", &x.semMode) + m.Load("hostfd", &x.hostfd) +} + +func init() { + state.Register("pkg/sentry/kernel/eventfd.EventOperations", (*EventOperations)(nil), state.Fns{Save: (*EventOperations).save, Load: (*EventOperations).load}) +} diff --git a/pkg/sentry/kernel/eventfd/eventfd_test.go b/pkg/sentry/kernel/eventfd/eventfd_test.go deleted file mode 100644 index 9b4892f74..000000000 --- a/pkg/sentry/kernel/eventfd/eventfd_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package eventfd - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestEventfd(t *testing.T) { - initVals := []uint64{ - 0, - // Using a non-zero initial value verifies that writing to an - // eventfd signals when the eventfd's counter was already - // non-zero. - 343, - } - - for _, initVal := range initVals { - ctx := contexttest.Context(t) - - // Make a new event that is writable. - event := New(ctx, initVal, false) - - // Register a callback for a write event. - w, ch := waiter.NewChannelEntry(nil) - event.EventRegister(&w, waiter.EventIn) - defer event.EventUnregister(&w) - - data := []byte("00000124") - // Create and submit a write request. - n, err := event.Writev(ctx, usermem.BytesIOSequence(data)) - if err != nil { - t.Fatal(err) - } - if n != 8 { - t.Errorf("eventfd.write wrote %d bytes, not full int64", n) - } - - // Check if the callback fired due to the write event. - select { - case <-ch: - default: - t.Errorf("Didn't get notified of EventIn after write") - } - } -} - -func TestEventfdStat(t *testing.T) { - ctx := contexttest.Context(t) - - // Make a new event that is writable. - event := New(ctx, 0, false) - - // Create and submit an stat request. - uattr, err := event.Dirent.Inode.UnstableAttr(ctx) - if err != nil { - t.Fatalf("eventfd stat request failed: %v", err) - } - if uattr.Size != 0 { - t.Fatal("EventFD size should be 0") - } -} diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD deleted file mode 100644 index b9126e946..000000000 --- a/pkg/sentry/kernel/fasync/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "fasync", - srcs = ["fasync.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/fs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sync", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/fasync/fasync_state_autogen.go b/pkg/sentry/kernel/fasync/fasync_state_autogen.go new file mode 100755 index 000000000..fdcd48f64 --- /dev/null +++ b/pkg/sentry/kernel/fasync/fasync_state_autogen.go @@ -0,0 +1,32 @@ +// automatically generated by stateify. + +package fasync + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *FileAsync) beforeSave() {} +func (x *FileAsync) save(m state.Map) { + x.beforeSave() + m.Save("e", &x.e) + m.Save("requester", &x.requester) + m.Save("registered", &x.registered) + m.Save("recipientPG", &x.recipientPG) + m.Save("recipientTG", &x.recipientTG) + m.Save("recipientT", &x.recipientT) +} + +func (x *FileAsync) afterLoad() {} +func (x *FileAsync) load(m state.Map) { + m.Load("e", &x.e) + m.Load("requester", &x.requester) + m.Load("registered", &x.registered) + m.Load("recipientPG", &x.recipientPG) + m.Load("recipientTG", &x.recipientTG) + m.Load("recipientT", &x.recipientT) +} + +func init() { + state.Register("pkg/sentry/kernel/fasync.FileAsync", (*FileAsync)(nil), state.Fns{Save: (*FileAsync).save, Load: (*FileAsync).load}) +} diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go deleted file mode 100644 index 29f95a2c4..000000000 --- a/pkg/sentry/kernel/fd_table_test.go +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernel - -import ( - "runtime" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/filetest" - "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sync" -) - -const ( - // maxFD is the maximum FD to try to create in the map. - // - // This number of open files has been seen in the wild. - maxFD = 2 * 1024 -) - -func runTest(t testing.TB, fn func(ctx context.Context, fdTable *FDTable, file *fs.File, limitSet *limits.LimitSet)) { - t.Helper() // Don't show in stacks. - - // Create the limits and context. - limitSet := limits.NewLimitSet() - limitSet.Set(limits.NumberOfFiles, limits.Limit{maxFD, maxFD}, true) - ctx := contexttest.WithLimitSet(contexttest.Context(t), limitSet) - - // Create a test file.; - file := filetest.NewTestFile(t) - - // Create the table. - fdTable := new(FDTable) - fdTable.init() - - // Run the test. - fn(ctx, fdTable, file, limitSet) -} - -// TestFDTableMany allocates maxFD FDs, i.e. maxes out the FDTable, until there -// is no room, then makes sure that NewFDAt works and also that if we remove -// one and add one that works too. -func TestFDTableMany(t *testing.T) { - runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) { - for i := 0; i < maxFD; i++ { - if _, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil { - t.Fatalf("Allocated %v FDs but wanted to allocate %v", i, maxFD) - } - } - - if _, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err == nil { - t.Fatalf("fdTable.NewFDs(0, r) in full map: got nil, wanted error") - } - - if err := fdTable.NewFDAt(ctx, 1, file, FDFlags{}); err != nil { - t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err) - } - - i := int32(2) - fdTable.Remove(i) - if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != i { - t.Fatalf("Allocated %v FDs but wanted to allocate %v: %v", i, maxFD, err) - } - }) -} - -func TestFDTableOverLimit(t *testing.T) { - runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) { - if _, err := fdTable.NewFDs(ctx, maxFD, []*fs.File{file}, FDFlags{}); err == nil { - t.Fatalf("fdTable.NewFDs(maxFD, f): got nil, wanted error") - } - - if _, err := fdTable.NewFDs(ctx, maxFD-2, []*fs.File{file, file, file}, FDFlags{}); err == nil { - t.Fatalf("fdTable.NewFDs(maxFD-2, {f,f,f}): got nil, wanted error") - } - - if fds, err := fdTable.NewFDs(ctx, maxFD-3, []*fs.File{file, file, file}, FDFlags{}); err != nil { - t.Fatalf("fdTable.NewFDs(maxFD-3, {f,f,f}): got %v, wanted nil", err) - } else { - for _, fd := range fds { - fdTable.Remove(fd) - } - } - - if fds, err := fdTable.NewFDs(ctx, maxFD-1, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != maxFD-1 { - t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err) - } - - if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil { - t.Fatalf("Adding an FD to a resized map: got %v, want nil", err) - } else if len(fds) != 1 || fds[0] != 0 { - t.Fatalf("Added an FD to a resized map: got %v, want {1}", fds) - } - }) -} - -// TestFDTable does a set of simple tests to make sure simple adds, removes, -// GetRefs, and DecRefs work. The ordering is just weird enough that a -// table-driven approach seemed clumsy. -func TestFDTable(t *testing.T) { - runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, limitSet *limits.LimitSet) { - // Cap the limit at one. - limitSet.Set(limits.NumberOfFiles, limits.Limit{1, maxFD}, true) - - if _, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil { - t.Fatalf("Adding an FD to an empty 1-size map: got %v, want nil", err) - } - - if _, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err == nil { - t.Fatalf("Adding an FD to a filled 1-size map: got nil, wanted an error") - } - - // Remove the previous limit. - limitSet.Set(limits.NumberOfFiles, limits.Limit{maxFD, maxFD}, true) - - if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil { - t.Fatalf("Adding an FD to a resized map: got %v, want nil", err) - } else if len(fds) != 1 || fds[0] != 1 { - t.Fatalf("Added an FD to a resized map: got %v, want {1}", fds) - } - - if err := fdTable.NewFDAt(ctx, 1, file, FDFlags{}); err != nil { - t.Fatalf("Replacing FD 1 via fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err) - } - - if err := fdTable.NewFDAt(ctx, maxFD+1, file, FDFlags{}); err == nil { - t.Fatalf("Using an FD that was too large via fdTable.NewFDAt(%v, r, FDFlags{}): got nil, wanted an error", maxFD+1) - } - - if ref, _ := fdTable.Get(1); ref == nil { - t.Fatalf("fdTable.Get(1): got nil, wanted %v", file) - } - - if ref, _ := fdTable.Get(2); ref != nil { - t.Fatalf("fdTable.Get(2): got a %v, wanted nil", ref) - } - - ref, _ := fdTable.Remove(1) - if ref == nil { - t.Fatalf("fdTable.Remove(1) for an existing FD: failed, want success") - } - ref.DecRef() - - if ref, _ := fdTable.Remove(1); ref != nil { - t.Fatalf("r.Remove(1) for a removed FD: got success, want failure") - } - }) -} - -func TestDescriptorFlags(t *testing.T) { - runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) { - if err := fdTable.NewFDAt(ctx, 2, file, FDFlags{CloseOnExec: true}); err != nil { - t.Fatalf("fdTable.NewFDAt(2, r, FDFlags{}): got %v, wanted nil", err) - } - - newFile, flags := fdTable.Get(2) - if newFile == nil { - t.Fatalf("fdTable.Get(2): got a %v, wanted nil", newFile) - } - - if !flags.CloseOnExec { - t.Fatalf("new File flags %v don't match original %d\n", flags, 0) - } - }) -} - -func BenchmarkFDLookupAndDecRef(b *testing.B) { - b.StopTimer() // Setup. - - runTest(b, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) { - fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file, file, file, file, file}, FDFlags{}) - if err != nil { - b.Fatalf("fdTable.NewFDs: got %v, wanted nil", err) - } - - b.StartTimer() // Benchmark. - for i := 0; i < b.N; i++ { - tf, _ := fdTable.Get(fds[i%len(fds)]) - tf.DecRef() - } - }) -} - -func BenchmarkFDLookupAndDecRefConcurrent(b *testing.B) { - b.StopTimer() // Setup. - - runTest(b, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) { - fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file, file, file, file, file}, FDFlags{}) - if err != nil { - b.Fatalf("fdTable.NewFDs: got %v, wanted nil", err) - } - - concurrency := runtime.GOMAXPROCS(0) - if concurrency < 4 { - concurrency = 4 - } - each := b.N / concurrency - - b.StartTimer() // Benchmark. - var wg sync.WaitGroup - for i := 0; i < concurrency; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for i := 0; i < each; i++ { - tf, _ := fdTable.Get(fds[i%len(fds)]) - tf.DecRef() - } - }() - } - wg.Wait() - }) -} diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD deleted file mode 100644 index c5021f2db..000000000 --- a/pkg/sentry/kernel/futex/BUILD +++ /dev/null @@ -1,57 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "atomicptr_bucket", - out = "atomicptr_bucket_unsafe.go", - package = "futex", - suffix = "Bucket", - template = "//pkg/sync:generic_atomicptr", - types = { - "Value": "bucket", - }, -) - -go_template_instance( - name = "waiter_list", - out = "waiter_list.go", - package = "futex", - prefix = "waiter", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Waiter", - "Linker": "*Waiter", - }, -) - -go_library( - name = "futex", - srcs = [ - "atomicptr_bucket_unsafe.go", - "futex.go", - "waiter_list.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/log", - "//pkg/sentry/memmap", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - ], -) - -go_test( - name = "futex_test", - size = "small", - srcs = ["futex_test.go"], - library = ":futex", - deps = [ - "//pkg/sync", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/kernel/futex/atomicptr_bucket_unsafe.go b/pkg/sentry/kernel/futex/atomicptr_bucket_unsafe.go new file mode 100755 index 000000000..d3fdf09b0 --- /dev/null +++ b/pkg/sentry/kernel/futex/atomicptr_bucket_unsafe.go @@ -0,0 +1,37 @@ +package futex + +import ( + "sync/atomic" + "unsafe" +) + +// An AtomicPtr is a pointer to a value of type Value that can be atomically +// loaded and stored. The zero value of an AtomicPtr represents nil. +// +// Note that copying AtomicPtr by value performs a non-atomic read of the +// stored pointer, which is unsafe if Store() can be called concurrently; in +// this case, do `dst.Store(src.Load())` instead. +// +// +stateify savable +type AtomicPtrBucket struct { + ptr unsafe.Pointer `state:".(*bucket)"` +} + +func (p *AtomicPtrBucket) savePtr() *bucket { + return p.Load() +} + +func (p *AtomicPtrBucket) loadPtr(v *bucket) { + p.Store(v) +} + +// Load returns the value set by the most recent Store. It returns nil if there +// has been no previous call to Store. +func (p *AtomicPtrBucket) Load() *bucket { + return (*bucket)(atomic.LoadPointer(&p.ptr)) +} + +// Store sets the value returned by Load to x. +func (p *AtomicPtrBucket) Store(x *bucket) { + atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x)) +} diff --git a/pkg/sentry/kernel/futex/futex_state_autogen.go b/pkg/sentry/kernel/futex/futex_state_autogen.go new file mode 100755 index 000000000..d5ed3466f --- /dev/null +++ b/pkg/sentry/kernel/futex/futex_state_autogen.go @@ -0,0 +1,79 @@ +// automatically generated by stateify. + +package futex + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *AtomicPtrBucket) beforeSave() {} +func (x *AtomicPtrBucket) save(m state.Map) { + x.beforeSave() + var ptr *bucket = x.savePtr() + m.SaveValue("ptr", ptr) +} + +func (x *AtomicPtrBucket) afterLoad() {} +func (x *AtomicPtrBucket) load(m state.Map) { + m.LoadValue("ptr", new(*bucket), func(y interface{}) { x.loadPtr(y.(*bucket)) }) +} + +func (x *bucket) beforeSave() {} +func (x *bucket) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.waiters) { + m.Failf("waiters is %v, expected zero", x.waiters) + } +} + +func (x *bucket) afterLoad() {} +func (x *bucket) load(m state.Map) { +} + +func (x *Manager) beforeSave() {} +func (x *Manager) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.privateBuckets) { + m.Failf("privateBuckets is %v, expected zero", x.privateBuckets) + } + m.Save("sharedBucket", &x.sharedBucket) +} + +func (x *Manager) afterLoad() {} +func (x *Manager) load(m state.Map) { + m.Load("sharedBucket", &x.sharedBucket) +} + +func (x *waiterList) beforeSave() {} +func (x *waiterList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *waiterList) afterLoad() {} +func (x *waiterList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *waiterEntry) beforeSave() {} +func (x *waiterEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *waiterEntry) afterLoad() {} +func (x *waiterEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("pkg/sentry/kernel/futex.AtomicPtrBucket", (*AtomicPtrBucket)(nil), state.Fns{Save: (*AtomicPtrBucket).save, Load: (*AtomicPtrBucket).load}) + state.Register("pkg/sentry/kernel/futex.bucket", (*bucket)(nil), state.Fns{Save: (*bucket).save, Load: (*bucket).load}) + state.Register("pkg/sentry/kernel/futex.Manager", (*Manager)(nil), state.Fns{Save: (*Manager).save, Load: (*Manager).load}) + state.Register("pkg/sentry/kernel/futex.waiterList", (*waiterList)(nil), state.Fns{Save: (*waiterList).save, Load: (*waiterList).load}) + state.Register("pkg/sentry/kernel/futex.waiterEntry", (*waiterEntry)(nil), state.Fns{Save: (*waiterEntry).save, Load: (*waiterEntry).load}) +} diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go deleted file mode 100644 index 7c5c7665b..000000000 --- a/pkg/sentry/kernel/futex/futex_test.go +++ /dev/null @@ -1,530 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package futex - -import ( - "math" - "runtime" - "sync/atomic" - "syscall" - "testing" - "unsafe" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/usermem" -) - -// testData implements the Target interface, and allows us to -// treat the address passed for futex operations as an index in -// a byte slice for testing simplicity. -type testData []byte - -const sizeofInt32 = 4 - -func newTestData(size uint) testData { - return make([]byte, size) -} - -func (t testData) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) { - val := atomic.SwapUint32((*uint32)(unsafe.Pointer(&t[addr])), new) - return val, nil -} - -func (t testData) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) { - if atomic.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&t[addr])), old, new) { - return old, nil - } - return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t[addr]))), nil -} - -func (t testData) LoadUint32(addr usermem.Addr) (uint32, error) { - return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t[addr]))), nil -} - -func (t testData) GetSharedKey(addr usermem.Addr) (Key, error) { - return Key{ - Kind: KindSharedMappable, - Offset: uint64(addr), - }, nil -} - -func futexKind(private bool) string { - if private { - return "private" - } - return "shared" -} - -func newPreparedTestWaiter(t *testing.T, m *Manager, ta Target, addr usermem.Addr, private bool, val uint32, bitmask uint32) *Waiter { - w := NewWaiter() - if err := m.WaitPrepare(w, ta, addr, private, val, bitmask); err != nil { - t.Fatalf("WaitPrepare failed: %v", err) - } - return w -} - -func TestFutexWake(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(sizeofInt32) - - // Start waiting for wakeup. - w := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w) - - // Perform a wakeup. - if n, err := m.Wake(d, 0, private, ^uint32(0), 1); err != nil || n != 1 { - t.Errorf("Wake: got (%d, %v), wanted (1, nil)", n, err) - } - - // Expect the waiter to have been woken. - if !w.woken() { - t.Error("waiter not woken") - } - }) - } -} - -func TestFutexWakeBitmask(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(sizeofInt32) - - // Start waiting for wakeup. - w := newPreparedTestWaiter(t, m, d, 0, private, 0, 0x0000ffff) - defer m.WaitComplete(w) - - // Perform a wakeup using the wrong bitmask. - if n, err := m.Wake(d, 0, private, 0xffff0000, 1); err != nil || n != 0 { - t.Errorf("Wake with non-matching bitmask: got (%d, %v), wanted (0, nil)", n, err) - } - - // Expect the waiter to still be waiting. - if w.woken() { - t.Error("waiter woken unexpectedly") - } - - // Perform a wakeup using the right bitmask. - if n, err := m.Wake(d, 0, private, 0x00000001, 1); err != nil || n != 1 { - t.Errorf("Wake with matching bitmask: got (%d, %v), wanted (1, nil)", n, err) - } - - // Expect that the waiter was woken. - if !w.woken() { - t.Error("waiter not woken") - } - }) - } -} - -func TestFutexWakeTwo(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(sizeofInt32) - - // Start three waiters waiting for wakeup. - var ws [3]*Waiter - for i := range ws { - ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(ws[i]) - } - - // Perform two wakeups. - const wakeups = 2 - if n, err := m.Wake(d, 0, private, ^uint32(0), 2); err != nil || n != wakeups { - t.Errorf("Wake: got (%d, %v), wanted (%d, nil)", n, err, wakeups) - } - - // Expect that exactly two waiters were woken. - // We don't get guarantees about exactly which two, - // (although we expect them to be w1 and w2). - awake := 0 - for i := range ws { - if ws[i].woken() { - awake++ - } - } - if awake != wakeups { - t.Errorf("got %d woken waiters, wanted %d", awake, wakeups) - } - }) - } -} - -func TestFutexWakeUnrelated(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(2 * sizeofInt32) - - // Start two waiters waiting for wakeup on different addresses. - w1 := newPreparedTestWaiter(t, m, d, 0*sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) - w2 := newPreparedTestWaiter(t, m, d, 1*sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) - - // Perform two wakeups on the second address. - if n, err := m.Wake(d, 1*sizeofInt32, private, ^uint32(0), 2); err != nil || n != 1 { - t.Errorf("Wake: got (%d, %v), wanted (1, nil)", n, err) - } - - // Expect that only the second waiter was woken. - if w1.woken() { - t.Error("w1 woken unexpectedly") - } - if !w2.woken() { - t.Error("w2 not woken") - } - }) - } -} - -func TestWakeOpEmpty(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(2 * sizeofInt32) - - // Perform wakeups with no waiters. - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 10, 0); err != nil || n != 0 { - t.Fatalf("WakeOp: got (%d, %v), wanted (0, nil)", n, err) - } - }) - } -} - -func TestWakeOpFirstNonEmpty(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add two waiters on address 0. - w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) - w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) - - // Perform 10 wakeups on address 0. - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 0, 0); err != nil || n != 2 { - t.Errorf("WakeOp: got (%d, %v), wanted (2, nil)", n, err) - } - - // Expect that both waiters were woken. - if !w1.woken() { - t.Error("w1 not woken") - } - if !w2.woken() { - t.Error("w2 not woken") - } - }) - } -} - -func TestWakeOpSecondNonEmpty(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add two waiters on address sizeofInt32. - w1 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) - w2 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) - - // Perform 10 wakeups on address sizeofInt32 (contingent on - // d.Op(0), which should succeed). - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 0, 10, 0); err != nil || n != 2 { - t.Errorf("WakeOp: got (%d, %v), wanted (2, nil)", n, err) - } - - // Expect that both waiters were woken. - if !w1.woken() { - t.Error("w1 not woken") - } - if !w2.woken() { - t.Error("w2 not woken") - } - }) - } -} - -func TestWakeOpSecondNonEmptyFailingOp(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add two waiters on address sizeofInt32. - w1 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) - w2 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) - - // Perform 10 wakeups on address sizeofInt32 (contingent on - // d.Op(1), which should fail). - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 0, 10, 1); err != nil || n != 0 { - t.Errorf("WakeOp: got (%d, %v), wanted (0, nil)", n, err) - } - - // Expect that neither waiter was woken. - if w1.woken() { - t.Error("w1 woken unexpectedly") - } - if w2.woken() { - t.Error("w2 woken unexpectedly") - } - }) - } -} - -func TestWakeOpAllNonEmpty(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add two waiters on address 0. - w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) - w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) - - // Add two waiters on address sizeofInt32. - w3 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w3) - w4 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w4) - - // Perform 10 wakeups on address 0 (unconditionally), and 10 - // wakeups on address sizeofInt32 (contingent on d.Op(0), which - // should succeed). - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 10, 0); err != nil || n != 4 { - t.Errorf("WakeOp: got (%d, %v), wanted (4, nil)", n, err) - } - - // Expect that all waiters were woken. - if !w1.woken() { - t.Error("w1 not woken") - } - if !w2.woken() { - t.Error("w2 not woken") - } - if !w3.woken() { - t.Error("w3 not woken") - } - if !w4.woken() { - t.Error("w4 not woken") - } - }) - } -} - -func TestWakeOpAllNonEmptyFailingOp(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add two waiters on address 0. - w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) - w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) - - // Add two waiters on address sizeofInt32. - w3 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w3) - w4 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w4) - - // Perform 10 wakeups on address 0 (unconditionally), and 10 - // wakeups on address sizeofInt32 (contingent on d.Op(1), which - // should fail). - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 10, 1); err != nil || n != 2 { - t.Errorf("WakeOp: got (%d, %v), wanted (2, nil)", n, err) - } - - // Expect that only the first two waiters were woken. - if !w1.woken() { - t.Error("w1 not woken") - } - if !w2.woken() { - t.Error("w2 not woken") - } - if w3.woken() { - t.Error("w3 woken unexpectedly") - } - if w4.woken() { - t.Error("w4 woken unexpectedly") - } - }) - } -} - -func TestWakeOpSameAddress(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add four waiters on address 0. - var ws [4]*Waiter - for i := range ws { - ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(ws[i]) - } - - // Perform 1 wakeup on address 0 (unconditionally), and 1 wakeup - // on address 0 (contingent on d.Op(0), which should succeed). - const wakeups = 2 - if n, err := m.WakeOp(d, 0, 0, private, 1, 1, 0); err != nil || n != wakeups { - t.Errorf("WakeOp: got (%d, %v), wanted (%d, nil)", n, err, wakeups) - } - - // Expect that exactly two waiters were woken. - awake := 0 - for i := range ws { - if ws[i].woken() { - awake++ - } - } - if awake != wakeups { - t.Errorf("got %d woken waiters, wanted %d", awake, wakeups) - } - }) - } -} - -func TestWakeOpSameAddressFailingOp(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add four waiters on address 0. - var ws [4]*Waiter - for i := range ws { - ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(ws[i]) - } - - // Perform 1 wakeup on address 0 (unconditionally), and 1 wakeup - // on address 0 (contingent on d.Op(1), which should fail). - const wakeups = 1 - if n, err := m.WakeOp(d, 0, 0, private, 1, 1, 1); err != nil || n != wakeups { - t.Errorf("WakeOp: got (%d, %v), wanted (%d, nil)", n, err, wakeups) - } - - // Expect that exactly one waiter was woken. - awake := 0 - for i := range ws { - if ws[i].woken() { - awake++ - } - } - if awake != wakeups { - t.Errorf("got %d woken waiters, wanted %d", awake, wakeups) - } - }) - } -} - -const ( - testMutexSize = sizeofInt32 - testMutexLocked uint32 = 1 - testMutexUnlocked uint32 = 0 -) - -// testMutex ties together a testData slice, an address, and a -// futex manager in order to implement the sync.Locker interface. -// Beyond being used as a Locker, this is a simple mechanism for -// changing the underlying values for simpler tests. -type testMutex struct { - a usermem.Addr - d testData - m *Manager -} - -func newTestMutex(addr usermem.Addr, d testData, m *Manager) *testMutex { - return &testMutex{a: addr, d: d, m: m} -} - -// Lock acquires the testMutex. -// This may wait for it to be available via the futex manager. -func (t *testMutex) Lock() { - for { - // Attempt to grab the lock. - if atomic.CompareAndSwapUint32( - (*uint32)(unsafe.Pointer(&t.d[t.a])), - testMutexUnlocked, - testMutexLocked) { - // Lock held. - return - } - - // Wait for it to be "not locked". - w := NewWaiter() - err := t.m.WaitPrepare(w, t.d, t.a, true, testMutexLocked, ^uint32(0)) - if err == syscall.EAGAIN { - continue - } - if err != nil { - // Should never happen. - panic("WaitPrepare returned unexpected error: " + err.Error()) - } - <-w.C - t.m.WaitComplete(w) - } -} - -// Unlock releases the testMutex. -// This will notify any waiters via the futex manager. -func (t *testMutex) Unlock() { - // Unlock. - atomic.StoreUint32((*uint32)(unsafe.Pointer(&t.d[t.a])), testMutexUnlocked) - - // Notify all waiters. - t.m.Wake(t.d, t.a, true, ^uint32(0), math.MaxInt32) -} - -// This function was shamelessly stolen from mutex_test.go. -func HammerMutex(l sync.Locker, loops int, cdone chan bool) { - for i := 0; i < loops; i++ { - l.Lock() - runtime.Gosched() - l.Unlock() - } - cdone <- true -} - -func TestMutexStress(t *testing.T) { - m := NewManager() - d := newTestData(testMutexSize) - tm := newTestMutex(0*testMutexSize, d, m) - c := make(chan bool) - - for i := 0; i < 10; i++ { - go HammerMutex(tm, 1000, c) - } - - for i := 0; i < 10; i++ { - <-c - } -} diff --git a/pkg/sentry/kernel/futex/waiter_list.go b/pkg/sentry/kernel/futex/waiter_list.go new file mode 100755 index 000000000..204eededf --- /dev/null +++ b/pkg/sentry/kernel/futex/waiter_list.go @@ -0,0 +1,186 @@ +package futex + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type waiterElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (waiterElementMapper) linkerFor(elem *Waiter) *Waiter { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type waiterList struct { + head *Waiter + tail *Waiter +} + +// Reset resets list l to the empty state. +func (l *waiterList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *waiterList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *waiterList) Front() *Waiter { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *waiterList) Back() *Waiter { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *waiterList) PushFront(e *Waiter) { + linker := waiterElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + waiterElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *waiterList) PushBack(e *Waiter) { + linker := waiterElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *waiterList) PushBackList(m *waiterList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(m.head) + waiterElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *waiterList) InsertAfter(b, e *Waiter) { + bLinker := waiterElementMapper{}.linkerFor(b) + eLinker := waiterElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + waiterElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *waiterList) InsertBefore(a, e *Waiter) { + aLinker := waiterElementMapper{}.linkerFor(a) + eLinker := waiterElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + waiterElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *waiterList) Remove(e *Waiter) { + linker := waiterElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + waiterElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + waiterElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type waiterEntry struct { + next *Waiter + prev *Waiter +} + +// Next returns the entry that follows e in the list. +func (e *waiterEntry) Next() *Waiter { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *waiterEntry) Prev() *Waiter { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *waiterEntry) SetNext(elem *Waiter) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *waiterEntry) SetPrev(elem *Waiter) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/g3doc/run_states.dot b/pkg/sentry/kernel/g3doc/run_states.dot deleted file mode 100644 index 7861fe1f5..000000000 --- a/pkg/sentry/kernel/g3doc/run_states.dot +++ /dev/null @@ -1,99 +0,0 @@ -digraph { - subgraph { - App; - } - subgraph { - Interrupt; - InterruptAfterSignalDeliveryStop; - } - subgraph { - Syscall; - SyscallAfterPtraceEventSeccomp; - SyscallEnter; - SyscallAfterSyscallEnterStop; - SyscallAfterSysemuStop; - SyscallInvoke; - SyscallAfterPtraceEventClone; - SyscallAfterExecStop; - SyscallAfterVforkStop; - SyscallReinvoke; - SyscallExit; - } - subgraph { - Vsyscall; - VsyscallAfterPtraceEventSeccomp; - VsyscallInvoke; - } - subgraph { - Exit; - ExitMain; // leave thread group, release resources, reparent children, kill PID namespace and wait if TGID 1 - ExitNotify; // signal parent/tracer, become waitable - ExitDone; // represented by t.runState == nil - } - - // Task exit - Exit -> ExitMain; - ExitMain -> ExitNotify; - ExitNotify -> ExitDone; - - // Execution of untrusted application code - App -> App; - - // Interrupts (usually signal delivery) - App -> Interrupt; - Interrupt -> Interrupt; // if other interrupt conditions may still apply - Interrupt -> Exit; // if killed - - // Syscalls - App -> Syscall; - Syscall -> SyscallEnter; - SyscallEnter -> SyscallInvoke; - SyscallInvoke -> SyscallExit; - SyscallExit -> App; - - // exit, exit_group - SyscallInvoke -> Exit; - - // execve - SyscallInvoke -> SyscallAfterExecStop; - SyscallAfterExecStop -> SyscallExit; - SyscallAfterExecStop -> App; // fatal signal pending - - // vfork - SyscallInvoke -> SyscallAfterVforkStop; - SyscallAfterVforkStop -> SyscallExit; - - // Vsyscalls - App -> Vsyscall; - Vsyscall -> VsyscallInvoke; - Vsyscall -> App; // fault while reading return address from stack - VsyscallInvoke -> App; - - // ptrace-specific branches - Interrupt -> InterruptAfterSignalDeliveryStop; - InterruptAfterSignalDeliveryStop -> Interrupt; - SyscallEnter -> SyscallAfterSyscallEnterStop; - SyscallAfterSyscallEnterStop -> SyscallInvoke; - SyscallAfterSyscallEnterStop -> SyscallExit; // skipped by tracer - SyscallAfterSyscallEnterStop -> App; // fatal signal pending - SyscallEnter -> SyscallAfterSysemuStop; - SyscallAfterSysemuStop -> SyscallExit; - SyscallAfterSysemuStop -> App; // fatal signal pending - SyscallInvoke -> SyscallAfterPtraceEventClone; - SyscallAfterPtraceEventClone -> SyscallExit; - SyscallAfterPtraceEventClone -> SyscallAfterVforkStop; - - // seccomp - Syscall -> App; // SECCOMP_RET_TRAP, SECCOMP_RET_ERRNO, SECCOMP_RET_KILL, SECCOMP_RET_TRACE without tracer - Syscall -> SyscallAfterPtraceEventSeccomp; // SECCOMP_RET_TRACE - SyscallAfterPtraceEventSeccomp -> SyscallEnter; - SyscallAfterPtraceEventSeccomp -> SyscallExit; // skipped by tracer - SyscallAfterPtraceEventSeccomp -> App; // fatal signal pending - Vsyscall -> VsyscallAfterPtraceEventSeccomp; - VsyscallAfterPtraceEventSeccomp -> VsyscallInvoke; - VsyscallAfterPtraceEventSeccomp -> App; - - // Autosave - SyscallInvoke -> SyscallReinvoke; - SyscallReinvoke -> SyscallInvoke; -} diff --git a/pkg/sentry/kernel/g3doc/run_states.png b/pkg/sentry/kernel/g3doc/run_states.png Binary files differdeleted file mode 100644 index b63b60f02..000000000 --- a/pkg/sentry/kernel/g3doc/run_states.png +++ /dev/null diff --git a/pkg/sentry/kernel/kernel_amd64_state_autogen.go b/pkg/sentry/kernel/kernel_amd64_state_autogen.go new file mode 100755 index 000000000..12de47ad0 --- /dev/null +++ b/pkg/sentry/kernel/kernel_amd64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build amd64 + +package kernel diff --git a/pkg/sentry/kernel/kernel_arm64_state_autogen.go b/pkg/sentry/kernel/kernel_arm64_state_autogen.go new file mode 100755 index 000000000..3c040d283 --- /dev/null +++ b/pkg/sentry/kernel/kernel_arm64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build arm64 + +package kernel diff --git a/pkg/sentry/kernel/kernel_opts.go b/pkg/sentry/kernel/kernel_opts.go index 2e66ec587..2e66ec587 100644..100755 --- a/pkg/sentry/kernel/kernel_opts.go +++ b/pkg/sentry/kernel/kernel_opts.go diff --git a/pkg/sentry/kernel/kernel_opts_state_autogen.go b/pkg/sentry/kernel/kernel_opts_state_autogen.go new file mode 100755 index 000000000..9ed7e27c9 --- /dev/null +++ b/pkg/sentry/kernel/kernel_opts_state_autogen.go @@ -0,0 +1,20 @@ +// automatically generated by stateify. + +package kernel + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *SpecialOpts) beforeSave() {} +func (x *SpecialOpts) save(m state.Map) { + x.beforeSave() +} + +func (x *SpecialOpts) afterLoad() {} +func (x *SpecialOpts) load(m state.Map) { +} + +func init() { + state.Register("pkg/sentry/kernel.SpecialOpts", (*SpecialOpts)(nil), state.Fns{Save: (*SpecialOpts).save, Load: (*SpecialOpts).load}) +} diff --git a/pkg/sentry/kernel/kernel_state_autogen.go b/pkg/sentry/kernel/kernel_state_autogen.go new file mode 100755 index 000000000..57a261086 --- /dev/null +++ b/pkg/sentry/kernel/kernel_state_autogen.go @@ -0,0 +1,1230 @@ +// automatically generated by stateify. + +package kernel + +import ( + "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/sentry/device" + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip" +) + +func (x *abstractEndpoint) beforeSave() {} +func (x *abstractEndpoint) save(m state.Map) { + x.beforeSave() + m.Save("ep", &x.ep) + m.Save("wr", &x.wr) + m.Save("name", &x.name) + m.Save("ns", &x.ns) +} + +func (x *abstractEndpoint) afterLoad() {} +func (x *abstractEndpoint) load(m state.Map) { + m.Load("ep", &x.ep) + m.Load("wr", &x.wr) + m.Load("name", &x.name) + m.Load("ns", &x.ns) +} + +func (x *AbstractSocketNamespace) beforeSave() {} +func (x *AbstractSocketNamespace) save(m state.Map) { + x.beforeSave() + m.Save("endpoints", &x.endpoints) +} + +func (x *AbstractSocketNamespace) afterLoad() {} +func (x *AbstractSocketNamespace) load(m state.Map) { + m.Load("endpoints", &x.endpoints) +} + +func (x *FDFlags) beforeSave() {} +func (x *FDFlags) save(m state.Map) { + x.beforeSave() + m.Save("CloseOnExec", &x.CloseOnExec) +} + +func (x *FDFlags) afterLoad() {} +func (x *FDFlags) load(m state.Map) { + m.Load("CloseOnExec", &x.CloseOnExec) +} + +func (x *descriptor) beforeSave() {} +func (x *descriptor) save(m state.Map) { + x.beforeSave() + m.Save("file", &x.file) + m.Save("fileVFS2", &x.fileVFS2) + m.Save("flags", &x.flags) +} + +func (x *descriptor) afterLoad() {} +func (x *descriptor) load(m state.Map) { + m.Load("file", &x.file) + m.Load("fileVFS2", &x.fileVFS2) + m.Load("flags", &x.flags) +} + +func (x *FDTable) beforeSave() {} +func (x *FDTable) save(m state.Map) { + x.beforeSave() + var descriptorTable map[int32]descriptor = x.saveDescriptorTable() + m.SaveValue("descriptorTable", descriptorTable) + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("k", &x.k) + m.Save("uid", &x.uid) + m.Save("next", &x.next) + m.Save("used", &x.used) +} + +func (x *FDTable) afterLoad() {} +func (x *FDTable) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("k", &x.k) + m.Load("uid", &x.uid) + m.Load("next", &x.next) + m.Load("used", &x.used) + m.LoadValue("descriptorTable", new(map[int32]descriptor), func(y interface{}) { x.loadDescriptorTable(y.(map[int32]descriptor)) }) +} + +func (x *FSContext) beforeSave() {} +func (x *FSContext) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("root", &x.root) + m.Save("rootVFS2", &x.rootVFS2) + m.Save("cwd", &x.cwd) + m.Save("cwdVFS2", &x.cwdVFS2) + m.Save("umask", &x.umask) +} + +func (x *FSContext) afterLoad() {} +func (x *FSContext) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("root", &x.root) + m.Load("rootVFS2", &x.rootVFS2) + m.Load("cwd", &x.cwd) + m.Load("cwdVFS2", &x.cwdVFS2) + m.Load("umask", &x.umask) +} + +func (x *IPCNamespace) beforeSave() {} +func (x *IPCNamespace) save(m state.Map) { + x.beforeSave() + m.Save("userNS", &x.userNS) + m.Save("semaphores", &x.semaphores) + m.Save("shms", &x.shms) +} + +func (x *IPCNamespace) afterLoad() {} +func (x *IPCNamespace) load(m state.Map) { + m.Load("userNS", &x.userNS) + m.Load("semaphores", &x.semaphores) + m.Load("shms", &x.shms) +} + +func (x *Kernel) beforeSave() {} +func (x *Kernel) save(m state.Map) { + x.beforeSave() + var danglingEndpoints []tcpip.Endpoint = x.saveDanglingEndpoints() + m.SaveValue("danglingEndpoints", danglingEndpoints) + var deviceRegistry *device.Registry = x.saveDeviceRegistry() + m.SaveValue("deviceRegistry", deviceRegistry) + m.Save("featureSet", &x.featureSet) + m.Save("timekeeper", &x.timekeeper) + m.Save("tasks", &x.tasks) + m.Save("rootUserNamespace", &x.rootUserNamespace) + m.Save("rootNetworkNamespace", &x.rootNetworkNamespace) + m.Save("applicationCores", &x.applicationCores) + m.Save("useHostCores", &x.useHostCores) + m.Save("extraAuxv", &x.extraAuxv) + m.Save("vdso", &x.vdso) + m.Save("rootUTSNamespace", &x.rootUTSNamespace) + m.Save("rootIPCNamespace", &x.rootIPCNamespace) + m.Save("rootAbstractSocketNamespace", &x.rootAbstractSocketNamespace) + m.Save("futexes", &x.futexes) + m.Save("globalInit", &x.globalInit) + m.Save("realtimeClock", &x.realtimeClock) + m.Save("monotonicClock", &x.monotonicClock) + m.Save("syslog", &x.syslog) + m.Save("runningTasks", &x.runningTasks) + m.Save("cpuClock", &x.cpuClock) + m.Save("cpuClockTickerDisabled", &x.cpuClockTickerDisabled) + m.Save("cpuClockTickerSetting", &x.cpuClockTickerSetting) + m.Save("fdMapUids", &x.fdMapUids) + m.Save("uniqueID", &x.uniqueID) + m.Save("nextInotifyCookie", &x.nextInotifyCookie) + m.Save("netlinkPorts", &x.netlinkPorts) + m.Save("sockets", &x.sockets) + m.Save("nextSocketEntry", &x.nextSocketEntry) + m.Save("DirentCacheLimiter", &x.DirentCacheLimiter) + m.Save("SpecialOpts", &x.SpecialOpts) + m.Save("vfs", &x.vfs) + m.Save("SleepForAddressSpaceActivation", &x.SleepForAddressSpaceActivation) +} + +func (x *Kernel) afterLoad() {} +func (x *Kernel) load(m state.Map) { + m.Load("featureSet", &x.featureSet) + m.Load("timekeeper", &x.timekeeper) + m.Load("tasks", &x.tasks) + m.Load("rootUserNamespace", &x.rootUserNamespace) + m.Load("rootNetworkNamespace", &x.rootNetworkNamespace) + m.Load("applicationCores", &x.applicationCores) + m.Load("useHostCores", &x.useHostCores) + m.Load("extraAuxv", &x.extraAuxv) + m.Load("vdso", &x.vdso) + m.Load("rootUTSNamespace", &x.rootUTSNamespace) + m.Load("rootIPCNamespace", &x.rootIPCNamespace) + m.Load("rootAbstractSocketNamespace", &x.rootAbstractSocketNamespace) + m.Load("futexes", &x.futexes) + m.Load("globalInit", &x.globalInit) + m.Load("realtimeClock", &x.realtimeClock) + m.Load("monotonicClock", &x.monotonicClock) + m.Load("syslog", &x.syslog) + m.Load("runningTasks", &x.runningTasks) + m.Load("cpuClock", &x.cpuClock) + m.Load("cpuClockTickerDisabled", &x.cpuClockTickerDisabled) + m.Load("cpuClockTickerSetting", &x.cpuClockTickerSetting) + m.Load("fdMapUids", &x.fdMapUids) + m.Load("uniqueID", &x.uniqueID) + m.Load("nextInotifyCookie", &x.nextInotifyCookie) + m.Load("netlinkPorts", &x.netlinkPorts) + m.Load("sockets", &x.sockets) + m.Load("nextSocketEntry", &x.nextSocketEntry) + m.Load("DirentCacheLimiter", &x.DirentCacheLimiter) + m.Load("SpecialOpts", &x.SpecialOpts) + m.Load("vfs", &x.vfs) + m.Load("SleepForAddressSpaceActivation", &x.SleepForAddressSpaceActivation) + m.LoadValue("danglingEndpoints", new([]tcpip.Endpoint), func(y interface{}) { x.loadDanglingEndpoints(y.([]tcpip.Endpoint)) }) + m.LoadValue("deviceRegistry", new(*device.Registry), func(y interface{}) { x.loadDeviceRegistry(y.(*device.Registry)) }) +} + +func (x *SocketEntry) beforeSave() {} +func (x *SocketEntry) save(m state.Map) { + x.beforeSave() + m.Save("socketEntry", &x.socketEntry) + m.Save("k", &x.k) + m.Save("Sock", &x.Sock) + m.Save("ID", &x.ID) +} + +func (x *SocketEntry) afterLoad() {} +func (x *SocketEntry) load(m state.Map) { + m.Load("socketEntry", &x.socketEntry) + m.Load("k", &x.k) + m.Load("Sock", &x.Sock) + m.Load("ID", &x.ID) +} + +func (x *pendingSignals) beforeSave() {} +func (x *pendingSignals) save(m state.Map) { + x.beforeSave() + var signals []savedPendingSignal = x.saveSignals() + m.SaveValue("signals", signals) +} + +func (x *pendingSignals) afterLoad() {} +func (x *pendingSignals) load(m state.Map) { + m.LoadValue("signals", new([]savedPendingSignal), func(y interface{}) { x.loadSignals(y.([]savedPendingSignal)) }) +} + +func (x *pendingSignalQueue) beforeSave() {} +func (x *pendingSignalQueue) save(m state.Map) { + x.beforeSave() + m.Save("pendingSignalList", &x.pendingSignalList) + m.Save("length", &x.length) +} + +func (x *pendingSignalQueue) afterLoad() {} +func (x *pendingSignalQueue) load(m state.Map) { + m.Load("pendingSignalList", &x.pendingSignalList) + m.Load("length", &x.length) +} + +func (x *pendingSignal) beforeSave() {} +func (x *pendingSignal) save(m state.Map) { + x.beforeSave() + m.Save("pendingSignalEntry", &x.pendingSignalEntry) + m.Save("SignalInfo", &x.SignalInfo) + m.Save("timer", &x.timer) +} + +func (x *pendingSignal) afterLoad() {} +func (x *pendingSignal) load(m state.Map) { + m.Load("pendingSignalEntry", &x.pendingSignalEntry) + m.Load("SignalInfo", &x.SignalInfo) + m.Load("timer", &x.timer) +} + +func (x *pendingSignalList) beforeSave() {} +func (x *pendingSignalList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *pendingSignalList) afterLoad() {} +func (x *pendingSignalList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *pendingSignalEntry) beforeSave() {} +func (x *pendingSignalEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *pendingSignalEntry) afterLoad() {} +func (x *pendingSignalEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *savedPendingSignal) beforeSave() {} +func (x *savedPendingSignal) save(m state.Map) { + x.beforeSave() + m.Save("si", &x.si) + m.Save("timer", &x.timer) +} + +func (x *savedPendingSignal) afterLoad() {} +func (x *savedPendingSignal) load(m state.Map) { + m.Load("si", &x.si) + m.Load("timer", &x.timer) +} + +func (x *IntervalTimer) beforeSave() {} +func (x *IntervalTimer) save(m state.Map) { + x.beforeSave() + m.Save("timer", &x.timer) + m.Save("target", &x.target) + m.Save("signo", &x.signo) + m.Save("id", &x.id) + m.Save("sigval", &x.sigval) + m.Save("group", &x.group) + m.Save("sigpending", &x.sigpending) + m.Save("sigorphan", &x.sigorphan) + m.Save("overrunCur", &x.overrunCur) + m.Save("overrunLast", &x.overrunLast) +} + +func (x *IntervalTimer) afterLoad() {} +func (x *IntervalTimer) load(m state.Map) { + m.Load("timer", &x.timer) + m.Load("target", &x.target) + m.Load("signo", &x.signo) + m.Load("id", &x.id) + m.Load("sigval", &x.sigval) + m.Load("group", &x.group) + m.Load("sigpending", &x.sigpending) + m.Load("sigorphan", &x.sigorphan) + m.Load("overrunCur", &x.overrunCur) + m.Load("overrunLast", &x.overrunLast) +} + +func (x *processGroupList) beforeSave() {} +func (x *processGroupList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *processGroupList) afterLoad() {} +func (x *processGroupList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *processGroupEntry) beforeSave() {} +func (x *processGroupEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *processGroupEntry) afterLoad() {} +func (x *processGroupEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *ptraceOptions) beforeSave() {} +func (x *ptraceOptions) save(m state.Map) { + x.beforeSave() + m.Save("ExitKill", &x.ExitKill) + m.Save("SysGood", &x.SysGood) + m.Save("TraceClone", &x.TraceClone) + m.Save("TraceExec", &x.TraceExec) + m.Save("TraceExit", &x.TraceExit) + m.Save("TraceFork", &x.TraceFork) + m.Save("TraceSeccomp", &x.TraceSeccomp) + m.Save("TraceVfork", &x.TraceVfork) + m.Save("TraceVforkDone", &x.TraceVforkDone) +} + +func (x *ptraceOptions) afterLoad() {} +func (x *ptraceOptions) load(m state.Map) { + m.Load("ExitKill", &x.ExitKill) + m.Load("SysGood", &x.SysGood) + m.Load("TraceClone", &x.TraceClone) + m.Load("TraceExec", &x.TraceExec) + m.Load("TraceExit", &x.TraceExit) + m.Load("TraceFork", &x.TraceFork) + m.Load("TraceSeccomp", &x.TraceSeccomp) + m.Load("TraceVfork", &x.TraceVfork) + m.Load("TraceVforkDone", &x.TraceVforkDone) +} + +func (x *ptraceStop) beforeSave() {} +func (x *ptraceStop) save(m state.Map) { + x.beforeSave() + m.Save("frozen", &x.frozen) + m.Save("listen", &x.listen) +} + +func (x *ptraceStop) afterLoad() {} +func (x *ptraceStop) load(m state.Map) { + m.Load("frozen", &x.frozen) + m.Load("listen", &x.listen) +} + +func (x *OldRSeqCriticalRegion) beforeSave() {} +func (x *OldRSeqCriticalRegion) save(m state.Map) { + x.beforeSave() + m.Save("CriticalSection", &x.CriticalSection) + m.Save("Restart", &x.Restart) +} + +func (x *OldRSeqCriticalRegion) afterLoad() {} +func (x *OldRSeqCriticalRegion) load(m state.Map) { + m.Load("CriticalSection", &x.CriticalSection) + m.Load("Restart", &x.Restart) +} + +func (x *sessionList) beforeSave() {} +func (x *sessionList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *sessionList) afterLoad() {} +func (x *sessionList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *sessionEntry) beforeSave() {} +func (x *sessionEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *sessionEntry) afterLoad() {} +func (x *sessionEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *Session) beforeSave() {} +func (x *Session) save(m state.Map) { + x.beforeSave() + m.Save("refs", &x.refs) + m.Save("leader", &x.leader) + m.Save("id", &x.id) + m.Save("foreground", &x.foreground) + m.Save("processGroups", &x.processGroups) + m.Save("sessionEntry", &x.sessionEntry) +} + +func (x *Session) afterLoad() {} +func (x *Session) load(m state.Map) { + m.Load("refs", &x.refs) + m.Load("leader", &x.leader) + m.Load("id", &x.id) + m.Load("foreground", &x.foreground) + m.Load("processGroups", &x.processGroups) + m.Load("sessionEntry", &x.sessionEntry) +} + +func (x *ProcessGroup) beforeSave() {} +func (x *ProcessGroup) save(m state.Map) { + x.beforeSave() + m.Save("refs", &x.refs) + m.Save("originator", &x.originator) + m.Save("id", &x.id) + m.Save("session", &x.session) + m.Save("ancestors", &x.ancestors) + m.Save("processGroupEntry", &x.processGroupEntry) +} + +func (x *ProcessGroup) afterLoad() {} +func (x *ProcessGroup) load(m state.Map) { + m.Load("refs", &x.refs) + m.Load("originator", &x.originator) + m.Load("id", &x.id) + m.Load("session", &x.session) + m.Load("ancestors", &x.ancestors) + m.Load("processGroupEntry", &x.processGroupEntry) +} + +func (x *SignalHandlers) beforeSave() {} +func (x *SignalHandlers) save(m state.Map) { + x.beforeSave() + m.Save("actions", &x.actions) +} + +func (x *SignalHandlers) afterLoad() {} +func (x *SignalHandlers) load(m state.Map) { + m.Load("actions", &x.actions) +} + +func (x *socketList) beforeSave() {} +func (x *socketList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *socketList) afterLoad() {} +func (x *socketList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *socketEntry) beforeSave() {} +func (x *socketEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *socketEntry) afterLoad() {} +func (x *socketEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *SyscallTable) beforeSave() {} +func (x *SyscallTable) save(m state.Map) { + x.beforeSave() + m.Save("OS", &x.OS) + m.Save("Arch", &x.Arch) +} + +func (x *SyscallTable) load(m state.Map) { + m.LoadWait("OS", &x.OS) + m.LoadWait("Arch", &x.Arch) + m.AfterLoad(x.afterLoad) +} + +func (x *syslog) beforeSave() {} +func (x *syslog) save(m state.Map) { + x.beforeSave() + m.Save("msg", &x.msg) +} + +func (x *syslog) afterLoad() {} +func (x *syslog) load(m state.Map) { + m.Load("msg", &x.msg) +} + +func (x *Task) beforeSave() {} +func (x *Task) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.signalQueue) { + m.Failf("signalQueue is %v, expected zero", x.signalQueue) + } + var ptraceTracer *Task = x.savePtraceTracer() + m.SaveValue("ptraceTracer", ptraceTracer) + var syscallFilters []bpf.Program = x.saveSyscallFilters() + m.SaveValue("syscallFilters", syscallFilters) + m.Save("taskNode", &x.taskNode) + m.Save("runState", &x.runState) + m.Save("haveSyscallReturn", &x.haveSyscallReturn) + m.Save("gosched", &x.gosched) + m.Save("yieldCount", &x.yieldCount) + m.Save("pendingSignals", &x.pendingSignals) + m.Save("signalMask", &x.signalMask) + m.Save("realSignalMask", &x.realSignalMask) + m.Save("haveSavedSignalMask", &x.haveSavedSignalMask) + m.Save("savedSignalMask", &x.savedSignalMask) + m.Save("signalStack", &x.signalStack) + m.Save("groupStopPending", &x.groupStopPending) + m.Save("groupStopAcknowledged", &x.groupStopAcknowledged) + m.Save("trapStopPending", &x.trapStopPending) + m.Save("trapNotifyPending", &x.trapNotifyPending) + m.Save("stop", &x.stop) + m.Save("exitStatus", &x.exitStatus) + m.Save("syscallRestartBlock", &x.syscallRestartBlock) + m.Save("k", &x.k) + m.Save("containerID", &x.containerID) + m.Save("tc", &x.tc) + m.Save("fsContext", &x.fsContext) + m.Save("fdTable", &x.fdTable) + m.Save("vforkParent", &x.vforkParent) + m.Save("exitState", &x.exitState) + m.Save("exitTracerNotified", &x.exitTracerNotified) + m.Save("exitTracerAcked", &x.exitTracerAcked) + m.Save("exitParentNotified", &x.exitParentNotified) + m.Save("exitParentAcked", &x.exitParentAcked) + m.Save("ptraceTracees", &x.ptraceTracees) + m.Save("ptraceSeized", &x.ptraceSeized) + m.Save("ptraceOpts", &x.ptraceOpts) + m.Save("ptraceSyscallMode", &x.ptraceSyscallMode) + m.Save("ptraceSinglestep", &x.ptraceSinglestep) + m.Save("ptraceCode", &x.ptraceCode) + m.Save("ptraceSiginfo", &x.ptraceSiginfo) + m.Save("ptraceEventMsg", &x.ptraceEventMsg) + m.Save("ioUsage", &x.ioUsage) + m.Save("creds", &x.creds) + m.Save("utsns", &x.utsns) + m.Save("ipcns", &x.ipcns) + m.Save("abstractSockets", &x.abstractSockets) + m.Save("mountNamespaceVFS2", &x.mountNamespaceVFS2) + m.Save("parentDeathSignal", &x.parentDeathSignal) + m.Save("cleartid", &x.cleartid) + m.Save("allowedCPUMask", &x.allowedCPUMask) + m.Save("cpu", &x.cpu) + m.Save("niceness", &x.niceness) + m.Save("numaPolicy", &x.numaPolicy) + m.Save("numaNodeMask", &x.numaNodeMask) + m.Save("netns", &x.netns) + m.Save("rseqCPU", &x.rseqCPU) + m.Save("oldRSeqCPUAddr", &x.oldRSeqCPUAddr) + m.Save("rseqAddr", &x.rseqAddr) + m.Save("rseqSignature", &x.rseqSignature) + m.Save("startTime", &x.startTime) +} + +func (x *Task) load(m state.Map) { + m.Load("taskNode", &x.taskNode) + m.Load("runState", &x.runState) + m.Load("haveSyscallReturn", &x.haveSyscallReturn) + m.Load("gosched", &x.gosched) + m.Load("yieldCount", &x.yieldCount) + m.Load("pendingSignals", &x.pendingSignals) + m.Load("signalMask", &x.signalMask) + m.Load("realSignalMask", &x.realSignalMask) + m.Load("haveSavedSignalMask", &x.haveSavedSignalMask) + m.Load("savedSignalMask", &x.savedSignalMask) + m.Load("signalStack", &x.signalStack) + m.Load("groupStopPending", &x.groupStopPending) + m.Load("groupStopAcknowledged", &x.groupStopAcknowledged) + m.Load("trapStopPending", &x.trapStopPending) + m.Load("trapNotifyPending", &x.trapNotifyPending) + m.Load("stop", &x.stop) + m.Load("exitStatus", &x.exitStatus) + m.Load("syscallRestartBlock", &x.syscallRestartBlock) + m.Load("k", &x.k) + m.Load("containerID", &x.containerID) + m.Load("tc", &x.tc) + m.Load("fsContext", &x.fsContext) + m.Load("fdTable", &x.fdTable) + m.Load("vforkParent", &x.vforkParent) + m.Load("exitState", &x.exitState) + m.Load("exitTracerNotified", &x.exitTracerNotified) + m.Load("exitTracerAcked", &x.exitTracerAcked) + m.Load("exitParentNotified", &x.exitParentNotified) + m.Load("exitParentAcked", &x.exitParentAcked) + m.Load("ptraceTracees", &x.ptraceTracees) + m.Load("ptraceSeized", &x.ptraceSeized) + m.Load("ptraceOpts", &x.ptraceOpts) + m.Load("ptraceSyscallMode", &x.ptraceSyscallMode) + m.Load("ptraceSinglestep", &x.ptraceSinglestep) + m.Load("ptraceCode", &x.ptraceCode) + m.Load("ptraceSiginfo", &x.ptraceSiginfo) + m.Load("ptraceEventMsg", &x.ptraceEventMsg) + m.Load("ioUsage", &x.ioUsage) + m.Load("creds", &x.creds) + m.Load("utsns", &x.utsns) + m.Load("ipcns", &x.ipcns) + m.Load("abstractSockets", &x.abstractSockets) + m.Load("mountNamespaceVFS2", &x.mountNamespaceVFS2) + m.Load("parentDeathSignal", &x.parentDeathSignal) + m.Load("cleartid", &x.cleartid) + m.Load("allowedCPUMask", &x.allowedCPUMask) + m.Load("cpu", &x.cpu) + m.Load("niceness", &x.niceness) + m.Load("numaPolicy", &x.numaPolicy) + m.Load("numaNodeMask", &x.numaNodeMask) + m.Load("netns", &x.netns) + m.Load("rseqCPU", &x.rseqCPU) + m.Load("oldRSeqCPUAddr", &x.oldRSeqCPUAddr) + m.Load("rseqAddr", &x.rseqAddr) + m.Load("rseqSignature", &x.rseqSignature) + m.Load("startTime", &x.startTime) + m.LoadValue("ptraceTracer", new(*Task), func(y interface{}) { x.loadPtraceTracer(y.(*Task)) }) + m.LoadValue("syscallFilters", new([]bpf.Program), func(y interface{}) { x.loadSyscallFilters(y.([]bpf.Program)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *runSyscallAfterPtraceEventClone) beforeSave() {} +func (x *runSyscallAfterPtraceEventClone) save(m state.Map) { + x.beforeSave() + m.Save("vforkChild", &x.vforkChild) + m.Save("vforkChildTID", &x.vforkChildTID) +} + +func (x *runSyscallAfterPtraceEventClone) afterLoad() {} +func (x *runSyscallAfterPtraceEventClone) load(m state.Map) { + m.Load("vforkChild", &x.vforkChild) + m.Load("vforkChildTID", &x.vforkChildTID) +} + +func (x *runSyscallAfterVforkStop) beforeSave() {} +func (x *runSyscallAfterVforkStop) save(m state.Map) { + x.beforeSave() + m.Save("childTID", &x.childTID) +} + +func (x *runSyscallAfterVforkStop) afterLoad() {} +func (x *runSyscallAfterVforkStop) load(m state.Map) { + m.Load("childTID", &x.childTID) +} + +func (x *vforkStop) beforeSave() {} +func (x *vforkStop) save(m state.Map) { + x.beforeSave() +} + +func (x *vforkStop) afterLoad() {} +func (x *vforkStop) load(m state.Map) { +} + +func (x *TaskContext) beforeSave() {} +func (x *TaskContext) save(m state.Map) { + x.beforeSave() + m.Save("Name", &x.Name) + m.Save("Arch", &x.Arch) + m.Save("MemoryManager", &x.MemoryManager) + m.Save("fu", &x.fu) + m.Save("st", &x.st) +} + +func (x *TaskContext) afterLoad() {} +func (x *TaskContext) load(m state.Map) { + m.Load("Name", &x.Name) + m.Load("Arch", &x.Arch) + m.Load("MemoryManager", &x.MemoryManager) + m.Load("fu", &x.fu) + m.Load("st", &x.st) +} + +func (x *execStop) beforeSave() {} +func (x *execStop) save(m state.Map) { + x.beforeSave() +} + +func (x *execStop) afterLoad() {} +func (x *execStop) load(m state.Map) { +} + +func (x *runSyscallAfterExecStop) beforeSave() {} +func (x *runSyscallAfterExecStop) save(m state.Map) { + x.beforeSave() + m.Save("tc", &x.tc) +} + +func (x *runSyscallAfterExecStop) afterLoad() {} +func (x *runSyscallAfterExecStop) load(m state.Map) { + m.Load("tc", &x.tc) +} + +func (x *ExitStatus) beforeSave() {} +func (x *ExitStatus) save(m state.Map) { + x.beforeSave() + m.Save("Code", &x.Code) + m.Save("Signo", &x.Signo) +} + +func (x *ExitStatus) afterLoad() {} +func (x *ExitStatus) load(m state.Map) { + m.Load("Code", &x.Code) + m.Load("Signo", &x.Signo) +} + +func (x *runExit) beforeSave() {} +func (x *runExit) save(m state.Map) { + x.beforeSave() +} + +func (x *runExit) afterLoad() {} +func (x *runExit) load(m state.Map) { +} + +func (x *runExitMain) beforeSave() {} +func (x *runExitMain) save(m state.Map) { + x.beforeSave() +} + +func (x *runExitMain) afterLoad() {} +func (x *runExitMain) load(m state.Map) { +} + +func (x *runExitNotify) beforeSave() {} +func (x *runExitNotify) save(m state.Map) { + x.beforeSave() +} + +func (x *runExitNotify) afterLoad() {} +func (x *runExitNotify) load(m state.Map) { +} + +func (x *taskList) beforeSave() {} +func (x *taskList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *taskList) afterLoad() {} +func (x *taskList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *taskEntry) beforeSave() {} +func (x *taskEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *taskEntry) afterLoad() {} +func (x *taskEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *runApp) beforeSave() {} +func (x *runApp) save(m state.Map) { + x.beforeSave() +} + +func (x *runApp) afterLoad() {} +func (x *runApp) load(m state.Map) { +} + +func (x *TaskGoroutineSchedInfo) beforeSave() {} +func (x *TaskGoroutineSchedInfo) save(m state.Map) { + x.beforeSave() + m.Save("Timestamp", &x.Timestamp) + m.Save("State", &x.State) + m.Save("UserTicks", &x.UserTicks) + m.Save("SysTicks", &x.SysTicks) +} + +func (x *TaskGoroutineSchedInfo) afterLoad() {} +func (x *TaskGoroutineSchedInfo) load(m state.Map) { + m.Load("Timestamp", &x.Timestamp) + m.Load("State", &x.State) + m.Load("UserTicks", &x.UserTicks) + m.Load("SysTicks", &x.SysTicks) +} + +func (x *taskClock) beforeSave() {} +func (x *taskClock) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) + m.Save("includeSys", &x.includeSys) +} + +func (x *taskClock) afterLoad() {} +func (x *taskClock) load(m state.Map) { + m.Load("t", &x.t) + m.Load("includeSys", &x.includeSys) +} + +func (x *tgClock) beforeSave() {} +func (x *tgClock) save(m state.Map) { + x.beforeSave() + m.Save("tg", &x.tg) + m.Save("includeSys", &x.includeSys) +} + +func (x *tgClock) afterLoad() {} +func (x *tgClock) load(m state.Map) { + m.Load("tg", &x.tg) + m.Load("includeSys", &x.includeSys) +} + +func (x *groupStop) beforeSave() {} +func (x *groupStop) save(m state.Map) { + x.beforeSave() +} + +func (x *groupStop) afterLoad() {} +func (x *groupStop) load(m state.Map) { +} + +func (x *runInterrupt) beforeSave() {} +func (x *runInterrupt) save(m state.Map) { + x.beforeSave() +} + +func (x *runInterrupt) afterLoad() {} +func (x *runInterrupt) load(m state.Map) { +} + +func (x *runInterruptAfterSignalDeliveryStop) beforeSave() {} +func (x *runInterruptAfterSignalDeliveryStop) save(m state.Map) { + x.beforeSave() +} + +func (x *runInterruptAfterSignalDeliveryStop) afterLoad() {} +func (x *runInterruptAfterSignalDeliveryStop) load(m state.Map) { +} + +func (x *runSyscallAfterSyscallEnterStop) beforeSave() {} +func (x *runSyscallAfterSyscallEnterStop) save(m state.Map) { + x.beforeSave() +} + +func (x *runSyscallAfterSyscallEnterStop) afterLoad() {} +func (x *runSyscallAfterSyscallEnterStop) load(m state.Map) { +} + +func (x *runSyscallAfterSysemuStop) beforeSave() {} +func (x *runSyscallAfterSysemuStop) save(m state.Map) { + x.beforeSave() +} + +func (x *runSyscallAfterSysemuStop) afterLoad() {} +func (x *runSyscallAfterSysemuStop) load(m state.Map) { +} + +func (x *runSyscallReinvoke) beforeSave() {} +func (x *runSyscallReinvoke) save(m state.Map) { + x.beforeSave() +} + +func (x *runSyscallReinvoke) afterLoad() {} +func (x *runSyscallReinvoke) load(m state.Map) { +} + +func (x *runSyscallExit) beforeSave() {} +func (x *runSyscallExit) save(m state.Map) { + x.beforeSave() +} + +func (x *runSyscallExit) afterLoad() {} +func (x *runSyscallExit) load(m state.Map) { +} + +func (x *ThreadGroup) beforeSave() {} +func (x *ThreadGroup) save(m state.Map) { + x.beforeSave() + var oldRSeqCritical *OldRSeqCriticalRegion = x.saveOldRSeqCritical() + m.SaveValue("oldRSeqCritical", oldRSeqCritical) + m.Save("threadGroupNode", &x.threadGroupNode) + m.Save("signalHandlers", &x.signalHandlers) + m.Save("pendingSignals", &x.pendingSignals) + m.Save("groupStopDequeued", &x.groupStopDequeued) + m.Save("groupStopSignal", &x.groupStopSignal) + m.Save("groupStopPendingCount", &x.groupStopPendingCount) + m.Save("groupStopComplete", &x.groupStopComplete) + m.Save("groupStopWaitable", &x.groupStopWaitable) + m.Save("groupContNotify", &x.groupContNotify) + m.Save("groupContInterrupted", &x.groupContInterrupted) + m.Save("groupContWaitable", &x.groupContWaitable) + m.Save("exiting", &x.exiting) + m.Save("exitStatus", &x.exitStatus) + m.Save("terminationSignal", &x.terminationSignal) + m.Save("itimerRealTimer", &x.itimerRealTimer) + m.Save("itimerVirtSetting", &x.itimerVirtSetting) + m.Save("itimerProfSetting", &x.itimerProfSetting) + m.Save("rlimitCPUSoftSetting", &x.rlimitCPUSoftSetting) + m.Save("cpuTimersEnabled", &x.cpuTimersEnabled) + m.Save("timers", &x.timers) + m.Save("nextTimerID", &x.nextTimerID) + m.Save("exitedCPUStats", &x.exitedCPUStats) + m.Save("childCPUStats", &x.childCPUStats) + m.Save("ioUsage", &x.ioUsage) + m.Save("maxRSS", &x.maxRSS) + m.Save("childMaxRSS", &x.childMaxRSS) + m.Save("limits", &x.limits) + m.Save("processGroup", &x.processGroup) + m.Save("execed", &x.execed) + m.Save("mounts", &x.mounts) + m.Save("tty", &x.tty) + m.Save("oomScoreAdj", &x.oomScoreAdj) +} + +func (x *ThreadGroup) afterLoad() {} +func (x *ThreadGroup) load(m state.Map) { + m.Load("threadGroupNode", &x.threadGroupNode) + m.Load("signalHandlers", &x.signalHandlers) + m.Load("pendingSignals", &x.pendingSignals) + m.Load("groupStopDequeued", &x.groupStopDequeued) + m.Load("groupStopSignal", &x.groupStopSignal) + m.Load("groupStopPendingCount", &x.groupStopPendingCount) + m.Load("groupStopComplete", &x.groupStopComplete) + m.Load("groupStopWaitable", &x.groupStopWaitable) + m.Load("groupContNotify", &x.groupContNotify) + m.Load("groupContInterrupted", &x.groupContInterrupted) + m.Load("groupContWaitable", &x.groupContWaitable) + m.Load("exiting", &x.exiting) + m.Load("exitStatus", &x.exitStatus) + m.Load("terminationSignal", &x.terminationSignal) + m.Load("itimerRealTimer", &x.itimerRealTimer) + m.Load("itimerVirtSetting", &x.itimerVirtSetting) + m.Load("itimerProfSetting", &x.itimerProfSetting) + m.Load("rlimitCPUSoftSetting", &x.rlimitCPUSoftSetting) + m.Load("cpuTimersEnabled", &x.cpuTimersEnabled) + m.Load("timers", &x.timers) + m.Load("nextTimerID", &x.nextTimerID) + m.Load("exitedCPUStats", &x.exitedCPUStats) + m.Load("childCPUStats", &x.childCPUStats) + m.Load("ioUsage", &x.ioUsage) + m.Load("maxRSS", &x.maxRSS) + m.Load("childMaxRSS", &x.childMaxRSS) + m.Load("limits", &x.limits) + m.Load("processGroup", &x.processGroup) + m.Load("execed", &x.execed) + m.Load("mounts", &x.mounts) + m.Load("tty", &x.tty) + m.Load("oomScoreAdj", &x.oomScoreAdj) + m.LoadValue("oldRSeqCritical", new(*OldRSeqCriticalRegion), func(y interface{}) { x.loadOldRSeqCritical(y.(*OldRSeqCriticalRegion)) }) +} + +func (x *itimerRealListener) beforeSave() {} +func (x *itimerRealListener) save(m state.Map) { + x.beforeSave() + m.Save("tg", &x.tg) +} + +func (x *itimerRealListener) afterLoad() {} +func (x *itimerRealListener) load(m state.Map) { + m.Load("tg", &x.tg) +} + +func (x *TaskSet) beforeSave() {} +func (x *TaskSet) save(m state.Map) { + x.beforeSave() + m.Save("Root", &x.Root) + m.Save("sessions", &x.sessions) +} + +func (x *TaskSet) afterLoad() {} +func (x *TaskSet) load(m state.Map) { + m.Load("Root", &x.Root) + m.Load("sessions", &x.sessions) +} + +func (x *PIDNamespace) beforeSave() {} +func (x *PIDNamespace) save(m state.Map) { + x.beforeSave() + m.Save("owner", &x.owner) + m.Save("parent", &x.parent) + m.Save("userns", &x.userns) + m.Save("last", &x.last) + m.Save("tasks", &x.tasks) + m.Save("tids", &x.tids) + m.Save("tgids", &x.tgids) + m.Save("sessions", &x.sessions) + m.Save("sids", &x.sids) + m.Save("processGroups", &x.processGroups) + m.Save("pgids", &x.pgids) + m.Save("exiting", &x.exiting) +} + +func (x *PIDNamespace) afterLoad() {} +func (x *PIDNamespace) load(m state.Map) { + m.Load("owner", &x.owner) + m.Load("parent", &x.parent) + m.Load("userns", &x.userns) + m.Load("last", &x.last) + m.Load("tasks", &x.tasks) + m.Load("tids", &x.tids) + m.Load("tgids", &x.tgids) + m.Load("sessions", &x.sessions) + m.Load("sids", &x.sids) + m.Load("processGroups", &x.processGroups) + m.Load("pgids", &x.pgids) + m.Load("exiting", &x.exiting) +} + +func (x *threadGroupNode) beforeSave() {} +func (x *threadGroupNode) save(m state.Map) { + x.beforeSave() + m.Save("pidns", &x.pidns) + m.Save("leader", &x.leader) + m.Save("execing", &x.execing) + m.Save("tasks", &x.tasks) + m.Save("tasksCount", &x.tasksCount) + m.Save("liveTasks", &x.liveTasks) + m.Save("activeTasks", &x.activeTasks) +} + +func (x *threadGroupNode) afterLoad() {} +func (x *threadGroupNode) load(m state.Map) { + m.Load("pidns", &x.pidns) + m.Load("leader", &x.leader) + m.Load("execing", &x.execing) + m.Load("tasks", &x.tasks) + m.Load("tasksCount", &x.tasksCount) + m.Load("liveTasks", &x.liveTasks) + m.Load("activeTasks", &x.activeTasks) +} + +func (x *taskNode) beforeSave() {} +func (x *taskNode) save(m state.Map) { + x.beforeSave() + m.Save("tg", &x.tg) + m.Save("taskEntry", &x.taskEntry) + m.Save("parent", &x.parent) + m.Save("children", &x.children) + m.Save("childPIDNamespace", &x.childPIDNamespace) +} + +func (x *taskNode) afterLoad() {} +func (x *taskNode) load(m state.Map) { + m.LoadWait("tg", &x.tg) + m.Load("taskEntry", &x.taskEntry) + m.Load("parent", &x.parent) + m.Load("children", &x.children) + m.Load("childPIDNamespace", &x.childPIDNamespace) +} + +func (x *Timekeeper) save(m state.Map) { + x.beforeSave() + m.Save("bootTime", &x.bootTime) + m.Save("saveMonotonic", &x.saveMonotonic) + m.Save("saveRealtime", &x.saveRealtime) + m.Save("params", &x.params) +} + +func (x *Timekeeper) load(m state.Map) { + m.Load("bootTime", &x.bootTime) + m.Load("saveMonotonic", &x.saveMonotonic) + m.Load("saveRealtime", &x.saveRealtime) + m.Load("params", &x.params) + m.AfterLoad(x.afterLoad) +} + +func (x *timekeeperClock) beforeSave() {} +func (x *timekeeperClock) save(m state.Map) { + x.beforeSave() + m.Save("tk", &x.tk) + m.Save("c", &x.c) +} + +func (x *timekeeperClock) afterLoad() {} +func (x *timekeeperClock) load(m state.Map) { + m.Load("tk", &x.tk) + m.Load("c", &x.c) +} + +func (x *TTY) beforeSave() {} +func (x *TTY) save(m state.Map) { + x.beforeSave() + m.Save("Index", &x.Index) + m.Save("tg", &x.tg) +} + +func (x *TTY) afterLoad() {} +func (x *TTY) load(m state.Map) { + m.Load("Index", &x.Index) + m.Load("tg", &x.tg) +} + +func (x *UTSNamespace) beforeSave() {} +func (x *UTSNamespace) save(m state.Map) { + x.beforeSave() + m.Save("hostName", &x.hostName) + m.Save("domainName", &x.domainName) + m.Save("userns", &x.userns) +} + +func (x *UTSNamespace) afterLoad() {} +func (x *UTSNamespace) load(m state.Map) { + m.Load("hostName", &x.hostName) + m.Load("domainName", &x.domainName) + m.Load("userns", &x.userns) +} + +func (x *VDSOParamPage) beforeSave() {} +func (x *VDSOParamPage) save(m state.Map) { + x.beforeSave() + m.Save("mfp", &x.mfp) + m.Save("fr", &x.fr) + m.Save("seq", &x.seq) +} + +func (x *VDSOParamPage) afterLoad() {} +func (x *VDSOParamPage) load(m state.Map) { + m.Load("mfp", &x.mfp) + m.Load("fr", &x.fr) + m.Load("seq", &x.seq) +} + +func init() { + state.Register("pkg/sentry/kernel.abstractEndpoint", (*abstractEndpoint)(nil), state.Fns{Save: (*abstractEndpoint).save, Load: (*abstractEndpoint).load}) + state.Register("pkg/sentry/kernel.AbstractSocketNamespace", (*AbstractSocketNamespace)(nil), state.Fns{Save: (*AbstractSocketNamespace).save, Load: (*AbstractSocketNamespace).load}) + state.Register("pkg/sentry/kernel.FDFlags", (*FDFlags)(nil), state.Fns{Save: (*FDFlags).save, Load: (*FDFlags).load}) + state.Register("pkg/sentry/kernel.descriptor", (*descriptor)(nil), state.Fns{Save: (*descriptor).save, Load: (*descriptor).load}) + state.Register("pkg/sentry/kernel.FDTable", (*FDTable)(nil), state.Fns{Save: (*FDTable).save, Load: (*FDTable).load}) + state.Register("pkg/sentry/kernel.FSContext", (*FSContext)(nil), state.Fns{Save: (*FSContext).save, Load: (*FSContext).load}) + state.Register("pkg/sentry/kernel.IPCNamespace", (*IPCNamespace)(nil), state.Fns{Save: (*IPCNamespace).save, Load: (*IPCNamespace).load}) + state.Register("pkg/sentry/kernel.Kernel", (*Kernel)(nil), state.Fns{Save: (*Kernel).save, Load: (*Kernel).load}) + state.Register("pkg/sentry/kernel.SocketEntry", (*SocketEntry)(nil), state.Fns{Save: (*SocketEntry).save, Load: (*SocketEntry).load}) + state.Register("pkg/sentry/kernel.pendingSignals", (*pendingSignals)(nil), state.Fns{Save: (*pendingSignals).save, Load: (*pendingSignals).load}) + state.Register("pkg/sentry/kernel.pendingSignalQueue", (*pendingSignalQueue)(nil), state.Fns{Save: (*pendingSignalQueue).save, Load: (*pendingSignalQueue).load}) + state.Register("pkg/sentry/kernel.pendingSignal", (*pendingSignal)(nil), state.Fns{Save: (*pendingSignal).save, Load: (*pendingSignal).load}) + state.Register("pkg/sentry/kernel.pendingSignalList", (*pendingSignalList)(nil), state.Fns{Save: (*pendingSignalList).save, Load: (*pendingSignalList).load}) + state.Register("pkg/sentry/kernel.pendingSignalEntry", (*pendingSignalEntry)(nil), state.Fns{Save: (*pendingSignalEntry).save, Load: (*pendingSignalEntry).load}) + state.Register("pkg/sentry/kernel.savedPendingSignal", (*savedPendingSignal)(nil), state.Fns{Save: (*savedPendingSignal).save, Load: (*savedPendingSignal).load}) + state.Register("pkg/sentry/kernel.IntervalTimer", (*IntervalTimer)(nil), state.Fns{Save: (*IntervalTimer).save, Load: (*IntervalTimer).load}) + state.Register("pkg/sentry/kernel.processGroupList", (*processGroupList)(nil), state.Fns{Save: (*processGroupList).save, Load: (*processGroupList).load}) + state.Register("pkg/sentry/kernel.processGroupEntry", (*processGroupEntry)(nil), state.Fns{Save: (*processGroupEntry).save, Load: (*processGroupEntry).load}) + state.Register("pkg/sentry/kernel.ptraceOptions", (*ptraceOptions)(nil), state.Fns{Save: (*ptraceOptions).save, Load: (*ptraceOptions).load}) + state.Register("pkg/sentry/kernel.ptraceStop", (*ptraceStop)(nil), state.Fns{Save: (*ptraceStop).save, Load: (*ptraceStop).load}) + state.Register("pkg/sentry/kernel.OldRSeqCriticalRegion", (*OldRSeqCriticalRegion)(nil), state.Fns{Save: (*OldRSeqCriticalRegion).save, Load: (*OldRSeqCriticalRegion).load}) + state.Register("pkg/sentry/kernel.sessionList", (*sessionList)(nil), state.Fns{Save: (*sessionList).save, Load: (*sessionList).load}) + state.Register("pkg/sentry/kernel.sessionEntry", (*sessionEntry)(nil), state.Fns{Save: (*sessionEntry).save, Load: (*sessionEntry).load}) + state.Register("pkg/sentry/kernel.Session", (*Session)(nil), state.Fns{Save: (*Session).save, Load: (*Session).load}) + state.Register("pkg/sentry/kernel.ProcessGroup", (*ProcessGroup)(nil), state.Fns{Save: (*ProcessGroup).save, Load: (*ProcessGroup).load}) + state.Register("pkg/sentry/kernel.SignalHandlers", (*SignalHandlers)(nil), state.Fns{Save: (*SignalHandlers).save, Load: (*SignalHandlers).load}) + state.Register("pkg/sentry/kernel.socketList", (*socketList)(nil), state.Fns{Save: (*socketList).save, Load: (*socketList).load}) + state.Register("pkg/sentry/kernel.socketEntry", (*socketEntry)(nil), state.Fns{Save: (*socketEntry).save, Load: (*socketEntry).load}) + state.Register("pkg/sentry/kernel.SyscallTable", (*SyscallTable)(nil), state.Fns{Save: (*SyscallTable).save, Load: (*SyscallTable).load}) + state.Register("pkg/sentry/kernel.syslog", (*syslog)(nil), state.Fns{Save: (*syslog).save, Load: (*syslog).load}) + state.Register("pkg/sentry/kernel.Task", (*Task)(nil), state.Fns{Save: (*Task).save, Load: (*Task).load}) + state.Register("pkg/sentry/kernel.runSyscallAfterPtraceEventClone", (*runSyscallAfterPtraceEventClone)(nil), state.Fns{Save: (*runSyscallAfterPtraceEventClone).save, Load: (*runSyscallAfterPtraceEventClone).load}) + state.Register("pkg/sentry/kernel.runSyscallAfterVforkStop", (*runSyscallAfterVforkStop)(nil), state.Fns{Save: (*runSyscallAfterVforkStop).save, Load: (*runSyscallAfterVforkStop).load}) + state.Register("pkg/sentry/kernel.vforkStop", (*vforkStop)(nil), state.Fns{Save: (*vforkStop).save, Load: (*vforkStop).load}) + state.Register("pkg/sentry/kernel.TaskContext", (*TaskContext)(nil), state.Fns{Save: (*TaskContext).save, Load: (*TaskContext).load}) + state.Register("pkg/sentry/kernel.execStop", (*execStop)(nil), state.Fns{Save: (*execStop).save, Load: (*execStop).load}) + state.Register("pkg/sentry/kernel.runSyscallAfterExecStop", (*runSyscallAfterExecStop)(nil), state.Fns{Save: (*runSyscallAfterExecStop).save, Load: (*runSyscallAfterExecStop).load}) + state.Register("pkg/sentry/kernel.ExitStatus", (*ExitStatus)(nil), state.Fns{Save: (*ExitStatus).save, Load: (*ExitStatus).load}) + state.Register("pkg/sentry/kernel.runExit", (*runExit)(nil), state.Fns{Save: (*runExit).save, Load: (*runExit).load}) + state.Register("pkg/sentry/kernel.runExitMain", (*runExitMain)(nil), state.Fns{Save: (*runExitMain).save, Load: (*runExitMain).load}) + state.Register("pkg/sentry/kernel.runExitNotify", (*runExitNotify)(nil), state.Fns{Save: (*runExitNotify).save, Load: (*runExitNotify).load}) + state.Register("pkg/sentry/kernel.taskList", (*taskList)(nil), state.Fns{Save: (*taskList).save, Load: (*taskList).load}) + state.Register("pkg/sentry/kernel.taskEntry", (*taskEntry)(nil), state.Fns{Save: (*taskEntry).save, Load: (*taskEntry).load}) + state.Register("pkg/sentry/kernel.runApp", (*runApp)(nil), state.Fns{Save: (*runApp).save, Load: (*runApp).load}) + state.Register("pkg/sentry/kernel.TaskGoroutineSchedInfo", (*TaskGoroutineSchedInfo)(nil), state.Fns{Save: (*TaskGoroutineSchedInfo).save, Load: (*TaskGoroutineSchedInfo).load}) + state.Register("pkg/sentry/kernel.taskClock", (*taskClock)(nil), state.Fns{Save: (*taskClock).save, Load: (*taskClock).load}) + state.Register("pkg/sentry/kernel.tgClock", (*tgClock)(nil), state.Fns{Save: (*tgClock).save, Load: (*tgClock).load}) + state.Register("pkg/sentry/kernel.groupStop", (*groupStop)(nil), state.Fns{Save: (*groupStop).save, Load: (*groupStop).load}) + state.Register("pkg/sentry/kernel.runInterrupt", (*runInterrupt)(nil), state.Fns{Save: (*runInterrupt).save, Load: (*runInterrupt).load}) + state.Register("pkg/sentry/kernel.runInterruptAfterSignalDeliveryStop", (*runInterruptAfterSignalDeliveryStop)(nil), state.Fns{Save: (*runInterruptAfterSignalDeliveryStop).save, Load: (*runInterruptAfterSignalDeliveryStop).load}) + state.Register("pkg/sentry/kernel.runSyscallAfterSyscallEnterStop", (*runSyscallAfterSyscallEnterStop)(nil), state.Fns{Save: (*runSyscallAfterSyscallEnterStop).save, Load: (*runSyscallAfterSyscallEnterStop).load}) + state.Register("pkg/sentry/kernel.runSyscallAfterSysemuStop", (*runSyscallAfterSysemuStop)(nil), state.Fns{Save: (*runSyscallAfterSysemuStop).save, Load: (*runSyscallAfterSysemuStop).load}) + state.Register("pkg/sentry/kernel.runSyscallReinvoke", (*runSyscallReinvoke)(nil), state.Fns{Save: (*runSyscallReinvoke).save, Load: (*runSyscallReinvoke).load}) + state.Register("pkg/sentry/kernel.runSyscallExit", (*runSyscallExit)(nil), state.Fns{Save: (*runSyscallExit).save, Load: (*runSyscallExit).load}) + state.Register("pkg/sentry/kernel.ThreadGroup", (*ThreadGroup)(nil), state.Fns{Save: (*ThreadGroup).save, Load: (*ThreadGroup).load}) + state.Register("pkg/sentry/kernel.itimerRealListener", (*itimerRealListener)(nil), state.Fns{Save: (*itimerRealListener).save, Load: (*itimerRealListener).load}) + state.Register("pkg/sentry/kernel.TaskSet", (*TaskSet)(nil), state.Fns{Save: (*TaskSet).save, Load: (*TaskSet).load}) + state.Register("pkg/sentry/kernel.PIDNamespace", (*PIDNamespace)(nil), state.Fns{Save: (*PIDNamespace).save, Load: (*PIDNamespace).load}) + state.Register("pkg/sentry/kernel.threadGroupNode", (*threadGroupNode)(nil), state.Fns{Save: (*threadGroupNode).save, Load: (*threadGroupNode).load}) + state.Register("pkg/sentry/kernel.taskNode", (*taskNode)(nil), state.Fns{Save: (*taskNode).save, Load: (*taskNode).load}) + state.Register("pkg/sentry/kernel.Timekeeper", (*Timekeeper)(nil), state.Fns{Save: (*Timekeeper).save, Load: (*Timekeeper).load}) + state.Register("pkg/sentry/kernel.timekeeperClock", (*timekeeperClock)(nil), state.Fns{Save: (*timekeeperClock).save, Load: (*timekeeperClock).load}) + state.Register("pkg/sentry/kernel.TTY", (*TTY)(nil), state.Fns{Save: (*TTY).save, Load: (*TTY).load}) + state.Register("pkg/sentry/kernel.UTSNamespace", (*UTSNamespace)(nil), state.Fns{Save: (*UTSNamespace).save, Load: (*UTSNamespace).load}) + state.Register("pkg/sentry/kernel.VDSOParamPage", (*VDSOParamPage)(nil), state.Fns{Save: (*VDSOParamPage).save, Load: (*VDSOParamPage).load}) +} diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD deleted file mode 100644 index 4486848d2..000000000 --- a/pkg/sentry/kernel/memevent/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -load("//tools:defs.bzl", "go_library", "proto_library") - -package(licenses = ["notice"]) - -go_library( - name = "memevent", - srcs = ["memory_events.go"], - visibility = ["//:sandbox"], - deps = [ - ":memory_events_go_proto", - "//pkg/eventchannel", - "//pkg/log", - "//pkg/metric", - "//pkg/sentry/kernel", - "//pkg/sentry/usage", - "//pkg/sync", - ], -) - -proto_library( - name = "memory_events", - srcs = ["memory_events.proto"], - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/kernel/memevent/memevent_state_autogen.go b/pkg/sentry/kernel/memevent/memevent_state_autogen.go new file mode 100755 index 000000000..4a1679fa9 --- /dev/null +++ b/pkg/sentry/kernel/memevent/memevent_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package memevent diff --git a/pkg/sentry/kernel/memevent/memory_events.go b/pkg/sentry/kernel/memevent/memory_events.go index 200565bb8..200565bb8 100644..100755 --- a/pkg/sentry/kernel/memevent/memory_events.go +++ b/pkg/sentry/kernel/memevent/memory_events.go diff --git a/pkg/sentry/kernel/memevent/memory_events.proto b/pkg/sentry/kernel/memevent/memory_events.proto deleted file mode 100644 index bf8029ff5..000000000 --- a/pkg/sentry/kernel/memevent/memory_events.proto +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -// MemoryUsageEvent describes the memory usage of the sandbox at a single -// instant in time. These messages are emitted periodically on the eventchannel. -message MemoryUsageEvent { - // The total memory usage of the sandboxed application in bytes, calculated - // using the 'fast' method. - uint64 total = 1; - - // Memory used to back memory-mapped regions for files in the application, in - // bytes. This corresponds to the usage.MemoryKind.Mapped memory type. - uint64 mapped = 2; -} diff --git a/pkg/sentry/kernel/memevent/memory_events_go_proto/memory_events.pb.go b/pkg/sentry/kernel/memevent/memory_events_go_proto/memory_events.pb.go new file mode 100755 index 000000000..f8b857fa9 --- /dev/null +++ b/pkg/sentry/kernel/memevent/memory_events_go_proto/memory_events.pb.go @@ -0,0 +1,88 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/sentry/kernel/memevent/memory_events.proto + +package gvisor + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type MemoryUsageEvent struct { + Total uint64 `protobuf:"varint,1,opt,name=total,proto3" json:"total,omitempty"` + Mapped uint64 `protobuf:"varint,2,opt,name=mapped,proto3" json:"mapped,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *MemoryUsageEvent) Reset() { *m = MemoryUsageEvent{} } +func (m *MemoryUsageEvent) String() string { return proto.CompactTextString(m) } +func (*MemoryUsageEvent) ProtoMessage() {} +func (*MemoryUsageEvent) Descriptor() ([]byte, []int) { + return fileDescriptor_cd85fc8d1130e4b0, []int{0} +} + +func (m *MemoryUsageEvent) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_MemoryUsageEvent.Unmarshal(m, b) +} +func (m *MemoryUsageEvent) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_MemoryUsageEvent.Marshal(b, m, deterministic) +} +func (m *MemoryUsageEvent) XXX_Merge(src proto.Message) { + xxx_messageInfo_MemoryUsageEvent.Merge(m, src) +} +func (m *MemoryUsageEvent) XXX_Size() int { + return xxx_messageInfo_MemoryUsageEvent.Size(m) +} +func (m *MemoryUsageEvent) XXX_DiscardUnknown() { + xxx_messageInfo_MemoryUsageEvent.DiscardUnknown(m) +} + +var xxx_messageInfo_MemoryUsageEvent proto.InternalMessageInfo + +func (m *MemoryUsageEvent) GetTotal() uint64 { + if m != nil { + return m.Total + } + return 0 +} + +func (m *MemoryUsageEvent) GetMapped() uint64 { + if m != nil { + return m.Mapped + } + return 0 +} + +func init() { + proto.RegisterType((*MemoryUsageEvent)(nil), "gvisor.MemoryUsageEvent") +} + +func init() { + proto.RegisterFile("pkg/sentry/kernel/memevent/memory_events.proto", fileDescriptor_cd85fc8d1130e4b0) +} + +var fileDescriptor_cd85fc8d1130e4b0 = []byte{ + // 128 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xd2, 0x2b, 0xc8, 0x4e, 0xd7, + 0x2f, 0x4e, 0xcd, 0x2b, 0x29, 0xaa, 0xd4, 0xcf, 0x4e, 0x2d, 0xca, 0x4b, 0xcd, 0xd1, 0xcf, 0x4d, + 0xcd, 0x4d, 0x2d, 0x4b, 0xcd, 0x2b, 0x01, 0x31, 0xf2, 0x8b, 0x2a, 0xe3, 0xc1, 0x9c, 0x62, 0xbd, + 0x82, 0xa2, 0xfc, 0x92, 0x7c, 0x21, 0xb6, 0xf4, 0xb2, 0xcc, 0xe2, 0xfc, 0x22, 0x25, 0x07, 0x2e, + 0x01, 0x5f, 0xb0, 0x74, 0x68, 0x71, 0x62, 0x7a, 0xaa, 0x2b, 0x48, 0x89, 0x90, 0x08, 0x17, 0x6b, + 0x49, 0x7e, 0x49, 0x62, 0x8e, 0x04, 0xa3, 0x02, 0xa3, 0x06, 0x4b, 0x10, 0x84, 0x23, 0x24, 0xc6, + 0xc5, 0x96, 0x9b, 0x58, 0x50, 0x90, 0x9a, 0x22, 0xc1, 0x04, 0x16, 0x86, 0xf2, 0x92, 0xd8, 0xc0, + 0x06, 0x1a, 0x03, 0x02, 0x00, 0x00, 0xff, 0xff, 0x99, 0x31, 0x2f, 0x9d, 0x82, 0x00, 0x00, 0x00, +} diff --git a/pkg/sentry/kernel/pending_signals_list.go b/pkg/sentry/kernel/pending_signals_list.go new file mode 100755 index 000000000..8eb40ac2c --- /dev/null +++ b/pkg/sentry/kernel/pending_signals_list.go @@ -0,0 +1,186 @@ +package kernel + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type pendingSignalElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (pendingSignalElementMapper) linkerFor(elem *pendingSignal) *pendingSignal { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type pendingSignalList struct { + head *pendingSignal + tail *pendingSignal +} + +// Reset resets list l to the empty state. +func (l *pendingSignalList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *pendingSignalList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *pendingSignalList) Front() *pendingSignal { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *pendingSignalList) Back() *pendingSignal { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *pendingSignalList) PushFront(e *pendingSignal) { + linker := pendingSignalElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + pendingSignalElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *pendingSignalList) PushBack(e *pendingSignal) { + linker := pendingSignalElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + pendingSignalElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *pendingSignalList) PushBackList(m *pendingSignalList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + pendingSignalElementMapper{}.linkerFor(l.tail).SetNext(m.head) + pendingSignalElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *pendingSignalList) InsertAfter(b, e *pendingSignal) { + bLinker := pendingSignalElementMapper{}.linkerFor(b) + eLinker := pendingSignalElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + pendingSignalElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *pendingSignalList) InsertBefore(a, e *pendingSignal) { + aLinker := pendingSignalElementMapper{}.linkerFor(a) + eLinker := pendingSignalElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + pendingSignalElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *pendingSignalList) Remove(e *pendingSignal) { + linker := pendingSignalElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + pendingSignalElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + pendingSignalElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type pendingSignalEntry struct { + next *pendingSignal + prev *pendingSignal +} + +// Next returns the entry that follows e in the list. +func (e *pendingSignalEntry) Next() *pendingSignal { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *pendingSignalEntry) Prev() *pendingSignal { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *pendingSignalEntry) SetNext(elem *pendingSignal) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *pendingSignalEntry) SetPrev(elem *pendingSignal) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD deleted file mode 100644 index f29dc0472..000000000 --- a/pkg/sentry/kernel/pipe/BUILD +++ /dev/null @@ -1,51 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "pipe", - srcs = [ - "device.go", - "node.go", - "pipe.go", - "pipe_util.go", - "reader.go", - "reader_writer.go", - "vfs.go", - "writer.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/amutex", - "//pkg/buffer", - "//pkg/context", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "pipe_test", - size = "small", - srcs = [ - "node_test.go", - "pipe_test.go", - ], - library = ":pipe", - deps = [ - "//pkg/context", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go deleted file mode 100644 index ab75a87ff..000000000 --- a/pkg/sentry/kernel/pipe/node_test.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -type sleeper struct { - context.Context - ch chan struct{} -} - -func newSleeperContext(t *testing.T) context.Context { - return &sleeper{ - Context: contexttest.Context(t), - ch: make(chan struct{}), - } -} - -func (s *sleeper) SleepStart() <-chan struct{} { - return s.ch -} - -func (s *sleeper) SleepFinish(bool) { -} - -func (s *sleeper) Cancel() { - s.ch <- struct{}{} -} - -func (s *sleeper) Interrupted() bool { - return len(s.ch) != 0 -} - -type openResult struct { - *fs.File - error -} - -var perms fs.FilePermissions = fs.FilePermissions{ - User: fs.PermMask{Read: true, Write: true}, -} - -func testOpenOrDie(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs.FileFlags, doneChan chan<- struct{}) (*fs.File, error) { - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe}) - d := fs.NewDirent(ctx, inode, "pipe") - file, err := n.GetFile(ctx, d, flags) - if err != nil { - t.Fatalf("open with flags %+v failed: %v", flags, err) - } - if doneChan != nil { - doneChan <- struct{}{} - } - return file, err -} - -func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs.FileFlags, resChan chan<- openResult) (*fs.File, error) { - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe}) - d := fs.NewDirent(ctx, inode, "pipe") - file, err := n.GetFile(ctx, d, flags) - if resChan != nil { - resChan <- openResult{file, err} - } - return file, err -} - -func newNamedPipe(t *testing.T) *Pipe { - return NewPipe(true, DefaultPipeSize, usermem.PageSize) -} - -func newAnonPipe(t *testing.T) *Pipe { - return NewPipe(false, DefaultPipeSize, usermem.PageSize) -} - -// assertRecvBlocks ensures that a recv attempt on c blocks for at least -// blockDuration. This is useful for checking that a goroutine that is supposed -// to be executing a blocking operation is actually blocking. -func assertRecvBlocks(t *testing.T, c <-chan struct{}, blockDuration time.Duration, failMsg string) { - select { - case <-c: - t.Fatalf(failMsg) - case <-time.After(blockDuration): - // Ok, blocked for the required duration. - } -} - -func TestReadOpenBlocksForWriteOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone) - - // Verify that the open for read is blocking. - assertRecvBlocks(t, rDone, time.Millisecond*100, - "open for read not blocking with no writers") - - wDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - - <-wDone - <-rDone -} - -func TestWriteOpenBlocksForReadOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - wDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - - // Verify that the open for write is blocking - assertRecvBlocks(t, wDone, time.Millisecond*100, - "open for write not blocking with no readers") - - rDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone) - - <-rDone - <-wDone -} - -func TestMultipleWriteOpenDoesntCountAsReadOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rDone1 := make(chan struct{}) - rDone2 := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone1) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone2) - - assertRecvBlocks(t, rDone1, time.Millisecond*100, - "open for read didn't block with no writers") - assertRecvBlocks(t, rDone2, time.Millisecond*100, - "open for read didn't block with no writers") - - wDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - - <-wDone - <-rDone2 - <-rDone1 -} - -func TestClosedReaderBlocksWriteOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rFile, _ := testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, NonBlocking: true}, nil) - rFile.DecRef() - - wDone := make(chan struct{}) - // This open for write should block because the reader is now gone. - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - assertRecvBlocks(t, wDone, time.Millisecond*100, - "open for write didn't block with no concurrent readers") - - // Open for read again. This should unblock the open for write. - rDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone) - - <-rDone - <-wDone -} - -func TestReadWriteOpenNeverBlocks(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rwDone := make(chan struct{}) - // Open for read-write never wait for a reader or writer, even if the - // nonblocking flag is not set. - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, Write: true, NonBlocking: false}, rwDone) - <-rwDone -} - -func TestReadWriteOpenUnblocksReadOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone) - - rwDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, Write: true}, rwDone) - - <-rwDone - <-rDone -} - -func TestReadWriteOpenUnblocksWriteOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - wDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - - rwDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, Write: true}, rwDone) - - <-rwDone - <-wDone -} - -func TestBlockedOpenIsCancellable(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - done := make(chan openResult) - go testOpen(ctx, t, f, fs.FileFlags{Read: true}, done) - select { - case <-done: - t.Fatalf("open for read didn't block with no writers") - case <-time.After(time.Millisecond * 100): - // Ok. - } - - ctx.(*sleeper).Cancel() - // If the cancel on the sleeper didn't work, the open for read would never - // return. - res := <-done - if res.error != syserror.ErrInterrupted { - t.Fatalf("Cancellation didn't cause GetFile to return fs.ErrInterrupted, got %v.", - res.error) - } -} - -func TestNonblockingReadOpenFileNoWriters(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Read: true, NonBlocking: true}, nil); err != nil { - t.Fatalf("Nonblocking open for read failed with error %v.", err) - } -} - -func TestNonblockingWriteOpenFileNoReaders(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true, NonBlocking: true}, nil); err != syserror.ENXIO { - t.Fatalf("Nonblocking open for write failed unexpected error %v.", err) - } -} - -func TestNonBlockingReadOpenWithWriter(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - wDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - - // Open for write blocks since there are no readers yet. - assertRecvBlocks(t, wDone, time.Millisecond*100, - "Open for write didn't block with no reader.") - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Read: true, NonBlocking: true}, nil); err != nil { - t.Fatalf("Nonblocking open for read failed with error %v.", err) - } - - // Open for write should now be unblocked. - <-wDone -} - -func TestNonBlockingWriteOpenWithReader(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone) - - // Open for write blocked, since no reader yet. - assertRecvBlocks(t, rDone, time.Millisecond*100, - "Open for reader didn't block with no writer.") - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true, NonBlocking: true}, nil); err != nil { - t.Fatalf("Nonblocking open for write failed with error %v.", err) - } - - // Open for write should now be unblocked. - <-rDone -} - -func TestAnonReadOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newAnonPipe(t)) - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Read: true}, nil); err != nil { - t.Fatalf("open anon pipe for read failed: %v", err) - } -} - -func TestAnonWriteOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newAnonPipe(t)) - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true}, nil); err != nil { - t.Fatalf("open anon pipe for write failed: %v", err) - } -} diff --git a/pkg/sentry/kernel/pipe/pipe_state_autogen.go b/pkg/sentry/kernel/pipe/pipe_state_autogen.go new file mode 100755 index 000000000..b49ab46f9 --- /dev/null +++ b/pkg/sentry/kernel/pipe/pipe_state_autogen.go @@ -0,0 +1,84 @@ +// automatically generated by stateify. + +package pipe + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *inodeOperations) beforeSave() {} +func (x *inodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("p", &x.p) +} + +func (x *inodeOperations) afterLoad() {} +func (x *inodeOperations) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("p", &x.p) +} + +func (x *Pipe) beforeSave() {} +func (x *Pipe) save(m state.Map) { + x.beforeSave() + m.Save("isNamed", &x.isNamed) + m.Save("atomicIOBytes", &x.atomicIOBytes) + m.Save("readers", &x.readers) + m.Save("writers", &x.writers) + m.Save("view", &x.view) + m.Save("max", &x.max) + m.Save("hadWriter", &x.hadWriter) +} + +func (x *Pipe) afterLoad() {} +func (x *Pipe) load(m state.Map) { + m.Load("isNamed", &x.isNamed) + m.Load("atomicIOBytes", &x.atomicIOBytes) + m.Load("readers", &x.readers) + m.Load("writers", &x.writers) + m.Load("view", &x.view) + m.Load("max", &x.max) + m.Load("hadWriter", &x.hadWriter) +} + +func (x *Reader) beforeSave() {} +func (x *Reader) save(m state.Map) { + x.beforeSave() + m.Save("ReaderWriter", &x.ReaderWriter) +} + +func (x *Reader) afterLoad() {} +func (x *Reader) load(m state.Map) { + m.Load("ReaderWriter", &x.ReaderWriter) +} + +func (x *ReaderWriter) beforeSave() {} +func (x *ReaderWriter) save(m state.Map) { + x.beforeSave() + m.Save("Pipe", &x.Pipe) +} + +func (x *ReaderWriter) afterLoad() {} +func (x *ReaderWriter) load(m state.Map) { + m.Load("Pipe", &x.Pipe) +} + +func (x *Writer) beforeSave() {} +func (x *Writer) save(m state.Map) { + x.beforeSave() + m.Save("ReaderWriter", &x.ReaderWriter) +} + +func (x *Writer) afterLoad() {} +func (x *Writer) load(m state.Map) { + m.Load("ReaderWriter", &x.ReaderWriter) +} + +func init() { + state.Register("pkg/sentry/kernel/pipe.inodeOperations", (*inodeOperations)(nil), state.Fns{Save: (*inodeOperations).save, Load: (*inodeOperations).load}) + state.Register("pkg/sentry/kernel/pipe.Pipe", (*Pipe)(nil), state.Fns{Save: (*Pipe).save, Load: (*Pipe).load}) + state.Register("pkg/sentry/kernel/pipe.Reader", (*Reader)(nil), state.Fns{Save: (*Reader).save, Load: (*Reader).load}) + state.Register("pkg/sentry/kernel/pipe.ReaderWriter", (*ReaderWriter)(nil), state.Fns{Save: (*ReaderWriter).save, Load: (*ReaderWriter).load}) + state.Register("pkg/sentry/kernel/pipe.Writer", (*Writer)(nil), state.Fns{Save: (*Writer).save, Load: (*Writer).load}) +} diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go deleted file mode 100644 index bda739dbe..000000000 --- a/pkg/sentry/kernel/pipe/pipe_test.go +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -import ( - "bytes" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestPipeRW(t *testing.T) { - ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, 65536, 4096) - defer r.DecRef() - defer w.DecRef() - - msg := []byte("here's some bytes") - wantN := int64(len(msg)) - n, err := w.Writev(ctx, usermem.BytesIOSequence(msg)) - if n != wantN || err != nil { - t.Fatalf("Writev: got (%d, %v), wanted (%d, nil)", n, err, wantN) - } - - buf := make([]byte, len(msg)) - n, err = r.Readv(ctx, usermem.BytesIOSequence(buf)) - if n != wantN || err != nil || !bytes.Equal(buf, msg) { - t.Fatalf("Readv: got (%d, %v) %q, wanted (%d, nil) %q", n, err, buf, wantN, msg) - } -} - -func TestPipeReadBlock(t *testing.T) { - ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, 65536, 4096) - defer r.DecRef() - defer w.DecRef() - - n, err := r.Readv(ctx, usermem.BytesIOSequence(make([]byte, 1))) - if n != 0 || err != syserror.ErrWouldBlock { - t.Fatalf("Readv: got (%d, %v), wanted (0, %v)", n, err, syserror.ErrWouldBlock) - } -} - -func TestPipeWriteBlock(t *testing.T) { - const atomicIOBytes = 2 - const capacity = MinimumPipeSize - - ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, capacity, atomicIOBytes) - defer r.DecRef() - defer w.DecRef() - - msg := make([]byte, capacity+1) - n, err := w.Writev(ctx, usermem.BytesIOSequence(msg)) - if wantN, wantErr := int64(capacity), syserror.ErrWouldBlock; n != wantN || err != wantErr { - t.Fatalf("Writev: got (%d, %v), wanted (%d, %v)", n, err, wantN, wantErr) - } -} - -func TestPipeWriteUntilEnd(t *testing.T) { - const atomicIOBytes = 2 - - ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, atomicIOBytes, atomicIOBytes) - defer r.DecRef() - defer w.DecRef() - - msg := []byte("here's some bytes") - - wDone := make(chan struct{}, 0) - rDone := make(chan struct{}, 0) - defer func() { - // Signal the reader to stop and wait until it does so. - close(wDone) - <-rDone - }() - - go func() { - defer close(rDone) - // Read from r until done is closed. - ctx := contexttest.Context(t) - buf := make([]byte, len(msg)+1) - dst := usermem.BytesIOSequence(buf) - e, ch := waiter.NewChannelEntry(nil) - r.EventRegister(&e, waiter.EventIn) - defer r.EventUnregister(&e) - for { - n, err := r.Readv(ctx, dst) - dst = dst.DropFirst64(n) - if err == syserror.ErrWouldBlock { - select { - case <-ch: - continue - case <-wDone: - // We expect to have 1 byte left in dst since len(buf) == - // len(msg)+1. - if dst.NumBytes() != 1 || !bytes.Equal(buf[:len(msg)], msg) { - t.Errorf("Reader: got %q (%d bytes remaining), wanted %q", buf, dst.NumBytes(), msg) - } - return - } - } - if err != nil { - t.Fatalf("Readv: got unexpected error %v", err) - } - } - }() - - src := usermem.BytesIOSequence(msg) - e, ch := waiter.NewChannelEntry(nil) - w.EventRegister(&e, waiter.EventOut) - defer w.EventUnregister(&e) - for src.NumBytes() != 0 { - n, err := w.Writev(ctx, src) - src = src.DropFirst64(n) - if err == syserror.ErrWouldBlock { - <-ch - continue - } - if err != nil { - t.Fatalf("Writev: got (%d, %v)", n, err) - } - } -} diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index 5a1d4fd57..5a1d4fd57 100644..100755 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index a5675bd70..a5675bd70 100644..100755 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go diff --git a/pkg/sentry/kernel/process_group_list.go b/pkg/sentry/kernel/process_group_list.go new file mode 100755 index 000000000..40c1a13a4 --- /dev/null +++ b/pkg/sentry/kernel/process_group_list.go @@ -0,0 +1,186 @@ +package kernel + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type processGroupElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (processGroupElementMapper) linkerFor(elem *ProcessGroup) *ProcessGroup { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type processGroupList struct { + head *ProcessGroup + tail *ProcessGroup +} + +// Reset resets list l to the empty state. +func (l *processGroupList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *processGroupList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *processGroupList) Front() *ProcessGroup { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *processGroupList) Back() *ProcessGroup { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *processGroupList) PushFront(e *ProcessGroup) { + linker := processGroupElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + processGroupElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *processGroupList) PushBack(e *ProcessGroup) { + linker := processGroupElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + processGroupElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *processGroupList) PushBackList(m *processGroupList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + processGroupElementMapper{}.linkerFor(l.tail).SetNext(m.head) + processGroupElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *processGroupList) InsertAfter(b, e *ProcessGroup) { + bLinker := processGroupElementMapper{}.linkerFor(b) + eLinker := processGroupElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + processGroupElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *processGroupList) InsertBefore(a, e *ProcessGroup) { + aLinker := processGroupElementMapper{}.linkerFor(a) + eLinker := processGroupElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + processGroupElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *processGroupList) Remove(e *ProcessGroup) { + linker := processGroupElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + processGroupElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + processGroupElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type processGroupEntry struct { + next *ProcessGroup + prev *ProcessGroup +} + +// Next returns the entry that follows e in the list. +func (e *processGroupEntry) Next() *ProcessGroup { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *processGroupEntry) Prev() *ProcessGroup { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *processGroupEntry) SetNext(elem *ProcessGroup) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *processGroupEntry) SetPrev(elem *ProcessGroup) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD deleted file mode 100644 index 1b82e087b..000000000 --- a/pkg/sentry/kernel/sched/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "sched", - srcs = [ - "cpuset.go", - "sched.go", - ], - visibility = ["//pkg/sentry:internal"], -) - -go_test( - name = "sched_test", - size = "small", - srcs = ["cpuset_test.go"], - library = ":sched", -) diff --git a/pkg/sentry/kernel/sched/cpuset_test.go b/pkg/sentry/kernel/sched/cpuset_test.go deleted file mode 100644 index 3af9f1197..000000000 --- a/pkg/sentry/kernel/sched/cpuset_test.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sched - -import ( - "testing" -) - -func TestNumCPUs(t *testing.T) { - for i := uint(0); i < 1024; i++ { - c := NewCPUSet(i) - for j := uint(0); j < i; j++ { - c.Set(j) - } - n := c.NumCPUs() - if n != i { - t.Errorf("got wrong number of cpus %d, want %d", n, i) - } - } -} - -func TestClearAbove(t *testing.T) { - const n = 1024 - c := NewFullCPUSet(n) - for i := uint(0); i < n; i++ { - cpu := n - i - c.ClearAbove(cpu) - if got := c.NumCPUs(); got != cpu { - t.Errorf("iteration %d: got %d cpus, wanted %d", i, got, cpu) - } - } -} diff --git a/pkg/sentry/kernel/sched/sched_state_autogen.go b/pkg/sentry/kernel/sched/sched_state_autogen.go new file mode 100755 index 000000000..9705ca79d --- /dev/null +++ b/pkg/sentry/kernel/sched/sched_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package sched diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD deleted file mode 100644 index 65e5427c1..000000000 --- a/pkg/sentry/kernel/semaphore/BUILD +++ /dev/null @@ -1,49 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "waiter_list", - out = "waiter_list.go", - package = "semaphore", - prefix = "waiter", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*waiter", - "Linker": "*waiter", - }, -) - -go_library( - name = "semaphore", - srcs = [ - "semaphore.go", - "waiter_list.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/log", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sync", - "//pkg/syserror", - ], -) - -go_test( - name = "semaphore_test", - size = "small", - srcs = ["semaphore_test.go"], - library = ":semaphore", - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/contexttest", - "//pkg/sentry/kernel/auth", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/kernel/semaphore/semaphore_state_autogen.go b/pkg/sentry/kernel/semaphore/semaphore_state_autogen.go new file mode 100755 index 000000000..db80a1490 --- /dev/null +++ b/pkg/sentry/kernel/semaphore/semaphore_state_autogen.go @@ -0,0 +1,117 @@ +// automatically generated by stateify. + +package semaphore + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Registry) beforeSave() {} +func (x *Registry) save(m state.Map) { + x.beforeSave() + m.Save("userNS", &x.userNS) + m.Save("semaphores", &x.semaphores) + m.Save("lastIDUsed", &x.lastIDUsed) +} + +func (x *Registry) afterLoad() {} +func (x *Registry) load(m state.Map) { + m.Load("userNS", &x.userNS) + m.Load("semaphores", &x.semaphores) + m.Load("lastIDUsed", &x.lastIDUsed) +} + +func (x *Set) beforeSave() {} +func (x *Set) save(m state.Map) { + x.beforeSave() + m.Save("registry", &x.registry) + m.Save("ID", &x.ID) + m.Save("key", &x.key) + m.Save("creator", &x.creator) + m.Save("owner", &x.owner) + m.Save("perms", &x.perms) + m.Save("opTime", &x.opTime) + m.Save("changeTime", &x.changeTime) + m.Save("sems", &x.sems) + m.Save("dead", &x.dead) +} + +func (x *Set) afterLoad() {} +func (x *Set) load(m state.Map) { + m.Load("registry", &x.registry) + m.Load("ID", &x.ID) + m.Load("key", &x.key) + m.Load("creator", &x.creator) + m.Load("owner", &x.owner) + m.Load("perms", &x.perms) + m.Load("opTime", &x.opTime) + m.Load("changeTime", &x.changeTime) + m.Load("sems", &x.sems) + m.Load("dead", &x.dead) +} + +func (x *sem) beforeSave() {} +func (x *sem) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.waiters) { + m.Failf("waiters is %v, expected zero", x.waiters) + } + m.Save("value", &x.value) + m.Save("pid", &x.pid) +} + +func (x *sem) afterLoad() {} +func (x *sem) load(m state.Map) { + m.Load("value", &x.value) + m.Load("pid", &x.pid) +} + +func (x *waiter) beforeSave() {} +func (x *waiter) save(m state.Map) { + x.beforeSave() + m.Save("waiterEntry", &x.waiterEntry) + m.Save("value", &x.value) + m.Save("ch", &x.ch) +} + +func (x *waiter) afterLoad() {} +func (x *waiter) load(m state.Map) { + m.Load("waiterEntry", &x.waiterEntry) + m.Load("value", &x.value) + m.Load("ch", &x.ch) +} + +func (x *waiterList) beforeSave() {} +func (x *waiterList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *waiterList) afterLoad() {} +func (x *waiterList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *waiterEntry) beforeSave() {} +func (x *waiterEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *waiterEntry) afterLoad() {} +func (x *waiterEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("pkg/sentry/kernel/semaphore.Registry", (*Registry)(nil), state.Fns{Save: (*Registry).save, Load: (*Registry).load}) + state.Register("pkg/sentry/kernel/semaphore.Set", (*Set)(nil), state.Fns{Save: (*Set).save, Load: (*Set).load}) + state.Register("pkg/sentry/kernel/semaphore.sem", (*sem)(nil), state.Fns{Save: (*sem).save, Load: (*sem).load}) + state.Register("pkg/sentry/kernel/semaphore.waiter", (*waiter)(nil), state.Fns{Save: (*waiter).save, Load: (*waiter).load}) + state.Register("pkg/sentry/kernel/semaphore.waiterList", (*waiterList)(nil), state.Fns{Save: (*waiterList).save, Load: (*waiterList).load}) + state.Register("pkg/sentry/kernel/semaphore.waiterEntry", (*waiterEntry)(nil), state.Fns{Save: (*waiterEntry).save, Load: (*waiterEntry).load}) +} diff --git a/pkg/sentry/kernel/semaphore/semaphore_test.go b/pkg/sentry/kernel/semaphore/semaphore_test.go deleted file mode 100644 index e47acefdf..000000000 --- a/pkg/sentry/kernel/semaphore/semaphore_test.go +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package semaphore - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/syserror" -) - -func executeOps(ctx context.Context, t *testing.T, set *Set, ops []linux.Sembuf, block bool) chan struct{} { - ch, _, err := set.executeOps(ctx, ops, 123) - if err != nil { - t.Fatalf("ExecuteOps(ops) failed, err: %v, ops: %+v", err, ops) - } - if block { - if ch == nil { - t.Fatalf("ExecuteOps(ops) got: nil, expected: !nil, ops: %+v", ops) - } - if signalled(ch) { - t.Fatalf("ExecuteOps(ops) channel should not have been signalled, ops: %+v", ops) - } - } else { - if ch != nil { - t.Fatalf("ExecuteOps(ops) got: %v, expected: nil, ops: %+v", ch, ops) - } - } - return ch -} - -func signalled(ch chan struct{}) bool { - select { - case <-ch: - return true - default: - return false - } -} - -func TestBasic(t *testing.T) { - ctx := contexttest.Context(t) - set := &Set{ID: 123, sems: make([]sem, 1)} - ops := []linux.Sembuf{ - {SemOp: 1}, - } - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = -1 - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = -1 - ch1 := executeOps(ctx, t, set, ops, true) - - ops[0].SemOp = 1 - executeOps(ctx, t, set, ops, false) - if !signalled(ch1) { - t.Fatalf("ExecuteOps(ops) channel should not have been signalled, ops: %+v", ops) - } -} - -func TestWaitForZero(t *testing.T) { - ctx := contexttest.Context(t) - set := &Set{ID: 123, sems: make([]sem, 1)} - ops := []linux.Sembuf{ - {SemOp: 0}, - } - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = -2 - ch1 := executeOps(ctx, t, set, ops, true) - - ops[0].SemOp = 0 - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = 1 - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = 0 - chZero1 := executeOps(ctx, t, set, ops, true) - - ops[0].SemOp = 0 - chZero2 := executeOps(ctx, t, set, ops, true) - - ops[0].SemOp = 1 - executeOps(ctx, t, set, ops, false) - if !signalled(ch1) { - t.Fatalf("ExecuteOps(ops) channel should have been signalled, ops: %+v, set: %+v", ops, set) - } - - ops[0].SemOp = -2 - executeOps(ctx, t, set, ops, false) - if !signalled(chZero1) { - t.Fatalf("ExecuteOps(ops) channel zero 1 should have been signalled, ops: %+v, set: %+v", ops, set) - } - if !signalled(chZero2) { - t.Fatalf("ExecuteOps(ops) channel zero 2 should have been signalled, ops: %+v, set: %+v", ops, set) - } -} - -func TestNoWait(t *testing.T) { - ctx := contexttest.Context(t) - set := &Set{ID: 123, sems: make([]sem, 1)} - ops := []linux.Sembuf{ - {SemOp: 1}, - } - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = -2 - ops[0].SemFlg = linux.IPC_NOWAIT - if _, _, err := set.executeOps(ctx, ops, 123); err != syserror.ErrWouldBlock { - t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, syserror.ErrWouldBlock) - } - - ops[0].SemOp = 0 - ops[0].SemFlg = linux.IPC_NOWAIT - if _, _, err := set.executeOps(ctx, ops, 123); err != syserror.ErrWouldBlock { - t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, syserror.ErrWouldBlock) - } -} - -func TestUnregister(t *testing.T) { - ctx := contexttest.Context(t) - r := NewRegistry(auth.NewRootUserNamespace()) - set, err := r.FindOrCreate(ctx, 123, 2, linux.FileMode(0x600), true, true, true) - if err != nil { - t.Fatalf("FindOrCreate() failed, err: %v", err) - } - if got := r.FindByID(set.ID); got.ID != set.ID { - t.Fatalf("FindById(%d) failed, got: %+v, expected: %+v", set.ID, got, set) - } - - ops := []linux.Sembuf{ - {SemOp: -1}, - } - chs := make([]chan struct{}, 0, 5) - for i := 0; i < 5; i++ { - ch := executeOps(ctx, t, set, ops, true) - chs = append(chs, ch) - } - - creds := auth.CredentialsFromContext(ctx) - if err := r.RemoveID(set.ID, creds); err != nil { - t.Fatalf("RemoveID(%d) failed, err: %v", set.ID, err) - } - if !set.dead { - t.Fatalf("set is not dead: %+v", set) - } - if got := r.FindByID(set.ID); got != nil { - t.Fatalf("FindById(%d) failed, got: %+v, expected: nil", set.ID, got) - } - for i, ch := range chs { - if !signalled(ch) { - t.Fatalf("channel %d should have been signalled", i) - } - } -} diff --git a/pkg/sentry/kernel/semaphore/waiter_list.go b/pkg/sentry/kernel/semaphore/waiter_list.go new file mode 100755 index 000000000..27120afe3 --- /dev/null +++ b/pkg/sentry/kernel/semaphore/waiter_list.go @@ -0,0 +1,186 @@ +package semaphore + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type waiterElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (waiterElementMapper) linkerFor(elem *waiter) *waiter { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type waiterList struct { + head *waiter + tail *waiter +} + +// Reset resets list l to the empty state. +func (l *waiterList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *waiterList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *waiterList) Front() *waiter { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *waiterList) Back() *waiter { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *waiterList) PushFront(e *waiter) { + linker := waiterElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + waiterElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *waiterList) PushBack(e *waiter) { + linker := waiterElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *waiterList) PushBackList(m *waiterList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(m.head) + waiterElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *waiterList) InsertAfter(b, e *waiter) { + bLinker := waiterElementMapper{}.linkerFor(b) + eLinker := waiterElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + waiterElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *waiterList) InsertBefore(a, e *waiter) { + aLinker := waiterElementMapper{}.linkerFor(a) + eLinker := waiterElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + waiterElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *waiterList) Remove(e *waiter) { + linker := waiterElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + waiterElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + waiterElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type waiterEntry struct { + next *waiter + prev *waiter +} + +// Next returns the entry that follows e in the list. +func (e *waiterEntry) Next() *waiter { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *waiterEntry) Prev() *waiter { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *waiterEntry) SetNext(elem *waiter) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *waiterEntry) SetPrev(elem *waiter) { + e.prev = elem +} diff --git a/pkg/sync/seqatomic_unsafe.go b/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo_unsafe.go index eda6fb131..950645965 100644..100755 --- a/pkg/sync/seqatomic_unsafe.go +++ b/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo_unsafe.go @@ -1,11 +1,4 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package template doesn't exist. This file must be instantiated using the -// go_template_instance rule in tools/go_generics/defs.bzl. -package template +package kernel import ( "fmt" @@ -16,29 +9,19 @@ import ( "gvisor.dev/gvisor/pkg/sync" ) -// Value is a required type parameter. -// -// Value must not contain any pointers, including interface objects, function -// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs -// containing any of the above. An init() function will panic if this property -// does not hold. -type Value struct{} - // SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race // with any writer critical sections in sc. -func SeqAtomicLoad(sc *sync.SeqCount, ptr *Value) Value { +func SeqAtomicLoadTaskGoroutineSchedInfo(sc *sync.SeqCount, ptr *TaskGoroutineSchedInfo) TaskGoroutineSchedInfo { // This function doesn't use SeqAtomicTryLoad because doing so is // measurably, significantly (~20%) slower; Go is awful at inlining. - var val Value + var val TaskGoroutineSchedInfo for { epoch := sc.BeginRead() if sync.RaceEnabled { - // runtime.RaceDisable() doesn't actually stop the race detector, - // so it can't help us here. Instead, call runtime.memmove - // directly, which is not instrumented by the race detector. + sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) } else { - // This is ~40% faster for short reads than going through memmove. + val = *ptr } if sc.ReadOk(epoch) { @@ -52,8 +35,8 @@ func SeqAtomicLoad(sc *sync.SeqCount, ptr *Value) Value { // in sc initiated by a call to sc.BeginRead() that returned epoch. If the read // would race with a writer critical section, SeqAtomicTryLoad returns // (unspecified, false). -func SeqAtomicTryLoad(sc *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (Value, bool) { - var val Value +func SeqAtomicTryLoadTaskGoroutineSchedInfo(sc *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *TaskGoroutineSchedInfo) (TaskGoroutineSchedInfo, bool) { + var val TaskGoroutineSchedInfo if sync.RaceEnabled { sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) } else { @@ -62,8 +45,8 @@ func SeqAtomicTryLoad(sc *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) ( return val, sc.ReadOk(epoch) } -func init() { - var val Value +func initTaskGoroutineSchedInfo() { + var val TaskGoroutineSchedInfo typ := reflect.TypeOf(val) name := typ.Name() if ptrs := sync.PointersInType(typ, name); len(ptrs) != 0 { diff --git a/pkg/sentry/kernel/session_list.go b/pkg/sentry/kernel/session_list.go new file mode 100755 index 000000000..8174f413d --- /dev/null +++ b/pkg/sentry/kernel/session_list.go @@ -0,0 +1,186 @@ +package kernel + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type sessionElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (sessionElementMapper) linkerFor(elem *Session) *Session { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type sessionList struct { + head *Session + tail *Session +} + +// Reset resets list l to the empty state. +func (l *sessionList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *sessionList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *sessionList) Front() *Session { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *sessionList) Back() *Session { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *sessionList) PushFront(e *Session) { + linker := sessionElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + sessionElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *sessionList) PushBack(e *Session) { + linker := sessionElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + sessionElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *sessionList) PushBackList(m *sessionList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + sessionElementMapper{}.linkerFor(l.tail).SetNext(m.head) + sessionElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *sessionList) InsertAfter(b, e *Session) { + bLinker := sessionElementMapper{}.linkerFor(b) + eLinker := sessionElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + sessionElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *sessionList) InsertBefore(a, e *Session) { + aLinker := sessionElementMapper{}.linkerFor(a) + eLinker := sessionElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + sessionElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *sessionList) Remove(e *Session) { + linker := sessionElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + sessionElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + sessionElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type sessionEntry struct { + next *Session + prev *Session +} + +// Next returns the entry that follows e in the list. +func (e *sessionEntry) Next() *Session { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *sessionEntry) Prev() *Session { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *sessionEntry) SetNext(elem *Session) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *sessionEntry) SetPrev(elem *Session) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD deleted file mode 100644 index bfd779837..000000000 --- a/pkg/sentry/kernel/shm/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "shm", - srcs = [ - "device.go", - "shm.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/log", - "//pkg/refs", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/usage", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/kernel/shm/shm_state_autogen.go b/pkg/sentry/kernel/shm/shm_state_autogen.go new file mode 100755 index 000000000..fa8f896f7 --- /dev/null +++ b/pkg/sentry/kernel/shm/shm_state_autogen.go @@ -0,0 +1,74 @@ +// automatically generated by stateify. + +package shm + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Registry) beforeSave() {} +func (x *Registry) save(m state.Map) { + x.beforeSave() + m.Save("userNS", &x.userNS) + m.Save("shms", &x.shms) + m.Save("keysToShms", &x.keysToShms) + m.Save("totalPages", &x.totalPages) + m.Save("lastIDUsed", &x.lastIDUsed) +} + +func (x *Registry) afterLoad() {} +func (x *Registry) load(m state.Map) { + m.Load("userNS", &x.userNS) + m.Load("shms", &x.shms) + m.Load("keysToShms", &x.keysToShms) + m.Load("totalPages", &x.totalPages) + m.Load("lastIDUsed", &x.lastIDUsed) +} + +func (x *Shm) beforeSave() {} +func (x *Shm) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("mfp", &x.mfp) + m.Save("registry", &x.registry) + m.Save("ID", &x.ID) + m.Save("creator", &x.creator) + m.Save("size", &x.size) + m.Save("effectiveSize", &x.effectiveSize) + m.Save("fr", &x.fr) + m.Save("key", &x.key) + m.Save("perms", &x.perms) + m.Save("owner", &x.owner) + m.Save("attachTime", &x.attachTime) + m.Save("detachTime", &x.detachTime) + m.Save("changeTime", &x.changeTime) + m.Save("creatorPID", &x.creatorPID) + m.Save("lastAttachDetachPID", &x.lastAttachDetachPID) + m.Save("pendingDestruction", &x.pendingDestruction) +} + +func (x *Shm) afterLoad() {} +func (x *Shm) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("mfp", &x.mfp) + m.Load("registry", &x.registry) + m.Load("ID", &x.ID) + m.Load("creator", &x.creator) + m.Load("size", &x.size) + m.Load("effectiveSize", &x.effectiveSize) + m.Load("fr", &x.fr) + m.Load("key", &x.key) + m.Load("perms", &x.perms) + m.Load("owner", &x.owner) + m.Load("attachTime", &x.attachTime) + m.Load("detachTime", &x.detachTime) + m.Load("changeTime", &x.changeTime) + m.Load("creatorPID", &x.creatorPID) + m.Load("lastAttachDetachPID", &x.lastAttachDetachPID) + m.Load("pendingDestruction", &x.pendingDestruction) +} + +func init() { + state.Register("pkg/sentry/kernel/shm.Registry", (*Registry)(nil), state.Fns{Save: (*Registry).save, Load: (*Registry).load}) + state.Register("pkg/sentry/kernel/shm.Shm", (*Shm)(nil), state.Fns{Save: (*Shm).save, Load: (*Shm).load}) +} diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD deleted file mode 100644 index 3eb78e91b..000000000 --- a/pkg/sentry/kernel/signalfd/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "signalfd", - srcs = ["signalfd.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index 8243bb93e..8243bb93e 100644..100755 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go diff --git a/pkg/sentry/kernel/signalfd/signalfd_state_autogen.go b/pkg/sentry/kernel/signalfd/signalfd_state_autogen.go new file mode 100755 index 000000000..2ab5b4702 --- /dev/null +++ b/pkg/sentry/kernel/signalfd/signalfd_state_autogen.go @@ -0,0 +1,24 @@ +// automatically generated by stateify. + +package signalfd + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *SignalOperations) beforeSave() {} +func (x *SignalOperations) save(m state.Map) { + x.beforeSave() + m.Save("target", &x.target) + m.Save("mask", &x.mask) +} + +func (x *SignalOperations) afterLoad() {} +func (x *SignalOperations) load(m state.Map) { + m.Load("target", &x.target) + m.Load("mask", &x.mask) +} + +func init() { + state.Register("pkg/sentry/kernel/signalfd.SignalOperations", (*SignalOperations)(nil), state.Fns{Save: (*SignalOperations).save, Load: (*SignalOperations).load}) +} diff --git a/pkg/sentry/kernel/socket_list.go b/pkg/sentry/kernel/socket_list.go new file mode 100755 index 000000000..ac93e2365 --- /dev/null +++ b/pkg/sentry/kernel/socket_list.go @@ -0,0 +1,186 @@ +package kernel + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type socketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (socketElementMapper) linkerFor(elem *SocketEntry) *SocketEntry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type socketList struct { + head *SocketEntry + tail *SocketEntry +} + +// Reset resets list l to the empty state. +func (l *socketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *socketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *socketList) Front() *SocketEntry { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *socketList) Back() *SocketEntry { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *socketList) PushFront(e *SocketEntry) { + linker := socketElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + socketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *socketList) PushBack(e *SocketEntry) { + linker := socketElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + socketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *socketList) PushBackList(m *socketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + socketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + socketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *socketList) InsertAfter(b, e *SocketEntry) { + bLinker := socketElementMapper{}.linkerFor(b) + eLinker := socketElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + socketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *socketList) InsertBefore(a, e *SocketEntry) { + aLinker := socketElementMapper{}.linkerFor(a) + eLinker := socketElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + socketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *socketList) Remove(e *SocketEntry) { + linker := socketElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + socketElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + socketElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type socketEntry struct { + next *SocketEntry + prev *SocketEntry +} + +// Next returns the entry that follows e in the list. +func (e *socketEntry) Next() *SocketEntry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *socketEntry) Prev() *SocketEntry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *socketEntry) SetNext(elem *SocketEntry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *socketEntry) SetPrev(elem *SocketEntry) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/table_test.go b/pkg/sentry/kernel/table_test.go deleted file mode 100644 index 32cf47e05..000000000 --- a/pkg/sentry/kernel/table_test.go +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernel - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi" - "gvisor.dev/gvisor/pkg/sentry/arch" -) - -const ( - maxTestSyscall = 1000 -) - -func createSyscallTable() *SyscallTable { - m := make(map[uintptr]Syscall) - for i := uintptr(0); i <= maxTestSyscall; i++ { - j := i - m[i] = Syscall{ - Fn: func(*Task, arch.SyscallArguments) (uintptr, *SyscallControl, error) { - return j, nil, nil - }, - } - } - - s := &SyscallTable{ - OS: abi.Linux, - Arch: arch.AMD64, - Table: m, - } - - RegisterSyscallTable(s) - return s -} - -func TestTable(t *testing.T) { - table := createSyscallTable() - defer func() { - // Cleanup registered tables to keep tests separate. - allSyscallTables = []*SyscallTable{} - }() - - // Go through all functions and check that they return the right value. - for i := uintptr(0); i < maxTestSyscall; i++ { - fn := table.Lookup(i) - if fn == nil { - t.Errorf("Syscall %v is set to nil", i) - continue - } - - v, _, _ := fn(nil, arch.SyscallArguments{}) - if v != i { - t.Errorf("Wrong return value for syscall %v: expected %v, got %v", i, i, v) - } - } - - // Check that values outside the range return nil. - for i := uintptr(maxTestSyscall + 1); i < maxTestSyscall+100; i++ { - fn := table.Lookup(i) - if fn != nil { - t.Errorf("Syscall %v is not nil: %v", i, fn) - continue - } - } -} - -func BenchmarkTableLookup(b *testing.B) { - table := createSyscallTable() - - b.ResetTimer() - - j := uintptr(0) - for i := 0; i < b.N; i++ { - table.Lookup(j) - j = (j + 1) % 310 - } - - b.StopTimer() - // Cleanup registered tables to keep tests separate. - allSyscallTables = []*SyscallTable{} -} - -func BenchmarkTableMapLookup(b *testing.B) { - table := createSyscallTable() - - b.ResetTimer() - - j := uintptr(0) - for i := 0; i < b.N; i++ { - table.mapLookup(j) - j = (j + 1) % 310 - } - - b.StopTimer() - // Cleanup registered tables to keep tests separate. - allSyscallTables = []*SyscallTable{} -} diff --git a/pkg/sentry/kernel/task_list.go b/pkg/sentry/kernel/task_list.go new file mode 100755 index 000000000..4dfcdbf2c --- /dev/null +++ b/pkg/sentry/kernel/task_list.go @@ -0,0 +1,186 @@ +package kernel + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type taskElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (taskElementMapper) linkerFor(elem *Task) *Task { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type taskList struct { + head *Task + tail *Task +} + +// Reset resets list l to the empty state. +func (l *taskList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *taskList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *taskList) Front() *Task { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *taskList) Back() *Task { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *taskList) PushFront(e *Task) { + linker := taskElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + taskElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *taskList) PushBack(e *Task) { + linker := taskElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + taskElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *taskList) PushBackList(m *taskList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + taskElementMapper{}.linkerFor(l.tail).SetNext(m.head) + taskElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *taskList) InsertAfter(b, e *Task) { + bLinker := taskElementMapper{}.linkerFor(b) + eLinker := taskElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + taskElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *taskList) InsertBefore(a, e *Task) { + aLinker := taskElementMapper{}.linkerFor(a) + eLinker := taskElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + taskElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *taskList) Remove(e *Task) { + linker := taskElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + taskElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + taskElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type taskEntry struct { + next *Task + prev *Task +} + +// Next returns the entry that follows e in the list. +func (e *taskEntry) Next() *Task { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *taskEntry) Prev() *Task { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *taskEntry) SetNext(elem *Task) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *taskEntry) SetPrev(elem *Task) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/task_test.go b/pkg/sentry/kernel/task_test.go deleted file mode 100644 index cfcde9a7a..000000000 --- a/pkg/sentry/kernel/task_test.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernel - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/kernel/sched" -) - -func TestTaskCPU(t *testing.T) { - for _, test := range []struct { - mask sched.CPUSet - tid ThreadID - cpu int32 - }{ - { - mask: []byte{0xff}, - tid: 1, - cpu: 0, - }, - { - mask: []byte{0xff}, - tid: 10, - cpu: 1, - }, - { - // more than 8 cpus. - mask: []byte{0xff, 0xff}, - tid: 10, - cpu: 9, - }, - { - // missing the first cpu. - mask: []byte{0xfe}, - tid: 1, - cpu: 1, - }, - { - mask: []byte{0xfe}, - tid: 10, - cpu: 3, - }, - { - // missing the fifth cpu. - mask: []byte{0xef}, - tid: 10, - cpu: 2, - }, - } { - assigned := assignCPU(test.mask, test.tid) - if test.cpu != assigned { - t.Errorf("assignCPU(%v, %v) got %v, want %v", test.mask, test.tid, assigned, test.cpu) - } - } - -} diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD deleted file mode 100644 index 7ba7dc50c..000000000 --- a/pkg/sentry/kernel/time/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "time", - srcs = [ - "context.go", - "time.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sync", - "//pkg/syserror", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/time/time_state_autogen.go b/pkg/sentry/kernel/time/time_state_autogen.go new file mode 100755 index 000000000..ab6c6633d --- /dev/null +++ b/pkg/sentry/kernel/time/time_state_autogen.go @@ -0,0 +1,56 @@ +// automatically generated by stateify. + +package time + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Time) beforeSave() {} +func (x *Time) save(m state.Map) { + x.beforeSave() + m.Save("ns", &x.ns) +} + +func (x *Time) afterLoad() {} +func (x *Time) load(m state.Map) { + m.Load("ns", &x.ns) +} + +func (x *Setting) beforeSave() {} +func (x *Setting) save(m state.Map) { + x.beforeSave() + m.Save("Enabled", &x.Enabled) + m.Save("Next", &x.Next) + m.Save("Period", &x.Period) +} + +func (x *Setting) afterLoad() {} +func (x *Setting) load(m state.Map) { + m.Load("Enabled", &x.Enabled) + m.Load("Next", &x.Next) + m.Load("Period", &x.Period) +} + +func (x *Timer) beforeSave() {} +func (x *Timer) save(m state.Map) { + x.beforeSave() + m.Save("clock", &x.clock) + m.Save("listener", &x.listener) + m.Save("setting", &x.setting) + m.Save("paused", &x.paused) +} + +func (x *Timer) afterLoad() {} +func (x *Timer) load(m state.Map) { + m.Load("clock", &x.clock) + m.Load("listener", &x.listener) + m.Load("setting", &x.setting) + m.Load("paused", &x.paused) +} + +func init() { + state.Register("pkg/sentry/kernel/time.Time", (*Time)(nil), state.Fns{Save: (*Time).save, Load: (*Time).load}) + state.Register("pkg/sentry/kernel/time.Setting", (*Setting)(nil), state.Fns{Save: (*Setting).save, Load: (*Setting).load}) + state.Register("pkg/sentry/kernel/time.Timer", (*Timer)(nil), state.Fns{Save: (*Timer).save, Load: (*Timer).load}) +} diff --git a/pkg/sentry/kernel/timekeeper_test.go b/pkg/sentry/kernel/timekeeper_test.go deleted file mode 100644 index cf2f7ca72..000000000 --- a/pkg/sentry/kernel/timekeeper_test.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernel - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - sentrytime "gvisor.dev/gvisor/pkg/sentry/time" - "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -// mockClocks is a sentrytime.Clocks that simply returns the times in the -// struct. -type mockClocks struct { - monotonic int64 - realtime int64 -} - -// Update implements sentrytime.Clocks.Update. It does nothing. -func (*mockClocks) Update() (monotonicParams sentrytime.Parameters, monotonicOk bool, realtimeParam sentrytime.Parameters, realtimeOk bool) { - return -} - -// Update implements sentrytime.Clocks.GetTime. -func (c *mockClocks) GetTime(id sentrytime.ClockID) (int64, error) { - switch id { - case sentrytime.Monotonic: - return c.monotonic, nil - case sentrytime.Realtime: - return c.realtime, nil - default: - return 0, syserror.EINVAL - } -} - -// stateTestClocklessTimekeeper returns a test Timekeeper which has not had -// SetClocks called. -func stateTestClocklessTimekeeper(tb testing.TB) *Timekeeper { - ctx := contexttest.Context(tb) - mfp := pgalloc.MemoryFileProviderFromContext(ctx) - fr, err := mfp.MemoryFile().Allocate(usermem.PageSize, usage.Anonymous) - if err != nil { - tb.Fatalf("failed to allocate memory: %v", err) - } - return &Timekeeper{ - params: NewVDSOParamPage(mfp, fr), - } -} - -func stateTestTimekeeper(tb testing.TB) *Timekeeper { - t := stateTestClocklessTimekeeper(tb) - t.SetClocks(sentrytime.NewCalibratedClocks()) - return t -} - -// TestTimekeeperMonotonicZero tests that monotonic time starts at zero. -func TestTimekeeperMonotonicZero(t *testing.T) { - c := &mockClocks{ - monotonic: 100000, - } - - tk := stateTestClocklessTimekeeper(t) - tk.SetClocks(c) - defer tk.Destroy() - - now, err := tk.GetTime(sentrytime.Monotonic) - if err != nil { - t.Errorf("GetTime err got %v want nil", err) - } - if now != 0 { - t.Errorf("GetTime got %d want 0", now) - } - - c.monotonic += 10 - - now, err = tk.GetTime(sentrytime.Monotonic) - if err != nil { - t.Errorf("GetTime err got %v want nil", err) - } - if now != 10 { - t.Errorf("GetTime got %d want 10", now) - } -} - -// TestTimekeeperMonotonicJumpForward tests that monotonic time jumps forward -// after restore. -func TestTimekeeperMonotonicForward(t *testing.T) { - c := &mockClocks{ - monotonic: 900000, - realtime: 600000, - } - - tk := stateTestClocklessTimekeeper(t) - tk.restored = make(chan struct{}) - tk.saveMonotonic = 100000 - tk.saveRealtime = 400000 - tk.SetClocks(c) - defer tk.Destroy() - - // The monotonic clock should jump ahead by 200000 to 300000. - // - // The new system monotonic time (900000) is irrelevant to what the app - // sees. - now, err := tk.GetTime(sentrytime.Monotonic) - if err != nil { - t.Errorf("GetTime err got %v want nil", err) - } - if now != 300000 { - t.Errorf("GetTime got %d want 300000", now) - } -} - -// TestTimekeeperMonotonicJumpBackwards tests that monotonic time does not jump -// backwards when realtime goes backwards. -func TestTimekeeperMonotonicJumpBackwards(t *testing.T) { - c := &mockClocks{ - monotonic: 900000, - realtime: 400000, - } - - tk := stateTestClocklessTimekeeper(t) - tk.restored = make(chan struct{}) - tk.saveMonotonic = 100000 - tk.saveRealtime = 600000 - tk.SetClocks(c) - defer tk.Destroy() - - // The monotonic clock should remain at 100000. - // - // The new system monotonic time (900000) is irrelevant to what the app - // sees and we don't want to jump the monotonic clock backwards like - // realtime did. - now, err := tk.GetTime(sentrytime.Monotonic) - if err != nil { - t.Errorf("GetTime err got %v want nil", err) - } - if now != 100000 { - t.Errorf("GetTime got %d want 100000", now) - } -} diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go index d0e0810e8..d0e0810e8 100644..100755 --- a/pkg/sentry/kernel/tty.go +++ b/pkg/sentry/kernel/tty.go diff --git a/pkg/sentry/kernel/uncaught_signal.proto b/pkg/sentry/kernel/uncaught_signal.proto deleted file mode 100644 index 0bdb062cb..000000000 --- a/pkg/sentry/kernel/uncaught_signal.proto +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -import "pkg/sentry/arch/registers.proto"; - -message UncaughtSignal { - // Thread ID. - int32 tid = 1; - - // Process ID. - int32 pid = 2; - - // Registers at the time of the fault or signal. - Registers registers = 3; - - // Signal number. - int32 signal_number = 4; - - // The memory location which caused the fault (set if applicable, 0 - // otherwise). This will be set for SIGILL, SIGFPE, SIGSEGV, and SIGBUS. - uint64 fault_addr = 5; -} diff --git a/pkg/sentry/kernel/uncaught_signal_go_proto/uncaught_signal.pb.go b/pkg/sentry/kernel/uncaught_signal_go_proto/uncaught_signal.pb.go new file mode 100755 index 000000000..822e549ab --- /dev/null +++ b/pkg/sentry/kernel/uncaught_signal_go_proto/uncaught_signal.pb.go @@ -0,0 +1,119 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/sentry/kernel/uncaught_signal.proto + +package gvisor + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + registers_go_proto "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type UncaughtSignal struct { + Tid int32 `protobuf:"varint,1,opt,name=tid,proto3" json:"tid,omitempty"` + Pid int32 `protobuf:"varint,2,opt,name=pid,proto3" json:"pid,omitempty"` + Registers *registers_go_proto.Registers `protobuf:"bytes,3,opt,name=registers,proto3" json:"registers,omitempty"` + SignalNumber int32 `protobuf:"varint,4,opt,name=signal_number,json=signalNumber,proto3" json:"signal_number,omitempty"` + FaultAddr uint64 `protobuf:"varint,5,opt,name=fault_addr,json=faultAddr,proto3" json:"fault_addr,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *UncaughtSignal) Reset() { *m = UncaughtSignal{} } +func (m *UncaughtSignal) String() string { return proto.CompactTextString(m) } +func (*UncaughtSignal) ProtoMessage() {} +func (*UncaughtSignal) Descriptor() ([]byte, []int) { + return fileDescriptor_5ca9e03e13704688, []int{0} +} + +func (m *UncaughtSignal) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_UncaughtSignal.Unmarshal(m, b) +} +func (m *UncaughtSignal) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_UncaughtSignal.Marshal(b, m, deterministic) +} +func (m *UncaughtSignal) XXX_Merge(src proto.Message) { + xxx_messageInfo_UncaughtSignal.Merge(m, src) +} +func (m *UncaughtSignal) XXX_Size() int { + return xxx_messageInfo_UncaughtSignal.Size(m) +} +func (m *UncaughtSignal) XXX_DiscardUnknown() { + xxx_messageInfo_UncaughtSignal.DiscardUnknown(m) +} + +var xxx_messageInfo_UncaughtSignal proto.InternalMessageInfo + +func (m *UncaughtSignal) GetTid() int32 { + if m != nil { + return m.Tid + } + return 0 +} + +func (m *UncaughtSignal) GetPid() int32 { + if m != nil { + return m.Pid + } + return 0 +} + +func (m *UncaughtSignal) GetRegisters() *registers_go_proto.Registers { + if m != nil { + return m.Registers + } + return nil +} + +func (m *UncaughtSignal) GetSignalNumber() int32 { + if m != nil { + return m.SignalNumber + } + return 0 +} + +func (m *UncaughtSignal) GetFaultAddr() uint64 { + if m != nil { + return m.FaultAddr + } + return 0 +} + +func init() { + proto.RegisterType((*UncaughtSignal)(nil), "gvisor.UncaughtSignal") +} + +func init() { + proto.RegisterFile("pkg/sentry/kernel/uncaught_signal.proto", fileDescriptor_5ca9e03e13704688) +} + +var fileDescriptor_5ca9e03e13704688 = []byte{ + // 210 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x4c, 0x8e, 0x4d, 0x4a, 0xc6, 0x30, + 0x10, 0x86, 0x89, 0xfd, 0x81, 0xc6, 0x1f, 0x34, 0xab, 0x20, 0x88, 0x45, 0x17, 0x76, 0xd5, 0x80, + 0x9e, 0xc0, 0x0b, 0xb8, 0x88, 0xb8, 0x2e, 0x69, 0x13, 0xd3, 0xd0, 0x9a, 0x86, 0x49, 0x22, 0x78, + 0x24, 0x6f, 0x29, 0x4d, 0xd4, 0xef, 0xdb, 0x0d, 0xcf, 0xbc, 0xf3, 0xcc, 0x8b, 0x1f, 0xdc, 0xa2, + 0x99, 0x57, 0x36, 0xc0, 0x17, 0x5b, 0x14, 0x58, 0xb5, 0xb2, 0x68, 0x27, 0x11, 0xf5, 0x1c, 0x06, + 0x6f, 0xb4, 0x15, 0x6b, 0xef, 0x60, 0x0b, 0x1b, 0xa9, 0xf5, 0xa7, 0xf1, 0x1b, 0x5c, 0xdf, 0x1e, + 0x1d, 0x08, 0x98, 0x66, 0x06, 0x4a, 0x1b, 0x1f, 0x14, 0xf8, 0x1c, 0xbc, 0xfb, 0x46, 0xf8, 0xe2, + 0xed, 0x57, 0xf1, 0x9a, 0x0c, 0xe4, 0x12, 0x17, 0xc1, 0x48, 0x8a, 0x5a, 0xd4, 0x55, 0x7c, 0x1f, + 0x77, 0xe2, 0x8c, 0xa4, 0x27, 0x99, 0x38, 0x23, 0x09, 0xc3, 0xcd, 0xbf, 0x89, 0x16, 0x2d, 0xea, + 0x4e, 0x1f, 0xaf, 0xfa, 0xfc, 0xb3, 0xe7, 0x7f, 0x0b, 0x7e, 0xc8, 0x90, 0x7b, 0x7c, 0x9e, 0x0b, + 0x0e, 0x36, 0x7e, 0x8c, 0x0a, 0x68, 0x99, 0x64, 0x67, 0x19, 0xbe, 0x24, 0x46, 0x6e, 0x30, 0x7e, + 0x17, 0x71, 0x0d, 0x83, 0x90, 0x12, 0x68, 0xd5, 0xa2, 0xae, 0xe4, 0x4d, 0x22, 0xcf, 0x52, 0xc2, + 0x58, 0xa7, 0xca, 0x4f, 0x3f, 0x01, 0x00, 0x00, 0xff, 0xff, 0xfd, 0x62, 0x54, 0xdf, 0x06, 0x01, + 0x00, 0x00, +} diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD deleted file mode 100644 index cf591c4c1..000000000 --- a/pkg/sentry/limits/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "limits", - srcs = [ - "context.go", - "limits.go", - "linux.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sync", - ], -) - -go_test( - name = "limits_test", - size = "small", - srcs = [ - "limits_test.go", - ], - library = ":limits", -) diff --git a/pkg/sentry/limits/limits_state_autogen.go b/pkg/sentry/limits/limits_state_autogen.go new file mode 100755 index 000000000..aa42533a9 --- /dev/null +++ b/pkg/sentry/limits/limits_state_autogen.go @@ -0,0 +1,36 @@ +// automatically generated by stateify. + +package limits + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Limit) beforeSave() {} +func (x *Limit) save(m state.Map) { + x.beforeSave() + m.Save("Cur", &x.Cur) + m.Save("Max", &x.Max) +} + +func (x *Limit) afterLoad() {} +func (x *Limit) load(m state.Map) { + m.Load("Cur", &x.Cur) + m.Load("Max", &x.Max) +} + +func (x *LimitSet) beforeSave() {} +func (x *LimitSet) save(m state.Map) { + x.beforeSave() + m.Save("data", &x.data) +} + +func (x *LimitSet) afterLoad() {} +func (x *LimitSet) load(m state.Map) { + m.Load("data", &x.data) +} + +func init() { + state.Register("pkg/sentry/limits.Limit", (*Limit)(nil), state.Fns{Save: (*Limit).save, Load: (*Limit).load}) + state.Register("pkg/sentry/limits.LimitSet", (*LimitSet)(nil), state.Fns{Save: (*LimitSet).save, Load: (*LimitSet).load}) +} diff --git a/pkg/sentry/limits/limits_test.go b/pkg/sentry/limits/limits_test.go deleted file mode 100644 index 658a20f56..000000000 --- a/pkg/sentry/limits/limits_test.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package limits - -import ( - "syscall" - "testing" -) - -func TestSet(t *testing.T) { - testCases := []struct { - limit Limit - privileged bool - expectedErr error - }{ - {limit: Limit{Cur: 50, Max: 50}, privileged: false, expectedErr: nil}, - {limit: Limit{Cur: 20, Max: 50}, privileged: false, expectedErr: nil}, - {limit: Limit{Cur: 20, Max: 60}, privileged: false, expectedErr: syscall.EPERM}, - {limit: Limit{Cur: 60, Max: 50}, privileged: false, expectedErr: syscall.EINVAL}, - {limit: Limit{Cur: 11, Max: 10}, privileged: false, expectedErr: syscall.EINVAL}, - {limit: Limit{Cur: 20, Max: 60}, privileged: true, expectedErr: nil}, - } - - ls := NewLimitSet() - for _, tc := range testCases { - if _, err := ls.Set(1, tc.limit, tc.privileged); err != tc.expectedErr { - t.Fatalf("Tried to set Limit to %+v and privilege %t: got %v, wanted %v", tc.limit, tc.privileged, err, tc.expectedErr) - } - } - -} diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD deleted file mode 100644 index c6aa65f28..000000000 --- a/pkg/sentry/loader/BUILD +++ /dev/null @@ -1,50 +0,0 @@ -load("//tools:defs.bzl", "go_embed_data", "go_library") - -package(licenses = ["notice"]) - -go_embed_data( - name = "vdso_bin", - src = "//vdso:vdso.so", - package = "loader", - var = "vdsoBin", -) - -go_library( - name = "loader", - srcs = [ - "elf.go", - "interpreter.go", - "loader.go", - "vdso.go", - "vdso_state.go", - ":vdso_bin", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi", - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/context", - "//pkg/cpuid", - "//pkg/log", - "//pkg/rand", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fsbridge", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/limits", - "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/pgalloc", - "//pkg/sentry/uniqueid", - "//pkg/sentry/usage", - "//pkg/sentry/vfs", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/loader/loader_state_autogen.go b/pkg/sentry/loader/loader_state_autogen.go new file mode 100755 index 000000000..e28667944 --- /dev/null +++ b/pkg/sentry/loader/loader_state_autogen.go @@ -0,0 +1,57 @@ +// automatically generated by stateify. + +package loader + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *VDSO) beforeSave() {} +func (x *VDSO) save(m state.Map) { + x.beforeSave() + var phdrs []elfProgHeader = x.savePhdrs() + m.SaveValue("phdrs", phdrs) + m.Save("ParamPage", &x.ParamPage) + m.Save("vdso", &x.vdso) + m.Save("os", &x.os) + m.Save("arch", &x.arch) +} + +func (x *VDSO) afterLoad() {} +func (x *VDSO) load(m state.Map) { + m.Load("ParamPage", &x.ParamPage) + m.Load("vdso", &x.vdso) + m.Load("os", &x.os) + m.Load("arch", &x.arch) + m.LoadValue("phdrs", new([]elfProgHeader), func(y interface{}) { x.loadPhdrs(y.([]elfProgHeader)) }) +} + +func (x *elfProgHeader) beforeSave() {} +func (x *elfProgHeader) save(m state.Map) { + x.beforeSave() + m.Save("Type", &x.Type) + m.Save("Flags", &x.Flags) + m.Save("Off", &x.Off) + m.Save("Vaddr", &x.Vaddr) + m.Save("Paddr", &x.Paddr) + m.Save("Filesz", &x.Filesz) + m.Save("Memsz", &x.Memsz) + m.Save("Align", &x.Align) +} + +func (x *elfProgHeader) afterLoad() {} +func (x *elfProgHeader) load(m state.Map) { + m.Load("Type", &x.Type) + m.Load("Flags", &x.Flags) + m.Load("Off", &x.Off) + m.Load("Vaddr", &x.Vaddr) + m.Load("Paddr", &x.Paddr) + m.Load("Filesz", &x.Filesz) + m.Load("Memsz", &x.Memsz) + m.Load("Align", &x.Align) +} + +func init() { + state.Register("pkg/sentry/loader.VDSO", (*VDSO)(nil), state.Fns{Save: (*VDSO).save, Load: (*VDSO).load}) + state.Register("pkg/sentry/loader.elfProgHeader", (*elfProgHeader)(nil), state.Fns{Save: (*elfProgHeader).save, Load: (*elfProgHeader).load}) +} diff --git a/pkg/sentry/loader/vdso_bin.go b/pkg/sentry/loader/vdso_bin.go new file mode 100755 index 000000000..a5e414e21 --- /dev/null +++ b/pkg/sentry/loader/vdso_bin.go @@ -0,0 +1,5 @@ +// Generated by go_embed_data for //pkg/sentry/loader:vdso_bin. DO NOT EDIT. + +package loader + +var vdsoBin = []byte("ELF\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00>\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@\x008\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00p\xff\xff\xff\xff\xff\x00\x00p\xff\xff\xff\xff\xff\x83\x00\x00\x00\x00\x00\x00\x83\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00P\x00\x00\x00\x00\x00\x00Pp\xff\xff\xff\xff\xffPp\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00P\xe5td\x00\x00\x00@\x00\x00\x00\x00\x00\x00@p\xff\xff\xff\xff\xff@p\xff\xff\xff\xff\xff<\x00\x00\x00\x00\x00\x00\x00<\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\n\x00\x00\x00 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00\x00p\xff\xff\xff\xff\xff\"\x00\x00\x00\x00\x00\x00\x00'\x00\x00\x00\x00 \x00pp\xff\xff\xff\xff\xff&\x00\x00\x00\x00\x00\x00\x00<\x00\x00\x00\x00 \x00\xa0p\xff\xff\xff\xff\xff_\x00\x00\x00\x00\x00\x00\x00P\x00\x00\x00\x00 \x000p\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00^\x00\x00\x00\"\x00 \x00\x00p\xff\xff\xff\xff\xff\"\x00\x00\x00\x00\x00\x00\x00c\x00\x00\x00\"\x00 \x00pp\xff\xff\xff\xff\xff&\x00\x00\x00\x00\x00\x00\x00q\x00\x00\x00\"\x00 \x00\xa0p\xff\xff\xff\xff\xff_\x00\x00\x00\x00\x00\x00\x00~\x00\x00\x00\"\x00 \x000p\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00linux-vdso.so.1\x00LINUX_2.6\x00__vdso_time\x00__vdso_clock_gettime\x00__vdso_gettimeofday\x00__vdso_getcpu\x00time\x00clock_gettime\x00gettimeofday\x00getcpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa1\xbf\xee
\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf6u\xae\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00;<\x00\x00\x00\x00\x00\x000
\x00\x00X\x00\x00\x00`
\x00\x00p\x00\x00\x00\xc0
\x00\x00\x98\x00\x00\x00\xf0
\x00\x00\xb8\x00\x00\x00\x00\x00\x00\xd0\x00\x00\x00\xa0\x00\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00zR\x00x\x90\x00\x00\x00\x00\x00\x00\x00\x00\xd0\x00\x00&\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00$\x00\x00\x004\x00\x00\x00\xe8\x00\x00_\x00\x00\x00\x00BAD0\x83\x8eTAB\x00\x00\x00\\\x00\x00\x00
\x00\x00\"\x00\x00\x00\x00AD \x83[A\x00\x00\x00\x00|\x00\x00\x000
\x00\x00\n\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x94\x00\x00\x00(
\x00\x00\xa0\x00\x00\x00\x00A\x83\x90AM\x00\x00\x00\xb4\x00\x00\x00\xa8
\x00\x00\xa3\x00\x00\x00\x00A\x83\x90AP\x00\x00\x00\x00\x00\x00\x00`p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00Pp\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00\x85\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0\xff\xffo\x00\x00\x00\x00\xd6p\xff\xff\xff\xff\xff\xfc\xff\xffo\x00\x00\x00\x00\xecp\xff\xff\xff\xff\xff\xfd\xff\xffo\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x83\xfft\x83\xfft\x85\xffuH\x89\xf7\xe9\xba\x00\x00\x00H\x89\xf7\xe9R\x00\x00\xb8\xe4\x00\x00\x00\xc3f.\x84\x00\x00\x00\x00\x00AVSH\x83\xecI\x89\xf6H\x85\xfft:H\x89\xfbH\x8d|$\xe8\x84\x00\x00\x00\x85\xc0u7H\x8bD$H\x89H\xb8\xcf\xf7S㥛\xc4 H\xf7l$H\x89\xd0H\xc1\xe8?H\xc1\xfaH\xc2H\x89S1\xc0M\x85\xf6tI\xc7\x00\x00\x00\x00H\x83\xc4[A^ÐSH\x83\xecH\x89\xfbH\x89\xe7\xe80\x00\x00\x00H\x8b$H\x85\xdbtH\x89H\x83\xc4[\xc3f.\x84\x00\x00\x00\x00\x00@\x00\xb85\x00\x00H\x98Ð\x90\x90\x90\x90\x90SI\x89\xf8H\x8d
\xb5\xde\xff\xffH\x8b1f\x90Hc\xdeH\x83\xe3\xfeH\x8by(L\x8bY0L\x8bI8L\x8bQ@\xae\xe81H\x8b1H9\xdeu\xdcH\x85\xfftYH\xc1\xe2 \x89\xc0H \xd01\xc9L)\xd8HM\xc8H\xb8\x00\x00\x00\x00\x00ʚ;1\xd2I\xf7\xf2H\xf7\xe1H\xa4\xc2 I\xd1L\x89\xc8H\xc1\xe8 H\xb9SZ\x9b\xa0/\xb8D\x00H\xf7\xe1H\xc1\xeaHi\xc2\x00ʚ;I)\xc1I\x89M\x89H1\xc0[\xc31\xffL\x89Ƹ\xe4\x00\x00\x00[\xc3SI\x89\xf8H\x8d
\xde\xff\xffH\x8b1f\x90Hc\xdeH\x83\xe3\xfeH\x8byL\x8bYL\x8bIL\x8bQ \xae\xe81H\x8b1H9\xdeu\xdcH\x85\xfftYH\xc1\xe2 \x89\xc0H \xd01\xc9L)\xd8HM\xc8H\xb8\x00\x00\x00\x00\x00ʚ;1\xd2I\xf7\xf2H\xf7\xe1H\xa4\xc2 I\xd1L\x89\xc8H\xc1\xe8 H\xb9SZ\x9b\xa0/\xb8D\x00H\xf7\xe1H\xc1\xeaHi\xc2\x00ʚ;I)\xc1I\x89M\x89H1\xc0[ÿ\x00\x00\x00L\x89Ƹ\xe4\x00\x00\x00[\xc3\x00clang version 10.0.0 (https://github.com/llvm/llvm-project 407ac2eb5f136af5ddd213b8bcca176481ec5198)\x00\x00\x00\x00\x00\x00\x00 \x00\x00\x00\x00\x00\x00GNU\x00gold 1.11\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00@p\xff\xff\xff\xff\xff\xa0\x00\x00\x00\x00\x00\x00\x009\x00\x00\x00\x00 \x00\xe0p\xff\xff\xff\xff\xff\xa3\x00\x00\x00\x00\x00\x00\x00]\x00\x00\x00\x00Pp\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00f\x00\x00\x00\x00\x00\xf1\xff\x00\x00p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00s\x00\x00\x00\x00\x00\xf1\xff\x00\xf0o\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00{\x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x85\x00\x00\x00\x00 \x00\x00p\xff\xff\xff\xff\xff\"\x00\x00\x00\x00\x00\x00\x00\x91\x00\x00\x00\x00 \x00pp\xff\xff\xff\xff\xff&\x00\x00\x00\x00\x00\x00\x00\xa6\x00\x00\x00\x00 \x00\xa0p\xff\xff\xff\xff\xff_\x00\x00\x00\x00\x00\x00\x00\xba\x00\x00\x00\x00 \x000p\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00\xc8\x00\x00\x00\"\x00 \x00\x00p\xff\xff\xff\xff\xff\"\x00\x00\x00\x00\x00\x00\x00\xcd\x00\x00\x00\"\x00 \x00pp\xff\xff\xff\xff\xff&\x00\x00\x00\x00\x00\x00\x00\xdb\x00\x00\x00\"\x00 \x00\xa0p\xff\xff\xff\xff\xff_\x00\x00\x00\x00\x00\x00\x00\xe8\x00\x00\x00\"\x00 \x000p\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00\x00vdso.cc\x00vdso_time.cc\x00_ZN4vdso13ClockRealtimeEP8timespec\x00_ZN4vdso14ClockMonotonicEP8timespec\x00_DYNAMIC\x00VDSO_PRELINK\x00_params\x00LINUX_2.6\x00__vdso_time\x00__vdso_clock_gettime\x00__vdso_gettimeofday\x00__vdso_getcpu\x00time\x00clock_gettime\x00gettimeofday\x00getcpu\x00\x00.text\x00.comment\x00.dynstr\x00.eh_frame_hdr\x00.gnu.version\x00.dynsym\x00.hash\x00.note\x00.eh_frame\x00.gnu.version_d\x00.dynamic\x00.shstrtab\x00.strtab\x00.symtab\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00;\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 p\xff\xff\xff\xff\xff \x00\x00\x00\x00\x00\x00<\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x003\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00`p\xff\xff\xff\xff\xff`\x00\x00\x00\x00\x00\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00Pp\xff\xff\xff\xff\xffP\x00\x00\x00\x00\x00\x00\x85\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00&\x00\x00\x00\xff\xff\xffo\x00\x00\x00\x00\x00\x00\x00\xd6p\xff\xff\xff\xff\xff\xd6\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00Q\x00\x00\x00\xfd\xff\xffo\x00\x00\x00\x00\x00\x00\x00\xecp\xff\xff\xff\xff\xff\xec\x00\x00\x00\x00\x00\x008\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@p\xff\xff\xff\xff\xff@\x00\x00\x00\x00\x00\x00<\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00G\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80p\xff\xff\xff\xff\xff\x80\x00\x00\x00\x00\x00\x00\xd0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00`\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00Pp\xff\xff\xff\xff\xffP\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00pp\xff\xff\xff\xff\xffp\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x83\x00\x00\x00\x00\x00\x00f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00A\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xec\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00{\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x98\x00\x00\x00\x00\x00\x00
\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00s\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa0\x00\x00\x00\x00\x00\x00\xef\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00i\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x8f\x00\x00\x00\x00\x00\x00\x83\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD deleted file mode 100644 index a98b66de1..000000000 --- a/pkg/sentry/memmap/BUILD +++ /dev/null @@ -1,55 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "mappable_range", - out = "mappable_range.go", - package = "memmap", - prefix = "Mappable", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - -go_template_instance( - name = "mapping_set_impl", - out = "mapping_set_impl.go", - package = "memmap", - prefix = "Mapping", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "MappableRange", - "Value": "MappingsOfRange", - "Functions": "mappingSetFunctions", - }, -) - -go_library( - name = "memmap", - srcs = [ - "mappable_range.go", - "mapping_set.go", - "mapping_set_impl.go", - "memmap.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/log", - "//pkg/sentry/platform", - "//pkg/syserror", - "//pkg/usermem", - ], -) - -go_test( - name = "memmap_test", - size = "small", - srcs = ["mapping_set_test.go"], - library = ":memmap", - deps = ["//pkg/usermem"], -) diff --git a/pkg/sentry/memmap/mappable_range.go b/pkg/sentry/memmap/mappable_range.go new file mode 100755 index 000000000..6b6c2c685 --- /dev/null +++ b/pkg/sentry/memmap/mappable_range.go @@ -0,0 +1,62 @@ +package memmap + +// A Range represents a contiguous range of T. +// +// +stateify savable +type MappableRange struct { + // Start is the inclusive start of the range. + Start uint64 + + // End is the exclusive end of the range. + End uint64 +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +func (r MappableRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +func (r MappableRange) Length() uint64 { + return r.End - r.Start +} + +// Contains returns true if r contains x. +func (r MappableRange) Contains(x uint64) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +func (r MappableRange) Overlaps(r2 MappableRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +func (r MappableRange) IsSupersetOf(r2 MappableRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +func (r MappableRange) Intersect(r2 MappableRange) MappableRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +func (r MappableRange) CanSplitAt(x uint64) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/sentry/memmap/mapping_set_impl.go b/pkg/sentry/memmap/mapping_set_impl.go new file mode 100755 index 000000000..e632f28a5 --- /dev/null +++ b/pkg/sentry/memmap/mapping_set_impl.go @@ -0,0 +1,1270 @@ +package memmap + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + MappingminDegree = 3 + + MappingmaxDegree = 2 * MappingminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type MappingSet struct { + root Mappingnode `state:".(*MappingSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *MappingSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *MappingSet) IsEmptyRange(r MappableRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *MappingSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *MappingSet) SpanRange(r MappableRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *MappingSet) FirstSegment() MappingIterator { + if s.root.nrSegments == 0 { + return MappingIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *MappingSet) LastSegment() MappingIterator { + if s.root.nrSegments == 0 { + return MappingIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *MappingSet) FirstGap() MappingGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return MappingGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *MappingSet) LastGap() MappingGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return MappingGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *MappingSet) Find(key uint64) (MappingIterator, MappingGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return MappingIterator{n, i}, MappingGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return MappingIterator{}, MappingGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *MappingSet) FindSegment(key uint64) MappingIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *MappingSet) LowerBoundSegment(min uint64) MappingIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *MappingSet) UpperBoundSegment(max uint64) MappingIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *MappingSet) FindGap(key uint64) MappingGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *MappingSet) LowerBoundGap(min uint64) MappingGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *MappingSet) UpperBoundGap(max uint64) MappingGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *MappingSet) Add(r MappableRange, val MappingsOfRange) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *MappingSet) AddWithoutMerging(r MappableRange, val MappingsOfRange) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *MappingSet) Insert(gap MappingGapIterator, r MappableRange, val MappingsOfRange) MappingIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (mappingSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (mappingSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (mappingSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *MappingSet) InsertWithoutMerging(gap MappingGapIterator, r MappableRange, val MappingsOfRange) MappingIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *MappingSet) InsertWithoutMergingUnchecked(gap MappingGapIterator, r MappableRange, val MappingsOfRange) MappingIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return MappingIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *MappingSet) Remove(seg MappingIterator) MappingGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + mappingSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(MappingGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *MappingSet) RemoveAll() { + s.root = Mappingnode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *MappingSet) RemoveRange(r MappableRange) MappingGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *MappingSet) Merge(first, second MappingIterator) MappingIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *MappingSet) MergeUnchecked(first, second MappingIterator) MappingIterator { + if first.End() == second.Start() { + if mval, ok := (mappingSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return MappingIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *MappingSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *MappingSet) MergeRange(r MappableRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *MappingSet) MergeAdjacent(r MappableRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *MappingSet) Split(seg MappingIterator, split uint64) (MappingIterator, MappingIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *MappingSet) SplitUnchecked(seg MappingIterator, split uint64) (MappingIterator, MappingIterator) { + val1, val2 := (mappingSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), MappableRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *MappingSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *MappingSet) Isolate(seg MappingIterator, r MappableRange) MappingIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *MappingSet) ApplyContiguous(r MappableRange, fn func(seg MappingIterator)) MappingGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return MappingGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return MappingGapIterator{} + } + } +} + +// +stateify savable +type Mappingnode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *Mappingnode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [MappingmaxDegree - 1]MappableRange + values [MappingmaxDegree - 1]MappingsOfRange + children [MappingmaxDegree]*Mappingnode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Mappingnode) firstSegment() MappingIterator { + for n.hasChildren { + n = n.children[0] + } + return MappingIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Mappingnode) lastSegment() MappingIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return MappingIterator{n, n.nrSegments - 1} +} + +func (n *Mappingnode) prevSibling() *Mappingnode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *Mappingnode) nextSibling() *Mappingnode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *Mappingnode) rebalanceBeforeInsert(gap MappingGapIterator) MappingGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < MappingmaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &Mappingnode{ + nrSegments: MappingminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &Mappingnode{ + nrSegments: MappingminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:MappingminDegree-1], n.keys[:MappingminDegree-1]) + copy(left.values[:MappingminDegree-1], n.values[:MappingminDegree-1]) + copy(right.keys[:MappingminDegree-1], n.keys[MappingminDegree:]) + copy(right.values[:MappingminDegree-1], n.values[MappingminDegree:]) + n.keys[0], n.values[0] = n.keys[MappingminDegree-1], n.values[MappingminDegree-1] + MappingzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:MappingminDegree], n.children[:MappingminDegree]) + copy(right.children[:MappingminDegree], n.children[MappingminDegree:]) + MappingzeroNodeSlice(n.children[2:]) + for i := 0; i < MappingminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < MappingminDegree { + return MappingGapIterator{left, gap.index} + } + return MappingGapIterator{right, gap.index - MappingminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[MappingminDegree-1], n.values[MappingminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &Mappingnode{ + nrSegments: MappingminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:MappingminDegree-1], n.keys[MappingminDegree:]) + copy(sibling.values[:MappingminDegree-1], n.values[MappingminDegree:]) + MappingzeroValueSlice(n.values[MappingminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:MappingminDegree], n.children[MappingminDegree:]) + MappingzeroNodeSlice(n.children[MappingminDegree:]) + for i := 0; i < MappingminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = MappingminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < MappingminDegree { + return gap + } + return MappingGapIterator{sibling, gap.index - MappingminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *Mappingnode) rebalanceAfterRemove(gap MappingGapIterator) MappingGapIterator { + for { + if n.nrSegments >= MappingminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= MappingminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + mappingSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return MappingGapIterator{n, 0} + } + if gap.node == n { + return MappingGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= MappingminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + mappingSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return MappingGapIterator{n, n.nrSegments} + } + return MappingGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return MappingGapIterator{p, gap.index} + } + if gap.node == right { + return MappingGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *Mappingnode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = MappingGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + mappingSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type MappingIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *Mappingnode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg MappingIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg MappingIterator) Range() MappableRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg MappingIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg MappingIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg MappingIterator) SetRangeUnchecked(r MappableRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg MappingIterator) SetRange(r MappableRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg MappingIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg MappingIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg MappingIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg MappingIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg MappingIterator) Value() MappingsOfRange { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg MappingIterator) ValuePtr() *MappingsOfRange { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg MappingIterator) SetValue(val MappingsOfRange) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg MappingIterator) PrevSegment() MappingIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return MappingIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return MappingIterator{} + } + return MappingsegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg MappingIterator) NextSegment() MappingIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return MappingIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return MappingIterator{} + } + return MappingsegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg MappingIterator) PrevGap() MappingGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return MappingGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg MappingIterator) NextGap() MappingGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return MappingGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg MappingIterator) PrevNonEmpty() (MappingIterator, MappingGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return MappingIterator{}, gap + } + return gap.PrevSegment(), MappingGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg MappingIterator) NextNonEmpty() (MappingIterator, MappingGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return MappingIterator{}, gap + } + return gap.NextSegment(), MappingGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type MappingGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *Mappingnode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap MappingGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap MappingGapIterator) Range() MappableRange { + return MappableRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap MappingGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return mappingSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap MappingGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return mappingSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap MappingGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap MappingGapIterator) PrevSegment() MappingIterator { + return MappingsegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap MappingGapIterator) NextSegment() MappingIterator { + return MappingsegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap MappingGapIterator) PrevGap() MappingGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return MappingGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap MappingGapIterator) NextGap() MappingGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return MappingGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func MappingsegmentBeforePosition(n *Mappingnode, i int) MappingIterator { + for i == 0 { + if n.parent == nil { + return MappingIterator{} + } + n, i = n.parent, n.parentIndex + } + return MappingIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func MappingsegmentAfterPosition(n *Mappingnode, i int) MappingIterator { + for i == n.nrSegments { + if n.parent == nil { + return MappingIterator{} + } + n, i = n.parent, n.parentIndex + } + return MappingIterator{n, i} +} + +func MappingzeroValueSlice(slice []MappingsOfRange) { + + for i := range slice { + mappingSetFunctions{}.ClearValue(&slice[i]) + } +} + +func MappingzeroNodeSlice(slice []*Mappingnode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *MappingSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *Mappingnode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *Mappingnode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type MappingSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []MappingsOfRange +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *MappingSet) ExportSortedSlices() *MappingSegmentDataSlices { + var sds MappingSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *MappingSet) ImportSortedSlices(sds *MappingSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := MappableRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *MappingSet) saveRoot() *MappingSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *MappingSet) loadRoot(sds *MappingSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/memmap/mapping_set_test.go b/pkg/sentry/memmap/mapping_set_test.go deleted file mode 100644 index d39efe38f..000000000 --- a/pkg/sentry/memmap/mapping_set_test.go +++ /dev/null @@ -1,260 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package memmap - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/usermem" -) - -type testMappingSpace struct { - // Ideally we'd store the full ranges that were invalidated, rather - // than individual calls to Invalidate, as they are an implementation - // detail, but this is the simplest way for now. - inv []usermem.AddrRange -} - -func (n *testMappingSpace) reset() { - n.inv = []usermem.AddrRange{} -} - -func (n *testMappingSpace) Invalidate(ar usermem.AddrRange, opts InvalidateOpts) { - n.inv = append(n.inv, ar) -} - -func TestAddRemoveMapping(t *testing.T) { - set := MappingSet{} - ms := &testMappingSpace{} - - mapped := set.AddMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, true) - if got, want := mapped, []MappableRange{{0x1000, 0x3000}}; !reflect.DeepEqual(got, want) { - t.Errorf("AddMapping: got %+v, wanted %+v", got, want) - } - - // Mappings (usermem.AddrRanges => memmap.MappableRange): - // [0x10000, 0x12000) => [0x1000, 0x3000) - t.Log(&set) - - mapped = set.AddMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true) - if len(mapped) != 0 { - t.Errorf("AddMapping: got %+v, wanted []", mapped) - } - - // Mappings: - // [0x10000, 0x11000) => [0x1000, 0x2000) - // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000) - t.Log(&set) - - mapped = set.AddMapping(ms, usermem.AddrRange{0x30000, 0x31000}, 0x4000, true) - if got, want := mapped, []MappableRange{{0x4000, 0x5000}}; !reflect.DeepEqual(got, want) { - t.Errorf("AddMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x10000, 0x11000) => [0x1000, 0x2000) - // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000) - // [0x30000, 0x31000) => [0x4000, 0x5000) - t.Log(&set) - - mapped = set.AddMapping(ms, usermem.AddrRange{0x12000, 0x15000}, 0x3000, true) - if got, want := mapped, []MappableRange{{0x3000, 0x4000}, {0x5000, 0x6000}}; !reflect.DeepEqual(got, want) { - t.Errorf("AddMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x10000, 0x11000) => [0x1000, 0x2000) - // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000) - // [0x12000, 0x13000) => [0x3000, 0x4000) - // [0x13000, 0x14000) and [0x30000, 0x31000) => [0x4000, 0x5000) - // [0x14000, 0x15000) => [0x5000, 0x6000) - t.Log(&set) - - unmapped := set.RemoveMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0x1000, true) - if got, want := unmapped, []MappableRange{{0x1000, 0x2000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000) - // [0x12000, 0x13000) => [0x3000, 0x4000) - // [0x13000, 0x14000) and [0x30000, 0x31000) => [0x4000, 0x5000) - // [0x14000, 0x15000) => [0x5000, 0x6000) - t.Log(&set) - - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true) - if len(unmapped) != 0 { - t.Errorf("RemoveMapping: got %+v, wanted []", unmapped) - } - - // Mappings: - // [0x11000, 0x13000) => [0x2000, 0x4000) - // [0x13000, 0x14000) and [0x30000, 0x31000) => [0x4000, 0x5000) - // [0x14000, 0x15000) => [0x5000, 0x6000) - t.Log(&set) - - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x11000, 0x15000}, 0x2000, true) - if got, want := unmapped, []MappableRange{{0x2000, 0x4000}, {0x5000, 0x6000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x30000, 0x31000) => [0x4000, 0x5000) - t.Log(&set) - - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x30000, 0x31000}, 0x4000, true) - if got, want := unmapped, []MappableRange{{0x4000, 0x5000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } -} - -func TestInvalidateWholeMapping(t *testing.T) { - set := MappingSet{} - ms := &testMappingSpace{} - - set.AddMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0, true) - // Mappings: - // [0x10000, 0x11000) => [0, 0x1000) - t.Log(&set) - set.Invalidate(MappableRange{0, 0x1000}, InvalidateOpts{}) - if got, want := ms.inv, []usermem.AddrRange{{0x10000, 0x11000}}; !reflect.DeepEqual(got, want) { - t.Errorf("Invalidate: got %+v, wanted %+v", got, want) - } -} - -func TestInvalidatePartialMapping(t *testing.T) { - set := MappingSet{} - ms := &testMappingSpace{} - - set.AddMapping(ms, usermem.AddrRange{0x10000, 0x13000}, 0, true) - // Mappings: - // [0x10000, 0x13000) => [0, 0x3000) - t.Log(&set) - set.Invalidate(MappableRange{0x1000, 0x2000}, InvalidateOpts{}) - if got, want := ms.inv, []usermem.AddrRange{{0x11000, 0x12000}}; !reflect.DeepEqual(got, want) { - t.Errorf("Invalidate: got %+v, wanted %+v", got, want) - } -} - -func TestInvalidateMultipleMappings(t *testing.T) { - set := MappingSet{} - ms := &testMappingSpace{} - - set.AddMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0, true) - set.AddMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true) - // Mappings: - // [0x10000, 0x11000) => [0, 0x1000) - // [0x12000, 0x13000) => [0x2000, 0x3000) - t.Log(&set) - set.Invalidate(MappableRange{0, 0x3000}, InvalidateOpts{}) - if got, want := ms.inv, []usermem.AddrRange{{0x10000, 0x11000}, {0x20000, 0x21000}}; !reflect.DeepEqual(got, want) { - t.Errorf("Invalidate: got %+v, wanted %+v", got, want) - } -} - -func TestInvalidateOverlappingMappings(t *testing.T) { - set := MappingSet{} - ms1 := &testMappingSpace{} - ms2 := &testMappingSpace{} - - set.AddMapping(ms1, usermem.AddrRange{0x10000, 0x12000}, 0, true) - set.AddMapping(ms2, usermem.AddrRange{0x20000, 0x22000}, 0x1000, true) - // Mappings: - // ms1:[0x10000, 0x12000) => [0, 0x2000) - // ms2:[0x11000, 0x13000) => [0x1000, 0x3000) - t.Log(&set) - set.Invalidate(MappableRange{0x1000, 0x2000}, InvalidateOpts{}) - if got, want := ms1.inv, []usermem.AddrRange{{0x11000, 0x12000}}; !reflect.DeepEqual(got, want) { - t.Errorf("Invalidate: ms1: got %+v, wanted %+v", got, want) - } - if got, want := ms2.inv, []usermem.AddrRange{{0x20000, 0x21000}}; !reflect.DeepEqual(got, want) { - t.Errorf("Invalidate: ms1: got %+v, wanted %+v", got, want) - } -} - -func TestMixedWritableMappings(t *testing.T) { - set := MappingSet{} - ms := &testMappingSpace{} - - mapped := set.AddMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, true) - if got, want := mapped, []MappableRange{{0x1000, 0x3000}}; !reflect.DeepEqual(got, want) { - t.Errorf("AddMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x10000, 0x12000) writable => [0x1000, 0x3000) - t.Log(&set) - - mapped = set.AddMapping(ms, usermem.AddrRange{0x20000, 0x22000}, 0x2000, false) - if got, want := mapped, []MappableRange{{0x3000, 0x4000}}; !reflect.DeepEqual(got, want) { - t.Errorf("AddMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x10000, 0x11000) writable => [0x1000, 0x2000) - // [0x11000, 0x12000) writable and [0x20000, 0x21000) readonly => [0x2000, 0x3000) - // [0x21000, 0x22000) readonly => [0x3000, 0x4000) - t.Log(&set) - - // Unmap should fail because we specified the readonly map address range, but - // asked to unmap a writable segment. - unmapped := set.RemoveMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true) - if len(unmapped) != 0 { - t.Errorf("RemoveMapping: got %+v, wanted []", unmapped) - } - - // Readonly mapping removed, but writable mapping still exists in the range, - // so no mappable range fully unmapped. - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, false) - if len(unmapped) != 0 { - t.Errorf("RemoveMapping: got %+v, wanted []", unmapped) - } - - // Mappings: - // [0x10000, 0x12000) writable => [0x1000, 0x3000) - // [0x21000, 0x22000) readonly => [0x3000, 0x4000) - t.Log(&set) - - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x11000, 0x12000}, 0x2000, true) - if got, want := unmapped, []MappableRange{{0x2000, 0x3000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x10000, 0x12000) writable => [0x1000, 0x3000) - // [0x21000, 0x22000) readonly => [0x3000, 0x4000) - t.Log(&set) - - // Unmap should fail since writable bit doesn't match. - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, false) - if len(unmapped) != 0 { - t.Errorf("RemoveMapping: got %+v, wanted []", unmapped) - } - - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, true) - if got, want := unmapped, []MappableRange{{0x1000, 0x2000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x21000, 0x22000) readonly => [0x3000, 0x4000) - t.Log(&set) - - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x21000, 0x22000}, 0x3000, false) - if got, want := unmapped, []MappableRange{{0x3000, 0x4000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } -} diff --git a/pkg/sentry/memmap/memmap_impl_state_autogen.go b/pkg/sentry/memmap/memmap_impl_state_autogen.go new file mode 100755 index 000000000..b231fd9c3 --- /dev/null +++ b/pkg/sentry/memmap/memmap_impl_state_autogen.go @@ -0,0 +1,63 @@ +// automatically generated by stateify. + +package memmap + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *MappingSet) beforeSave() {} +func (x *MappingSet) save(m state.Map) { + x.beforeSave() + var root *MappingSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *MappingSet) afterLoad() {} +func (x *MappingSet) load(m state.Map) { + m.LoadValue("root", new(*MappingSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*MappingSegmentDataSlices)) }) +} + +func (x *Mappingnode) beforeSave() {} +func (x *Mappingnode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *Mappingnode) afterLoad() {} +func (x *Mappingnode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *MappingSegmentDataSlices) beforeSave() {} +func (x *MappingSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *MappingSegmentDataSlices) afterLoad() {} +func (x *MappingSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func init() { + state.Register("pkg/sentry/memmap.MappingSet", (*MappingSet)(nil), state.Fns{Save: (*MappingSet).save, Load: (*MappingSet).load}) + state.Register("pkg/sentry/memmap.Mappingnode", (*Mappingnode)(nil), state.Fns{Save: (*Mappingnode).save, Load: (*Mappingnode).load}) + state.Register("pkg/sentry/memmap.MappingSegmentDataSlices", (*MappingSegmentDataSlices)(nil), state.Fns{Save: (*MappingSegmentDataSlices).save, Load: (*MappingSegmentDataSlices).load}) +} diff --git a/pkg/sentry/memmap/memmap_state_autogen.go b/pkg/sentry/memmap/memmap_state_autogen.go new file mode 100755 index 000000000..2072dbad2 --- /dev/null +++ b/pkg/sentry/memmap/memmap_state_autogen.go @@ -0,0 +1,40 @@ +// automatically generated by stateify. + +package memmap + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *MappableRange) beforeSave() {} +func (x *MappableRange) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) +} + +func (x *MappableRange) afterLoad() {} +func (x *MappableRange) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) +} + +func (x *MappingOfRange) beforeSave() {} +func (x *MappingOfRange) save(m state.Map) { + x.beforeSave() + m.Save("MappingSpace", &x.MappingSpace) + m.Save("AddrRange", &x.AddrRange) + m.Save("Writable", &x.Writable) +} + +func (x *MappingOfRange) afterLoad() {} +func (x *MappingOfRange) load(m state.Map) { + m.Load("MappingSpace", &x.MappingSpace) + m.Load("AddrRange", &x.AddrRange) + m.Load("Writable", &x.Writable) +} + +func init() { + state.Register("pkg/sentry/memmap.MappableRange", (*MappableRange)(nil), state.Fns{Save: (*MappableRange).save, Load: (*MappableRange).load}) + state.Register("pkg/sentry/memmap.MappingOfRange", (*MappingOfRange)(nil), state.Fns{Save: (*MappingOfRange).save, Load: (*MappingOfRange).load}) +} diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD deleted file mode 100644 index 73591dab7..000000000 --- a/pkg/sentry/mm/BUILD +++ /dev/null @@ -1,141 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "file_refcount_set", - out = "file_refcount_set.go", - imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", - }, - package = "mm", - prefix = "fileRefcount", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "platform.FileRange", - "Value": "int32", - "Functions": "fileRefcountSetFunctions", - }, -) - -go_template_instance( - name = "vma_set", - out = "vma_set.go", - consts = { - "minDegree": "8", - }, - imports = { - "usermem": "gvisor.dev/gvisor/pkg/usermem", - }, - package = "mm", - prefix = "vma", - template = "//pkg/segment:generic_set", - types = { - "Key": "usermem.Addr", - "Range": "usermem.AddrRange", - "Value": "vma", - "Functions": "vmaSetFunctions", - }, -) - -go_template_instance( - name = "pma_set", - out = "pma_set.go", - consts = { - "minDegree": "8", - }, - imports = { - "usermem": "gvisor.dev/gvisor/pkg/usermem", - }, - package = "mm", - prefix = "pma", - template = "//pkg/segment:generic_set", - types = { - "Key": "usermem.Addr", - "Range": "usermem.AddrRange", - "Value": "pma", - "Functions": "pmaSetFunctions", - }, -) - -go_template_instance( - name = "io_list", - out = "io_list.go", - package = "mm", - prefix = "io", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*ioResult", - "Linker": "*ioResult", - }, -) - -go_library( - name = "mm", - srcs = [ - "address_space.go", - "aio_context.go", - "aio_context_state.go", - "debug.go", - "file_refcount_set.go", - "io.go", - "io_list.go", - "lifecycle.go", - "metadata.go", - "mm.go", - "pma.go", - "pma_set.go", - "procfs.go", - "save_restore.go", - "shm.go", - "special_mappable.go", - "syscalls.go", - "vma.go", - "vma_set.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/atomicbitops", - "//pkg/context", - "//pkg/log", - "//pkg/refs", - "//pkg/safecopy", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/fs/proc/seqfile", - "//pkg/sentry/fsbridge", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/futex", - "//pkg/sentry/kernel/shm", - "//pkg/sentry/limits", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/usage", - "//pkg/sync", - "//pkg/syserror", - "//pkg/tcpip/buffer", - "//pkg/usermem", - ], -) - -go_test( - name = "mm_test", - size = "small", - srcs = ["mm_test.go"], - library = ":mm", - deps = [ - "//pkg/context", - "//pkg/sentry/arch", - "//pkg/sentry/contexttest", - "//pkg/sentry/limits", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/syserror", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/mm/README.md b/pkg/sentry/mm/README.md deleted file mode 100644 index f4d43d927..000000000 --- a/pkg/sentry/mm/README.md +++ /dev/null @@ -1,280 +0,0 @@ -This package provides an emulation of Linux semantics for application virtual -memory mappings. - -For completeness, this document also describes aspects of the memory management -subsystem defined outside this package. - -# Background - -We begin by describing semantics for virtual memory in Linux. - -A virtual address space is defined as a collection of mappings from virtual -addresses to physical memory. However, userspace applications do not configure -mappings to physical memory directly. Instead, applications configure memory -mappings from virtual addresses to offsets into a file using the `mmap` system -call.[^mmap-anon] For example, a call to: - - mmap( - /* addr = */ 0x400000, - /* length = */ 0x1000, - PROT_READ | PROT_WRITE, - MAP_SHARED, - /* fd = */ 3, - /* offset = */ 0); - -creates a mapping of length 0x1000 bytes, starting at virtual address (VA) -0x400000, to offset 0 in the file represented by file descriptor (FD) 3. Within -the Linux kernel, virtual memory mappings are represented by *virtual memory -areas* (VMAs). Supposing that FD 3 represents file /tmp/foo, the state of the -virtual memory subsystem after the `mmap` call may be depicted as: - - VMA: VA:0x400000 -> /tmp/foo:0x0 - -Establishing a virtual memory area does not necessarily establish a mapping to a -physical address, because Linux has not necessarily provisioned physical memory -to store the file's contents. Thus, if the application attempts to read the -contents of VA 0x400000, it may incur a *page fault*, a CPU exception that -forces the kernel to create such a mapping to service the read. - -For a file, doing so consists of several logical phases: - -1. The kernel allocates physical memory to store the contents of the required - part of the file, and copies file contents to the allocated memory. - Supposing that the kernel chooses the physical memory at physical address - (PA) 0x2fb000, the resulting state of the system is: - - VMA: VA:0x400000 -> /tmp/foo:0x0 - Filemap: /tmp/foo:0x0 -> PA:0x2fb000 - - (In Linux the state of the mapping from file offset to physical memory is - stored in `struct address_space`, but to avoid confusion with other notions - of address space we will refer to this system as filemap, named after Linux - kernel source file `mm/filemap.c`.) - -2. The kernel stores the effective mapping from virtual to physical address in - a *page table entry* (PTE) in the application's *page tables*, which are - used by the CPU's virtual memory hardware to perform address translation. - The resulting state of the system is: - - VMA: VA:0x400000 -> /tmp/foo:0x0 - Filemap: /tmp/foo:0x0 -> PA:0x2fb000 - PTE: VA:0x400000 -----------------> PA:0x2fb000 - - The PTE is required for the application to actually use the contents of the - mapped file as virtual memory. However, the PTE is derived from the VMA and - filemap state, both of which are independently mutable, such that mutations - to either will affect the PTE. For example: - - - The application may remove the VMA using the `munmap` system call. This - breaks the mapping from VA:0x400000 to /tmp/foo:0x0, and consequently - the mapping from VA:0x400000 to PA:0x2fb000. However, it does not - necessarily break the mapping from /tmp/foo:0x0 to PA:0x2fb000, so a - future mapping of the same file offset may reuse this physical memory. - - - The application may invalidate the file's contents by passing a length - of 0 to the `ftruncate` system call. This breaks the mapping from - /tmp/foo:0x0 to PA:0x2fb000, and consequently the mapping from - VA:0x400000 to PA:0x2fb000. However, it does not break the mapping from - VA:0x400000 to /tmp/foo:0x0, so future changes to the file's contents - may again be made visible at VA:0x400000 after another page fault - results in the allocation of a new physical address. - - Note that, in order to correctly break the mapping from VA:0x400000 to - PA:0x2fb000 in the latter case, filemap must also store a *reverse mapping* - from /tmp/foo:0x0 to VA:0x400000 so that it can locate and remove the PTE. - -[^mmap-anon]: Memory mappings to non-files are discussed in later sections. - -## Private Mappings - -The preceding example considered VMAs created using the `MAP_SHARED` flag, which -means that PTEs derived from the mapping should always use physical memory that -represents the current state of the mapped file.[^mmap-dev-zero] Applications -can alternatively pass the `MAP_PRIVATE` flag to create a *private mapping*. -Private mappings are *copy-on-write*. - -Suppose that the application instead created a private mapping in the previous -example. In Linux, the state of the system after a read page fault would be: - - VMA: VA:0x400000 -> /tmp/foo:0x0 (private) - Filemap: /tmp/foo:0x0 -> PA:0x2fb000 - PTE: VA:0x400000 -----------------> PA:0x2fb000 (read-only) - -Now suppose the application attempts to write to VA:0x400000. For a shared -mapping, the write would be propagated to PA:0x2fb000, and the kernel would be -responsible for ensuring that the write is later propagated to the mapped file. -For a private mapping, the write incurs another page fault since the PTE is -marked read-only. In response, the kernel allocates physical memory to store the -mapping's *private copy* of the file's contents, copies file contents to the -allocated memory, and changes the PTE to map to the private copy. Supposing that -the kernel chooses the physical memory at physical address (PA) 0x5ea000, the -resulting state of the system is: - - VMA: VA:0x400000 -> /tmp/foo:0x0 (private) - Filemap: /tmp/foo:0x0 -> PA:0x2fb000 - PTE: VA:0x400000 -----------------> PA:0x5ea000 - -Note that the filemap mapping from /tmp/foo:0x0 to PA:0x2fb000 may still exist, -but is now irrelevant to this mapping. - -[^mmap-dev-zero]: Modulo files with special mmap semantics such as `/dev/zero`. - -## Anonymous Mappings - -Instead of passing a file to the `mmap` system call, applications can instead -request an *anonymous* mapping by passing the `MAP_ANONYMOUS` flag. -Semantically, an anonymous mapping is essentially a mapping to an ephemeral file -initially filled with zero bytes. Practically speaking, this is how shared -anonymous mappings are implemented, but private anonymous mappings do not result -in the creation of an ephemeral file; since there would be no way to modify the -contents of the underlying file through a private mapping, all private anonymous -mappings use a single shared page filled with zero bytes until copy-on-write -occurs. - -# Virtual Memory in the Sentry - -The sentry implements application virtual memory atop a host kernel, introducing -an additional level of indirection to the above. - -Consider the same scenario as in the previous section. Since the sentry handles -application system calls, the effect of an application `mmap` system call is to -create a VMA in the sentry (as opposed to the host kernel): - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 - -When the application first incurs a page fault on this address, the host kernel -delivers information about the page fault to the sentry in a platform-dependent -manner, and the sentry handles the fault: - -1. The sentry allocates memory to store the contents of the required part of - the file, and copies file contents to the allocated memory. However, since - the sentry is implemented atop a host kernel, it does not configure mappings - to physical memory directly. Instead, mappable "memory" in the sentry is - represented by a host file descriptor and offset, since (as noted in - "Background") this is the memory mapping primitive provided by the host - kernel. In general, memory is allocated from a temporary host file using the - `pgalloc` package. Supposing that the sentry allocates offset 0x3000 from - host file "memory-file", the resulting state is: - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 - Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000 - -2. The sentry stores the effective mapping from virtual address to host file in - a host VMA by invoking the `mmap` system call: - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 - Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000 - Host VMA: VA:0x400000 -----------------> host:memory-file:0x3000 - -3. The sentry returns control to the application, which immediately incurs the - page fault again.[^mmap-populate] However, since a host VMA now exists for - the faulting virtual address, the host kernel now handles the page fault as - described in "Background": - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 - Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000 - Host VMA: VA:0x400000 -----------------> host:memory-file:0x3000 - Host filemap: host:memory-file:0x3000 -> PA:0x2fb000 - Host PTE: VA:0x400000 --------------------------------------------> PA:0x2fb000 - -Thus, from an implementation standpoint, host VMAs serve the same purpose in the -sentry that PTEs do in Linux. As in Linux, sentry VMA and filemap state is -independently mutable, and the desired state of host VMAs is derived from that -state. - -[^mmap-populate]: The sentry could force the host kernel to establish PTEs when - it creates the host VMA by passing the `MAP_POPULATE` flag to - the `mmap` system call, but usually does not. This is because, - to reduce the number of page faults that require handling by - the sentry and (correspondingly) the number of host `mmap` - system calls, the sentry usually creates host VMAs that are - much larger than the single faulting page. - -## Private Mappings - -The sentry implements private mappings consistently with Linux. Before -copy-on-write, the private mapping example given in the Background results in: - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 (private) - Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000 - Host VMA: VA:0x400000 -----------------> host:memory-file:0x3000 (read-only) - Host filemap: host:memory-file:0x3000 -> PA:0x2fb000 - Host PTE: VA:0x400000 --------------------------------------------> PA:0x2fb000 (read-only) - -When the application attempts to write to this address, the host kernel delivers -information about the resulting page fault to the sentry. Analogous to Linux, -the sentry allocates memory to store the mapping's private copy of the file's -contents, copies file contents to the allocated memory, and changes the host VMA -to map to the private copy. Supposing that the sentry chooses the offset 0x4000 -in host file `memory-file` to store the private copy, the state of the system -after copy-on-write is: - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 (private) - Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000 - Host VMA: VA:0x400000 -----------------> host:memory-file:0x4000 - Host filemap: host:memory-file:0x4000 -> PA:0x5ea000 - Host PTE: VA:0x400000 --------------------------------------------> PA:0x5ea000 - -However, this highlights an important difference between Linux and the sentry. -In Linux, page tables are concrete (architecture-dependent) data structures -owned by the kernel. Conversely, the sentry has the ability to create and -destroy host VMAs using host system calls, but it does not have direct access to -their state. Thus, as written, if the application invokes the `munmap` system -call to remove the sentry VMA, it is non-trivial for the sentry to determine -that it should deallocate `host:memory-file:0x4000`. This implies that the -sentry must retain information about the host VMAs that it has created. - -## Anonymous Mappings - -The sentry implements anonymous mappings consistently with Linux, except that -there is no shared zero page. - -# Implementation Constructs - -In Linux: - -- A virtual address space is represented by `struct mm_struct`. - -- VMAs are represented by `struct vm_area_struct`, stored in `struct - mm_struct::mmap`. - -- Mappings from file offsets to physical memory are stored in `struct - address_space`. - -- Reverse mappings from file offsets to virtual mappings are stored in `struct - address_space::i_mmap`. - -- Physical memory pages are represented by a pointer to `struct page` or an - index called a *page frame number* (PFN), represented by `pfn_t`. - -- PTEs are represented by architecture-dependent type `pte_t`, stored in a - table hierarchy rooted at `struct mm_struct::pgd`. - -In the sentry: - -- A virtual address space is represented by type [`mm.MemoryManager`][mm]. - -- Sentry VMAs are represented by type [`mm.vma`][mm], stored in - `mm.MemoryManager.vmas`. - -- Mappings from sentry file offsets to host file offsets are abstracted - through interface method [`memmap.Mappable.Translate`][memmap]. - -- Reverse mappings from sentry file offsets to virtual mappings are abstracted - through interface methods - [`memmap.Mappable.AddMapping` and `memmap.Mappable.RemoveMapping`][memmap]. - -- Host files that may be mapped into host VMAs are represented by type - [`platform.File`][platform]. - -- Host VMAs are represented in the sentry by type [`mm.pma`][mm] ("platform - mapping area"), stored in `mm.MemoryManager.pmas`. - -- Creation and destruction of host VMAs is abstracted through interface - methods - [`platform.AddressSpace.MapFile` and `platform.AddressSpace.Unmap`][platform]. - -[memmap]: https://github.com/google/gvisor/blob/master/pkg/sentry/memmap/memmap.go -[mm]: https://github.com/google/gvisor/blob/master/pkg/sentry/mm/mm.go -[pgalloc]: https://github.com/google/gvisor/blob/master/pkg/sentry/pgalloc/pgalloc.go -[platform]: https://github.com/google/gvisor/blob/master/pkg/sentry/platform/platform.go diff --git a/pkg/sentry/mm/file_refcount_set.go b/pkg/sentry/mm/file_refcount_set.go new file mode 100755 index 000000000..6b3081009 --- /dev/null +++ b/pkg/sentry/mm/file_refcount_set.go @@ -0,0 +1,1274 @@ +package mm + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/platform" +) + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + fileRefcountminDegree = 3 + + fileRefcountmaxDegree = 2 * fileRefcountminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type fileRefcountSet struct { + root fileRefcountnode `state:".(*fileRefcountSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *fileRefcountSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *fileRefcountSet) IsEmptyRange(r __generics_imported0.FileRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *fileRefcountSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *fileRefcountSet) SpanRange(r __generics_imported0.FileRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *fileRefcountSet) FirstSegment() fileRefcountIterator { + if s.root.nrSegments == 0 { + return fileRefcountIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *fileRefcountSet) LastSegment() fileRefcountIterator { + if s.root.nrSegments == 0 { + return fileRefcountIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *fileRefcountSet) FirstGap() fileRefcountGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return fileRefcountGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *fileRefcountSet) LastGap() fileRefcountGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return fileRefcountGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *fileRefcountSet) Find(key uint64) (fileRefcountIterator, fileRefcountGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return fileRefcountIterator{n, i}, fileRefcountGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return fileRefcountIterator{}, fileRefcountGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *fileRefcountSet) FindSegment(key uint64) fileRefcountIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *fileRefcountSet) LowerBoundSegment(min uint64) fileRefcountIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *fileRefcountSet) UpperBoundSegment(max uint64) fileRefcountIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *fileRefcountSet) FindGap(key uint64) fileRefcountGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *fileRefcountSet) LowerBoundGap(min uint64) fileRefcountGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *fileRefcountSet) UpperBoundGap(max uint64) fileRefcountGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *fileRefcountSet) Add(r __generics_imported0.FileRange, val int32) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *fileRefcountSet) AddWithoutMerging(r __generics_imported0.FileRange, val int32) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *fileRefcountSet) Insert(gap fileRefcountGapIterator, r __generics_imported0.FileRange, val int32) fileRefcountIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (fileRefcountSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (fileRefcountSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (fileRefcountSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *fileRefcountSet) InsertWithoutMerging(gap fileRefcountGapIterator, r __generics_imported0.FileRange, val int32) fileRefcountIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *fileRefcountSet) InsertWithoutMergingUnchecked(gap fileRefcountGapIterator, r __generics_imported0.FileRange, val int32) fileRefcountIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return fileRefcountIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *fileRefcountSet) Remove(seg fileRefcountIterator) fileRefcountGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + fileRefcountSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(fileRefcountGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *fileRefcountSet) RemoveAll() { + s.root = fileRefcountnode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *fileRefcountSet) RemoveRange(r __generics_imported0.FileRange) fileRefcountGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *fileRefcountSet) Merge(first, second fileRefcountIterator) fileRefcountIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *fileRefcountSet) MergeUnchecked(first, second fileRefcountIterator) fileRefcountIterator { + if first.End() == second.Start() { + if mval, ok := (fileRefcountSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return fileRefcountIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *fileRefcountSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *fileRefcountSet) MergeRange(r __generics_imported0.FileRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *fileRefcountSet) MergeAdjacent(r __generics_imported0.FileRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *fileRefcountSet) Split(seg fileRefcountIterator, split uint64) (fileRefcountIterator, fileRefcountIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *fileRefcountSet) SplitUnchecked(seg fileRefcountIterator, split uint64) (fileRefcountIterator, fileRefcountIterator) { + val1, val2 := (fileRefcountSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.FileRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *fileRefcountSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *fileRefcountSet) Isolate(seg fileRefcountIterator, r __generics_imported0.FileRange) fileRefcountIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *fileRefcountSet) ApplyContiguous(r __generics_imported0.FileRange, fn func(seg fileRefcountIterator)) fileRefcountGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return fileRefcountGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return fileRefcountGapIterator{} + } + } +} + +// +stateify savable +type fileRefcountnode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *fileRefcountnode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [fileRefcountmaxDegree - 1]__generics_imported0.FileRange + values [fileRefcountmaxDegree - 1]int32 + children [fileRefcountmaxDegree]*fileRefcountnode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *fileRefcountnode) firstSegment() fileRefcountIterator { + for n.hasChildren { + n = n.children[0] + } + return fileRefcountIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *fileRefcountnode) lastSegment() fileRefcountIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return fileRefcountIterator{n, n.nrSegments - 1} +} + +func (n *fileRefcountnode) prevSibling() *fileRefcountnode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *fileRefcountnode) nextSibling() *fileRefcountnode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *fileRefcountnode) rebalanceBeforeInsert(gap fileRefcountGapIterator) fileRefcountGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < fileRefcountmaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &fileRefcountnode{ + nrSegments: fileRefcountminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &fileRefcountnode{ + nrSegments: fileRefcountminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:fileRefcountminDegree-1], n.keys[:fileRefcountminDegree-1]) + copy(left.values[:fileRefcountminDegree-1], n.values[:fileRefcountminDegree-1]) + copy(right.keys[:fileRefcountminDegree-1], n.keys[fileRefcountminDegree:]) + copy(right.values[:fileRefcountminDegree-1], n.values[fileRefcountminDegree:]) + n.keys[0], n.values[0] = n.keys[fileRefcountminDegree-1], n.values[fileRefcountminDegree-1] + fileRefcountzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:fileRefcountminDegree], n.children[:fileRefcountminDegree]) + copy(right.children[:fileRefcountminDegree], n.children[fileRefcountminDegree:]) + fileRefcountzeroNodeSlice(n.children[2:]) + for i := 0; i < fileRefcountminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < fileRefcountminDegree { + return fileRefcountGapIterator{left, gap.index} + } + return fileRefcountGapIterator{right, gap.index - fileRefcountminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[fileRefcountminDegree-1], n.values[fileRefcountminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &fileRefcountnode{ + nrSegments: fileRefcountminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:fileRefcountminDegree-1], n.keys[fileRefcountminDegree:]) + copy(sibling.values[:fileRefcountminDegree-1], n.values[fileRefcountminDegree:]) + fileRefcountzeroValueSlice(n.values[fileRefcountminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:fileRefcountminDegree], n.children[fileRefcountminDegree:]) + fileRefcountzeroNodeSlice(n.children[fileRefcountminDegree:]) + for i := 0; i < fileRefcountminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = fileRefcountminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < fileRefcountminDegree { + return gap + } + return fileRefcountGapIterator{sibling, gap.index - fileRefcountminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *fileRefcountnode) rebalanceAfterRemove(gap fileRefcountGapIterator) fileRefcountGapIterator { + for { + if n.nrSegments >= fileRefcountminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= fileRefcountminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + fileRefcountSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return fileRefcountGapIterator{n, 0} + } + if gap.node == n { + return fileRefcountGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= fileRefcountminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + fileRefcountSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return fileRefcountGapIterator{n, n.nrSegments} + } + return fileRefcountGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return fileRefcountGapIterator{p, gap.index} + } + if gap.node == right { + return fileRefcountGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *fileRefcountnode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = fileRefcountGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + fileRefcountSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type fileRefcountIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *fileRefcountnode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg fileRefcountIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg fileRefcountIterator) Range() __generics_imported0.FileRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg fileRefcountIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg fileRefcountIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg fileRefcountIterator) SetRangeUnchecked(r __generics_imported0.FileRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg fileRefcountIterator) SetRange(r __generics_imported0.FileRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg fileRefcountIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg fileRefcountIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg fileRefcountIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg fileRefcountIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg fileRefcountIterator) Value() int32 { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg fileRefcountIterator) ValuePtr() *int32 { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg fileRefcountIterator) SetValue(val int32) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg fileRefcountIterator) PrevSegment() fileRefcountIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return fileRefcountIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return fileRefcountIterator{} + } + return fileRefcountsegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg fileRefcountIterator) NextSegment() fileRefcountIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return fileRefcountIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return fileRefcountIterator{} + } + return fileRefcountsegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg fileRefcountIterator) PrevGap() fileRefcountGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return fileRefcountGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg fileRefcountIterator) NextGap() fileRefcountGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return fileRefcountGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg fileRefcountIterator) PrevNonEmpty() (fileRefcountIterator, fileRefcountGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return fileRefcountIterator{}, gap + } + return gap.PrevSegment(), fileRefcountGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg fileRefcountIterator) NextNonEmpty() (fileRefcountIterator, fileRefcountGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return fileRefcountIterator{}, gap + } + return gap.NextSegment(), fileRefcountGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type fileRefcountGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *fileRefcountnode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap fileRefcountGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap fileRefcountGapIterator) Range() __generics_imported0.FileRange { + return __generics_imported0.FileRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap fileRefcountGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return fileRefcountSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap fileRefcountGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return fileRefcountSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap fileRefcountGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap fileRefcountGapIterator) PrevSegment() fileRefcountIterator { + return fileRefcountsegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap fileRefcountGapIterator) NextSegment() fileRefcountIterator { + return fileRefcountsegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap fileRefcountGapIterator) PrevGap() fileRefcountGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return fileRefcountGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap fileRefcountGapIterator) NextGap() fileRefcountGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return fileRefcountGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func fileRefcountsegmentBeforePosition(n *fileRefcountnode, i int) fileRefcountIterator { + for i == 0 { + if n.parent == nil { + return fileRefcountIterator{} + } + n, i = n.parent, n.parentIndex + } + return fileRefcountIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func fileRefcountsegmentAfterPosition(n *fileRefcountnode, i int) fileRefcountIterator { + for i == n.nrSegments { + if n.parent == nil { + return fileRefcountIterator{} + } + n, i = n.parent, n.parentIndex + } + return fileRefcountIterator{n, i} +} + +func fileRefcountzeroValueSlice(slice []int32) { + + for i := range slice { + fileRefcountSetFunctions{}.ClearValue(&slice[i]) + } +} + +func fileRefcountzeroNodeSlice(slice []*fileRefcountnode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *fileRefcountSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *fileRefcountnode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *fileRefcountnode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type fileRefcountSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []int32 +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *fileRefcountSet) ExportSortedSlices() *fileRefcountSegmentDataSlices { + var sds fileRefcountSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *fileRefcountSet) ImportSortedSlices(sds *fileRefcountSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.FileRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *fileRefcountSet) saveRoot() *fileRefcountSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *fileRefcountSet) loadRoot(sds *fileRefcountSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/mm/io_list.go b/pkg/sentry/mm/io_list.go new file mode 100755 index 000000000..287e4305c --- /dev/null +++ b/pkg/sentry/mm/io_list.go @@ -0,0 +1,186 @@ +package mm + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type ioElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (ioElementMapper) linkerFor(elem *ioResult) *ioResult { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type ioList struct { + head *ioResult + tail *ioResult +} + +// Reset resets list l to the empty state. +func (l *ioList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *ioList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *ioList) Front() *ioResult { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *ioList) Back() *ioResult { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *ioList) PushFront(e *ioResult) { + linker := ioElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + ioElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *ioList) PushBack(e *ioResult) { + linker := ioElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + ioElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *ioList) PushBackList(m *ioList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + ioElementMapper{}.linkerFor(l.tail).SetNext(m.head) + ioElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *ioList) InsertAfter(b, e *ioResult) { + bLinker := ioElementMapper{}.linkerFor(b) + eLinker := ioElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + ioElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *ioList) InsertBefore(a, e *ioResult) { + aLinker := ioElementMapper{}.linkerFor(a) + eLinker := ioElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + ioElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *ioList) Remove(e *ioResult) { + linker := ioElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + ioElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + ioElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type ioEntry struct { + next *ioResult + prev *ioResult +} + +// Next returns the entry that follows e in the list. +func (e *ioEntry) Next() *ioResult { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *ioEntry) Prev() *ioResult { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *ioEntry) SetNext(elem *ioResult) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *ioEntry) SetPrev(elem *ioResult) { + e.prev = elem +} diff --git a/pkg/sentry/mm/mm_state_autogen.go b/pkg/sentry/mm/mm_state_autogen.go new file mode 100755 index 000000000..ef95c2836 --- /dev/null +++ b/pkg/sentry/mm/mm_state_autogen.go @@ -0,0 +1,396 @@ +// automatically generated by stateify. + +package mm + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *aioManager) beforeSave() {} +func (x *aioManager) save(m state.Map) { + x.beforeSave() + m.Save("contexts", &x.contexts) +} + +func (x *aioManager) afterLoad() {} +func (x *aioManager) load(m state.Map) { + m.Load("contexts", &x.contexts) +} + +func (x *ioResult) beforeSave() {} +func (x *ioResult) save(m state.Map) { + x.beforeSave() + m.Save("data", &x.data) + m.Save("ioEntry", &x.ioEntry) +} + +func (x *ioResult) afterLoad() {} +func (x *ioResult) load(m state.Map) { + m.Load("data", &x.data) + m.Load("ioEntry", &x.ioEntry) +} + +func (x *AIOContext) beforeSave() {} +func (x *AIOContext) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.dead) { + m.Failf("dead is %v, expected zero", x.dead) + } + m.Save("results", &x.results) + m.Save("maxOutstanding", &x.maxOutstanding) + m.Save("outstanding", &x.outstanding) +} + +func (x *AIOContext) load(m state.Map) { + m.Load("results", &x.results) + m.Load("maxOutstanding", &x.maxOutstanding) + m.Load("outstanding", &x.outstanding) + m.AfterLoad(x.afterLoad) +} + +func (x *aioMappable) beforeSave() {} +func (x *aioMappable) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("mfp", &x.mfp) + m.Save("fr", &x.fr) +} + +func (x *aioMappable) afterLoad() {} +func (x *aioMappable) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("mfp", &x.mfp) + m.Load("fr", &x.fr) +} + +func (x *fileRefcountSet) beforeSave() {} +func (x *fileRefcountSet) save(m state.Map) { + x.beforeSave() + var root *fileRefcountSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *fileRefcountSet) afterLoad() {} +func (x *fileRefcountSet) load(m state.Map) { + m.LoadValue("root", new(*fileRefcountSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*fileRefcountSegmentDataSlices)) }) +} + +func (x *fileRefcountnode) beforeSave() {} +func (x *fileRefcountnode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *fileRefcountnode) afterLoad() {} +func (x *fileRefcountnode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *fileRefcountSegmentDataSlices) beforeSave() {} +func (x *fileRefcountSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *fileRefcountSegmentDataSlices) afterLoad() {} +func (x *fileRefcountSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func (x *ioList) beforeSave() {} +func (x *ioList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *ioList) afterLoad() {} +func (x *ioList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *ioEntry) beforeSave() {} +func (x *ioEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *ioEntry) afterLoad() {} +func (x *ioEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *MemoryManager) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.active) { + m.Failf("active is %v, expected zero", x.active) + } + if !state.IsZeroValue(x.captureInvalidations) { + m.Failf("captureInvalidations is %v, expected zero", x.captureInvalidations) + } + m.Save("p", &x.p) + m.Save("mfp", &x.mfp) + m.Save("layout", &x.layout) + m.Save("privateRefs", &x.privateRefs) + m.Save("users", &x.users) + m.Save("vmas", &x.vmas) + m.Save("brk", &x.brk) + m.Save("usageAS", &x.usageAS) + m.Save("lockedAS", &x.lockedAS) + m.Save("dataAS", &x.dataAS) + m.Save("defMLockMode", &x.defMLockMode) + m.Save("pmas", &x.pmas) + m.Save("curRSS", &x.curRSS) + m.Save("maxRSS", &x.maxRSS) + m.Save("argv", &x.argv) + m.Save("envv", &x.envv) + m.Save("auxv", &x.auxv) + m.Save("executable", &x.executable) + m.Save("dumpability", &x.dumpability) + m.Save("aioManager", &x.aioManager) + m.Save("sleepForActivation", &x.sleepForActivation) +} + +func (x *MemoryManager) load(m state.Map) { + m.Load("p", &x.p) + m.Load("mfp", &x.mfp) + m.Load("layout", &x.layout) + m.Load("privateRefs", &x.privateRefs) + m.Load("users", &x.users) + m.Load("vmas", &x.vmas) + m.Load("brk", &x.brk) + m.Load("usageAS", &x.usageAS) + m.Load("lockedAS", &x.lockedAS) + m.Load("dataAS", &x.dataAS) + m.Load("defMLockMode", &x.defMLockMode) + m.Load("pmas", &x.pmas) + m.Load("curRSS", &x.curRSS) + m.Load("maxRSS", &x.maxRSS) + m.Load("argv", &x.argv) + m.Load("envv", &x.envv) + m.Load("auxv", &x.auxv) + m.Load("executable", &x.executable) + m.Load("dumpability", &x.dumpability) + m.Load("aioManager", &x.aioManager) + m.Load("sleepForActivation", &x.sleepForActivation) + m.AfterLoad(x.afterLoad) +} + +func (x *vma) beforeSave() {} +func (x *vma) save(m state.Map) { + x.beforeSave() + var realPerms int = x.saveRealPerms() + m.SaveValue("realPerms", realPerms) + m.Save("mappable", &x.mappable) + m.Save("off", &x.off) + m.Save("dontfork", &x.dontfork) + m.Save("mlockMode", &x.mlockMode) + m.Save("numaPolicy", &x.numaPolicy) + m.Save("numaNodemask", &x.numaNodemask) + m.Save("id", &x.id) + m.Save("hint", &x.hint) +} + +func (x *vma) afterLoad() {} +func (x *vma) load(m state.Map) { + m.Load("mappable", &x.mappable) + m.Load("off", &x.off) + m.Load("dontfork", &x.dontfork) + m.Load("mlockMode", &x.mlockMode) + m.Load("numaPolicy", &x.numaPolicy) + m.Load("numaNodemask", &x.numaNodemask) + m.Load("id", &x.id) + m.Load("hint", &x.hint) + m.LoadValue("realPerms", new(int), func(y interface{}) { x.loadRealPerms(y.(int)) }) +} + +func (x *pma) beforeSave() {} +func (x *pma) save(m state.Map) { + x.beforeSave() + m.Save("off", &x.off) + m.Save("translatePerms", &x.translatePerms) + m.Save("effectivePerms", &x.effectivePerms) + m.Save("maxPerms", &x.maxPerms) + m.Save("needCOW", &x.needCOW) + m.Save("private", &x.private) +} + +func (x *pma) afterLoad() {} +func (x *pma) load(m state.Map) { + m.Load("off", &x.off) + m.Load("translatePerms", &x.translatePerms) + m.Load("effectivePerms", &x.effectivePerms) + m.Load("maxPerms", &x.maxPerms) + m.Load("needCOW", &x.needCOW) + m.Load("private", &x.private) +} + +func (x *privateRefs) beforeSave() {} +func (x *privateRefs) save(m state.Map) { + x.beforeSave() + m.Save("refs", &x.refs) +} + +func (x *privateRefs) afterLoad() {} +func (x *privateRefs) load(m state.Map) { + m.Load("refs", &x.refs) +} + +func (x *pmaSet) beforeSave() {} +func (x *pmaSet) save(m state.Map) { + x.beforeSave() + var root *pmaSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *pmaSet) afterLoad() {} +func (x *pmaSet) load(m state.Map) { + m.LoadValue("root", new(*pmaSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*pmaSegmentDataSlices)) }) +} + +func (x *pmanode) beforeSave() {} +func (x *pmanode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *pmanode) afterLoad() {} +func (x *pmanode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *pmaSegmentDataSlices) beforeSave() {} +func (x *pmaSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *pmaSegmentDataSlices) afterLoad() {} +func (x *pmaSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func (x *SpecialMappable) beforeSave() {} +func (x *SpecialMappable) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("mfp", &x.mfp) + m.Save("fr", &x.fr) + m.Save("name", &x.name) +} + +func (x *SpecialMappable) afterLoad() {} +func (x *SpecialMappable) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("mfp", &x.mfp) + m.Load("fr", &x.fr) + m.Load("name", &x.name) +} + +func (x *vmaSet) beforeSave() {} +func (x *vmaSet) save(m state.Map) { + x.beforeSave() + var root *vmaSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *vmaSet) afterLoad() {} +func (x *vmaSet) load(m state.Map) { + m.LoadValue("root", new(*vmaSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*vmaSegmentDataSlices)) }) +} + +func (x *vmanode) beforeSave() {} +func (x *vmanode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *vmanode) afterLoad() {} +func (x *vmanode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *vmaSegmentDataSlices) beforeSave() {} +func (x *vmaSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *vmaSegmentDataSlices) afterLoad() {} +func (x *vmaSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func init() { + state.Register("pkg/sentry/mm.aioManager", (*aioManager)(nil), state.Fns{Save: (*aioManager).save, Load: (*aioManager).load}) + state.Register("pkg/sentry/mm.ioResult", (*ioResult)(nil), state.Fns{Save: (*ioResult).save, Load: (*ioResult).load}) + state.Register("pkg/sentry/mm.AIOContext", (*AIOContext)(nil), state.Fns{Save: (*AIOContext).save, Load: (*AIOContext).load}) + state.Register("pkg/sentry/mm.aioMappable", (*aioMappable)(nil), state.Fns{Save: (*aioMappable).save, Load: (*aioMappable).load}) + state.Register("pkg/sentry/mm.fileRefcountSet", (*fileRefcountSet)(nil), state.Fns{Save: (*fileRefcountSet).save, Load: (*fileRefcountSet).load}) + state.Register("pkg/sentry/mm.fileRefcountnode", (*fileRefcountnode)(nil), state.Fns{Save: (*fileRefcountnode).save, Load: (*fileRefcountnode).load}) + state.Register("pkg/sentry/mm.fileRefcountSegmentDataSlices", (*fileRefcountSegmentDataSlices)(nil), state.Fns{Save: (*fileRefcountSegmentDataSlices).save, Load: (*fileRefcountSegmentDataSlices).load}) + state.Register("pkg/sentry/mm.ioList", (*ioList)(nil), state.Fns{Save: (*ioList).save, Load: (*ioList).load}) + state.Register("pkg/sentry/mm.ioEntry", (*ioEntry)(nil), state.Fns{Save: (*ioEntry).save, Load: (*ioEntry).load}) + state.Register("pkg/sentry/mm.MemoryManager", (*MemoryManager)(nil), state.Fns{Save: (*MemoryManager).save, Load: (*MemoryManager).load}) + state.Register("pkg/sentry/mm.vma", (*vma)(nil), state.Fns{Save: (*vma).save, Load: (*vma).load}) + state.Register("pkg/sentry/mm.pma", (*pma)(nil), state.Fns{Save: (*pma).save, Load: (*pma).load}) + state.Register("pkg/sentry/mm.privateRefs", (*privateRefs)(nil), state.Fns{Save: (*privateRefs).save, Load: (*privateRefs).load}) + state.Register("pkg/sentry/mm.pmaSet", (*pmaSet)(nil), state.Fns{Save: (*pmaSet).save, Load: (*pmaSet).load}) + state.Register("pkg/sentry/mm.pmanode", (*pmanode)(nil), state.Fns{Save: (*pmanode).save, Load: (*pmanode).load}) + state.Register("pkg/sentry/mm.pmaSegmentDataSlices", (*pmaSegmentDataSlices)(nil), state.Fns{Save: (*pmaSegmentDataSlices).save, Load: (*pmaSegmentDataSlices).load}) + state.Register("pkg/sentry/mm.SpecialMappable", (*SpecialMappable)(nil), state.Fns{Save: (*SpecialMappable).save, Load: (*SpecialMappable).load}) + state.Register("pkg/sentry/mm.vmaSet", (*vmaSet)(nil), state.Fns{Save: (*vmaSet).save, Load: (*vmaSet).load}) + state.Register("pkg/sentry/mm.vmanode", (*vmanode)(nil), state.Fns{Save: (*vmanode).save, Load: (*vmanode).load}) + state.Register("pkg/sentry/mm.vmaSegmentDataSlices", (*vmaSegmentDataSlices)(nil), state.Fns{Save: (*vmaSegmentDataSlices).save, Load: (*vmaSegmentDataSlices).load}) +} diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go deleted file mode 100644 index fdc308542..000000000 --- a/pkg/sentry/mm/mm_test.go +++ /dev/null @@ -1,230 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package mm - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -func testMemoryManager(ctx context.Context) *MemoryManager { - p := platform.FromContext(ctx) - mfp := pgalloc.MemoryFileProviderFromContext(ctx) - mm := NewMemoryManager(p, mfp, false) - mm.layout = arch.MmapLayout{ - MinAddr: p.MinUserAddress(), - MaxAddr: p.MaxUserAddress(), - BottomUpBase: p.MinUserAddress(), - TopDownBase: p.MaxUserAddress(), - } - return mm -} - -func (mm *MemoryManager) realUsageAS() uint64 { - return uint64(mm.vmas.Span()) -} - -func TestUsageASUpdates(t *testing.T) { - ctx := contexttest.Context(t) - mm := testMemoryManager(ctx) - defer mm.DecUsers(ctx) - - addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: 2 * usermem.PageSize, - }) - if err != nil { - t.Fatalf("MMap got err %v want nil", err) - } - realUsage := mm.realUsageAS() - if mm.usageAS != realUsage { - t.Fatalf("usageAS believes %v bytes are mapped; %v bytes are actually mapped", mm.usageAS, realUsage) - } - - mm.MUnmap(ctx, addr, usermem.PageSize) - realUsage = mm.realUsageAS() - if mm.usageAS != realUsage { - t.Fatalf("usageAS believes %v bytes are mapped; %v bytes are actually mapped", mm.usageAS, realUsage) - } -} - -func (mm *MemoryManager) realDataAS() uint64 { - var sz uint64 - for seg := mm.vmas.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - vma := seg.Value() - if vma.isPrivateDataLocked() { - sz += uint64(seg.Range().Length()) - } - } - return sz -} - -func TestDataASUpdates(t *testing.T) { - ctx := contexttest.Context(t) - mm := testMemoryManager(ctx) - defer mm.DecUsers(ctx) - - addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: 3 * usermem.PageSize, - Private: true, - Perms: usermem.Write, - MaxPerms: usermem.AnyAccess, - }) - if err != nil { - t.Fatalf("MMap got err %v want nil", err) - } - if mm.dataAS == 0 { - t.Fatalf("dataAS is 0, wanted not 0") - } - realDataAS := mm.realDataAS() - if mm.dataAS != realDataAS { - t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS) - } - - mm.MUnmap(ctx, addr, usermem.PageSize) - realDataAS = mm.realDataAS() - if mm.dataAS != realDataAS { - t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS) - } - - mm.MProtect(addr+usermem.PageSize, usermem.PageSize, usermem.Read, false) - realDataAS = mm.realDataAS() - if mm.dataAS != realDataAS { - t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS) - } - - mm.MRemap(ctx, addr+2*usermem.PageSize, usermem.PageSize, 2*usermem.PageSize, MRemapOpts{ - Move: MRemapMayMove, - }) - realDataAS = mm.realDataAS() - if mm.dataAS != realDataAS { - t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS) - } -} - -func TestBrkDataLimitUpdates(t *testing.T) { - limitSet := limits.NewLimitSet() - limitSet.Set(limits.Data, limits.Limit{}, true /* privileged */) // zero RLIMIT_DATA - - ctx := contexttest.WithLimitSet(contexttest.Context(t), limitSet) - mm := testMemoryManager(ctx) - defer mm.DecUsers(ctx) - - // Try to extend the brk by one page and expect doing so to fail. - oldBrk, _ := mm.Brk(ctx, 0) - if newBrk, _ := mm.Brk(ctx, oldBrk+usermem.PageSize); newBrk != oldBrk { - t.Errorf("brk() increased data segment above RLIMIT_DATA (old brk = %#x, new brk = %#x", oldBrk, newBrk) - } -} - -// TestIOAfterUnmap ensures that IO fails after unmap. -func TestIOAfterUnmap(t *testing.T) { - ctx := contexttest.Context(t) - mm := testMemoryManager(ctx) - defer mm.DecUsers(ctx) - - addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: usermem.PageSize, - Private: true, - Perms: usermem.Read, - MaxPerms: usermem.AnyAccess, - }) - if err != nil { - t.Fatalf("MMap got err %v want nil", err) - } - - // IO works before munmap. - b := make([]byte, 1) - n, err := mm.CopyIn(ctx, addr, b, usermem.IOOpts{}) - if err != nil { - t.Errorf("CopyIn got err %v want nil", err) - } - if n != 1 { - t.Errorf("CopyIn got %d want 1", n) - } - - err = mm.MUnmap(ctx, addr, usermem.PageSize) - if err != nil { - t.Fatalf("MUnmap got err %v want nil", err) - } - - n, err = mm.CopyIn(ctx, addr, b, usermem.IOOpts{}) - if err != syserror.EFAULT { - t.Errorf("CopyIn got err %v want EFAULT", err) - } - if n != 0 { - t.Errorf("CopyIn got %d want 0", n) - } -} - -// TestIOAfterMProtect tests IO interaction with mprotect permissions. -func TestIOAfterMProtect(t *testing.T) { - ctx := contexttest.Context(t) - mm := testMemoryManager(ctx) - defer mm.DecUsers(ctx) - - addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: usermem.PageSize, - Private: true, - Perms: usermem.ReadWrite, - MaxPerms: usermem.AnyAccess, - }) - if err != nil { - t.Fatalf("MMap got err %v want nil", err) - } - - // Writing works before mprotect. - b := make([]byte, 1) - n, err := mm.CopyOut(ctx, addr, b, usermem.IOOpts{}) - if err != nil { - t.Errorf("CopyOut got err %v want nil", err) - } - if n != 1 { - t.Errorf("CopyOut got %d want 1", n) - } - - err = mm.MProtect(addr, usermem.PageSize, usermem.Read, false) - if err != nil { - t.Errorf("MProtect got err %v want nil", err) - } - - // Without IgnorePermissions, CopyOut should no longer succeed. - n, err = mm.CopyOut(ctx, addr, b, usermem.IOOpts{}) - if err != syserror.EFAULT { - t.Errorf("CopyOut got err %v want EFAULT", err) - } - if n != 0 { - t.Errorf("CopyOut got %d want 0", n) - } - - // With IgnorePermissions, CopyOut should succeed despite mprotect. - n, err = mm.CopyOut(ctx, addr, b, usermem.IOOpts{ - IgnorePermissions: true, - }) - if err != nil { - t.Errorf("CopyOut got err %v want nil", err) - } - if n != 1 { - t.Errorf("CopyOut got %d want 1", n) - } -} diff --git a/pkg/sentry/mm/pma_set.go b/pkg/sentry/mm/pma_set.go new file mode 100755 index 000000000..8906e4edc --- /dev/null +++ b/pkg/sentry/mm/pma_set.go @@ -0,0 +1,1274 @@ +package mm + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/usermem" +) + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + pmaminDegree = 8 + + pmamaxDegree = 2 * pmaminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type pmaSet struct { + root pmanode `state:".(*pmaSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *pmaSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *pmaSet) IsEmptyRange(r __generics_imported0.AddrRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *pmaSet) Span() __generics_imported0.Addr { + var sz __generics_imported0.Addr + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *pmaSet) SpanRange(r __generics_imported0.AddrRange) __generics_imported0.Addr { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz __generics_imported0.Addr + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *pmaSet) FirstSegment() pmaIterator { + if s.root.nrSegments == 0 { + return pmaIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *pmaSet) LastSegment() pmaIterator { + if s.root.nrSegments == 0 { + return pmaIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *pmaSet) FirstGap() pmaGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return pmaGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *pmaSet) LastGap() pmaGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return pmaGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *pmaSet) Find(key __generics_imported0.Addr) (pmaIterator, pmaGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return pmaIterator{n, i}, pmaGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return pmaIterator{}, pmaGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *pmaSet) FindSegment(key __generics_imported0.Addr) pmaIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *pmaSet) LowerBoundSegment(min __generics_imported0.Addr) pmaIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *pmaSet) UpperBoundSegment(max __generics_imported0.Addr) pmaIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *pmaSet) FindGap(key __generics_imported0.Addr) pmaGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *pmaSet) LowerBoundGap(min __generics_imported0.Addr) pmaGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *pmaSet) UpperBoundGap(max __generics_imported0.Addr) pmaGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *pmaSet) Add(r __generics_imported0.AddrRange, val pma) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *pmaSet) AddWithoutMerging(r __generics_imported0.AddrRange, val pma) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *pmaSet) Insert(gap pmaGapIterator, r __generics_imported0.AddrRange, val pma) pmaIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (pmaSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (pmaSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (pmaSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *pmaSet) InsertWithoutMerging(gap pmaGapIterator, r __generics_imported0.AddrRange, val pma) pmaIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *pmaSet) InsertWithoutMergingUnchecked(gap pmaGapIterator, r __generics_imported0.AddrRange, val pma) pmaIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return pmaIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *pmaSet) Remove(seg pmaIterator) pmaGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + pmaSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(pmaGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *pmaSet) RemoveAll() { + s.root = pmanode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *pmaSet) RemoveRange(r __generics_imported0.AddrRange) pmaGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *pmaSet) Merge(first, second pmaIterator) pmaIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *pmaSet) MergeUnchecked(first, second pmaIterator) pmaIterator { + if first.End() == second.Start() { + if mval, ok := (pmaSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return pmaIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *pmaSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *pmaSet) MergeRange(r __generics_imported0.AddrRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *pmaSet) MergeAdjacent(r __generics_imported0.AddrRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *pmaSet) Split(seg pmaIterator, split __generics_imported0.Addr) (pmaIterator, pmaIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *pmaSet) SplitUnchecked(seg pmaIterator, split __generics_imported0.Addr) (pmaIterator, pmaIterator) { + val1, val2 := (pmaSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.AddrRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *pmaSet) SplitAt(split __generics_imported0.Addr) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *pmaSet) Isolate(seg pmaIterator, r __generics_imported0.AddrRange) pmaIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *pmaSet) ApplyContiguous(r __generics_imported0.AddrRange, fn func(seg pmaIterator)) pmaGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return pmaGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return pmaGapIterator{} + } + } +} + +// +stateify savable +type pmanode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *pmanode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [pmamaxDegree - 1]__generics_imported0.AddrRange + values [pmamaxDegree - 1]pma + children [pmamaxDegree]*pmanode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *pmanode) firstSegment() pmaIterator { + for n.hasChildren { + n = n.children[0] + } + return pmaIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *pmanode) lastSegment() pmaIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return pmaIterator{n, n.nrSegments - 1} +} + +func (n *pmanode) prevSibling() *pmanode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *pmanode) nextSibling() *pmanode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *pmanode) rebalanceBeforeInsert(gap pmaGapIterator) pmaGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < pmamaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &pmanode{ + nrSegments: pmaminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &pmanode{ + nrSegments: pmaminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:pmaminDegree-1], n.keys[:pmaminDegree-1]) + copy(left.values[:pmaminDegree-1], n.values[:pmaminDegree-1]) + copy(right.keys[:pmaminDegree-1], n.keys[pmaminDegree:]) + copy(right.values[:pmaminDegree-1], n.values[pmaminDegree:]) + n.keys[0], n.values[0] = n.keys[pmaminDegree-1], n.values[pmaminDegree-1] + pmazeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:pmaminDegree], n.children[:pmaminDegree]) + copy(right.children[:pmaminDegree], n.children[pmaminDegree:]) + pmazeroNodeSlice(n.children[2:]) + for i := 0; i < pmaminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < pmaminDegree { + return pmaGapIterator{left, gap.index} + } + return pmaGapIterator{right, gap.index - pmaminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[pmaminDegree-1], n.values[pmaminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &pmanode{ + nrSegments: pmaminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:pmaminDegree-1], n.keys[pmaminDegree:]) + copy(sibling.values[:pmaminDegree-1], n.values[pmaminDegree:]) + pmazeroValueSlice(n.values[pmaminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:pmaminDegree], n.children[pmaminDegree:]) + pmazeroNodeSlice(n.children[pmaminDegree:]) + for i := 0; i < pmaminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = pmaminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < pmaminDegree { + return gap + } + return pmaGapIterator{sibling, gap.index - pmaminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *pmanode) rebalanceAfterRemove(gap pmaGapIterator) pmaGapIterator { + for { + if n.nrSegments >= pmaminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= pmaminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + pmaSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return pmaGapIterator{n, 0} + } + if gap.node == n { + return pmaGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= pmaminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + pmaSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return pmaGapIterator{n, n.nrSegments} + } + return pmaGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return pmaGapIterator{p, gap.index} + } + if gap.node == right { + return pmaGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *pmanode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = pmaGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + pmaSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type pmaIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *pmanode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg pmaIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg pmaIterator) Range() __generics_imported0.AddrRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg pmaIterator) Start() __generics_imported0.Addr { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg pmaIterator) End() __generics_imported0.Addr { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg pmaIterator) SetRangeUnchecked(r __generics_imported0.AddrRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg pmaIterator) SetRange(r __generics_imported0.AddrRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg pmaIterator) SetStartUnchecked(start __generics_imported0.Addr) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg pmaIterator) SetStart(start __generics_imported0.Addr) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg pmaIterator) SetEndUnchecked(end __generics_imported0.Addr) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg pmaIterator) SetEnd(end __generics_imported0.Addr) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg pmaIterator) Value() pma { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg pmaIterator) ValuePtr() *pma { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg pmaIterator) SetValue(val pma) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg pmaIterator) PrevSegment() pmaIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return pmaIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return pmaIterator{} + } + return pmasegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg pmaIterator) NextSegment() pmaIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return pmaIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return pmaIterator{} + } + return pmasegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg pmaIterator) PrevGap() pmaGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return pmaGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg pmaIterator) NextGap() pmaGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return pmaGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg pmaIterator) PrevNonEmpty() (pmaIterator, pmaGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return pmaIterator{}, gap + } + return gap.PrevSegment(), pmaGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg pmaIterator) NextNonEmpty() (pmaIterator, pmaGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return pmaIterator{}, gap + } + return gap.NextSegment(), pmaGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type pmaGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *pmanode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap pmaGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap pmaGapIterator) Range() __generics_imported0.AddrRange { + return __generics_imported0.AddrRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap pmaGapIterator) Start() __generics_imported0.Addr { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return pmaSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap pmaGapIterator) End() __generics_imported0.Addr { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return pmaSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap pmaGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap pmaGapIterator) PrevSegment() pmaIterator { + return pmasegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap pmaGapIterator) NextSegment() pmaIterator { + return pmasegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap pmaGapIterator) PrevGap() pmaGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return pmaGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap pmaGapIterator) NextGap() pmaGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return pmaGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func pmasegmentBeforePosition(n *pmanode, i int) pmaIterator { + for i == 0 { + if n.parent == nil { + return pmaIterator{} + } + n, i = n.parent, n.parentIndex + } + return pmaIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func pmasegmentAfterPosition(n *pmanode, i int) pmaIterator { + for i == n.nrSegments { + if n.parent == nil { + return pmaIterator{} + } + n, i = n.parent, n.parentIndex + } + return pmaIterator{n, i} +} + +func pmazeroValueSlice(slice []pma) { + + for i := range slice { + pmaSetFunctions{}.ClearValue(&slice[i]) + } +} + +func pmazeroNodeSlice(slice []*pmanode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *pmaSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *pmanode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *pmanode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type pmaSegmentDataSlices struct { + Start []__generics_imported0.Addr + End []__generics_imported0.Addr + Values []pma +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *pmaSet) ExportSortedSlices() *pmaSegmentDataSlices { + var sds pmaSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *pmaSet) ImportSortedSlices(sds *pmaSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.AddrRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *pmaSet) saveRoot() *pmaSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *pmaSet) loadRoot(sds *pmaSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/mm/vma_set.go b/pkg/sentry/mm/vma_set.go new file mode 100755 index 000000000..af6b1d317 --- /dev/null +++ b/pkg/sentry/mm/vma_set.go @@ -0,0 +1,1274 @@ +package mm + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/usermem" +) + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + vmaminDegree = 8 + + vmamaxDegree = 2 * vmaminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type vmaSet struct { + root vmanode `state:".(*vmaSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *vmaSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *vmaSet) IsEmptyRange(r __generics_imported0.AddrRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *vmaSet) Span() __generics_imported0.Addr { + var sz __generics_imported0.Addr + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *vmaSet) SpanRange(r __generics_imported0.AddrRange) __generics_imported0.Addr { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz __generics_imported0.Addr + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *vmaSet) FirstSegment() vmaIterator { + if s.root.nrSegments == 0 { + return vmaIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *vmaSet) LastSegment() vmaIterator { + if s.root.nrSegments == 0 { + return vmaIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *vmaSet) FirstGap() vmaGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return vmaGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *vmaSet) LastGap() vmaGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return vmaGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *vmaSet) Find(key __generics_imported0.Addr) (vmaIterator, vmaGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return vmaIterator{n, i}, vmaGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return vmaIterator{}, vmaGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *vmaSet) FindSegment(key __generics_imported0.Addr) vmaIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *vmaSet) LowerBoundSegment(min __generics_imported0.Addr) vmaIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *vmaSet) UpperBoundSegment(max __generics_imported0.Addr) vmaIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *vmaSet) FindGap(key __generics_imported0.Addr) vmaGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *vmaSet) LowerBoundGap(min __generics_imported0.Addr) vmaGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *vmaSet) UpperBoundGap(max __generics_imported0.Addr) vmaGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *vmaSet) Add(r __generics_imported0.AddrRange, val vma) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *vmaSet) AddWithoutMerging(r __generics_imported0.AddrRange, val vma) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *vmaSet) Insert(gap vmaGapIterator, r __generics_imported0.AddrRange, val vma) vmaIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (vmaSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (vmaSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (vmaSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *vmaSet) InsertWithoutMerging(gap vmaGapIterator, r __generics_imported0.AddrRange, val vma) vmaIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *vmaSet) InsertWithoutMergingUnchecked(gap vmaGapIterator, r __generics_imported0.AddrRange, val vma) vmaIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return vmaIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *vmaSet) Remove(seg vmaIterator) vmaGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + vmaSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(vmaGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *vmaSet) RemoveAll() { + s.root = vmanode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *vmaSet) RemoveRange(r __generics_imported0.AddrRange) vmaGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *vmaSet) Merge(first, second vmaIterator) vmaIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *vmaSet) MergeUnchecked(first, second vmaIterator) vmaIterator { + if first.End() == second.Start() { + if mval, ok := (vmaSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return vmaIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *vmaSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *vmaSet) MergeRange(r __generics_imported0.AddrRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *vmaSet) MergeAdjacent(r __generics_imported0.AddrRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *vmaSet) Split(seg vmaIterator, split __generics_imported0.Addr) (vmaIterator, vmaIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *vmaSet) SplitUnchecked(seg vmaIterator, split __generics_imported0.Addr) (vmaIterator, vmaIterator) { + val1, val2 := (vmaSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.AddrRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *vmaSet) SplitAt(split __generics_imported0.Addr) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *vmaSet) Isolate(seg vmaIterator, r __generics_imported0.AddrRange) vmaIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *vmaSet) ApplyContiguous(r __generics_imported0.AddrRange, fn func(seg vmaIterator)) vmaGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return vmaGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return vmaGapIterator{} + } + } +} + +// +stateify savable +type vmanode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *vmanode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [vmamaxDegree - 1]__generics_imported0.AddrRange + values [vmamaxDegree - 1]vma + children [vmamaxDegree]*vmanode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *vmanode) firstSegment() vmaIterator { + for n.hasChildren { + n = n.children[0] + } + return vmaIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *vmanode) lastSegment() vmaIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return vmaIterator{n, n.nrSegments - 1} +} + +func (n *vmanode) prevSibling() *vmanode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *vmanode) nextSibling() *vmanode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *vmanode) rebalanceBeforeInsert(gap vmaGapIterator) vmaGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < vmamaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &vmanode{ + nrSegments: vmaminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &vmanode{ + nrSegments: vmaminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:vmaminDegree-1], n.keys[:vmaminDegree-1]) + copy(left.values[:vmaminDegree-1], n.values[:vmaminDegree-1]) + copy(right.keys[:vmaminDegree-1], n.keys[vmaminDegree:]) + copy(right.values[:vmaminDegree-1], n.values[vmaminDegree:]) + n.keys[0], n.values[0] = n.keys[vmaminDegree-1], n.values[vmaminDegree-1] + vmazeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:vmaminDegree], n.children[:vmaminDegree]) + copy(right.children[:vmaminDegree], n.children[vmaminDegree:]) + vmazeroNodeSlice(n.children[2:]) + for i := 0; i < vmaminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < vmaminDegree { + return vmaGapIterator{left, gap.index} + } + return vmaGapIterator{right, gap.index - vmaminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[vmaminDegree-1], n.values[vmaminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &vmanode{ + nrSegments: vmaminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:vmaminDegree-1], n.keys[vmaminDegree:]) + copy(sibling.values[:vmaminDegree-1], n.values[vmaminDegree:]) + vmazeroValueSlice(n.values[vmaminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:vmaminDegree], n.children[vmaminDegree:]) + vmazeroNodeSlice(n.children[vmaminDegree:]) + for i := 0; i < vmaminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = vmaminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < vmaminDegree { + return gap + } + return vmaGapIterator{sibling, gap.index - vmaminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *vmanode) rebalanceAfterRemove(gap vmaGapIterator) vmaGapIterator { + for { + if n.nrSegments >= vmaminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= vmaminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + vmaSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return vmaGapIterator{n, 0} + } + if gap.node == n { + return vmaGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= vmaminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + vmaSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return vmaGapIterator{n, n.nrSegments} + } + return vmaGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return vmaGapIterator{p, gap.index} + } + if gap.node == right { + return vmaGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *vmanode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = vmaGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + vmaSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type vmaIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *vmanode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg vmaIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg vmaIterator) Range() __generics_imported0.AddrRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg vmaIterator) Start() __generics_imported0.Addr { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg vmaIterator) End() __generics_imported0.Addr { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg vmaIterator) SetRangeUnchecked(r __generics_imported0.AddrRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg vmaIterator) SetRange(r __generics_imported0.AddrRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg vmaIterator) SetStartUnchecked(start __generics_imported0.Addr) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg vmaIterator) SetStart(start __generics_imported0.Addr) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg vmaIterator) SetEndUnchecked(end __generics_imported0.Addr) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg vmaIterator) SetEnd(end __generics_imported0.Addr) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg vmaIterator) Value() vma { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg vmaIterator) ValuePtr() *vma { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg vmaIterator) SetValue(val vma) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg vmaIterator) PrevSegment() vmaIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return vmaIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return vmaIterator{} + } + return vmasegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg vmaIterator) NextSegment() vmaIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return vmaIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return vmaIterator{} + } + return vmasegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg vmaIterator) PrevGap() vmaGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return vmaGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg vmaIterator) NextGap() vmaGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return vmaGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg vmaIterator) PrevNonEmpty() (vmaIterator, vmaGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return vmaIterator{}, gap + } + return gap.PrevSegment(), vmaGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg vmaIterator) NextNonEmpty() (vmaIterator, vmaGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return vmaIterator{}, gap + } + return gap.NextSegment(), vmaGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type vmaGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *vmanode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap vmaGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap vmaGapIterator) Range() __generics_imported0.AddrRange { + return __generics_imported0.AddrRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap vmaGapIterator) Start() __generics_imported0.Addr { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return vmaSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap vmaGapIterator) End() __generics_imported0.Addr { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return vmaSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap vmaGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap vmaGapIterator) PrevSegment() vmaIterator { + return vmasegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap vmaGapIterator) NextSegment() vmaIterator { + return vmasegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap vmaGapIterator) PrevGap() vmaGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return vmaGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap vmaGapIterator) NextGap() vmaGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return vmaGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func vmasegmentBeforePosition(n *vmanode, i int) vmaIterator { + for i == 0 { + if n.parent == nil { + return vmaIterator{} + } + n, i = n.parent, n.parentIndex + } + return vmaIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func vmasegmentAfterPosition(n *vmanode, i int) vmaIterator { + for i == n.nrSegments { + if n.parent == nil { + return vmaIterator{} + } + n, i = n.parent, n.parentIndex + } + return vmaIterator{n, i} +} + +func vmazeroValueSlice(slice []vma) { + + for i := range slice { + vmaSetFunctions{}.ClearValue(&slice[i]) + } +} + +func vmazeroNodeSlice(slice []*vmanode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *vmaSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *vmanode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *vmanode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type vmaSegmentDataSlices struct { + Start []__generics_imported0.Addr + End []__generics_imported0.Addr + Values []vma +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *vmaSet) ExportSortedSlices() *vmaSegmentDataSlices { + var sds vmaSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *vmaSet) ImportSortedSlices(sds *vmaSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.AddrRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *vmaSet) saveRoot() *vmaSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *vmaSet) loadRoot(sds *vmaSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD deleted file mode 100644 index 1eeb9f317..000000000 --- a/pkg/sentry/pgalloc/BUILD +++ /dev/null @@ -1,85 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "evictable_range", - out = "evictable_range.go", - package = "pgalloc", - prefix = "Evictable", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - -go_template_instance( - name = "evictable_range_set", - out = "evictable_range_set.go", - package = "pgalloc", - prefix = "evictableRange", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "EvictableRange", - "Value": "evictableRangeSetValue", - "Functions": "evictableRangeSetFunctions", - }, -) - -go_template_instance( - name = "usage_set", - out = "usage_set.go", - consts = { - "minDegree": "10", - }, - imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", - }, - package = "pgalloc", - prefix = "usage", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "platform.FileRange", - "Value": "usageInfo", - "Functions": "usageSetFunctions", - }, -) - -go_library( - name = "pgalloc", - srcs = [ - "context.go", - "evictable_range.go", - "evictable_range_set.go", - "pgalloc.go", - "pgalloc_unsafe.go", - "save_restore.go", - "usage_set.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/log", - "//pkg/memutil", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/hostmm", - "//pkg/sentry/platform", - "//pkg/sentry/usage", - "//pkg/state", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - ], -) - -go_test( - name = "pgalloc_test", - size = "small", - srcs = ["pgalloc_test.go"], - library = ":pgalloc", - deps = ["//pkg/usermem"], -) diff --git a/pkg/sentry/pgalloc/evictable_range.go b/pkg/sentry/pgalloc/evictable_range.go new file mode 100755 index 000000000..10ce2ff44 --- /dev/null +++ b/pkg/sentry/pgalloc/evictable_range.go @@ -0,0 +1,62 @@ +package pgalloc + +// A Range represents a contiguous range of T. +// +// +stateify savable +type EvictableRange struct { + // Start is the inclusive start of the range. + Start uint64 + + // End is the exclusive end of the range. + End uint64 +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +func (r EvictableRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +func (r EvictableRange) Length() uint64 { + return r.End - r.Start +} + +// Contains returns true if r contains x. +func (r EvictableRange) Contains(x uint64) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +func (r EvictableRange) Overlaps(r2 EvictableRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +func (r EvictableRange) IsSupersetOf(r2 EvictableRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +func (r EvictableRange) Intersect(r2 EvictableRange) EvictableRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +func (r EvictableRange) CanSplitAt(x uint64) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/sentry/pgalloc/evictable_range_set.go b/pkg/sentry/pgalloc/evictable_range_set.go new file mode 100755 index 000000000..6fbd02434 --- /dev/null +++ b/pkg/sentry/pgalloc/evictable_range_set.go @@ -0,0 +1,1270 @@ +package pgalloc + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + evictableRangeminDegree = 3 + + evictableRangemaxDegree = 2 * evictableRangeminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type evictableRangeSet struct { + root evictableRangenode `state:".(*evictableRangeSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *evictableRangeSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *evictableRangeSet) IsEmptyRange(r EvictableRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *evictableRangeSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *evictableRangeSet) SpanRange(r EvictableRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *evictableRangeSet) FirstSegment() evictableRangeIterator { + if s.root.nrSegments == 0 { + return evictableRangeIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *evictableRangeSet) LastSegment() evictableRangeIterator { + if s.root.nrSegments == 0 { + return evictableRangeIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *evictableRangeSet) FirstGap() evictableRangeGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return evictableRangeGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *evictableRangeSet) LastGap() evictableRangeGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return evictableRangeGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *evictableRangeSet) Find(key uint64) (evictableRangeIterator, evictableRangeGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return evictableRangeIterator{n, i}, evictableRangeGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return evictableRangeIterator{}, evictableRangeGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *evictableRangeSet) FindSegment(key uint64) evictableRangeIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *evictableRangeSet) LowerBoundSegment(min uint64) evictableRangeIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *evictableRangeSet) UpperBoundSegment(max uint64) evictableRangeIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *evictableRangeSet) FindGap(key uint64) evictableRangeGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *evictableRangeSet) LowerBoundGap(min uint64) evictableRangeGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *evictableRangeSet) UpperBoundGap(max uint64) evictableRangeGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *evictableRangeSet) Add(r EvictableRange, val evictableRangeSetValue) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *evictableRangeSet) AddWithoutMerging(r EvictableRange, val evictableRangeSetValue) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *evictableRangeSet) Insert(gap evictableRangeGapIterator, r EvictableRange, val evictableRangeSetValue) evictableRangeIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (evictableRangeSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (evictableRangeSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (evictableRangeSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *evictableRangeSet) InsertWithoutMerging(gap evictableRangeGapIterator, r EvictableRange, val evictableRangeSetValue) evictableRangeIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *evictableRangeSet) InsertWithoutMergingUnchecked(gap evictableRangeGapIterator, r EvictableRange, val evictableRangeSetValue) evictableRangeIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return evictableRangeIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *evictableRangeSet) Remove(seg evictableRangeIterator) evictableRangeGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + evictableRangeSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(evictableRangeGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *evictableRangeSet) RemoveAll() { + s.root = evictableRangenode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *evictableRangeSet) RemoveRange(r EvictableRange) evictableRangeGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *evictableRangeSet) Merge(first, second evictableRangeIterator) evictableRangeIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *evictableRangeSet) MergeUnchecked(first, second evictableRangeIterator) evictableRangeIterator { + if first.End() == second.Start() { + if mval, ok := (evictableRangeSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return evictableRangeIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *evictableRangeSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *evictableRangeSet) MergeRange(r EvictableRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *evictableRangeSet) MergeAdjacent(r EvictableRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *evictableRangeSet) Split(seg evictableRangeIterator, split uint64) (evictableRangeIterator, evictableRangeIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *evictableRangeSet) SplitUnchecked(seg evictableRangeIterator, split uint64) (evictableRangeIterator, evictableRangeIterator) { + val1, val2 := (evictableRangeSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), EvictableRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *evictableRangeSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *evictableRangeSet) Isolate(seg evictableRangeIterator, r EvictableRange) evictableRangeIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *evictableRangeSet) ApplyContiguous(r EvictableRange, fn func(seg evictableRangeIterator)) evictableRangeGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return evictableRangeGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return evictableRangeGapIterator{} + } + } +} + +// +stateify savable +type evictableRangenode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *evictableRangenode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [evictableRangemaxDegree - 1]EvictableRange + values [evictableRangemaxDegree - 1]evictableRangeSetValue + children [evictableRangemaxDegree]*evictableRangenode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *evictableRangenode) firstSegment() evictableRangeIterator { + for n.hasChildren { + n = n.children[0] + } + return evictableRangeIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *evictableRangenode) lastSegment() evictableRangeIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return evictableRangeIterator{n, n.nrSegments - 1} +} + +func (n *evictableRangenode) prevSibling() *evictableRangenode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *evictableRangenode) nextSibling() *evictableRangenode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *evictableRangenode) rebalanceBeforeInsert(gap evictableRangeGapIterator) evictableRangeGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < evictableRangemaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &evictableRangenode{ + nrSegments: evictableRangeminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &evictableRangenode{ + nrSegments: evictableRangeminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:evictableRangeminDegree-1], n.keys[:evictableRangeminDegree-1]) + copy(left.values[:evictableRangeminDegree-1], n.values[:evictableRangeminDegree-1]) + copy(right.keys[:evictableRangeminDegree-1], n.keys[evictableRangeminDegree:]) + copy(right.values[:evictableRangeminDegree-1], n.values[evictableRangeminDegree:]) + n.keys[0], n.values[0] = n.keys[evictableRangeminDegree-1], n.values[evictableRangeminDegree-1] + evictableRangezeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:evictableRangeminDegree], n.children[:evictableRangeminDegree]) + copy(right.children[:evictableRangeminDegree], n.children[evictableRangeminDegree:]) + evictableRangezeroNodeSlice(n.children[2:]) + for i := 0; i < evictableRangeminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < evictableRangeminDegree { + return evictableRangeGapIterator{left, gap.index} + } + return evictableRangeGapIterator{right, gap.index - evictableRangeminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[evictableRangeminDegree-1], n.values[evictableRangeminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &evictableRangenode{ + nrSegments: evictableRangeminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:evictableRangeminDegree-1], n.keys[evictableRangeminDegree:]) + copy(sibling.values[:evictableRangeminDegree-1], n.values[evictableRangeminDegree:]) + evictableRangezeroValueSlice(n.values[evictableRangeminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:evictableRangeminDegree], n.children[evictableRangeminDegree:]) + evictableRangezeroNodeSlice(n.children[evictableRangeminDegree:]) + for i := 0; i < evictableRangeminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = evictableRangeminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < evictableRangeminDegree { + return gap + } + return evictableRangeGapIterator{sibling, gap.index - evictableRangeminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *evictableRangenode) rebalanceAfterRemove(gap evictableRangeGapIterator) evictableRangeGapIterator { + for { + if n.nrSegments >= evictableRangeminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= evictableRangeminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + evictableRangeSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return evictableRangeGapIterator{n, 0} + } + if gap.node == n { + return evictableRangeGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= evictableRangeminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + evictableRangeSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return evictableRangeGapIterator{n, n.nrSegments} + } + return evictableRangeGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return evictableRangeGapIterator{p, gap.index} + } + if gap.node == right { + return evictableRangeGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *evictableRangenode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = evictableRangeGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + evictableRangeSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type evictableRangeIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *evictableRangenode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg evictableRangeIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg evictableRangeIterator) Range() EvictableRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg evictableRangeIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg evictableRangeIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg evictableRangeIterator) SetRangeUnchecked(r EvictableRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg evictableRangeIterator) SetRange(r EvictableRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg evictableRangeIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg evictableRangeIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg evictableRangeIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg evictableRangeIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg evictableRangeIterator) Value() evictableRangeSetValue { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg evictableRangeIterator) ValuePtr() *evictableRangeSetValue { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg evictableRangeIterator) SetValue(val evictableRangeSetValue) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg evictableRangeIterator) PrevSegment() evictableRangeIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return evictableRangeIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return evictableRangeIterator{} + } + return evictableRangesegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg evictableRangeIterator) NextSegment() evictableRangeIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return evictableRangeIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return evictableRangeIterator{} + } + return evictableRangesegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg evictableRangeIterator) PrevGap() evictableRangeGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return evictableRangeGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg evictableRangeIterator) NextGap() evictableRangeGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return evictableRangeGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg evictableRangeIterator) PrevNonEmpty() (evictableRangeIterator, evictableRangeGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return evictableRangeIterator{}, gap + } + return gap.PrevSegment(), evictableRangeGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg evictableRangeIterator) NextNonEmpty() (evictableRangeIterator, evictableRangeGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return evictableRangeIterator{}, gap + } + return gap.NextSegment(), evictableRangeGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type evictableRangeGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *evictableRangenode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap evictableRangeGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap evictableRangeGapIterator) Range() EvictableRange { + return EvictableRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap evictableRangeGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return evictableRangeSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap evictableRangeGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return evictableRangeSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap evictableRangeGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap evictableRangeGapIterator) PrevSegment() evictableRangeIterator { + return evictableRangesegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap evictableRangeGapIterator) NextSegment() evictableRangeIterator { + return evictableRangesegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap evictableRangeGapIterator) PrevGap() evictableRangeGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return evictableRangeGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap evictableRangeGapIterator) NextGap() evictableRangeGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return evictableRangeGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func evictableRangesegmentBeforePosition(n *evictableRangenode, i int) evictableRangeIterator { + for i == 0 { + if n.parent == nil { + return evictableRangeIterator{} + } + n, i = n.parent, n.parentIndex + } + return evictableRangeIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func evictableRangesegmentAfterPosition(n *evictableRangenode, i int) evictableRangeIterator { + for i == n.nrSegments { + if n.parent == nil { + return evictableRangeIterator{} + } + n, i = n.parent, n.parentIndex + } + return evictableRangeIterator{n, i} +} + +func evictableRangezeroValueSlice(slice []evictableRangeSetValue) { + + for i := range slice { + evictableRangeSetFunctions{}.ClearValue(&slice[i]) + } +} + +func evictableRangezeroNodeSlice(slice []*evictableRangenode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *evictableRangeSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *evictableRangenode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *evictableRangenode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type evictableRangeSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []evictableRangeSetValue +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *evictableRangeSet) ExportSortedSlices() *evictableRangeSegmentDataSlices { + var sds evictableRangeSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *evictableRangeSet) ImportSortedSlices(sds *evictableRangeSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := EvictableRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *evictableRangeSet) saveRoot() *evictableRangeSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *evictableRangeSet) loadRoot(sds *evictableRangeSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/pgalloc/pgalloc_state_autogen.go b/pkg/sentry/pgalloc/pgalloc_state_autogen.go new file mode 100755 index 000000000..97e1c883b --- /dev/null +++ b/pkg/sentry/pgalloc/pgalloc_state_autogen.go @@ -0,0 +1,146 @@ +// automatically generated by stateify. + +package pgalloc + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *EvictableRange) beforeSave() {} +func (x *EvictableRange) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) +} + +func (x *EvictableRange) afterLoad() {} +func (x *EvictableRange) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) +} + +func (x *evictableRangeSet) beforeSave() {} +func (x *evictableRangeSet) save(m state.Map) { + x.beforeSave() + var root *evictableRangeSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *evictableRangeSet) afterLoad() {} +func (x *evictableRangeSet) load(m state.Map) { + m.LoadValue("root", new(*evictableRangeSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*evictableRangeSegmentDataSlices)) }) +} + +func (x *evictableRangenode) beforeSave() {} +func (x *evictableRangenode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *evictableRangenode) afterLoad() {} +func (x *evictableRangenode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *evictableRangeSegmentDataSlices) beforeSave() {} +func (x *evictableRangeSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *evictableRangeSegmentDataSlices) afterLoad() {} +func (x *evictableRangeSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func (x *usageInfo) beforeSave() {} +func (x *usageInfo) save(m state.Map) { + x.beforeSave() + m.Save("kind", &x.kind) + m.Save("knownCommitted", &x.knownCommitted) + m.Save("refs", &x.refs) +} + +func (x *usageInfo) afterLoad() {} +func (x *usageInfo) load(m state.Map) { + m.Load("kind", &x.kind) + m.Load("knownCommitted", &x.knownCommitted) + m.Load("refs", &x.refs) +} + +func (x *usageSet) beforeSave() {} +func (x *usageSet) save(m state.Map) { + x.beforeSave() + var root *usageSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *usageSet) afterLoad() {} +func (x *usageSet) load(m state.Map) { + m.LoadValue("root", new(*usageSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*usageSegmentDataSlices)) }) +} + +func (x *usagenode) beforeSave() {} +func (x *usagenode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *usagenode) afterLoad() {} +func (x *usagenode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *usageSegmentDataSlices) beforeSave() {} +func (x *usageSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *usageSegmentDataSlices) afterLoad() {} +func (x *usageSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func init() { + state.Register("pkg/sentry/pgalloc.EvictableRange", (*EvictableRange)(nil), state.Fns{Save: (*EvictableRange).save, Load: (*EvictableRange).load}) + state.Register("pkg/sentry/pgalloc.evictableRangeSet", (*evictableRangeSet)(nil), state.Fns{Save: (*evictableRangeSet).save, Load: (*evictableRangeSet).load}) + state.Register("pkg/sentry/pgalloc.evictableRangenode", (*evictableRangenode)(nil), state.Fns{Save: (*evictableRangenode).save, Load: (*evictableRangenode).load}) + state.Register("pkg/sentry/pgalloc.evictableRangeSegmentDataSlices", (*evictableRangeSegmentDataSlices)(nil), state.Fns{Save: (*evictableRangeSegmentDataSlices).save, Load: (*evictableRangeSegmentDataSlices).load}) + state.Register("pkg/sentry/pgalloc.usageInfo", (*usageInfo)(nil), state.Fns{Save: (*usageInfo).save, Load: (*usageInfo).load}) + state.Register("pkg/sentry/pgalloc.usageSet", (*usageSet)(nil), state.Fns{Save: (*usageSet).save, Load: (*usageSet).load}) + state.Register("pkg/sentry/pgalloc.usagenode", (*usagenode)(nil), state.Fns{Save: (*usagenode).save, Load: (*usagenode).load}) + state.Register("pkg/sentry/pgalloc.usageSegmentDataSlices", (*usageSegmentDataSlices)(nil), state.Fns{Save: (*usageSegmentDataSlices).save, Load: (*usageSegmentDataSlices).load}) +} diff --git a/pkg/sentry/pgalloc/pgalloc_test.go b/pkg/sentry/pgalloc/pgalloc_test.go deleted file mode 100644 index 293f22c6b..000000000 --- a/pkg/sentry/pgalloc/pgalloc_test.go +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pgalloc - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/usermem" -) - -const ( - page = usermem.PageSize - hugepage = usermem.HugePageSize -) - -func TestFindUnallocatedRange(t *testing.T) { - for _, test := range []struct { - desc string - usage *usageSegmentDataSlices - start uint64 - length uint64 - alignment uint64 - unallocated uint64 - minUnallocated uint64 - }{ - { - desc: "Initial allocation succeeds", - usage: &usageSegmentDataSlices{}, - start: 0, - length: page, - alignment: page, - unallocated: 0, - minUnallocated: 0, - }, - { - desc: "Allocation begins at start of file", - usage: &usageSegmentDataSlices{ - Start: []uint64{page}, - End: []uint64{2 * page}, - Values: []usageInfo{{refs: 1}}, - }, - start: 0, - length: page, - alignment: page, - unallocated: 0, - minUnallocated: 0, - }, - { - desc: "In-use frames are not allocatable", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, page}, - End: []uint64{page, 2 * page}, - Values: []usageInfo{{refs: 1}, {refs: 2}}, - }, - start: 0, - length: page, - alignment: page, - unallocated: 2 * page, - minUnallocated: 2 * page, - }, - { - desc: "Reclaimable frames are not allocatable", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, page, 2 * page}, - End: []uint64{page, 2 * page, 3 * page}, - Values: []usageInfo{{refs: 1}, {refs: 0}, {refs: 1}}, - }, - start: 0, - length: page, - alignment: page, - unallocated: 3 * page, - minUnallocated: 3 * page, - }, - { - desc: "Gaps between in-use frames are allocatable", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, 2 * page}, - End: []uint64{page, 3 * page}, - Values: []usageInfo{{refs: 1}, {refs: 1}}, - }, - start: 0, - length: page, - alignment: page, - unallocated: page, - minUnallocated: page, - }, - { - desc: "Inadequately-sized gaps are rejected", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, 2 * page}, - End: []uint64{page, 3 * page}, - Values: []usageInfo{{refs: 1}, {refs: 1}}, - }, - start: 0, - length: 2 * page, - alignment: page, - unallocated: 3 * page, - minUnallocated: page, - }, - { - desc: "Hugepage alignment is honored", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, hugepage + page}, - // Hugepage-sized gap here that shouldn't be allocated from - // since it's incorrectly aligned. - End: []uint64{page, hugepage + 2*page}, - Values: []usageInfo{{refs: 1}, {refs: 1}}, - }, - start: 0, - length: hugepage, - alignment: hugepage, - unallocated: 2 * hugepage, - minUnallocated: page, - }, - { - desc: "Pages before start ignored", - usage: &usageSegmentDataSlices{ - Start: []uint64{page, 3 * page}, - End: []uint64{2 * page, 4 * page}, - Values: []usageInfo{{refs: 1}, {refs: 2}}, - }, - start: page, - length: page, - alignment: page, - unallocated: 2 * page, - minUnallocated: 2 * page, - }, - { - desc: "start may be in the middle of segment", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, 3 * page}, - End: []uint64{2 * page, 4 * page}, - Values: []usageInfo{{refs: 1}, {refs: 2}}, - }, - start: page, - length: page, - alignment: page, - unallocated: 2 * page, - minUnallocated: 2 * page, - }, - } { - t.Run(test.desc, func(t *testing.T) { - var usage usageSet - if err := usage.ImportSortedSlices(test.usage); err != nil { - t.Fatalf("Failed to initialize usage from %v: %v", test.usage, err) - } - unallocated, minUnallocated := findUnallocatedRange(&usage, test.start, test.length, test.alignment) - if unallocated != test.unallocated { - t.Errorf("findUnallocatedRange(%v, %x, %x, %x): got unallocated %x, wanted %x", test.usage, test.start, test.length, test.alignment, unallocated, test.unallocated) - } - if minUnallocated != test.minUnallocated { - t.Errorf("findUnallocatedRange(%v, %x, %x, %x): got minUnallocated %x, wanted %x", test.usage, test.start, test.length, test.alignment, minUnallocated, test.minUnallocated) - } - }) - } -} diff --git a/pkg/sentry/pgalloc/usage_set.go b/pkg/sentry/pgalloc/usage_set.go new file mode 100755 index 000000000..37b9235ca --- /dev/null +++ b/pkg/sentry/pgalloc/usage_set.go @@ -0,0 +1,1274 @@ +package pgalloc + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/platform" +) + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + usageminDegree = 10 + + usagemaxDegree = 2 * usageminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type usageSet struct { + root usagenode `state:".(*usageSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *usageSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *usageSet) IsEmptyRange(r __generics_imported0.FileRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *usageSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *usageSet) SpanRange(r __generics_imported0.FileRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *usageSet) FirstSegment() usageIterator { + if s.root.nrSegments == 0 { + return usageIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *usageSet) LastSegment() usageIterator { + if s.root.nrSegments == 0 { + return usageIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *usageSet) FirstGap() usageGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return usageGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *usageSet) LastGap() usageGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return usageGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *usageSet) Find(key uint64) (usageIterator, usageGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return usageIterator{n, i}, usageGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return usageIterator{}, usageGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *usageSet) FindSegment(key uint64) usageIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *usageSet) LowerBoundSegment(min uint64) usageIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *usageSet) UpperBoundSegment(max uint64) usageIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *usageSet) FindGap(key uint64) usageGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *usageSet) LowerBoundGap(min uint64) usageGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *usageSet) UpperBoundGap(max uint64) usageGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *usageSet) Add(r __generics_imported0.FileRange, val usageInfo) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *usageSet) AddWithoutMerging(r __generics_imported0.FileRange, val usageInfo) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *usageSet) Insert(gap usageGapIterator, r __generics_imported0.FileRange, val usageInfo) usageIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (usageSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (usageSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (usageSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *usageSet) InsertWithoutMerging(gap usageGapIterator, r __generics_imported0.FileRange, val usageInfo) usageIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *usageSet) InsertWithoutMergingUnchecked(gap usageGapIterator, r __generics_imported0.FileRange, val usageInfo) usageIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return usageIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *usageSet) Remove(seg usageIterator) usageGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + usageSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(usageGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *usageSet) RemoveAll() { + s.root = usagenode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *usageSet) RemoveRange(r __generics_imported0.FileRange) usageGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *usageSet) Merge(first, second usageIterator) usageIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *usageSet) MergeUnchecked(first, second usageIterator) usageIterator { + if first.End() == second.Start() { + if mval, ok := (usageSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return usageIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *usageSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *usageSet) MergeRange(r __generics_imported0.FileRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *usageSet) MergeAdjacent(r __generics_imported0.FileRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *usageSet) Split(seg usageIterator, split uint64) (usageIterator, usageIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *usageSet) SplitUnchecked(seg usageIterator, split uint64) (usageIterator, usageIterator) { + val1, val2 := (usageSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.FileRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *usageSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *usageSet) Isolate(seg usageIterator, r __generics_imported0.FileRange) usageIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *usageSet) ApplyContiguous(r __generics_imported0.FileRange, fn func(seg usageIterator)) usageGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return usageGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return usageGapIterator{} + } + } +} + +// +stateify savable +type usagenode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *usagenode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [usagemaxDegree - 1]__generics_imported0.FileRange + values [usagemaxDegree - 1]usageInfo + children [usagemaxDegree]*usagenode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *usagenode) firstSegment() usageIterator { + for n.hasChildren { + n = n.children[0] + } + return usageIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *usagenode) lastSegment() usageIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return usageIterator{n, n.nrSegments - 1} +} + +func (n *usagenode) prevSibling() *usagenode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *usagenode) nextSibling() *usagenode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *usagenode) rebalanceBeforeInsert(gap usageGapIterator) usageGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < usagemaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &usagenode{ + nrSegments: usageminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &usagenode{ + nrSegments: usageminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:usageminDegree-1], n.keys[:usageminDegree-1]) + copy(left.values[:usageminDegree-1], n.values[:usageminDegree-1]) + copy(right.keys[:usageminDegree-1], n.keys[usageminDegree:]) + copy(right.values[:usageminDegree-1], n.values[usageminDegree:]) + n.keys[0], n.values[0] = n.keys[usageminDegree-1], n.values[usageminDegree-1] + usagezeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:usageminDegree], n.children[:usageminDegree]) + copy(right.children[:usageminDegree], n.children[usageminDegree:]) + usagezeroNodeSlice(n.children[2:]) + for i := 0; i < usageminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < usageminDegree { + return usageGapIterator{left, gap.index} + } + return usageGapIterator{right, gap.index - usageminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[usageminDegree-1], n.values[usageminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &usagenode{ + nrSegments: usageminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:usageminDegree-1], n.keys[usageminDegree:]) + copy(sibling.values[:usageminDegree-1], n.values[usageminDegree:]) + usagezeroValueSlice(n.values[usageminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:usageminDegree], n.children[usageminDegree:]) + usagezeroNodeSlice(n.children[usageminDegree:]) + for i := 0; i < usageminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = usageminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < usageminDegree { + return gap + } + return usageGapIterator{sibling, gap.index - usageminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *usagenode) rebalanceAfterRemove(gap usageGapIterator) usageGapIterator { + for { + if n.nrSegments >= usageminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= usageminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + usageSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return usageGapIterator{n, 0} + } + if gap.node == n { + return usageGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= usageminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + usageSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return usageGapIterator{n, n.nrSegments} + } + return usageGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return usageGapIterator{p, gap.index} + } + if gap.node == right { + return usageGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *usagenode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = usageGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + usageSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type usageIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *usagenode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg usageIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg usageIterator) Range() __generics_imported0.FileRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg usageIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg usageIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg usageIterator) SetRangeUnchecked(r __generics_imported0.FileRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg usageIterator) SetRange(r __generics_imported0.FileRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg usageIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg usageIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg usageIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg usageIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg usageIterator) Value() usageInfo { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg usageIterator) ValuePtr() *usageInfo { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg usageIterator) SetValue(val usageInfo) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg usageIterator) PrevSegment() usageIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return usageIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return usageIterator{} + } + return usagesegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg usageIterator) NextSegment() usageIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return usageIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return usageIterator{} + } + return usagesegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg usageIterator) PrevGap() usageGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return usageGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg usageIterator) NextGap() usageGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return usageGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg usageIterator) PrevNonEmpty() (usageIterator, usageGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return usageIterator{}, gap + } + return gap.PrevSegment(), usageGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg usageIterator) NextNonEmpty() (usageIterator, usageGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return usageIterator{}, gap + } + return gap.NextSegment(), usageGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type usageGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *usagenode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap usageGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap usageGapIterator) Range() __generics_imported0.FileRange { + return __generics_imported0.FileRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap usageGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return usageSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap usageGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return usageSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap usageGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap usageGapIterator) PrevSegment() usageIterator { + return usagesegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap usageGapIterator) NextSegment() usageIterator { + return usagesegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap usageGapIterator) PrevGap() usageGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return usageGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap usageGapIterator) NextGap() usageGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return usageGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func usagesegmentBeforePosition(n *usagenode, i int) usageIterator { + for i == 0 { + if n.parent == nil { + return usageIterator{} + } + n, i = n.parent, n.parentIndex + } + return usageIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func usagesegmentAfterPosition(n *usagenode, i int) usageIterator { + for i == n.nrSegments { + if n.parent == nil { + return usageIterator{} + } + n, i = n.parent, n.parentIndex + } + return usageIterator{n, i} +} + +func usagezeroValueSlice(slice []usageInfo) { + + for i := range slice { + usageSetFunctions{}.ClearValue(&slice[i]) + } +} + +func usagezeroNodeSlice(slice []*usagenode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *usageSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *usagenode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *usagenode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type usageSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []usageInfo +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *usageSet) ExportSortedSlices() *usageSegmentDataSlices { + var sds usageSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *usageSet) ImportSortedSlices(sds *usageSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.FileRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *usageSet) saveRoot() *usageSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *usageSet) loadRoot(sds *usageSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD deleted file mode 100644 index 453241eca..000000000 --- a/pkg/sentry/platform/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "file_range", - out = "file_range.go", - package = "platform", - prefix = "File", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - -go_library( - name = "platform", - srcs = [ - "context.go", - "file_range.go", - "mmap_min_addr.go", - "platform.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/atomicbitops", - "//pkg/context", - "//pkg/log", - "//pkg/safecopy", - "//pkg/safemem", - "//pkg/seccomp", - "//pkg/sentry/arch", - "//pkg/sentry/usage", - "//pkg/syserror", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/platform/file_range.go b/pkg/sentry/platform/file_range.go new file mode 100755 index 000000000..685d360e3 --- /dev/null +++ b/pkg/sentry/platform/file_range.go @@ -0,0 +1,62 @@ +package platform + +// A Range represents a contiguous range of T. +// +// +stateify savable +type FileRange struct { + // Start is the inclusive start of the range. + Start uint64 + + // End is the exclusive end of the range. + End uint64 +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +func (r FileRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +func (r FileRange) Length() uint64 { + return r.End - r.Start +} + +// Contains returns true if r contains x. +func (r FileRange) Contains(x uint64) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +func (r FileRange) Overlaps(r2 FileRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +func (r FileRange) IsSupersetOf(r2 FileRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +func (r FileRange) Intersect(r2 FileRange) FileRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +func (r FileRange) CanSplitAt(x uint64) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD deleted file mode 100644 index 83b385f14..000000000 --- a/pkg/sentry/platform/interrupt/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "interrupt", - srcs = [ - "interrupt.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/sync"], -) - -go_test( - name = "interrupt_test", - size = "small", - srcs = ["interrupt_test.go"], - library = ":interrupt", -) diff --git a/pkg/sentry/platform/interrupt/interrupt_state_autogen.go b/pkg/sentry/platform/interrupt/interrupt_state_autogen.go new file mode 100755 index 000000000..1336e5f01 --- /dev/null +++ b/pkg/sentry/platform/interrupt/interrupt_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package interrupt diff --git a/pkg/sentry/platform/interrupt/interrupt_test.go b/pkg/sentry/platform/interrupt/interrupt_test.go deleted file mode 100644 index 0ecdf6e7a..000000000 --- a/pkg/sentry/platform/interrupt/interrupt_test.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package interrupt - -import ( - "testing" -) - -type countingReceiver struct { - interrupts int -} - -// NotifyInterrupt implements Receiver.NotifyInterrupt. -func (r *countingReceiver) NotifyInterrupt() { - r.interrupts++ -} - -func TestSingleInterruptBeforeEnable(t *testing.T) { - var ( - f Forwarder - r countingReceiver - ) - f.NotifyInterrupt() - // The interrupt should cause the first Enable to fail. - if f.Enable(&r) { - f.Disable() - t.Fatalf("Enable: got true, wanted false") - } - // The failing Enable "acknowledges" the interrupt, allowing future Enables - // to succeed. - if !f.Enable(&r) { - t.Fatalf("Enable: got false, wanted true") - } - f.Disable() -} - -func TestMultipleInterruptsBeforeEnable(t *testing.T) { - var ( - f Forwarder - r countingReceiver - ) - f.NotifyInterrupt() - f.NotifyInterrupt() - // The interrupts should cause the first Enable to fail. - if f.Enable(&r) { - f.Disable() - t.Fatalf("Enable: got true, wanted false") - } - // Interrupts are deduplicated while the Forwarder is disabled, so the - // failing Enable "acknowledges" all interrupts, allowing future Enables to - // succeed. - if !f.Enable(&r) { - t.Fatalf("Enable: got false, wanted true") - } - f.Disable() -} - -func TestSingleInterruptAfterEnable(t *testing.T) { - var ( - f Forwarder - r countingReceiver - ) - if !f.Enable(&r) { - t.Fatalf("Enable: got false, wanted true") - } - defer f.Disable() - f.NotifyInterrupt() - if r.interrupts != 1 { - t.Errorf("interrupts: got %d, wanted 1", r.interrupts) - } -} - -func TestMultipleInterruptsAfterEnable(t *testing.T) { - var ( - f Forwarder - r countingReceiver - ) - if !f.Enable(&r) { - t.Fatalf("Enable: got false, wanted true") - } - defer f.Disable() - f.NotifyInterrupt() - f.NotifyInterrupt() - if r.interrupts != 2 { - t.Errorf("interrupts: got %d, wanted 2", r.interrupts) - } -} diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD deleted file mode 100644 index 159f7eafd..000000000 --- a/pkg/sentry/platform/kvm/BUILD +++ /dev/null @@ -1,80 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "kvm", - srcs = [ - "address_space.go", - "allocator.go", - "bluepill.go", - "bluepill_amd64.go", - "bluepill_amd64.s", - "bluepill_amd64_unsafe.go", - "bluepill_arm64.go", - "bluepill_arm64.s", - "bluepill_arm64_unsafe.go", - "bluepill_fault.go", - "bluepill_unsafe.go", - "context.go", - "filters_amd64.go", - "filters_arm64.go", - "kvm.go", - "kvm_amd64.go", - "kvm_amd64_unsafe.go", - "kvm_arm64.go", - "kvm_arm64_unsafe.go", - "kvm_const.go", - "kvm_const_arm64.go", - "machine.go", - "machine_amd64.go", - "machine_amd64_unsafe.go", - "machine_arm64.go", - "machine_arm64_unsafe.go", - "machine_unsafe.go", - "physical_map.go", - "physical_map_amd64.go", - "physical_map_arm64.go", - "virtual_map.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/atomicbitops", - "//pkg/cpuid", - "//pkg/log", - "//pkg/procid", - "//pkg/safecopy", - "//pkg/seccomp", - "//pkg/sentry/arch", - "//pkg/sentry/platform", - "//pkg/sentry/platform/interrupt", - "//pkg/sentry/platform/ring0", - "//pkg/sentry/platform/ring0/pagetables", - "//pkg/sentry/time", - "//pkg/sync", - "//pkg/usermem", - ], -) - -go_test( - name = "kvm_test", - srcs = [ - "kvm_test.go", - "virtual_map_test.go", - ], - library = ":kvm", - tags = [ - "manual", - "nogotsan", - "requires-kvm", - ], - deps = [ - "//pkg/sentry/arch", - "//pkg/sentry/platform", - "//pkg/sentry/platform/kvm/testutil", - "//pkg/sentry/platform/ring0", - "//pkg/sentry/platform/ring0/pagetables", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index 552341721..552341721 100644..100755 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s index 04efa0147..04efa0147 100644..100755 --- a/pkg/sentry/platform/kvm/bluepill_arm64.s +++ b/pkg/sentry/platform/kvm/bluepill_arm64.s diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index eb5ed574e..eb5ed574e 100644..100755 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go diff --git a/pkg/sentry/platform/kvm/filters_amd64.go b/pkg/sentry/platform/kvm/filters_amd64.go index 7d949f1dd..7d949f1dd 100644..100755 --- a/pkg/sentry/platform/kvm/filters_amd64.go +++ b/pkg/sentry/platform/kvm/filters_amd64.go diff --git a/pkg/sentry/platform/kvm/filters_arm64.go b/pkg/sentry/platform/kvm/filters_arm64.go index 9245d07c2..9245d07c2 100644..100755 --- a/pkg/sentry/platform/kvm/filters_arm64.go +++ b/pkg/sentry/platform/kvm/filters_arm64.go diff --git a/pkg/sentry/platform/kvm/kvm_amd64_state_autogen.go b/pkg/sentry/platform/kvm/kvm_amd64_state_autogen.go new file mode 100755 index 000000000..a69cbee8b --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_amd64_state_autogen.go @@ -0,0 +1,7 @@ +// automatically generated by stateify. + +// +build amd64 +// +build amd64 +// +build amd64 + +package kvm diff --git a/pkg/sentry/platform/kvm/kvm_amd64_unsafe_state_autogen.go b/pkg/sentry/platform/kvm/kvm_amd64_unsafe_state_autogen.go new file mode 100755 index 000000000..a69cbee8b --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_amd64_unsafe_state_autogen.go @@ -0,0 +1,7 @@ +// automatically generated by stateify. + +// +build amd64 +// +build amd64 +// +build amd64 + +package kvm diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go index 79045651e..79045651e 100644..100755 --- a/pkg/sentry/platform/kvm/kvm_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_arm64.go diff --git a/pkg/sentry/platform/kvm/kvm_arm64_state_autogen.go b/pkg/sentry/platform/kvm/kvm_arm64_state_autogen.go new file mode 100755 index 000000000..90183b764 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_arm64_state_autogen.go @@ -0,0 +1,7 @@ +// automatically generated by stateify. + +// +build arm64 +// +build arm64 +// +build arm64 + +package kvm diff --git a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go index 6531bae1d..6531bae1d 100644..100755 --- a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go diff --git a/pkg/sentry/platform/kvm/kvm_arm64_unsafe_state_autogen.go b/pkg/sentry/platform/kvm/kvm_arm64_unsafe_state_autogen.go new file mode 100755 index 000000000..90183b764 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_arm64_unsafe_state_autogen.go @@ -0,0 +1,7 @@ +// automatically generated by stateify. + +// +build arm64 +// +build arm64 +// +build arm64 + +package kvm diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go index 5a74c6e36..5a74c6e36 100644..100755 --- a/pkg/sentry/platform/kvm/kvm_const_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go diff --git a/pkg/sentry/platform/kvm/kvm_state_autogen.go b/pkg/sentry/platform/kvm/kvm_state_autogen.go new file mode 100755 index 000000000..2325262dc --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_state_autogen.go @@ -0,0 +1,8 @@ +// automatically generated by stateify. + +// +build go1.12 +// +build !go1.15 +// +build go1.12 +// +build !go1.15 + +package kvm diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go deleted file mode 100644 index c42752d50..000000000 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ /dev/null @@ -1,533 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kvm - -import ( - "math/rand" - "reflect" - "sync/atomic" - "syscall" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" - "gvisor.dev/gvisor/pkg/usermem" -) - -var dummyFPState = (*byte)(arch.NewFloatingPointData()) - -type testHarness interface { - Errorf(format string, args ...interface{}) - Fatalf(format string, args ...interface{}) -} - -func kvmTest(t testHarness, setup func(*KVM), fn func(*vCPU) bool) { - // Create the machine. - deviceFile, err := OpenDevice() - if err != nil { - t.Fatalf("error opening device file: %v", err) - } - k, err := New(deviceFile) - if err != nil { - t.Fatalf("error creating KVM instance: %v", err) - } - defer k.machine.Destroy() - - // Call additional setup. - if setup != nil { - setup(k) - } - - var c *vCPU // For recovery. - defer func() { - redpill() - if c != nil { - k.machine.Put(c) - } - }() - for { - c = k.machine.Get() - if !fn(c) { - break - } - - // We put the vCPU here and clear the value so that the - // deferred recovery will not re-put it above. - k.machine.Put(c) - c = nil - } -} - -func bluepillTest(t testHarness, fn func(*vCPU)) { - kvmTest(t, nil, func(c *vCPU) bool { - bluepill(c) - fn(c) - return false - }) -} - -func TestKernelSyscall(t *testing.T) { - bluepillTest(t, func(c *vCPU) { - redpill() // Leave guest mode. - if got := atomic.LoadUint32(&c.state); got != vCPUUser { - t.Errorf("vCPU not in ready state: got %v", got) - } - }) -} - -func hostFault() { - defer func() { - recover() - }() - var foo *int - *foo = 0 -} - -func TestKernelFault(t *testing.T) { - hostFault() // Ensure recovery works. - bluepillTest(t, func(c *vCPU) { - hostFault() - if got := atomic.LoadUint32(&c.state); got != vCPUUser { - t.Errorf("vCPU not in ready state: got %v", got) - } - }) -} - -func TestKernelFloatingPoint(t *testing.T) { - bluepillTest(t, func(c *vCPU) { - if !testutil.FloatingPointWorks() { - t.Errorf("floating point does not work, and it should!") - } - }) -} - -func applicationTest(t testHarness, useHostMappings bool, target func(), fn func(*vCPU, *syscall.PtraceRegs, *pagetables.PageTables) bool) { - // Initialize registers & page tables. - var ( - regs syscall.PtraceRegs - pt *pagetables.PageTables - ) - testutil.SetTestTarget(®s, target) - - kvmTest(t, func(k *KVM) { - // Create new page tables. - as, _, err := k.NewAddressSpace(nil /* invalidator */) - if err != nil { - t.Fatalf("can't create new address space: %v", err) - } - pt = as.(*addressSpace).pageTables - - if useHostMappings { - // Apply the physical mappings to these page tables. - // (This is normally dangerous, since they point to - // physical pages that may not exist. This shouldn't be - // done for regular user code, but is fine for test - // purposes.) - applyPhysicalRegions(func(pr physicalRegion) bool { - pt.Map(usermem.Addr(pr.virtual), pr.length, pagetables.MapOpts{ - AccessType: usermem.AnyAccess, - User: true, - }, pr.physical) - return true // Keep iterating. - }) - } - }, func(c *vCPU) bool { - // Invoke the function with the extra data. - return fn(c, ®s, pt) - }) -} - -func TestApplicationSyscall(t *testing.T) { - applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if err != nil { - t.Errorf("application syscall with full restore failed: %v", err) - } - return false - }) - applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if err != nil { - t.Errorf("application syscall with partial restore failed: %v", err) - } - return false - }) -} - -func TestApplicationFault(t *testing.T) { - applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - testutil.SetTouchTarget(regs, nil) // Cause fault. - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if err != platform.ErrContextSignal || si.Signo != int32(syscall.SIGSEGV) { - t.Errorf("application fault with full restore got (%v, %v), expected (%v, SIGSEGV)", err, si, platform.ErrContextSignal) - } - return false - }) - applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - testutil.SetTouchTarget(regs, nil) // Cause fault. - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if err != platform.ErrContextSignal || si.Signo != int32(syscall.SIGSEGV) { - t.Errorf("application fault with partial restore got (%v, %v), expected (%v, SIGSEGV)", err, si, platform.ErrContextSignal) - } - return false - }) -} - -func TestRegistersSyscall(t *testing.T) { - applicationTest(t, true, testutil.TwiddleRegsSyscall, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - testutil.SetTestRegs(regs) // Fill values for all registers. - for { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != nil { - t.Errorf("application register check with partial restore got unexpected error: %v", err) - } - if err := testutil.CheckTestRegs(regs, false); err != nil { - t.Errorf("application register check with partial restore failed: %v", err) - } - break // Done. - } - return false - }) -} - -func TestRegistersFault(t *testing.T) { - applicationTest(t, true, testutil.TwiddleRegsFault, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - testutil.SetTestRegs(regs) // Fill values for all registers. - for { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != platform.ErrContextSignal || si.Signo != int32(syscall.SIGSEGV) { - t.Errorf("application register check with full restore got unexpected error: %v", err) - } - if err := testutil.CheckTestRegs(regs, true); err != nil { - t.Errorf("application register check with full restore failed: %v", err) - } - break // Done. - } - return false - }) -} - -func TestSegments(t *testing.T) { - applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - testutil.SetTestSegments(regs) - for { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != nil { - t.Errorf("application segment check with full restore got unexpected error: %v", err) - } - if err := testutil.CheckTestSegments(regs); err != nil { - t.Errorf("application segment check with full restore failed: %v", err) - } - break // Done. - } - return false - }) -} - -func TestBounce(t *testing.T) { - applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - go func() { - time.Sleep(time.Millisecond) - c.BounceToKernel() - }() - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err != platform.ErrContextInterrupt { - t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt) - } - return false - }) - applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - go func() { - time.Sleep(time.Millisecond) - c.BounceToKernel() - }() - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err != platform.ErrContextInterrupt { - t.Errorf("application full restore: got %v, wanted %v", err, platform.ErrContextInterrupt) - } - return false - }) -} - -func TestBounceStress(t *testing.T) { - applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - randomSleep := func() { - // O(hundreds of microseconds) is appropriate to ensure - // different overlaps and different schedules. - if n := rand.Intn(1000); n > 100 { - time.Sleep(time.Duration(n) * time.Microsecond) - } - } - for i := 0; i < 1000; i++ { - // Start an asynchronously executing goroutine that - // calls Bounce at pseudo-random point in time. - // This should wind up calling Bounce when the - // kernel is in various stages of the switch. - go func() { - randomSleep() - c.BounceToKernel() - }() - randomSleep() - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err != platform.ErrContextInterrupt { - t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt) - } - c.unlock() - randomSleep() - c.lock() - } - return false - }) -} - -func TestInvalidate(t *testing.T) { - var data uintptr // Used below. - applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - testutil.SetTouchTarget(regs, &data) // Read legitimate value. - for { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != nil { - t.Errorf("application partial restore: got %v, wanted nil", err) - } - break // Done. - } - // Unmap the page containing data & invalidate. - pt.Unmap(usermem.Addr(reflect.ValueOf(&data).Pointer() & ^uintptr(usermem.PageSize-1)), usermem.PageSize) - for { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - Flush: true, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != platform.ErrContextSignal { - t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextSignal) - } - break // Success. - } - return false - }) -} - -// IsFault returns true iff the given signal represents a fault. -func IsFault(err error, si *arch.SignalInfo) bool { - return err == platform.ErrContextSignal && si.Signo == int32(syscall.SIGSEGV) -} - -func TestEmptyAddressSpace(t *testing.T) { - applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if !IsFault(err, &si) { - t.Errorf("first fault with partial restore failed got %v", err) - t.Logf("registers: %#v", ®s) - } - return false - }) - applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if !IsFault(err, &si) { - t.Errorf("first fault with full restore failed got %v", err) - t.Logf("registers: %#v", ®s) - } - return false - }) -} - -func TestWrongVCPU(t *testing.T) { - kvmTest(t, nil, func(c1 *vCPU) bool { - kvmTest(t, nil, func(c2 *vCPU) bool { - // Basic test, one then the other. - bluepill(c1) - bluepill(c2) - if c2.switches == 0 { - // Don't allow the test to proceed if this fails. - t.Fatalf("wrong vCPU#2 switches: vCPU1=%+v,vCPU2=%+v", c1, c2) - } - - // Alternate vCPUs; we expect to need to trigger the - // wrong vCPU path on each switch. - for i := 0; i < 100; i++ { - bluepill(c1) - bluepill(c2) - } - if count := c1.switches; count < 90 { - t.Errorf("wrong vCPU#1 switches: vCPU1=%+v,vCPU2=%+v", c1, c2) - } - if count := c2.switches; count < 90 { - t.Errorf("wrong vCPU#2 switches: vCPU1=%+v,vCPU2=%+v", c1, c2) - } - return false - }) - return false - }) - kvmTest(t, nil, func(c1 *vCPU) bool { - kvmTest(t, nil, func(c2 *vCPU) bool { - bluepill(c1) - bluepill(c2) - return false - }) - return false - }) -} - -func BenchmarkApplicationSyscall(b *testing.B) { - var ( - i int // Iteration includes machine.Get() / machine.Put(). - a int // Count for ErrContextInterrupt. - ) - applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - a++ - return true // Ignore. - } else if err != nil { - b.Fatalf("benchmark failed: %v", err) - } - i++ - return i < b.N - }) - if a != 0 { - b.Logf("ErrContextInterrupt occurred %d times (in %d iterations).", a, a+i) - } -} - -func BenchmarkKernelSyscall(b *testing.B) { - // Note that the target passed here is irrelevant, we never execute SwitchToUser. - applicationTest(b, true, testutil.Getpid, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - // iteration does not include machine.Get() / machine.Put(). - for i := 0; i < b.N; i++ { - testutil.Getpid() - } - return false - }) -} - -func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) { - // see BenchmarkApplicationSyscall. - var ( - i int - a int - ) - applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - a++ - return true // Ignore. - } else if err != nil { - b.Fatalf("benchmark failed: %v", err) - } - // This will intentionally cause the world switch. By executing - // a host syscall here, we force the transition between guest - // and host mode. - testutil.Getpid() - i++ - return i < b.N - }) - if a != 0 { - b.Logf("ErrContextInterrupt occurred %d times (in %d iterations).", a, a+i) - } -} diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 09552837a..09552837a 100644..100755 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index b531f2f85..b531f2f85 100644..100755 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go diff --git a/pkg/sentry/platform/kvm/physical_map_amd64.go b/pkg/sentry/platform/kvm/physical_map_amd64.go index c5adfb577..c5adfb577 100644..100755 --- a/pkg/sentry/platform/kvm/physical_map_amd64.go +++ b/pkg/sentry/platform/kvm/physical_map_amd64.go diff --git a/pkg/sentry/platform/kvm/physical_map_arm64.go b/pkg/sentry/platform/kvm/physical_map_arm64.go index 4d8561453..4d8561453 100644..100755 --- a/pkg/sentry/platform/kvm/physical_map_arm64.go +++ b/pkg/sentry/platform/kvm/physical_map_arm64.go diff --git a/pkg/sentry/platform/kvm/testutil/BUILD b/pkg/sentry/platform/kvm/testutil/BUILD deleted file mode 100644 index f7605df8a..000000000 --- a/pkg/sentry/platform/kvm/testutil/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "testutil", - testonly = 1, - srcs = [ - "testutil.go", - "testutil_amd64.go", - "testutil_amd64.s", - "testutil_arm64.go", - "testutil_arm64.s", - ], - visibility = ["//pkg/sentry/platform/kvm:__pkg__"], -) diff --git a/pkg/sentry/platform/kvm/testutil/testutil.go b/pkg/sentry/platform/kvm/testutil/testutil.go deleted file mode 100644 index 5c1efa0fd..000000000 --- a/pkg/sentry/platform/kvm/testutil/testutil.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testutil provides common assembly stubs for testing. -package testutil - -import ( - "fmt" - "strings" -) - -// Getpid executes a trivial system call. -func Getpid() - -// Touch touches the value in the first register. -func Touch() - -// SyscallLoop executes a syscall and loops. -func SyscallLoop() - -// SpinLoop spins on the CPU. -func SpinLoop() - -// HaltLoop immediately halts and loops. -func HaltLoop() - -// TwiddleRegsFault twiddles registers then faults. -func TwiddleRegsFault() - -// TwiddleRegsSyscall twiddles registers then executes a syscall. -func TwiddleRegsSyscall() - -// FloatingPointWorks is a floating point test. -// -// It returns true or false. -func FloatingPointWorks() bool - -// RegisterMismatchError is used for checking registers. -type RegisterMismatchError []string - -// Error returns a human-readable error. -func (r RegisterMismatchError) Error() string { - return strings.Join([]string(r), ";") -} - -// addRegisterMisatch allows simple chaining of register mismatches. -func addRegisterMismatch(err error, reg string, got, expected interface{}) error { - errStr := fmt.Sprintf("%s got %08x, expected %08x", reg, got, expected) - switch r := err.(type) { - case nil: - // Return a new register mismatch. - return RegisterMismatchError{errStr} - case RegisterMismatchError: - // Append the error. - r = append(r, errStr) - return r - default: - // Leave as is. - return err - } -} diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go deleted file mode 100644 index 4c108abbf..000000000 --- a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build amd64 - -package testutil - -import ( - "reflect" - "syscall" -) - -// TwiddleSegments reads segments into known registers. -func TwiddleSegments() - -// SetTestTarget sets the rip appropriately. -func SetTestTarget(regs *syscall.PtraceRegs, fn func()) { - regs.Rip = uint64(reflect.ValueOf(fn).Pointer()) -} - -// SetTouchTarget sets rax appropriately. -func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) { - if target != nil { - regs.Rax = uint64(reflect.ValueOf(target).Pointer()) - } else { - regs.Rax = 0 - } -} - -// RewindSyscall rewinds a syscall RIP. -func RewindSyscall(regs *syscall.PtraceRegs) { - regs.Rip -= 2 -} - -// SetTestRegs initializes registers to known values. -func SetTestRegs(regs *syscall.PtraceRegs) { - regs.R15 = 0x15 - regs.R14 = 0x14 - regs.R13 = 0x13 - regs.R12 = 0x12 - regs.Rbp = 0xb9 - regs.Rbx = 0xb4 - regs.R11 = 0x11 - regs.R10 = 0x10 - regs.R9 = 0x09 - regs.R8 = 0x08 - regs.Rax = 0x44 - regs.Rcx = 0xc4 - regs.Rdx = 0xd4 - regs.Rsi = 0x51 - regs.Rdi = 0xd1 - regs.Rsp = 0x59 -} - -// CheckTestRegs checks that registers were twiddled per TwiddleRegs. -func CheckTestRegs(regs *syscall.PtraceRegs, full bool) (err error) { - if need := ^uint64(0x15); regs.R15 != need { - err = addRegisterMismatch(err, "R15", regs.R15, need) - } - if need := ^uint64(0x14); regs.R14 != need { - err = addRegisterMismatch(err, "R14", regs.R14, need) - } - if need := ^uint64(0x13); regs.R13 != need { - err = addRegisterMismatch(err, "R13", regs.R13, need) - } - if need := ^uint64(0x12); regs.R12 != need { - err = addRegisterMismatch(err, "R12", regs.R12, need) - } - if need := ^uint64(0xb9); regs.Rbp != need { - err = addRegisterMismatch(err, "Rbp", regs.Rbp, need) - } - if need := ^uint64(0xb4); regs.Rbx != need { - err = addRegisterMismatch(err, "Rbx", regs.Rbx, need) - } - if need := ^uint64(0x10); regs.R10 != need { - err = addRegisterMismatch(err, "R10", regs.R10, need) - } - if need := ^uint64(0x09); regs.R9 != need { - err = addRegisterMismatch(err, "R9", regs.R9, need) - } - if need := ^uint64(0x08); regs.R8 != need { - err = addRegisterMismatch(err, "R8", regs.R8, need) - } - if need := ^uint64(0x44); regs.Rax != need { - err = addRegisterMismatch(err, "Rax", regs.Rax, need) - } - if need := ^uint64(0xd4); regs.Rdx != need { - err = addRegisterMismatch(err, "Rdx", regs.Rdx, need) - } - if need := ^uint64(0x51); regs.Rsi != need { - err = addRegisterMismatch(err, "Rsi", regs.Rsi, need) - } - if need := ^uint64(0xd1); regs.Rdi != need { - err = addRegisterMismatch(err, "Rdi", regs.Rdi, need) - } - if need := ^uint64(0x59); regs.Rsp != need { - err = addRegisterMismatch(err, "Rsp", regs.Rsp, need) - } - // Rcx & R11 are ignored if !full is set. - if need := ^uint64(0x11); full && regs.R11 != need { - err = addRegisterMismatch(err, "R11", regs.R11, need) - } - if need := ^uint64(0xc4); full && regs.Rcx != need { - err = addRegisterMismatch(err, "Rcx", regs.Rcx, need) - } - return -} - -var fsData uint64 = 0x55 -var gsData uint64 = 0x85 - -// SetTestSegments initializes segments to known values. -func SetTestSegments(regs *syscall.PtraceRegs) { - regs.Fs_base = uint64(reflect.ValueOf(&fsData).Pointer()) - regs.Gs_base = uint64(reflect.ValueOf(&gsData).Pointer()) -} - -// CheckTestSegments checks that registers were twiddled per TwiddleSegments. -func CheckTestSegments(regs *syscall.PtraceRegs) (err error) { - if regs.Rax != fsData { - err = addRegisterMismatch(err, "Rax", regs.Rax, fsData) - } - if regs.Rbx != gsData { - err = addRegisterMismatch(err, "Rbx", regs.Rcx, gsData) - } - return -} diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.s b/pkg/sentry/platform/kvm/testutil/testutil_amd64.s deleted file mode 100644 index 491ec0c2a..000000000 --- a/pkg/sentry/platform/kvm/testutil/testutil_amd64.s +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build amd64 - -// test_util_amd64.s provides AMD64 test functions. - -#include "funcdata.h" -#include "textflag.h" - -TEXT ·Getpid(SB),NOSPLIT,$0 - NO_LOCAL_POINTERS - MOVQ $39, AX // getpid - SYSCALL - RET - -TEXT ·Touch(SB),NOSPLIT,$0 -start: - MOVQ 0(AX), BX // deref AX - MOVQ $39, AX // getpid - SYSCALL - JMP start - -TEXT ·HaltLoop(SB),NOSPLIT,$0 -start: - HLT - JMP start - -TEXT ·SyscallLoop(SB),NOSPLIT,$0 -start: - SYSCALL - JMP start - -TEXT ·SpinLoop(SB),NOSPLIT,$0 -start: - JMP start - -TEXT ·FloatingPointWorks(SB),NOSPLIT,$0-8 - NO_LOCAL_POINTERS - MOVQ $1, AX - MOVQ AX, X0 - MOVQ $39, AX // getpid - SYSCALL - MOVQ X0, AX - CMPQ AX, $1 - SETEQ ret+0(FP) - RET - -#define TWIDDLE_REGS() \ - NOTQ R15; \ - NOTQ R14; \ - NOTQ R13; \ - NOTQ R12; \ - NOTQ BP; \ - NOTQ BX; \ - NOTQ R11; \ - NOTQ R10; \ - NOTQ R9; \ - NOTQ R8; \ - NOTQ AX; \ - NOTQ CX; \ - NOTQ DX; \ - NOTQ SI; \ - NOTQ DI; \ - NOTQ SP; - -TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0 - TWIDDLE_REGS() - SYSCALL - RET // never reached - -TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0 - TWIDDLE_REGS() - JMP AX // must fault - RET // never reached - -#define READ_FS() BYTE $0x64; BYTE $0x48; BYTE $0x8b; BYTE $0x00; -#define READ_GS() BYTE $0x65; BYTE $0x48; BYTE $0x8b; BYTE $0x00; - -TEXT ·TwiddleSegments(SB),NOSPLIT,$0 - MOVQ $0x0, AX - READ_GS() - MOVQ AX, BX - MOVQ $0x0, AX - READ_FS() - SYSCALL - RET // never reached diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go deleted file mode 100644 index 40b2e4acc..000000000 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build arm64 - -package testutil - -import ( - "fmt" - "reflect" - "syscall" -) - -// SetTestTarget sets the rip appropriately. -func SetTestTarget(regs *syscall.PtraceRegs, fn func()) { - regs.Pc = uint64(reflect.ValueOf(fn).Pointer()) -} - -// SetTouchTarget sets rax appropriately. -func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) { - if target != nil { - regs.Regs[8] = uint64(reflect.ValueOf(target).Pointer()) - } else { - regs.Regs[8] = 0 - } -} - -// RewindSyscall rewinds a syscall RIP. -func RewindSyscall(regs *syscall.PtraceRegs) { - regs.Pc -= 4 -} - -// SetTestRegs initializes registers to known values. -func SetTestRegs(regs *syscall.PtraceRegs) { - for i := 0; i <= 30; i++ { - regs.Regs[i] = uint64(i) + 1 - } -} - -// CheckTestRegs checks that registers were twiddled per TwiddleRegs. -func CheckTestRegs(regs *syscall.PtraceRegs, full bool) (err error) { - for i := 0; i <= 30; i++ { - if need := ^uint64(i + 1); regs.Regs[i] != need { - err = addRegisterMismatch(err, fmt.Sprintf("R%d", i), regs.Regs[i], need) - } - } - return -} diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s deleted file mode 100644 index 0bebee852..000000000 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build arm64 - -// test_util_arm64.s provides ARM64 test functions. - -#include "funcdata.h" -#include "textflag.h" - -#define SYS_GETPID 172 - -// This function simulates the getpid syscall. -TEXT ·Getpid(SB),NOSPLIT,$0 - NO_LOCAL_POINTERS - MOVD $SYS_GETPID, R8 - SVC - RET - -TEXT ·Touch(SB),NOSPLIT,$0 -start: - MOVD 0(R8), R1 - MOVD $SYS_GETPID, R8 // getpid - SVC - B start - -TEXT ·HaltLoop(SB),NOSPLIT,$0 -start: - HLT - B start - -// This function simulates a loop of syscall. -TEXT ·SyscallLoop(SB),NOSPLIT,$0 -start: - SVC - B start - -TEXT ·SpinLoop(SB),NOSPLIT,$0 -start: - B start - -TEXT ·FloatingPointWorks(SB),NOSPLIT,$0-8 - NO_LOCAL_POINTERS - FMOVD $(9.9), F0 - MOVD $SYS_GETPID, R8 // getpid - SVC - FMOVD $(9.9), F1 - FCMPD F0, F1 - BNE isNaN - MOVD $1, R0 - MOVD R0, ret+0(FP) - RET -isNaN: - MOVD $0, ret+0(FP) - RET - -// MVN: bitwise logical NOT -// This case simulates an application that modified R0-R30. -#define TWIDDLE_REGS() \ - MVN R0, R0; \ - MVN R1, R1; \ - MVN R2, R2; \ - MVN R3, R3; \ - MVN R4, R4; \ - MVN R5, R5; \ - MVN R6, R6; \ - MVN R7, R7; \ - MVN R8, R8; \ - MVN R9, R9; \ - MVN R10, R10; \ - MVN R11, R11; \ - MVN R12, R12; \ - MVN R13, R13; \ - MVN R14, R14; \ - MVN R15, R15; \ - MVN R16, R16; \ - MVN R17, R17; \ - MVN R18_PLATFORM, R18_PLATFORM; \ - MVN R19, R19; \ - MVN R20, R20; \ - MVN R21, R21; \ - MVN R22, R22; \ - MVN R23, R23; \ - MVN R24, R24; \ - MVN R25, R25; \ - MVN R26, R26; \ - MVN R27, R27; \ - MVN g, g; \ - MVN R29, R29; \ - MVN R30, R30; - -TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0 - TWIDDLE_REGS() - SVC - RET // never reached diff --git a/pkg/sentry/platform/kvm/virtual_map_test.go b/pkg/sentry/platform/kvm/virtual_map_test.go deleted file mode 100644 index 327e2be4f..000000000 --- a/pkg/sentry/platform/kvm/virtual_map_test.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kvm - -import ( - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/usermem" -) - -type checker struct { - ok bool - accessType usermem.AccessType -} - -func (c *checker) Containing(addr uintptr) func(virtualRegion) { - c.ok = false // Reset for below calls. - return func(vr virtualRegion) { - if vr.virtual <= addr && addr < vr.virtual+vr.length { - c.ok = true - c.accessType = vr.accessType - } - } -} - -func TestParseMaps(t *testing.T) { - c := new(checker) - - // Simple test. - if err := applyVirtualRegions(c.Containing(0)); err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // MMap a new page. - addr, _, errno := syscall.RawSyscall6( - syscall.SYS_MMAP, 0, usermem.PageSize, - syscall.PROT_READ|syscall.PROT_WRITE, - syscall.MAP_ANONYMOUS|syscall.MAP_PRIVATE, 0, 0) - if errno != 0 { - t.Fatalf("unexpected map error: %v", errno) - } - - // Re-parse maps. - if err := applyVirtualRegions(c.Containing(addr)); err != nil { - syscall.RawSyscall(syscall.SYS_MUNMAP, addr, usermem.PageSize, 0) - t.Fatalf("unexpected error: %v", err) - } - - // Assert that it now does contain the region. - if !c.ok { - syscall.RawSyscall(syscall.SYS_MUNMAP, addr, usermem.PageSize, 0) - t.Fatalf("updated map does not contain 0x%08x, expected true", addr) - } - - // Map the region as PROT_NONE. - newAddr, _, errno := syscall.RawSyscall6( - syscall.SYS_MMAP, addr, usermem.PageSize, - syscall.PROT_NONE, - syscall.MAP_ANONYMOUS|syscall.MAP_FIXED|syscall.MAP_PRIVATE, 0, 0) - if errno != 0 { - t.Fatalf("unexpected map error: %v", errno) - } - if newAddr != addr { - t.Fatalf("unable to remap address: got 0x%08x, wanted 0x%08x", newAddr, addr) - } - - // Re-parse maps. - if err := applyVirtualRegions(c.Containing(addr)); err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !c.ok { - t.Fatalf("final map does not contain 0x%08x, expected true", addr) - } - if c.accessType.Read || c.accessType.Write || c.accessType.Execute { - t.Fatalf("final map has incorrect permissions for 0x%08x", addr) - } - - // Unmap the region. - syscall.RawSyscall(syscall.SYS_MUNMAP, addr, usermem.PageSize, 0) -} diff --git a/pkg/sentry/platform/platform_state_autogen.go b/pkg/sentry/platform/platform_state_autogen.go new file mode 100755 index 000000000..7597195ef --- /dev/null +++ b/pkg/sentry/platform/platform_state_autogen.go @@ -0,0 +1,24 @@ +// automatically generated by stateify. + +package platform + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *FileRange) beforeSave() {} +func (x *FileRange) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) +} + +func (x *FileRange) afterLoad() {} +func (x *FileRange) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) +} + +func init() { + state.Register("pkg/sentry/platform.FileRange", (*FileRange)(nil), state.Fns{Save: (*FileRange).save, Load: (*FileRange).load}) +} diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD deleted file mode 100644 index 30402c2df..000000000 --- a/pkg/sentry/platform/ptrace/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "ptrace", - srcs = [ - "filters.go", - "ptrace.go", - "ptrace_amd64.go", - "ptrace_arm64.go", - "ptrace_arm64_unsafe.go", - "ptrace_unsafe.go", - "stub_amd64.s", - "stub_arm64.s", - "stub_unsafe.go", - "subprocess.go", - "subprocess_amd64.go", - "subprocess_arm64.go", - "subprocess_linux.go", - "subprocess_linux_unsafe.go", - "subprocess_unsafe.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/log", - "//pkg/procid", - "//pkg/safecopy", - "//pkg/seccomp", - "//pkg/sentry/arch", - "//pkg/sentry/hostcpu", - "//pkg/sentry/platform", - "//pkg/sentry/platform/interrupt", - "//pkg/sync", - "//pkg/usermem", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/platform/ptrace/ptrace_amd64_state_autogen.go b/pkg/sentry/platform/ptrace/ptrace_amd64_state_autogen.go new file mode 100755 index 000000000..f730ab393 --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_amd64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build amd64 + +package ptrace diff --git a/pkg/sentry/platform/ptrace/ptrace_arm64_state_autogen.go b/pkg/sentry/platform/ptrace/ptrace_arm64_state_autogen.go new file mode 100755 index 000000000..6239d1305 --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_arm64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build arm64 + +package ptrace diff --git a/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go index 32b8a6be9..32b8a6be9 100644..100755 --- a/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go +++ b/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go diff --git a/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe_state_autogen.go b/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe_state_autogen.go new file mode 100755 index 000000000..6239d1305 --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build arm64 + +package ptrace diff --git a/pkg/sentry/platform/ptrace/ptrace_linux_state_autogen.go b/pkg/sentry/platform/ptrace/ptrace_linux_state_autogen.go new file mode 100755 index 000000000..9f90aef93 --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_linux_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build linux + +package ptrace diff --git a/pkg/sentry/platform/ptrace/ptrace_linux_unsafe_state_autogen.go b/pkg/sentry/platform/ptrace/ptrace_linux_unsafe_state_autogen.go new file mode 100755 index 000000000..45d94c547 --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_linux_unsafe_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +// +build linux +// +build amd64 arm64 + +package ptrace diff --git a/pkg/sentry/platform/ptrace/ptrace_state_autogen.go b/pkg/sentry/platform/ptrace/ptrace_state_autogen.go new file mode 100755 index 000000000..4526fc387 --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +// +build go1.12 +// +build !go1.15 + +package ptrace diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD deleted file mode 100644 index 934b6fbcd..000000000 --- a/pkg/sentry/platform/ring0/BUILD +++ /dev/null @@ -1,83 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - -package(licenses = ["notice"]) - -go_template( - name = "defs_amd64", - srcs = [ - "defs.go", - "defs_amd64.go", - "offsets_amd64.go", - "x86.go", - ], - visibility = [":__subpackages__"], -) - -go_template( - name = "defs_arm64", - srcs = [ - "aarch64.go", - "defs.go", - "defs_arm64.go", - "offsets_arm64.go", - ], - visibility = [":__subpackages__"], -) - -go_template_instance( - name = "defs_impl_amd64", - out = "defs_impl_amd64.go", - package = "ring0", - template = ":defs_amd64", -) - -go_template_instance( - name = "defs_impl_arm64", - out = "defs_impl_arm64.go", - package = "ring0", - template = ":defs_arm64", -) - -genrule( - name = "entry_impl_amd64", - srcs = ["entry_amd64.s"], - outs = ["entry_impl_amd64.s"], - cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", - tools = ["//pkg/sentry/platform/ring0/gen_offsets"], -) - -genrule( - name = "entry_impl_arm64", - srcs = ["entry_arm64.s"], - outs = ["entry_impl_arm64.s"], - cmd = "(echo -e '// build +arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", - tools = ["//pkg/sentry/platform/ring0/gen_offsets"], -) - -go_library( - name = "ring0", - srcs = [ - "defs_impl_amd64.go", - "defs_impl_arm64.go", - "entry_amd64.go", - "entry_arm64.go", - "entry_impl_amd64.s", - "entry_impl_arm64.s", - "kernel.go", - "kernel_amd64.go", - "kernel_arm64.go", - "kernel_unsafe.go", - "lib_amd64.go", - "lib_amd64.s", - "lib_arm64.go", - "lib_arm64.s", - "ring0.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/cpuid", - "//pkg/sentry/platform/ring0/pagetables", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go deleted file mode 100644 index 8122ac6e2..000000000 --- a/pkg/sentry/platform/ring0/aarch64.go +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build arm64 - -package ring0 - -// Useful bits. -const ( - _PGD_PGT_BASE = 0x1000 - _PGD_PGT_SIZE = 0x1000 - _PUD_PGT_BASE = 0x2000 - _PUD_PGT_SIZE = 0x1000 - _PMD_PGT_BASE = 0x3000 - _PMD_PGT_SIZE = 0x4000 - _PTE_PGT_BASE = 0x7000 - _PTE_PGT_SIZE = 0x1000 - - _PSR_D_BIT = 0x00000200 - _PSR_A_BIT = 0x00000100 - _PSR_I_BIT = 0x00000080 - _PSR_F_BIT = 0x00000040 -) - -const ( - // PSR bits - PSR_MODE_EL0t = 0x00000000 - PSR_MODE_EL1t = 0x00000004 - PSR_MODE_EL1h = 0x00000005 - PSR_MODE_MASK = 0x0000000f - - // KernelFlagsSet should always be set in the kernel. - KernelFlagsSet = PSR_MODE_EL1h - - // UserFlagsSet are always set in userspace. - UserFlagsSet = PSR_MODE_EL0t - - KernelFlagsClear = PSR_MODE_MASK - UserFlagsClear = PSR_MODE_MASK - - PsrDefaultSet = _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT -) - -// Vector is an exception vector. -type Vector uintptr - -// Exception vectors. -const ( - El1SyncInvalid = iota - El1IrqInvalid - El1FiqInvalid - El1ErrorInvalid - El1Sync - El1Irq - El1Fiq - El1Error - El0Sync - El0Irq - El0Fiq - El0Error - El0Sync_invalid - El0Irq_invalid - El0Fiq_invalid - El0Error_invalid - El1Sync_da - El1Sync_ia - El1Sync_sp_pc - El1Sync_undef - El1Sync_dbg - El1Sync_inv - El0Sync_svc - El0Sync_da - El0Sync_ia - El0Sync_fpsimd_acc - El0Sync_sve_acc - El0Sync_sys - El0Sync_sp_pc - El0Sync_undef - El0Sync_dbg - El0Sync_inv - _NR_INTERRUPTS -) - -// System call vectors. -const ( - Syscall Vector = El0Sync_svc - PageFault Vector = El0Sync_da - VirtualizationException Vector = El0Error -) - -// VirtualAddressBits returns the number bits available for virtual addresses. -func VirtualAddressBits() uint32 { - return 48 -} - -// PhysicalAddressBits returns the number of bits available for physical addresses. -func PhysicalAddressBits() uint32 { - return 40 -} diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go deleted file mode 100644 index 86fd5ed58..000000000 --- a/pkg/sentry/platform/ring0/defs.go +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ring0 - -import ( - "syscall" - - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" -) - -// Kernel is a global kernel object. -// -// This contains global state, shared by multiple CPUs. -type Kernel struct { - KernelArchState -} - -// Hooks are hooks for kernel functions. -type Hooks interface { - // KernelSyscall is called for kernel system calls. - // - // Return from this call will restore registers and return to the kernel: the - // registers must be modified directly. - // - // If this function is not provided, a kernel exception results in halt. - // - // This must be go:nosplit, as this will be on the interrupt stack. - // Closures are permitted, as the pointer to the closure frame is not - // passed on the stack. - KernelSyscall() - - // KernelException handles an exception during kernel execution. - // - // Return from this call will restore registers and return to the kernel: the - // registers must be modified directly. - // - // If this function is not provided, a kernel exception results in halt. - // - // This must be go:nosplit, as this will be on the interrupt stack. - // Closures are permitted, as the pointer to the closure frame is not - // passed on the stack. - KernelException(Vector) -} - -// CPU is the per-CPU struct. -type CPU struct { - // self is a self reference. - // - // This is always guaranteed to be at offset zero. - self *CPU - - // kernel is reference to the kernel that this CPU was initialized - // with. This reference is kept for garbage collection purposes: CPU - // registers may refer to objects within the Kernel object that cannot - // be safely freed. - kernel *Kernel - - // CPUArchState is architecture-specific state. - CPUArchState - - // registers is a set of registers; these may be used on kernel system - // calls and exceptions via the Registers function. - registers syscall.PtraceRegs - - // hooks are kernel hooks. - hooks Hooks -} - -// Registers returns a modifiable-copy of the kernel registers. -// -// This is explicitly safe to call during KernelException and KernelSyscall. -// -//go:nosplit -func (c *CPU) Registers() *syscall.PtraceRegs { - return &c.registers -} - -// SwitchOpts are passed to the Switch function. -type SwitchOpts struct { - // Registers are the user register state. - Registers *syscall.PtraceRegs - - // FloatingPointState is a byte pointer where floating point state is - // saved and restored. - FloatingPointState *byte - - // PageTables are the application page tables. - PageTables *pagetables.PageTables - - // Flush indicates that a TLB flush should be forced on switch. - Flush bool - - // FullRestore indicates that an iret-based restore should be used. - FullRestore bool - - // SwitchArchOpts are architecture-specific options. - SwitchArchOpts -} diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go deleted file mode 100644 index 9c6c2cf5c..000000000 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build amd64 - -package ring0 - -import ( - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" - "gvisor.dev/gvisor/pkg/usermem" -) - -var ( - // UserspaceSize is the total size of userspace. - UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1) - - // MaximumUserAddress is the largest possible user address. - MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) - - // KernelStartAddress is the starting kernel address. - KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) -) - -// Segment indices and Selectors. -const ( - // Index into GDT array. - _ = iota // Null descriptor first. - _ // Reserved (Linux is kernel 32). - segKcode // Kernel code (64-bit). - segKdata // Kernel data. - segUcode32 // User code (32-bit). - segUdata // User data. - segUcode64 // User code (64-bit). - segTss // Task segment descriptor. - segTssHi // Upper bits for TSS. - segLast // Last segment (terminal, not included). -) - -// Selectors. -const ( - Kcode Selector = segKcode << 3 - Kdata Selector = segKdata << 3 - Ucode32 Selector = (segUcode32 << 3) | 3 - Udata Selector = (segUdata << 3) | 3 - Ucode64 Selector = (segUcode64 << 3) | 3 - Tss Selector = segTss << 3 -) - -// Standard segments. -var ( - UserCodeSegment32 SegmentDescriptor - UserDataSegment SegmentDescriptor - UserCodeSegment64 SegmentDescriptor - KernelCodeSegment SegmentDescriptor - KernelDataSegment SegmentDescriptor -) - -// KernelOpts has initialization options for the kernel. -type KernelOpts struct { - // PageTables are the kernel pagetables; this must be provided. - PageTables *pagetables.PageTables -} - -// KernelArchState contains architecture-specific state. -type KernelArchState struct { - KernelOpts - - // globalIDT is our set of interrupt gates. - globalIDT idt64 -} - -// CPUArchState contains CPU-specific arch state. -type CPUArchState struct { - // stack is the stack used for interrupts on this CPU. - stack [256]byte - - // errorCode is the error code from the last exception. - errorCode uintptr - - // errorType indicates the type of error code here, it is always set - // along with the errorCode value above. - // - // It will either by 1, which indicates a user error, or 0 indicating a - // kernel error. If the error code below returns false (kernel error), - // then it cannot provide relevant information about the last - // exception. - errorType uintptr - - // gdt is the CPU's descriptor table. - gdt descriptorTable - - // tss is the CPU's task state. - tss TaskState64 -} - -// ErrorCode returns the last error code. -// -// The returned boolean indicates whether the error code corresponds to the -// last user error or not. If it does not, then fault information must be -// ignored. This is generally the result of a kernel fault while servicing a -// user fault. -// -//go:nosplit -func (c *CPU) ErrorCode() (value uintptr, user bool) { - return c.errorCode, c.errorType != 0 -} - -// ClearErrorCode resets the error code. -// -//go:nosplit -func (c *CPU) ClearErrorCode() { - c.errorCode = 0 // No code. - c.errorType = 1 // User mode. -} - -// SwitchArchOpts are embedded in SwitchOpts. -type SwitchArchOpts struct { - // UserPCID indicates that the application PCID to be used on switch, - // assuming that PCIDs are supported. - // - // Per pagetables_x86.go, a zero PCID implies a flush. - UserPCID uint16 - - // KernelPCID indicates that the kernel PCID to be used on return, - // assuming that PCIDs are supported. - // - // Per pagetables_x86.go, a zero PCID implies a flush. - KernelPCID uint16 -} - -func init() { - KernelCodeSegment.setCode64(0, 0, 0) - KernelDataSegment.setData(0, 0xffffffff, 0) - UserCodeSegment32.setCode64(0, 0, 3) - UserDataSegment.setData(0, 0xffffffff, 3) - UserCodeSegment64.setCode64(0, 0, 3) -} diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go deleted file mode 100644 index 1583dda12..000000000 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build arm64 - -package ring0 - -import ( - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" - "gvisor.dev/gvisor/pkg/usermem" -) - -var ( - // UserspaceSize is the total size of userspace. - UserspaceSize = uintptr(1) << (VirtualAddressBits()) - - // MaximumUserAddress is the largest possible user address. - MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) - - // KernelStartAddress is the starting kernel address. - KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) -) - -// KernelOpts has initialization options for the kernel. -type KernelOpts struct { - // PageTables are the kernel pagetables; this must be provided. - PageTables *pagetables.PageTables -} - -// KernelArchState contains architecture-specific state. -type KernelArchState struct { - KernelOpts -} - -// CPUArchState contains CPU-specific arch state. -type CPUArchState struct { - // stack is the stack used for interrupts on this CPU. - stack [512]byte - - // errorCode is the error code from the last exception. - errorCode uintptr - - // errorType indicates the type of error code here, it is always set - // along with the errorCode value above. - // - // It will either by 1, which indicates a user error, or 0 indicating a - // kernel error. If the error code below returns false (kernel error), - // then it cannot provide relevant information about the last - // exception. - errorType uintptr - - // faultAddr is the value of far_el1. - faultAddr uintptr - - // ttbr0Kvm is the value of ttbr0_el1 for sentry. - ttbr0Kvm uintptr - - // ttbr0App is the value of ttbr0_el1 for applicaton. - ttbr0App uintptr - - // exception vector. - vecCode Vector - - // application context pointer. - appAddr uintptr - - // lazyVFP is the value of cpacr_el1. - lazyVFP uintptr -} - -// ErrorCode returns the last error code. -// -// The returned boolean indicates whether the error code corresponds to the -// last user error or not. If it does not, then fault information must be -// ignored. This is generally the result of a kernel fault while servicing a -// user fault. -// -//go:nosplit -func (c *CPU) ErrorCode() (value uintptr, user bool) { - return c.errorCode, c.errorType != 0 -} - -// ClearErrorCode resets the error code. -// -//go:nosplit -func (c *CPU) ClearErrorCode() { - c.errorCode = 0 // No code. - c.errorType = 1 // User mode. -} - -//go:nosplit -func (c *CPU) GetFaultAddr() (value uintptr) { - return c.faultAddr -} - -//go:nosplit -func (c *CPU) SetTtbr0Kvm(value uintptr) { - c.ttbr0Kvm = value -} - -//go:nosplit -func (c *CPU) SetTtbr0App(value uintptr) { - c.ttbr0App = value -} - -//go:nosplit -func (c *CPU) GetVector() (value Vector) { - return c.vecCode -} - -//go:nosplit -func (c *CPU) SetAppAddr(value uintptr) { - c.appAddr = value -} - -// SwitchArchOpts are embedded in SwitchOpts. -type SwitchArchOpts struct { - // UserASID indicates that the application ASID to be used on switch, - UserASID uint16 - - // KernelASID indicates that the kernel ASID to be used on return, - KernelASID uint16 -} - -func init() { -} diff --git a/pkg/sentry/platform/ring0/defs_impl_amd64.go b/pkg/sentry/platform/ring0/defs_impl_amd64.go new file mode 100755 index 000000000..178eabd85 --- /dev/null +++ b/pkg/sentry/platform/ring0/defs_impl_amd64.go @@ -0,0 +1,538 @@ +package ring0 + +import ( + "fmt" + "gvisor.dev/gvisor/pkg/cpuid" + "io" + "reflect" + "syscall" + + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Kernel is a global kernel object. +// +// This contains global state, shared by multiple CPUs. +type Kernel struct { + KernelArchState +} + +// Hooks are hooks for kernel functions. +type Hooks interface { + // KernelSyscall is called for kernel system calls. + // + // Return from this call will restore registers and return to the kernel: the + // registers must be modified directly. + // + // If this function is not provided, a kernel exception results in halt. + // + // This must be go:nosplit, as this will be on the interrupt stack. + // Closures are permitted, as the pointer to the closure frame is not + // passed on the stack. + KernelSyscall() + + // KernelException handles an exception during kernel execution. + // + // Return from this call will restore registers and return to the kernel: the + // registers must be modified directly. + // + // If this function is not provided, a kernel exception results in halt. + // + // This must be go:nosplit, as this will be on the interrupt stack. + // Closures are permitted, as the pointer to the closure frame is not + // passed on the stack. + KernelException(Vector) +} + +// CPU is the per-CPU struct. +type CPU struct { + // self is a self reference. + // + // This is always guaranteed to be at offset zero. + self *CPU + + // kernel is reference to the kernel that this CPU was initialized + // with. This reference is kept for garbage collection purposes: CPU + // registers may refer to objects within the Kernel object that cannot + // be safely freed. + kernel *Kernel + + // CPUArchState is architecture-specific state. + CPUArchState + + // registers is a set of registers; these may be used on kernel system + // calls and exceptions via the Registers function. + registers syscall.PtraceRegs + + // hooks are kernel hooks. + hooks Hooks +} + +// Registers returns a modifiable-copy of the kernel registers. +// +// This is explicitly safe to call during KernelException and KernelSyscall. +// +//go:nosplit +func (c *CPU) Registers() *syscall.PtraceRegs { + return &c.registers +} + +// SwitchOpts are passed to the Switch function. +type SwitchOpts struct { + // Registers are the user register state. + Registers *syscall.PtraceRegs + + // FloatingPointState is a byte pointer where floating point state is + // saved and restored. + FloatingPointState *byte + + // PageTables are the application page tables. + PageTables *pagetables.PageTables + + // Flush indicates that a TLB flush should be forced on switch. + Flush bool + + // FullRestore indicates that an iret-based restore should be used. + FullRestore bool + + // SwitchArchOpts are architecture-specific options. + SwitchArchOpts +} + +var ( + // UserspaceSize is the total size of userspace. + UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1) + + // MaximumUserAddress is the largest possible user address. + MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) + + // KernelStartAddress is the starting kernel address. + KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) +) + +// Segment indices and Selectors. +const ( + // Index into GDT array. + _ = iota // Null descriptor first. + _ // Reserved (Linux is kernel 32). + segKcode // Kernel code (64-bit). + segKdata // Kernel data. + segUcode32 // User code (32-bit). + segUdata // User data. + segUcode64 // User code (64-bit). + segTss // Task segment descriptor. + segTssHi // Upper bits for TSS. + segLast // Last segment (terminal, not included). +) + +// Selectors. +const ( + Kcode Selector = segKcode << 3 + Kdata Selector = segKdata << 3 + Ucode32 Selector = (segUcode32 << 3) | 3 + Udata Selector = (segUdata << 3) | 3 + Ucode64 Selector = (segUcode64 << 3) | 3 + Tss Selector = segTss << 3 +) + +// Standard segments. +var ( + UserCodeSegment32 SegmentDescriptor + UserDataSegment SegmentDescriptor + UserCodeSegment64 SegmentDescriptor + KernelCodeSegment SegmentDescriptor + KernelDataSegment SegmentDescriptor +) + +// KernelOpts has initialization options for the kernel. +type KernelOpts struct { + // PageTables are the kernel pagetables; this must be provided. + PageTables *pagetables.PageTables +} + +// KernelArchState contains architecture-specific state. +type KernelArchState struct { + KernelOpts + + // globalIDT is our set of interrupt gates. + globalIDT idt64 +} + +// CPUArchState contains CPU-specific arch state. +type CPUArchState struct { + // stack is the stack used for interrupts on this CPU. + stack [256]byte + + // errorCode is the error code from the last exception. + errorCode uintptr + + // errorType indicates the type of error code here, it is always set + // along with the errorCode value above. + // + // It will either by 1, which indicates a user error, or 0 indicating a + // kernel error. If the error code below returns false (kernel error), + // then it cannot provide relevant information about the last + // exception. + errorType uintptr + + // gdt is the CPU's descriptor table. + gdt descriptorTable + + // tss is the CPU's task state. + tss TaskState64 +} + +// ErrorCode returns the last error code. +// +// The returned boolean indicates whether the error code corresponds to the +// last user error or not. If it does not, then fault information must be +// ignored. This is generally the result of a kernel fault while servicing a +// user fault. +// +//go:nosplit +func (c *CPU) ErrorCode() (value uintptr, user bool) { + return c.errorCode, c.errorType != 0 +} + +// ClearErrorCode resets the error code. +// +//go:nosplit +func (c *CPU) ClearErrorCode() { + c.errorCode = 0 + c.errorType = 1 +} + +// SwitchArchOpts are embedded in SwitchOpts. +type SwitchArchOpts struct { + // UserPCID indicates that the application PCID to be used on switch, + // assuming that PCIDs are supported. + // + // Per pagetables_x86.go, a zero PCID implies a flush. + UserPCID uint16 + + // KernelPCID indicates that the kernel PCID to be used on return, + // assuming that PCIDs are supported. + // + // Per pagetables_x86.go, a zero PCID implies a flush. + KernelPCID uint16 +} + +func init() { + KernelCodeSegment.setCode64(0, 0, 0) + KernelDataSegment.setData(0, 0xffffffff, 0) + UserCodeSegment32.setCode64(0, 0, 3) + UserDataSegment.setData(0, 0xffffffff, 3) + UserCodeSegment64.setCode64(0, 0, 3) +} + +// Emit prints architecture-specific offsets. +func Emit(w io.Writer) { + fmt.Fprintf(w, "// Automatically generated, do not edit.\n") + + c := &CPU{} + fmt.Fprintf(w, "\n// CPU offsets.\n") + fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack))) + fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer()) + + fmt.Fprintf(w, "\n// Bits.\n") + fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF) + fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) + + fmt.Fprintf(w, "\n// Vectors.\n") + fmt.Fprintf(w, "#define DivideByZero 0x%02x\n", DivideByZero) + fmt.Fprintf(w, "#define Debug 0x%02x\n", Debug) + fmt.Fprintf(w, "#define NMI 0x%02x\n", NMI) + fmt.Fprintf(w, "#define Breakpoint 0x%02x\n", Breakpoint) + fmt.Fprintf(w, "#define Overflow 0x%02x\n", Overflow) + fmt.Fprintf(w, "#define BoundRangeExceeded 0x%02x\n", BoundRangeExceeded) + fmt.Fprintf(w, "#define InvalidOpcode 0x%02x\n", InvalidOpcode) + fmt.Fprintf(w, "#define DeviceNotAvailable 0x%02x\n", DeviceNotAvailable) + fmt.Fprintf(w, "#define DoubleFault 0x%02x\n", DoubleFault) + fmt.Fprintf(w, "#define CoprocessorSegmentOverrun 0x%02x\n", CoprocessorSegmentOverrun) + fmt.Fprintf(w, "#define InvalidTSS 0x%02x\n", InvalidTSS) + fmt.Fprintf(w, "#define SegmentNotPresent 0x%02x\n", SegmentNotPresent) + fmt.Fprintf(w, "#define StackSegmentFault 0x%02x\n", StackSegmentFault) + fmt.Fprintf(w, "#define GeneralProtectionFault 0x%02x\n", GeneralProtectionFault) + fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault) + fmt.Fprintf(w, "#define X87FloatingPointException 0x%02x\n", X87FloatingPointException) + fmt.Fprintf(w, "#define AlignmentCheck 0x%02x\n", AlignmentCheck) + fmt.Fprintf(w, "#define MachineCheck 0x%02x\n", MachineCheck) + fmt.Fprintf(w, "#define SIMDFloatingPointException 0x%02x\n", SIMDFloatingPointException) + fmt.Fprintf(w, "#define VirtualizationException 0x%02x\n", VirtualizationException) + fmt.Fprintf(w, "#define SecurityException 0x%02x\n", SecurityException) + fmt.Fprintf(w, "#define SyscallInt80 0x%02x\n", SyscallInt80) + fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall) + + p := &syscall.PtraceRegs{} + fmt.Fprintf(w, "\n// Ptrace registers.\n") + fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.R15).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.R14).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R13 0x%02x\n", reflect.ValueOf(&p.R13).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R12 0x%02x\n", reflect.ValueOf(&p.R12).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RBP 0x%02x\n", reflect.ValueOf(&p.Rbp).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RBX 0x%02x\n", reflect.ValueOf(&p.Rbx).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R11 0x%02x\n", reflect.ValueOf(&p.R11).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R10 0x%02x\n", reflect.ValueOf(&p.R10).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R9 0x%02x\n", reflect.ValueOf(&p.R9).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R8 0x%02x\n", reflect.ValueOf(&p.R8).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RAX 0x%02x\n", reflect.ValueOf(&p.Rax).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RCX 0x%02x\n", reflect.ValueOf(&p.Rcx).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RDX 0x%02x\n", reflect.ValueOf(&p.Rdx).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RSI 0x%02x\n", reflect.ValueOf(&p.Rsi).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RDI 0x%02x\n", reflect.ValueOf(&p.Rdi).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_ORIGRAX 0x%02x\n", reflect.ValueOf(&p.Orig_rax).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RIP 0x%02x\n", reflect.ValueOf(&p.Rip).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_CS 0x%02x\n", reflect.ValueOf(&p.Cs).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_FLAGS 0x%02x\n", reflect.ValueOf(&p.Eflags).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RSP 0x%02x\n", reflect.ValueOf(&p.Rsp).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_SS 0x%02x\n", reflect.ValueOf(&p.Ss).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_FS 0x%02x\n", reflect.ValueOf(&p.Fs_base).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_GS 0x%02x\n", reflect.ValueOf(&p.Gs_base).Pointer()-reflect.ValueOf(p).Pointer()) +} + +// Useful bits. +const ( + _CR0_PE = 1 << 0 + _CR0_ET = 1 << 4 + _CR0_AM = 1 << 18 + _CR0_PG = 1 << 31 + + _CR4_PSE = 1 << 4 + _CR4_PAE = 1 << 5 + _CR4_PGE = 1 << 7 + _CR4_OSFXSR = 1 << 9 + _CR4_OSXMMEXCPT = 1 << 10 + _CR4_FSGSBASE = 1 << 16 + _CR4_PCIDE = 1 << 17 + _CR4_OSXSAVE = 1 << 18 + _CR4_SMEP = 1 << 20 + + _RFLAGS_AC = 1 << 18 + _RFLAGS_NT = 1 << 14 + _RFLAGS_IOPL = 3 << 12 + _RFLAGS_DF = 1 << 10 + _RFLAGS_IF = 1 << 9 + _RFLAGS_STEP = 1 << 8 + _RFLAGS_RESERVED = 1 << 1 + + _EFER_SCE = 0x001 + _EFER_LME = 0x100 + _EFER_LMA = 0x400 + _EFER_NX = 0x800 + + _MSR_STAR = 0xc0000081 + _MSR_LSTAR = 0xc0000082 + _MSR_CSTAR = 0xc0000083 + _MSR_SYSCALL_MASK = 0xc0000084 + _MSR_PLATFORM_INFO = 0xce + _MSR_MISC_FEATURES = 0x140 + + _PLATFORM_INFO_CPUID_FAULT = 1 << 31 + + _MISC_FEATURE_CPUID_TRAP = 0x1 +) + +const ( + // KernelFlagsSet should always be set in the kernel. + KernelFlagsSet = _RFLAGS_RESERVED + + // UserFlagsSet are always set in userspace. + UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF + + // KernelFlagsClear should always be clear in the kernel. + KernelFlagsClear = _RFLAGS_STEP | _RFLAGS_IF | _RFLAGS_IOPL | _RFLAGS_AC | _RFLAGS_NT + + // UserFlagsClear are always cleared in userspace. + UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL +) + +// Vector is an exception vector. +type Vector uintptr + +// Exception vectors. +const ( + DivideByZero Vector = iota + Debug + NMI + Breakpoint + Overflow + BoundRangeExceeded + InvalidOpcode + DeviceNotAvailable + DoubleFault + CoprocessorSegmentOverrun + InvalidTSS + SegmentNotPresent + StackSegmentFault + GeneralProtectionFault + PageFault + _ + X87FloatingPointException + AlignmentCheck + MachineCheck + SIMDFloatingPointException + VirtualizationException + SecurityException = 0x1e + SyscallInt80 = 0x80 + _NR_INTERRUPTS = SyscallInt80 + 1 +) + +// System call vectors. +const ( + Syscall Vector = _NR_INTERRUPTS +) + +// VirtualAddressBits returns the number bits available for virtual addresses. +// +// Note that sign-extension semantics apply to the highest order bit. +// +// FIXME(b/69382326): This should use the cpuid passed to Init. +func VirtualAddressBits() uint32 { + ax, _, _, _ := cpuid.HostID(0x80000008, 0) + return (ax >> 8) & 0xff +} + +// PhysicalAddressBits returns the number of bits available for physical addresses. +// +// FIXME(b/69382326): This should use the cpuid passed to Init. +func PhysicalAddressBits() uint32 { + ax, _, _, _ := cpuid.HostID(0x80000008, 0) + return ax & 0xff +} + +// Selector is a segment Selector. +type Selector uint16 + +// SegmentDescriptor is a segment descriptor. +type SegmentDescriptor struct { + bits [2]uint32 +} + +// descriptorTable is a collection of descriptors. +type descriptorTable [32]SegmentDescriptor + +// SegmentDescriptorFlags are typed flags within a descriptor. +type SegmentDescriptorFlags uint32 + +// SegmentDescriptorFlag declarations. +const ( + SegmentDescriptorAccess SegmentDescriptorFlags = 1 << 8 // Access bit (always set). + SegmentDescriptorWrite = 1 << 9 // Write permission. + SegmentDescriptorExpandDown = 1 << 10 // Grows down, not used. + SegmentDescriptorExecute = 1 << 11 // Execute permission. + SegmentDescriptorSystem = 1 << 12 // Zero => system, 1 => user code/data. + SegmentDescriptorPresent = 1 << 15 // Present. + SegmentDescriptorAVL = 1 << 20 // Available. + SegmentDescriptorLong = 1 << 21 // Long mode. + SegmentDescriptorDB = 1 << 22 // 16 or 32-bit. + SegmentDescriptorG = 1 << 23 // Granularity: page or byte. +) + +// Base returns the descriptor's base linear address. +func (d *SegmentDescriptor) Base() uint32 { + return d.bits[1]&0xFF000000 | (d.bits[1]&0x000000FF)<<16 | d.bits[0]>>16 +} + +// Limit returns the descriptor size. +func (d *SegmentDescriptor) Limit() uint32 { + l := d.bits[0]&0xFFFF | d.bits[1]&0xF0000 + if d.bits[1]&uint32(SegmentDescriptorG) != 0 { + l <<= 12 + l |= 0xFFF + } + return l +} + +// Flags returns descriptor flags. +func (d *SegmentDescriptor) Flags() SegmentDescriptorFlags { + return SegmentDescriptorFlags(d.bits[1] & 0x00F09F00) +} + +// DPL returns the descriptor privilege level. +func (d *SegmentDescriptor) DPL() int { + return int((d.bits[1] >> 13) & 3) +} + +func (d *SegmentDescriptor) setNull() { + d.bits[0] = 0 + d.bits[1] = 0 +} + +func (d *SegmentDescriptor) set(base, limit uint32, dpl int, flags SegmentDescriptorFlags) { + flags |= SegmentDescriptorPresent + if limit>>12 != 0 { + limit >>= 12 + flags |= SegmentDescriptorG + } + d.bits[0] = base<<16 | limit&0xFFFF + d.bits[1] = base&0xFF000000 | (base>>16)&0xFF | limit&0x000F0000 | uint32(flags) | uint32(dpl)<<13 +} + +func (d *SegmentDescriptor) setCode32(base, limit uint32, dpl int) { + d.set(base, limit, dpl, + SegmentDescriptorDB| + SegmentDescriptorExecute| + SegmentDescriptorSystem) +} + +func (d *SegmentDescriptor) setCode64(base, limit uint32, dpl int) { + d.set(base, limit, dpl, + SegmentDescriptorG| + SegmentDescriptorLong| + SegmentDescriptorExecute| + SegmentDescriptorSystem) +} + +func (d *SegmentDescriptor) setData(base, limit uint32, dpl int) { + d.set(base, limit, dpl, + SegmentDescriptorWrite| + SegmentDescriptorSystem) +} + +// setHi is only used for the TSS segment, which is magically 64-bits. +func (d *SegmentDescriptor) setHi(base uint32) { + d.bits[0] = base + d.bits[1] = 0 +} + +// Gate64 is a 64-bit task, trap, or interrupt gate. +type Gate64 struct { + bits [4]uint32 +} + +// idt64 is a 64-bit interrupt descriptor table. +type idt64 [_NR_INTERRUPTS]Gate64 + +func (g *Gate64) setInterrupt(cs Selector, rip uint64, dpl int, ist int) { + g.bits[0] = uint32(cs)<<16 | uint32(rip)&0xFFFF + g.bits[1] = uint32(rip)&0xFFFF0000 | SegmentDescriptorPresent | uint32(dpl)<<13 | 14<<8 | uint32(ist)&0x7 + g.bits[2] = uint32(rip >> 32) +} + +func (g *Gate64) setTrap(cs Selector, rip uint64, dpl int, ist int) { + g.setInterrupt(cs, rip, dpl, ist) + g.bits[1] |= 1 << 8 +} + +// TaskState64 is a 64-bit task state structure. +type TaskState64 struct { + _ uint32 + rsp0Lo, rsp0Hi uint32 + rsp1Lo, rsp1Hi uint32 + rsp2Lo, rsp2Hi uint32 + _ [2]uint32 + ist1Lo, ist1Hi uint32 + ist2Lo, ist2Hi uint32 + ist3Lo, ist3Hi uint32 + ist4Lo, ist4Hi uint32 + ist5Lo, ist5Hi uint32 + ist6Lo, ist6Hi uint32 + ist7Lo, ist7Hi uint32 + _ [2]uint32 + _ uint16 + ioPerm uint16 +} diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/defs_impl_arm64.go index 057fb5c69..23d30c1ef 100644..100755 --- a/pkg/sentry/platform/ring0/offsets_arm64.go +++ b/pkg/sentry/platform/ring0/defs_impl_arm64.go @@ -1,19 +1,3 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build arm64 - package ring0 import ( @@ -21,8 +5,308 @@ import ( "io" "reflect" "syscall" + + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Useful bits. +const ( + _PGD_PGT_BASE = 0x1000 + _PGD_PGT_SIZE = 0x1000 + _PUD_PGT_BASE = 0x2000 + _PUD_PGT_SIZE = 0x1000 + _PMD_PGT_BASE = 0x3000 + _PMD_PGT_SIZE = 0x4000 + _PTE_PGT_BASE = 0x7000 + _PTE_PGT_SIZE = 0x1000 + + _PSR_D_BIT = 0x00000200 + _PSR_A_BIT = 0x00000100 + _PSR_I_BIT = 0x00000080 + _PSR_F_BIT = 0x00000040 +) + +const ( + // PSR bits + PSR_MODE_EL0t = 0x00000000 + PSR_MODE_EL1t = 0x00000004 + PSR_MODE_EL1h = 0x00000005 + PSR_MODE_MASK = 0x0000000f + + // KernelFlagsSet should always be set in the kernel. + KernelFlagsSet = PSR_MODE_EL1h + + // UserFlagsSet are always set in userspace. + UserFlagsSet = PSR_MODE_EL0t + + KernelFlagsClear = PSR_MODE_MASK + UserFlagsClear = PSR_MODE_MASK + + PsrDefaultSet = _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT ) +// Vector is an exception vector. +type Vector uintptr + +// Exception vectors. +const ( + El1SyncInvalid = iota + El1IrqInvalid + El1FiqInvalid + El1ErrorInvalid + El1Sync + El1Irq + El1Fiq + El1Error + El0Sync + El0Irq + El0Fiq + El0Error + El0Sync_invalid + El0Irq_invalid + El0Fiq_invalid + El0Error_invalid + El1Sync_da + El1Sync_ia + El1Sync_sp_pc + El1Sync_undef + El1Sync_dbg + El1Sync_inv + El0Sync_svc + El0Sync_da + El0Sync_ia + El0Sync_fpsimd_acc + El0Sync_sve_acc + El0Sync_sys + El0Sync_sp_pc + El0Sync_undef + El0Sync_dbg + El0Sync_inv + _NR_INTERRUPTS +) + +// System call vectors. +const ( + Syscall Vector = El0Sync_svc + PageFault Vector = El0Sync_da + VirtualizationException Vector = El0Error +) + +// VirtualAddressBits returns the number bits available for virtual addresses. +func VirtualAddressBits() uint32 { + return 48 +} + +// PhysicalAddressBits returns the number of bits available for physical addresses. +func PhysicalAddressBits() uint32 { + return 40 +} + +// Kernel is a global kernel object. +// +// This contains global state, shared by multiple CPUs. +type Kernel struct { + KernelArchState +} + +// Hooks are hooks for kernel functions. +type Hooks interface { + // KernelSyscall is called for kernel system calls. + // + // Return from this call will restore registers and return to the kernel: the + // registers must be modified directly. + // + // If this function is not provided, a kernel exception results in halt. + // + // This must be go:nosplit, as this will be on the interrupt stack. + // Closures are permitted, as the pointer to the closure frame is not + // passed on the stack. + KernelSyscall() + + // KernelException handles an exception during kernel execution. + // + // Return from this call will restore registers and return to the kernel: the + // registers must be modified directly. + // + // If this function is not provided, a kernel exception results in halt. + // + // This must be go:nosplit, as this will be on the interrupt stack. + // Closures are permitted, as the pointer to the closure frame is not + // passed on the stack. + KernelException(Vector) +} + +// CPU is the per-CPU struct. +type CPU struct { + // self is a self reference. + // + // This is always guaranteed to be at offset zero. + self *CPU + + // kernel is reference to the kernel that this CPU was initialized + // with. This reference is kept for garbage collection purposes: CPU + // registers may refer to objects within the Kernel object that cannot + // be safely freed. + kernel *Kernel + + // CPUArchState is architecture-specific state. + CPUArchState + + // registers is a set of registers; these may be used on kernel system + // calls and exceptions via the Registers function. + registers syscall.PtraceRegs + + // hooks are kernel hooks. + hooks Hooks +} + +// Registers returns a modifiable-copy of the kernel registers. +// +// This is explicitly safe to call during KernelException and KernelSyscall. +// +//go:nosplit +func (c *CPU) Registers() *syscall.PtraceRegs { + return &c.registers +} + +// SwitchOpts are passed to the Switch function. +type SwitchOpts struct { + // Registers are the user register state. + Registers *syscall.PtraceRegs + + // FloatingPointState is a byte pointer where floating point state is + // saved and restored. + FloatingPointState *byte + + // PageTables are the application page tables. + PageTables *pagetables.PageTables + + // Flush indicates that a TLB flush should be forced on switch. + Flush bool + + // FullRestore indicates that an iret-based restore should be used. + FullRestore bool + + // SwitchArchOpts are architecture-specific options. + SwitchArchOpts +} + +var ( + // UserspaceSize is the total size of userspace. + UserspaceSize = uintptr(1) << (VirtualAddressBits()) + + // MaximumUserAddress is the largest possible user address. + MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) + + // KernelStartAddress is the starting kernel address. + KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) +) + +// KernelOpts has initialization options for the kernel. +type KernelOpts struct { + // PageTables are the kernel pagetables; this must be provided. + PageTables *pagetables.PageTables +} + +// KernelArchState contains architecture-specific state. +type KernelArchState struct { + KernelOpts +} + +// CPUArchState contains CPU-specific arch state. +type CPUArchState struct { + // stack is the stack used for interrupts on this CPU. + stack [512]byte + + // errorCode is the error code from the last exception. + errorCode uintptr + + // errorType indicates the type of error code here, it is always set + // along with the errorCode value above. + // + // It will either by 1, which indicates a user error, or 0 indicating a + // kernel error. If the error code below returns false (kernel error), + // then it cannot provide relevant information about the last + // exception. + errorType uintptr + + // faultAddr is the value of far_el1. + faultAddr uintptr + + // ttbr0Kvm is the value of ttbr0_el1 for sentry. + ttbr0Kvm uintptr + + // ttbr0App is the value of ttbr0_el1 for applicaton. + ttbr0App uintptr + + // exception vector. + vecCode Vector + + // application context pointer. + appAddr uintptr + + // lazyVFP is the value of cpacr_el1. + lazyVFP uintptr +} + +// ErrorCode returns the last error code. +// +// The returned boolean indicates whether the error code corresponds to the +// last user error or not. If it does not, then fault information must be +// ignored. This is generally the result of a kernel fault while servicing a +// user fault. +// +//go:nosplit +func (c *CPU) ErrorCode() (value uintptr, user bool) { + return c.errorCode, c.errorType != 0 +} + +// ClearErrorCode resets the error code. +// +//go:nosplit +func (c *CPU) ClearErrorCode() { + c.errorCode = 0 + c.errorType = 1 +} + +//go:nosplit +func (c *CPU) GetFaultAddr() (value uintptr) { + return c.faultAddr +} + +//go:nosplit +func (c *CPU) SetTtbr0Kvm(value uintptr) { + c.ttbr0Kvm = value +} + +//go:nosplit +func (c *CPU) SetTtbr0App(value uintptr) { + c.ttbr0App = value +} + +//go:nosplit +func (c *CPU) GetVector() (value Vector) { + return c.vecCode +} + +//go:nosplit +func (c *CPU) SetAppAddr(value uintptr) { + c.appAddr = value +} + +// SwitchArchOpts are embedded in SwitchOpts. +type SwitchArchOpts struct { + // UserASID indicates that the application ASID to be used on switch, + UserASID uint16 + + // KernelASID indicates that the kernel ASID to be used on return, + KernelASID uint16 +} + +func init() { +} + // Emit prints architecture-specific offsets. func Emit(w io.Writer) { fmt.Fprintf(w, "// Automatically generated, do not edit.\n") diff --git a/pkg/sentry/platform/ring0/entry_arm64.go b/pkg/sentry/platform/ring0/entry_arm64.go index 62a93f3d6..62a93f3d6 100644..100755 --- a/pkg/sentry/platform/ring0/entry_arm64.go +++ b/pkg/sentry/platform/ring0/entry_arm64.go diff --git a/pkg/sentry/platform/ring0/entry_amd64.s b/pkg/sentry/platform/ring0/entry_impl_amd64.s index 02df38331..daba45f9d 100644..100755 --- a/pkg/sentry/platform/ring0/entry_amd64.s +++ b/pkg/sentry/platform/ring0/entry_impl_amd64.s @@ -1,3 +1,67 @@ +// build +amd64 + +// Automatically generated, do not edit. + +// CPU offsets. +#define CPU_SELF 0x00 +#define CPU_REGISTERS 0x288 +#define CPU_STACK_TOP 0x110 +#define CPU_ERROR_CODE 0x110 +#define CPU_ERROR_TYPE 0x118 + +// Bits. +#define _RFLAGS_IF 0x200 +#define _KERNEL_FLAGS 0x02 + +// Vectors. +#define DivideByZero 0x00 +#define Debug 0x01 +#define NMI 0x02 +#define Breakpoint 0x03 +#define Overflow 0x04 +#define BoundRangeExceeded 0x05 +#define InvalidOpcode 0x06 +#define DeviceNotAvailable 0x07 +#define DoubleFault 0x08 +#define CoprocessorSegmentOverrun 0x09 +#define InvalidTSS 0x0a +#define SegmentNotPresent 0x0b +#define StackSegmentFault 0x0c +#define GeneralProtectionFault 0x0d +#define PageFault 0x0e +#define X87FloatingPointException 0x10 +#define AlignmentCheck 0x11 +#define MachineCheck 0x12 +#define SIMDFloatingPointException 0x13 +#define VirtualizationException 0x14 +#define SecurityException 0x1e +#define SyscallInt80 0x80 +#define Syscall 0x81 + +// Ptrace registers. +#define PTRACE_R15 0x00 +#define PTRACE_R14 0x08 +#define PTRACE_R13 0x10 +#define PTRACE_R12 0x18 +#define PTRACE_RBP 0x20 +#define PTRACE_RBX 0x28 +#define PTRACE_R11 0x30 +#define PTRACE_R10 0x38 +#define PTRACE_R9 0x40 +#define PTRACE_R8 0x48 +#define PTRACE_RAX 0x50 +#define PTRACE_RCX 0x58 +#define PTRACE_RDX 0x60 +#define PTRACE_RSI 0x68 +#define PTRACE_RDI 0x70 +#define PTRACE_ORIGRAX 0x78 +#define PTRACE_RIP 0x80 +#define PTRACE_CS 0x88 +#define PTRACE_FLAGS 0x90 +#define PTRACE_RSP 0x98 +#define PTRACE_SS 0xa0 +#define PTRACE_FS 0xa8 +#define PTRACE_GS 0xb0 // Copyright 2018 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_impl_arm64.s index d42eda37b..5a5e81152 100644..100755 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_impl_arm64.s @@ -1,3 +1,67 @@ +// build +arm64 + +// Automatically generated, do not edit. + +// CPU offsets. +#define CPU_SELF 0x00 +#define CPU_REGISTERS 0x288 +#define CPU_STACK_TOP 0x110 +#define CPU_ERROR_CODE 0x110 +#define CPU_ERROR_TYPE 0x118 + +// Bits. +#define _RFLAGS_IF 0x200 +#define _KERNEL_FLAGS 0x02 + +// Vectors. +#define DivideByZero 0x00 +#define Debug 0x01 +#define NMI 0x02 +#define Breakpoint 0x03 +#define Overflow 0x04 +#define BoundRangeExceeded 0x05 +#define InvalidOpcode 0x06 +#define DeviceNotAvailable 0x07 +#define DoubleFault 0x08 +#define CoprocessorSegmentOverrun 0x09 +#define InvalidTSS 0x0a +#define SegmentNotPresent 0x0b +#define StackSegmentFault 0x0c +#define GeneralProtectionFault 0x0d +#define PageFault 0x0e +#define X87FloatingPointException 0x10 +#define AlignmentCheck 0x11 +#define MachineCheck 0x12 +#define SIMDFloatingPointException 0x13 +#define VirtualizationException 0x14 +#define SecurityException 0x1e +#define SyscallInt80 0x80 +#define Syscall 0x81 + +// Ptrace registers. +#define PTRACE_R15 0x00 +#define PTRACE_R14 0x08 +#define PTRACE_R13 0x10 +#define PTRACE_R12 0x18 +#define PTRACE_RBP 0x20 +#define PTRACE_RBX 0x28 +#define PTRACE_R11 0x30 +#define PTRACE_R10 0x38 +#define PTRACE_R9 0x40 +#define PTRACE_R8 0x48 +#define PTRACE_RAX 0x50 +#define PTRACE_RCX 0x58 +#define PTRACE_RDX 0x60 +#define PTRACE_RSI 0x68 +#define PTRACE_RDI 0x70 +#define PTRACE_ORIGRAX 0x78 +#define PTRACE_RIP 0x80 +#define PTRACE_CS 0x88 +#define PTRACE_FLAGS 0x90 +#define PTRACE_RSP 0x98 +#define PTRACE_SS 0xa0 +#define PTRACE_FS 0xa8 +#define PTRACE_GS 0xb0 // Copyright 2019 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD deleted file mode 100644 index 4cae10459..000000000 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -load("//tools:defs.bzl", "go_binary") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "defs_impl_arm64", - out = "defs_impl_arm64.go", - package = "main", - template = "//pkg/sentry/platform/ring0:defs_arm64", -) - -go_template_instance( - name = "defs_impl_amd64", - out = "defs_impl_amd64.go", - package = "main", - template = "//pkg/sentry/platform/ring0:defs_amd64", -) - -go_binary( - name = "gen_offsets", - srcs = [ - "defs_impl_amd64.go", - "defs_impl_arm64.go", - "main.go", - ], - visibility = ["//pkg/sentry/platform/ring0:__pkg__"], - deps = [ - "//pkg/cpuid", - "//pkg/sentry/platform/ring0/pagetables", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/platform/ring0/gen_offsets/main.go b/pkg/sentry/platform/ring0/gen_offsets/main.go deleted file mode 100644 index a4927da2f..000000000 --- a/pkg/sentry/platform/ring0/gen_offsets/main.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Binary gen_offsets is a helper for generating offset headers. -package main - -import ( - "os" -) - -func main() { - Emit(os.Stdout) -} diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index ccacaea6b..ccacaea6b 100644..100755 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index af075aae4..af075aae4 100644..100755 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s index 0e6a6235b..0e6a6235b 100644..100755 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/sentry/platform/ring0/lib_arm64.s diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/sentry/platform/ring0/offsets_amd64.go deleted file mode 100644 index 85cc3fdad..000000000 --- a/pkg/sentry/platform/ring0/offsets_amd64.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build amd64 - -package ring0 - -import ( - "fmt" - "io" - "reflect" - "syscall" -) - -// Emit prints architecture-specific offsets. -func Emit(w io.Writer) { - fmt.Fprintf(w, "// Automatically generated, do not edit.\n") - - c := &CPU{} - fmt.Fprintf(w, "\n// CPU offsets.\n") - fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer()) - fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer()) - fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack))) - fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer()) - fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer()) - - fmt.Fprintf(w, "\n// Bits.\n") - fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF) - fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) - - fmt.Fprintf(w, "\n// Vectors.\n") - fmt.Fprintf(w, "#define DivideByZero 0x%02x\n", DivideByZero) - fmt.Fprintf(w, "#define Debug 0x%02x\n", Debug) - fmt.Fprintf(w, "#define NMI 0x%02x\n", NMI) - fmt.Fprintf(w, "#define Breakpoint 0x%02x\n", Breakpoint) - fmt.Fprintf(w, "#define Overflow 0x%02x\n", Overflow) - fmt.Fprintf(w, "#define BoundRangeExceeded 0x%02x\n", BoundRangeExceeded) - fmt.Fprintf(w, "#define InvalidOpcode 0x%02x\n", InvalidOpcode) - fmt.Fprintf(w, "#define DeviceNotAvailable 0x%02x\n", DeviceNotAvailable) - fmt.Fprintf(w, "#define DoubleFault 0x%02x\n", DoubleFault) - fmt.Fprintf(w, "#define CoprocessorSegmentOverrun 0x%02x\n", CoprocessorSegmentOverrun) - fmt.Fprintf(w, "#define InvalidTSS 0x%02x\n", InvalidTSS) - fmt.Fprintf(w, "#define SegmentNotPresent 0x%02x\n", SegmentNotPresent) - fmt.Fprintf(w, "#define StackSegmentFault 0x%02x\n", StackSegmentFault) - fmt.Fprintf(w, "#define GeneralProtectionFault 0x%02x\n", GeneralProtectionFault) - fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault) - fmt.Fprintf(w, "#define X87FloatingPointException 0x%02x\n", X87FloatingPointException) - fmt.Fprintf(w, "#define AlignmentCheck 0x%02x\n", AlignmentCheck) - fmt.Fprintf(w, "#define MachineCheck 0x%02x\n", MachineCheck) - fmt.Fprintf(w, "#define SIMDFloatingPointException 0x%02x\n", SIMDFloatingPointException) - fmt.Fprintf(w, "#define VirtualizationException 0x%02x\n", VirtualizationException) - fmt.Fprintf(w, "#define SecurityException 0x%02x\n", SecurityException) - fmt.Fprintf(w, "#define SyscallInt80 0x%02x\n", SyscallInt80) - fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall) - - p := &syscall.PtraceRegs{} - fmt.Fprintf(w, "\n// Ptrace registers.\n") - fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.R15).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.R14).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R13 0x%02x\n", reflect.ValueOf(&p.R13).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R12 0x%02x\n", reflect.ValueOf(&p.R12).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RBP 0x%02x\n", reflect.ValueOf(&p.Rbp).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RBX 0x%02x\n", reflect.ValueOf(&p.Rbx).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R11 0x%02x\n", reflect.ValueOf(&p.R11).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R10 0x%02x\n", reflect.ValueOf(&p.R10).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R9 0x%02x\n", reflect.ValueOf(&p.R9).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R8 0x%02x\n", reflect.ValueOf(&p.R8).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RAX 0x%02x\n", reflect.ValueOf(&p.Rax).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RCX 0x%02x\n", reflect.ValueOf(&p.Rcx).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RDX 0x%02x\n", reflect.ValueOf(&p.Rdx).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RSI 0x%02x\n", reflect.ValueOf(&p.Rsi).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RDI 0x%02x\n", reflect.ValueOf(&p.Rdi).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_ORIGRAX 0x%02x\n", reflect.ValueOf(&p.Orig_rax).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RIP 0x%02x\n", reflect.ValueOf(&p.Rip).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_CS 0x%02x\n", reflect.ValueOf(&p.Cs).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_FLAGS 0x%02x\n", reflect.ValueOf(&p.Eflags).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RSP 0x%02x\n", reflect.ValueOf(&p.Rsp).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_SS 0x%02x\n", reflect.ValueOf(&p.Ss).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_FS 0x%02x\n", reflect.ValueOf(&p.Fs_base).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_GS 0x%02x\n", reflect.ValueOf(&p.Gs_base).Pointer()-reflect.ValueOf(p).Pointer()) -} diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD deleted file mode 100644 index 581841555..000000000 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ /dev/null @@ -1,112 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test", "select_arch") -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - -package(licenses = ["notice"]) - -go_template( - name = "generic_walker", - srcs = select_arch( - amd64 = ["walker_amd64.go"], - arm64 = ["walker_arm64.go"], - ), - opt_types = [ - "Visitor", - ], - visibility = [":__pkg__"], -) - -go_template_instance( - name = "walker_map", - out = "walker_map.go", - package = "pagetables", - prefix = "map", - template = ":generic_walker", - types = { - "Visitor": "mapVisitor", - }, -) - -go_template_instance( - name = "walker_unmap", - out = "walker_unmap.go", - package = "pagetables", - prefix = "unmap", - template = ":generic_walker", - types = { - "Visitor": "unmapVisitor", - }, -) - -go_template_instance( - name = "walker_lookup", - out = "walker_lookup.go", - package = "pagetables", - prefix = "lookup", - template = ":generic_walker", - types = { - "Visitor": "lookupVisitor", - }, -) - -go_template_instance( - name = "walker_empty", - out = "walker_empty.go", - package = "pagetables", - prefix = "empty", - template = ":generic_walker", - types = { - "Visitor": "emptyVisitor", - }, -) - -go_template_instance( - name = "walker_check", - out = "walker_check.go", - package = "pagetables", - prefix = "check", - template = ":generic_walker", - types = { - "Visitor": "checkVisitor", - }, -) - -go_library( - name = "pagetables", - srcs = [ - "allocator.go", - "allocator_unsafe.go", - "pagetables.go", - "pagetables_aarch64.go", - "pagetables_amd64.go", - "pagetables_arm64.go", - "pagetables_x86.go", - "pcids.go", - "walker_amd64.go", - "walker_arm64.go", - "walker_empty.go", - "walker_lookup.go", - "walker_map.go", - "walker_unmap.go", - ], - visibility = [ - "//pkg/sentry/platform/kvm:__subpackages__", - "//pkg/sentry/platform/ring0:__subpackages__", - ], - deps = [ - "//pkg/sync", - "//pkg/usermem", - ], -) - -go_test( - name = "pagetables_test", - size = "small", - srcs = [ - "pagetables_amd64_test.go", - "pagetables_arm64_test.go", - "pagetables_test.go", - "walker_check.go", - ], - library = ":pagetables", - deps = ["//pkg/usermem"], -) diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go index 78510ebed..78510ebed 100644..100755 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64_state_autogen.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64_state_autogen.go new file mode 100755 index 000000000..ae9d2b272 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build arm64 + +package pagetables diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_state_autogen.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_state_autogen.go new file mode 100755 index 000000000..f48a8acd1 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build amd64 + +package pagetables diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go deleted file mode 100644 index 54e8e554f..000000000 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build amd64 - -package pagetables - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/usermem" -) - -func Test2MAnd4K(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a small page and a huge page. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) - pt.Map(0x00007f0000000000, pmdSize, MapOpts{AccessType: usermem.Read}, pmdSize*47) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}}, - {0x00007f0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read}}, - }) -} - -func Test1GAnd4K(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a small page and a super page. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) - pt.Map(0x00007f0000000000, pudSize, MapOpts{AccessType: usermem.Read}, pudSize*47) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}}, - {0x00007f0000000000, pudSize, pudSize * 47, MapOpts{AccessType: usermem.Read}}, - }) -} - -func TestSplit1GPage(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a super page and knock out the middle. - pt.Map(0x00007f0000000000, pudSize, MapOpts{AccessType: usermem.Read}, pudSize*42) - pt.Unmap(usermem.Addr(0x00007f0000000000+pteSize), pudSize-(2*pteSize)) - - checkMappings(t, pt, []mapping{ - {0x00007f0000000000, pteSize, pudSize * 42, MapOpts{AccessType: usermem.Read}}, - {0x00007f0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: usermem.Read}}, - }) -} - -func TestSplit2MPage(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a huge page and knock out the middle. - pt.Map(0x00007f0000000000, pmdSize, MapOpts{AccessType: usermem.Read}, pmdSize*42) - pt.Unmap(usermem.Addr(0x00007f0000000000+pteSize), pmdSize-(2*pteSize)) - - checkMappings(t, pt, []mapping{ - {0x00007f0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: usermem.Read}}, - {0x00007f0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: usermem.Read}}, - }) -} diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go index 1a49f12a2..1a49f12a2 100644..100755 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_state_autogen.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_state_autogen.go new file mode 100755 index 000000000..ae9d2b272 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build arm64 + +package pagetables diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go deleted file mode 100644 index 2f73d424f..000000000 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build arm64 - -package pagetables - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/usermem" -) - -func Test2MAnd4K(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a small page and a huge page. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42) - pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*47) - - pt.Map(0xffff000000400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: false}, pteSize*42) - pt.Map(0xffffff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: false}, pmdSize*47) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}}, - {0x0000ff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: true}}, - {0xffff000000400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: false}}, - {0xffffff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: false}}, - }) -} - -func Test1GAnd4K(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a small page and a super page. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42) - pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*47) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}}, - {0x0000ff0000000000, pudSize, pudSize * 47, MapOpts{AccessType: usermem.Read, User: true}}, - }) -} - -func TestSplit1GPage(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a super page and knock out the middle. - pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*42) - pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pudSize-(2*pteSize)) - - checkMappings(t, pt, []mapping{ - {0x0000ff0000000000, pteSize, pudSize * 42, MapOpts{AccessType: usermem.Read, User: true}}, - {0x0000ff0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}}, - }) -} - -func TestSplit2MPage(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a huge page and knock out the middle. - pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*42) - pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pmdSize-(2*pteSize)) - - checkMappings(t, pt, []mapping{ - {0x0000ff0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: usermem.Read, User: true}}, - {0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}}, - }) -} diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_state_autogen.go b/pkg/sentry/platform/ring0/pagetables/pagetables_state_autogen.go new file mode 100755 index 000000000..52bab66fe --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build i386 amd64 + +package pagetables diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_test.go deleted file mode 100644 index 5c88d087d..000000000 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_test.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pagetables - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/usermem" -) - -type mapping struct { - start uintptr - length uintptr - addr uintptr - opts MapOpts -} - -type checkVisitor struct { - expected []mapping // Input. - current int // Temporary. - found []mapping // Output. - failed string // Output. -} - -func (v *checkVisitor) visit(start uintptr, pte *PTE, align uintptr) { - v.found = append(v.found, mapping{ - start: start, - length: align + 1, - addr: pte.Address(), - opts: pte.Opts(), - }) - if v.failed != "" { - // Don't keep looking for errors. - return - } - - if v.current >= len(v.expected) { - v.failed = "more mappings than expected" - } else if v.expected[v.current].start != start { - v.failed = "start didn't match expected" - } else if v.expected[v.current].length != (align + 1) { - v.failed = "end didn't match expected" - } else if v.expected[v.current].addr != pte.Address() { - v.failed = "address didn't match expected" - } else if v.expected[v.current].opts != pte.Opts() { - v.failed = "opts didn't match" - } - v.current++ -} - -func (*checkVisitor) requiresAlloc() bool { return false } - -func (*checkVisitor) requiresSplit() bool { return false } - -func checkMappings(t *testing.T, pt *PageTables, m []mapping) { - // Iterate over all the mappings. - w := checkWalker{ - pageTables: pt, - visitor: checkVisitor{ - expected: m, - }, - } - w.iterateRange(0, ^uintptr(0)) - - // Were we expected additional mappings? - if w.visitor.failed == "" && w.visitor.current != len(w.visitor.expected) { - w.visitor.failed = "insufficient mappings found" - } - - // Emit a meaningful error message on failure. - if w.visitor.failed != "" { - t.Errorf("%s; got %#v, wanted %#v", w.visitor.failed, w.visitor.found, w.visitor.expected) - } -} - -func TestUnmap(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map and unmap one entry. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) - pt.Unmap(0x400000, pteSize) - - checkMappings(t, pt, nil) -} - -func TestReadOnly(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map one entry. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.Read}, pteSize*42) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.Read}}, - }) -} - -func TestReadWrite(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map one entry. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}}, - }) -} - -func TestSerialEntries(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map two sequential entries. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) - pt.Map(0x401000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*47) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}}, - {0x401000, pteSize, pteSize * 47, MapOpts{AccessType: usermem.ReadWrite}}, - }) -} - -func TestSpanningEntries(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Span a pgd with two pages. - pt.Map(0x00007efffffff000, 2*pteSize, MapOpts{AccessType: usermem.Read}, pteSize*42) - - checkMappings(t, pt, []mapping{ - {0x00007efffffff000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.Read}}, - {0x00007f0000000000, pteSize, pteSize * 43, MapOpts{AccessType: usermem.Read}}, - }) -} - -func TestSparseEntries(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map two entries in different pgds. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) - pt.Map(0x00007f0000000000, pteSize, MapOpts{AccessType: usermem.Read}, pteSize*47) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}}, - {0x00007f0000000000, pteSize, pteSize * 47, MapOpts{AccessType: usermem.Read}}, - }) -} diff --git a/pkg/sentry/platform/ring0/pagetables/pcids.go b/pkg/sentry/platform/ring0/pagetables/pcids.go index 9206030bf..9206030bf 100644..100755 --- a/pkg/sentry/platform/ring0/pagetables/pcids.go +++ b/pkg/sentry/platform/ring0/pagetables/pcids.go diff --git a/pkg/sentry/platform/ring0/pagetables/walker_amd64.go b/pkg/sentry/platform/ring0/pagetables/walker_amd64.go index 8f9dacd93..8f9dacd93 100644..100755 --- a/pkg/sentry/platform/ring0/pagetables/walker_amd64.go +++ b/pkg/sentry/platform/ring0/pagetables/walker_amd64.go diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go index c261d393a..c261d393a 100644..100755 --- a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go +++ b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go diff --git a/pkg/sentry/platform/ring0/pagetables/walker_empty.go b/pkg/sentry/platform/ring0/pagetables/walker_empty.go new file mode 100755 index 000000000..417784e17 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/walker_empty.go @@ -0,0 +1,255 @@ +package pagetables + +// Walker walks page tables. +type emptyWalker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor emptyVisitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// +// Precondition: start must be less than end. +// +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *emptyWalker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func emptynext(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *emptyWalker) iterateRangeCanonical(start, end uintptr) { + for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &w.pageTables.root[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + start = emptynext(start, pgdSize) + continue + } + + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPUDEntries++ + start = emptynext(start, pudSize) + continue + } + + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSuper() + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if pudEntry.Valid() { + start = emptynext(start, pudSize) + continue + } + } + + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < emptynext(start, pudSize)) { + + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSuper() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + + if !pudEntry.Valid() { + clearPUDEntries++ + } + + start = emptynext(start, pudSize) + continue + } + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPMDEntries++ + start = emptynext(start, pmdSize) + continue + } + + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSuper() + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if pmdEntry.Valid() { + start = emptynext(start, pmdSize) + continue + } + } + + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < emptynext(start, pmdSize)) { + + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + start = emptynext(start, pmdSize) + continue + } + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + w.visitor.visit(uintptr(start), pteEntry, pteSize-1) + if !pteEntry.Valid() { + if w.visitor.requiresAlloc() { + panic("PTE not set after iteration with requiresAlloc!") + } + clearPTEEntries++ + } + + start += pteSize + continue + } + + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } +} diff --git a/pkg/sentry/platform/ring0/pagetables/walker_lookup.go b/pkg/sentry/platform/ring0/pagetables/walker_lookup.go new file mode 100755 index 000000000..906c9c50f --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/walker_lookup.go @@ -0,0 +1,255 @@ +package pagetables + +// Walker walks page tables. +type lookupWalker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor lookupVisitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// +// Precondition: start must be less than end. +// +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *lookupWalker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func lookupnext(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *lookupWalker) iterateRangeCanonical(start, end uintptr) { + for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &w.pageTables.root[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + start = lookupnext(start, pgdSize) + continue + } + + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPUDEntries++ + start = lookupnext(start, pudSize) + continue + } + + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSuper() + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if pudEntry.Valid() { + start = lookupnext(start, pudSize) + continue + } + } + + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < lookupnext(start, pudSize)) { + + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSuper() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + + if !pudEntry.Valid() { + clearPUDEntries++ + } + + start = lookupnext(start, pudSize) + continue + } + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPMDEntries++ + start = lookupnext(start, pmdSize) + continue + } + + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSuper() + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if pmdEntry.Valid() { + start = lookupnext(start, pmdSize) + continue + } + } + + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < lookupnext(start, pmdSize)) { + + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + start = lookupnext(start, pmdSize) + continue + } + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + w.visitor.visit(uintptr(start), pteEntry, pteSize-1) + if !pteEntry.Valid() { + if w.visitor.requiresAlloc() { + panic("PTE not set after iteration with requiresAlloc!") + } + clearPTEEntries++ + } + + start += pteSize + continue + } + + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } +} diff --git a/pkg/sentry/platform/ring0/pagetables/walker_map.go b/pkg/sentry/platform/ring0/pagetables/walker_map.go new file mode 100755 index 000000000..61ee3c825 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/walker_map.go @@ -0,0 +1,255 @@ +package pagetables + +// Walker walks page tables. +type mapWalker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor mapVisitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// +// Precondition: start must be less than end. +// +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *mapWalker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func mapnext(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *mapWalker) iterateRangeCanonical(start, end uintptr) { + for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &w.pageTables.root[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + start = mapnext(start, pgdSize) + continue + } + + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPUDEntries++ + start = mapnext(start, pudSize) + continue + } + + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSuper() + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if pudEntry.Valid() { + start = mapnext(start, pudSize) + continue + } + } + + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < mapnext(start, pudSize)) { + + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSuper() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + + if !pudEntry.Valid() { + clearPUDEntries++ + } + + start = mapnext(start, pudSize) + continue + } + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPMDEntries++ + start = mapnext(start, pmdSize) + continue + } + + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSuper() + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if pmdEntry.Valid() { + start = mapnext(start, pmdSize) + continue + } + } + + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < mapnext(start, pmdSize)) { + + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + start = mapnext(start, pmdSize) + continue + } + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + w.visitor.visit(uintptr(start), pteEntry, pteSize-1) + if !pteEntry.Valid() { + if w.visitor.requiresAlloc() { + panic("PTE not set after iteration with requiresAlloc!") + } + clearPTEEntries++ + } + + start += pteSize + continue + } + + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } +} diff --git a/pkg/sentry/platform/ring0/pagetables/walker_unmap.go b/pkg/sentry/platform/ring0/pagetables/walker_unmap.go new file mode 100755 index 000000000..be2aa0ce4 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/walker_unmap.go @@ -0,0 +1,255 @@ +package pagetables + +// Walker walks page tables. +type unmapWalker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor unmapVisitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// +// Precondition: start must be less than end. +// +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *unmapWalker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func unmapnext(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *unmapWalker) iterateRangeCanonical(start, end uintptr) { + for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &w.pageTables.root[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + start = unmapnext(start, pgdSize) + continue + } + + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPUDEntries++ + start = unmapnext(start, pudSize) + continue + } + + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSuper() + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if pudEntry.Valid() { + start = unmapnext(start, pudSize) + continue + } + } + + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < unmapnext(start, pudSize)) { + + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSuper() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + + if !pudEntry.Valid() { + clearPUDEntries++ + } + + start = unmapnext(start, pudSize) + continue + } + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPMDEntries++ + start = unmapnext(start, pmdSize) + continue + } + + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSuper() + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if pmdEntry.Valid() { + start = unmapnext(start, pmdSize) + continue + } + } + + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < unmapnext(start, pmdSize)) { + + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + start = unmapnext(start, pmdSize) + continue + } + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + w.visitor.visit(uintptr(start), pteEntry, pteSize-1) + if !pteEntry.Valid() { + if w.visitor.requiresAlloc() { + panic("PTE not set after iteration with requiresAlloc!") + } + clearPTEEntries++ + } + + start += pteSize + continue + } + + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } +} diff --git a/pkg/sentry/platform/ring0/ring0_amd64_state_autogen.go b/pkg/sentry/platform/ring0/ring0_amd64_state_autogen.go new file mode 100755 index 000000000..96cf5d331 --- /dev/null +++ b/pkg/sentry/platform/ring0/ring0_amd64_state_autogen.go @@ -0,0 +1,7 @@ +// automatically generated by stateify. + +// +build amd64 +// +build amd64 +// +build amd64 + +package ring0 diff --git a/pkg/sentry/platform/ring0/ring0_arm64_state_autogen.go b/pkg/sentry/platform/ring0/ring0_arm64_state_autogen.go new file mode 100755 index 000000000..7f2ab3537 --- /dev/null +++ b/pkg/sentry/platform/ring0/ring0_arm64_state_autogen.go @@ -0,0 +1,7 @@ +// automatically generated by stateify. + +// +build arm64 +// +build arm64 +// +build arm64 + +package ring0 diff --git a/pkg/sentry/platform/ring0/ring0_state_autogen.go b/pkg/sentry/platform/ring0/ring0_state_autogen.go new file mode 100755 index 000000000..327aba163 --- /dev/null +++ b/pkg/sentry/platform/ring0/ring0_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package ring0 diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/sentry/platform/ring0/x86.go deleted file mode 100644 index 5f80d64e8..000000000 --- a/pkg/sentry/platform/ring0/x86.go +++ /dev/null @@ -1,264 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build i386 amd64 - -package ring0 - -import ( - "gvisor.dev/gvisor/pkg/cpuid" -) - -// Useful bits. -const ( - _CR0_PE = 1 << 0 - _CR0_ET = 1 << 4 - _CR0_AM = 1 << 18 - _CR0_PG = 1 << 31 - - _CR4_PSE = 1 << 4 - _CR4_PAE = 1 << 5 - _CR4_PGE = 1 << 7 - _CR4_OSFXSR = 1 << 9 - _CR4_OSXMMEXCPT = 1 << 10 - _CR4_FSGSBASE = 1 << 16 - _CR4_PCIDE = 1 << 17 - _CR4_OSXSAVE = 1 << 18 - _CR4_SMEP = 1 << 20 - - _RFLAGS_AC = 1 << 18 - _RFLAGS_NT = 1 << 14 - _RFLAGS_IOPL = 3 << 12 - _RFLAGS_DF = 1 << 10 - _RFLAGS_IF = 1 << 9 - _RFLAGS_STEP = 1 << 8 - _RFLAGS_RESERVED = 1 << 1 - - _EFER_SCE = 0x001 - _EFER_LME = 0x100 - _EFER_LMA = 0x400 - _EFER_NX = 0x800 - - _MSR_STAR = 0xc0000081 - _MSR_LSTAR = 0xc0000082 - _MSR_CSTAR = 0xc0000083 - _MSR_SYSCALL_MASK = 0xc0000084 - _MSR_PLATFORM_INFO = 0xce - _MSR_MISC_FEATURES = 0x140 - - _PLATFORM_INFO_CPUID_FAULT = 1 << 31 - - _MISC_FEATURE_CPUID_TRAP = 0x1 -) - -const ( - // KernelFlagsSet should always be set in the kernel. - KernelFlagsSet = _RFLAGS_RESERVED - - // UserFlagsSet are always set in userspace. - UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF - - // KernelFlagsClear should always be clear in the kernel. - KernelFlagsClear = _RFLAGS_STEP | _RFLAGS_IF | _RFLAGS_IOPL | _RFLAGS_AC | _RFLAGS_NT - - // UserFlagsClear are always cleared in userspace. - UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL -) - -// Vector is an exception vector. -type Vector uintptr - -// Exception vectors. -const ( - DivideByZero Vector = iota - Debug - NMI - Breakpoint - Overflow - BoundRangeExceeded - InvalidOpcode - DeviceNotAvailable - DoubleFault - CoprocessorSegmentOverrun - InvalidTSS - SegmentNotPresent - StackSegmentFault - GeneralProtectionFault - PageFault - _ - X87FloatingPointException - AlignmentCheck - MachineCheck - SIMDFloatingPointException - VirtualizationException - SecurityException = 0x1e - SyscallInt80 = 0x80 - _NR_INTERRUPTS = SyscallInt80 + 1 -) - -// System call vectors. -const ( - Syscall Vector = _NR_INTERRUPTS -) - -// VirtualAddressBits returns the number bits available for virtual addresses. -// -// Note that sign-extension semantics apply to the highest order bit. -// -// FIXME(b/69382326): This should use the cpuid passed to Init. -func VirtualAddressBits() uint32 { - ax, _, _, _ := cpuid.HostID(0x80000008, 0) - return (ax >> 8) & 0xff -} - -// PhysicalAddressBits returns the number of bits available for physical addresses. -// -// FIXME(b/69382326): This should use the cpuid passed to Init. -func PhysicalAddressBits() uint32 { - ax, _, _, _ := cpuid.HostID(0x80000008, 0) - return ax & 0xff -} - -// Selector is a segment Selector. -type Selector uint16 - -// SegmentDescriptor is a segment descriptor. -type SegmentDescriptor struct { - bits [2]uint32 -} - -// descriptorTable is a collection of descriptors. -type descriptorTable [32]SegmentDescriptor - -// SegmentDescriptorFlags are typed flags within a descriptor. -type SegmentDescriptorFlags uint32 - -// SegmentDescriptorFlag declarations. -const ( - SegmentDescriptorAccess SegmentDescriptorFlags = 1 << 8 // Access bit (always set). - SegmentDescriptorWrite = 1 << 9 // Write permission. - SegmentDescriptorExpandDown = 1 << 10 // Grows down, not used. - SegmentDescriptorExecute = 1 << 11 // Execute permission. - SegmentDescriptorSystem = 1 << 12 // Zero => system, 1 => user code/data. - SegmentDescriptorPresent = 1 << 15 // Present. - SegmentDescriptorAVL = 1 << 20 // Available. - SegmentDescriptorLong = 1 << 21 // Long mode. - SegmentDescriptorDB = 1 << 22 // 16 or 32-bit. - SegmentDescriptorG = 1 << 23 // Granularity: page or byte. -) - -// Base returns the descriptor's base linear address. -func (d *SegmentDescriptor) Base() uint32 { - return d.bits[1]&0xFF000000 | (d.bits[1]&0x000000FF)<<16 | d.bits[0]>>16 -} - -// Limit returns the descriptor size. -func (d *SegmentDescriptor) Limit() uint32 { - l := d.bits[0]&0xFFFF | d.bits[1]&0xF0000 - if d.bits[1]&uint32(SegmentDescriptorG) != 0 { - l <<= 12 - l |= 0xFFF - } - return l -} - -// Flags returns descriptor flags. -func (d *SegmentDescriptor) Flags() SegmentDescriptorFlags { - return SegmentDescriptorFlags(d.bits[1] & 0x00F09F00) -} - -// DPL returns the descriptor privilege level. -func (d *SegmentDescriptor) DPL() int { - return int((d.bits[1] >> 13) & 3) -} - -func (d *SegmentDescriptor) setNull() { - d.bits[0] = 0 - d.bits[1] = 0 -} - -func (d *SegmentDescriptor) set(base, limit uint32, dpl int, flags SegmentDescriptorFlags) { - flags |= SegmentDescriptorPresent - if limit>>12 != 0 { - limit >>= 12 - flags |= SegmentDescriptorG - } - d.bits[0] = base<<16 | limit&0xFFFF - d.bits[1] = base&0xFF000000 | (base>>16)&0xFF | limit&0x000F0000 | uint32(flags) | uint32(dpl)<<13 -} - -func (d *SegmentDescriptor) setCode32(base, limit uint32, dpl int) { - d.set(base, limit, dpl, - SegmentDescriptorDB| - SegmentDescriptorExecute| - SegmentDescriptorSystem) -} - -func (d *SegmentDescriptor) setCode64(base, limit uint32, dpl int) { - d.set(base, limit, dpl, - SegmentDescriptorG| - SegmentDescriptorLong| - SegmentDescriptorExecute| - SegmentDescriptorSystem) -} - -func (d *SegmentDescriptor) setData(base, limit uint32, dpl int) { - d.set(base, limit, dpl, - SegmentDescriptorWrite| - SegmentDescriptorSystem) -} - -// setHi is only used for the TSS segment, which is magically 64-bits. -func (d *SegmentDescriptor) setHi(base uint32) { - d.bits[0] = base - d.bits[1] = 0 -} - -// Gate64 is a 64-bit task, trap, or interrupt gate. -type Gate64 struct { - bits [4]uint32 -} - -// idt64 is a 64-bit interrupt descriptor table. -type idt64 [_NR_INTERRUPTS]Gate64 - -func (g *Gate64) setInterrupt(cs Selector, rip uint64, dpl int, ist int) { - g.bits[0] = uint32(cs)<<16 | uint32(rip)&0xFFFF - g.bits[1] = uint32(rip)&0xFFFF0000 | SegmentDescriptorPresent | uint32(dpl)<<13 | 14<<8 | uint32(ist)&0x7 - g.bits[2] = uint32(rip >> 32) -} - -func (g *Gate64) setTrap(cs Selector, rip uint64, dpl int, ist int) { - g.setInterrupt(cs, rip, dpl, ist) - g.bits[1] |= 1 << 8 -} - -// TaskState64 is a 64-bit task state structure. -type TaskState64 struct { - _ uint32 - rsp0Lo, rsp0Hi uint32 - rsp1Lo, rsp1Hi uint32 - rsp2Lo, rsp2Hi uint32 - _ [2]uint32 - ist1Lo, ist1Hi uint32 - ist2Lo, ist2Hi uint32 - ist3Lo, ist3Hi uint32 - ist4Lo, ist4Hi uint32 - ist5Lo, ist5Hi uint32 - ist6Lo, ist6Hi uint32 - ist7Lo, ist7Hi uint32 - _ [2]uint32 - _ uint16 - ioPerm uint16 -} diff --git a/pkg/sentry/sighandling/BUILD b/pkg/sentry/sighandling/BUILD deleted file mode 100644 index 6c38a3f44..000000000 --- a/pkg/sentry/sighandling/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "sighandling", - srcs = [ - "sighandling.go", - "sighandling_unsafe.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/abi/linux"], -) diff --git a/pkg/sentry/sighandling/sighandling_state_autogen.go b/pkg/sentry/sighandling/sighandling_state_autogen.go new file mode 100755 index 000000000..da9d96382 --- /dev/null +++ b/pkg/sentry/sighandling/sighandling_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package sighandling diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD deleted file mode 100644 index 611fa22c3..000000000 --- a/pkg/sentry/socket/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "socket", - srcs = ["socket.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket/unix/transport", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD deleted file mode 100644 index 4d42d29cb..000000000 --- a/pkg/sentry/socket/control/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "control", - srcs = ["control.go"], - imports = [ - "gvisor.dev/gvisor/pkg/sentry/fs", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/context", - "//pkg/sentry/fs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/socket", - "//pkg/sentry/socket/unix/transport", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/socket/control/control_state_autogen.go b/pkg/sentry/socket/control/control_state_autogen.go new file mode 100755 index 000000000..8a37b04c0 --- /dev/null +++ b/pkg/sentry/socket/control/control_state_autogen.go @@ -0,0 +1,36 @@ +// automatically generated by stateify. + +package control + +import ( + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/state" +) + +func (x *RightsFiles) save(m state.Map) { + m.SaveValue("", ([]*fs.File)(*x)) +} + +func (x *RightsFiles) load(m state.Map) { + m.LoadValue("", new([]*fs.File), func(y interface{}) { *x = (RightsFiles)(y.([]*fs.File)) }) +} + +func (x *scmCredentials) beforeSave() {} +func (x *scmCredentials) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) + m.Save("kuid", &x.kuid) + m.Save("kgid", &x.kgid) +} + +func (x *scmCredentials) afterLoad() {} +func (x *scmCredentials) load(m state.Map) { + m.Load("t", &x.t) + m.Load("kuid", &x.kuid) + m.Load("kgid", &x.kgid) +} + +func init() { + state.Register("pkg/sentry/socket/control.RightsFiles", (*RightsFiles)(nil), state.Fns{Save: (*RightsFiles).save, Load: (*RightsFiles).load}) + state.Register("pkg/sentry/socket/control.scmCredentials", (*scmCredentials)(nil), state.Fns{Save: (*scmCredentials).save, Load: (*scmCredentials).load}) +} diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD deleted file mode 100644 index 023bad156..000000000 --- a/pkg/sentry/socket/hostinet/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "hostinet", - srcs = [ - "device.go", - "hostinet.go", - "save_restore.go", - "socket.go", - "socket_unsafe.go", - "sockopt_impl.go", - "stack.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/context", - "//pkg/fdnotifier", - "//pkg/log", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/control", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip/stack", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/socket/hostinet/hostinet_impl_state_autogen.go b/pkg/sentry/socket/hostinet/hostinet_impl_state_autogen.go new file mode 100755 index 000000000..b0a59ba93 --- /dev/null +++ b/pkg/sentry/socket/hostinet/hostinet_impl_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package hostinet diff --git a/pkg/sentry/socket/hostinet/hostinet_state_autogen.go b/pkg/sentry/socket/hostinet/hostinet_state_autogen.go new file mode 100755 index 000000000..b0a59ba93 --- /dev/null +++ b/pkg/sentry/socket/hostinet/hostinet_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package hostinet diff --git a/pkg/sentry/socket/hostinet/sockopt_impl.go b/pkg/sentry/socket/hostinet/sockopt_impl.go index 8a783712e..8a783712e 100644..100755 --- a/pkg/sentry/socket/hostinet/sockopt_impl.go +++ b/pkg/sentry/socket/hostinet/sockopt_impl.go diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD deleted file mode 100644 index 7cd2ce55b..000000000 --- a/pkg/sentry/socket/netfilter/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "netfilter", - srcs = [ - "extensions.go", - "netfilter.go", - "targets.go", - "tcp_matcher.go", - "udp_matcher.go", - ], - # This target depends on netstack and should only be used by epsocket, - # which is allowed to depend on netstack. - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/log", - "//pkg/sentry/kernel", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/stack", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index b4b244abf..b4b244abf 100644..100755 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go diff --git a/pkg/sentry/socket/netfilter/netfilter_state_autogen.go b/pkg/sentry/socket/netfilter/netfilter_state_autogen.go new file mode 100755 index 000000000..6e95d89a4 --- /dev/null +++ b/pkg/sentry/socket/netfilter/netfilter_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package netfilter diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index c421b87cf..c421b87cf 100644..100755 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index f9945e214..f9945e214 100644..100755 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 86aa11696..86aa11696 100644..100755 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD deleted file mode 100644 index 1911cd9b8..000000000 --- a/pkg/sentry/socket/netlink/BUILD +++ /dev/null @@ -1,47 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "netlink", - srcs = [ - "message.go", - "provider.go", - "socket.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/context", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/netlink/port", - "//pkg/sentry/socket/unix", - "//pkg/sentry/socket/unix/transport", - "//pkg/sync", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "netlink_test", - size = "small", - srcs = [ - "message_test.go", - ], - deps = [ - ":netlink", - "//pkg/abi/linux", - ], -) diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go deleted file mode 100644 index ef13d9386..000000000 --- a/pkg/sentry/socket/netlink/message_test.go +++ /dev/null @@ -1,312 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package message_test - -import ( - "bytes" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/socket/netlink" -) - -type dummyNetlinkMsg struct { - Foo uint16 -} - -func TestParseMessage(t *testing.T) { - tests := []struct { - desc string - input []byte - - header linux.NetlinkMessageHeader - dataMsg *dummyNetlinkMsg - restLen int - ok bool - }{ - { - desc: "valid", - input: []byte{ - 0x14, 0x00, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding - }, - header: linux.NetlinkMessageHeader{ - Length: 20, - Type: 1, - Flags: 2, - Seq: 3, - PortID: 4, - }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, - restLen: 0, - ok: true, - }, - { - desc: "valid with next message", - input: []byte{ - 0x14, 0x00, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding - 0xFF, // Next message (rest) - }, - header: linux.NetlinkMessageHeader{ - Length: 20, - Type: 1, - Flags: 2, - Seq: 3, - PortID: 4, - }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, - restLen: 1, - ok: true, - }, - { - desc: "valid for last message without padding", - input: []byte{ - 0x12, 0x00, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, // Data message - }, - header: linux.NetlinkMessageHeader{ - Length: 18, - Type: 1, - Flags: 2, - Seq: 3, - PortID: 4, - }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, - restLen: 0, - ok: true, - }, - { - desc: "valid for last message not to be aligned", - input: []byte{ - 0x13, 0x00, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, // Data message - 0x00, // Excessive 1 byte permitted at end - }, - header: linux.NetlinkMessageHeader{ - Length: 19, - Type: 1, - Flags: 2, - Seq: 3, - PortID: 4, - }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, - restLen: 0, - ok: true, - }, - { - desc: "header.Length too short", - input: []byte{ - 0x04, 0x00, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding - }, - ok: false, - }, - { - desc: "header.Length too long", - input: []byte{ - 0xFF, 0xFF, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding - }, - ok: false, - }, - { - desc: "header incomplete", - input: []byte{ - 0x04, 0x00, 0x00, 0x00, // Length - }, - ok: false, - }, - { - desc: "empty message", - input: []byte{}, - ok: false, - }, - } - for _, test := range tests { - msg, rest, ok := netlink.ParseMessage(test.input) - if ok != test.ok { - t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok) - continue - } - if !test.ok { - continue - } - if !reflect.DeepEqual(msg.Header(), test.header) { - t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header) - } - - dataMsg := &dummyNetlinkMsg{} - _, dataOk := msg.GetData(dataMsg) - if !dataOk { - t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk) - } else if !reflect.DeepEqual(dataMsg, test.dataMsg) { - t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg) - } - - if got, want := rest, test.input[len(test.input)-test.restLen:]; !bytes.Equal(got, want) { - t.Errorf("%v: got rest = %v, want = %v", test.desc, got, want) - } - } -} - -func TestAttrView(t *testing.T) { - tests := []struct { - desc string - input []byte - - // Outputs for ParseFirst. - hdr linux.NetlinkAttrHeader - value []byte - restLen int - ok bool - - // Outputs for Empty. - isEmpty bool - }{ - { - desc: "valid", - input: []byte{ - 0x06, 0x00, // Length - 0x01, 0x00, // Type - 0x30, 0x31, 0x00, 0x00, // Data with 2 bytes padding - }, - hdr: linux.NetlinkAttrHeader{ - Length: 6, - Type: 1, - }, - value: []byte{0x30, 0x31}, - restLen: 0, - ok: true, - isEmpty: false, - }, - { - desc: "at alignment", - input: []byte{ - 0x08, 0x00, // Length - 0x01, 0x00, // Type - 0x30, 0x31, 0x32, 0x33, // Data - }, - hdr: linux.NetlinkAttrHeader{ - Length: 8, - Type: 1, - }, - value: []byte{0x30, 0x31, 0x32, 0x33}, - restLen: 0, - ok: true, - isEmpty: false, - }, - { - desc: "at alignment with rest data", - input: []byte{ - 0x08, 0x00, // Length - 0x01, 0x00, // Type - 0x30, 0x31, 0x32, 0x33, // Data - 0xFF, 0xFE, // Rest data - }, - hdr: linux.NetlinkAttrHeader{ - Length: 8, - Type: 1, - }, - value: []byte{0x30, 0x31, 0x32, 0x33}, - restLen: 2, - ok: true, - isEmpty: false, - }, - { - desc: "hdr.Length too long", - input: []byte{ - 0xFF, 0x00, // Length - 0x01, 0x00, // Type - 0x30, 0x31, 0x32, 0x33, // Data - }, - ok: false, - isEmpty: false, - }, - { - desc: "hdr.Length too short", - input: []byte{ - 0x01, 0x00, // Length - 0x01, 0x00, // Type - 0x30, 0x31, 0x32, 0x33, // Data - }, - ok: false, - isEmpty: false, - }, - { - desc: "empty", - input: []byte{}, - ok: false, - isEmpty: true, - }, - } - for _, test := range tests { - attrs := netlink.AttrsView(test.input) - - // Test ParseFirst(). - hdr, value, rest, ok := attrs.ParseFirst() - if ok != test.ok { - t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok) - } else if test.ok { - if !reflect.DeepEqual(hdr, test.hdr) { - t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, hdr, test.hdr) - } - if !bytes.Equal(value, test.value) { - t.Errorf("%v: got value = %v, want = %v", test.desc, value, test.value) - } - if wantRest := test.input[len(test.input)-test.restLen:]; !bytes.Equal(rest, wantRest) { - t.Errorf("%v: got rest = %v, want = %v", test.desc, rest, wantRest) - } - } - - // Test Empty(). - if got, want := attrs.Empty(), test.isEmpty; got != want { - t.Errorf("%v: got empty = %v, want = %v", test.desc, got, want) - } - } -} diff --git a/pkg/sentry/socket/netlink/netlink_state_autogen.go b/pkg/sentry/socket/netlink/netlink_state_autogen.go new file mode 100755 index 000000000..792ac6774 --- /dev/null +++ b/pkg/sentry/socket/netlink/netlink_state_autogen.go @@ -0,0 +1,52 @@ +// automatically generated by stateify. + +package netlink + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Socket) beforeSave() {} +func (x *Socket) save(m state.Map) { + x.beforeSave() + m.Save("SendReceiveTimeout", &x.SendReceiveTimeout) + m.Save("ports", &x.ports) + m.Save("protocol", &x.protocol) + m.Save("skType", &x.skType) + m.Save("ep", &x.ep) + m.Save("connection", &x.connection) + m.Save("bound", &x.bound) + m.Save("portID", &x.portID) + m.Save("sendBufferSize", &x.sendBufferSize) + m.Save("passcred", &x.passcred) + m.Save("filter", &x.filter) +} + +func (x *Socket) afterLoad() {} +func (x *Socket) load(m state.Map) { + m.Load("SendReceiveTimeout", &x.SendReceiveTimeout) + m.Load("ports", &x.ports) + m.Load("protocol", &x.protocol) + m.Load("skType", &x.skType) + m.Load("ep", &x.ep) + m.Load("connection", &x.connection) + m.Load("bound", &x.bound) + m.Load("portID", &x.portID) + m.Load("sendBufferSize", &x.sendBufferSize) + m.Load("passcred", &x.passcred) + m.Load("filter", &x.filter) +} + +func (x *kernelSCM) beforeSave() {} +func (x *kernelSCM) save(m state.Map) { + x.beforeSave() +} + +func (x *kernelSCM) afterLoad() {} +func (x *kernelSCM) load(m state.Map) { +} + +func init() { + state.Register("pkg/sentry/socket/netlink.Socket", (*Socket)(nil), state.Fns{Save: (*Socket).save, Load: (*Socket).load}) + state.Register("pkg/sentry/socket/netlink.kernelSCM", (*kernelSCM)(nil), state.Fns{Save: (*kernelSCM).save, Load: (*kernelSCM).load}) +} diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD deleted file mode 100644 index 3a22923d8..000000000 --- a/pkg/sentry/socket/netlink/port/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "port", - srcs = ["port.go"], - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/sync"], -) - -go_test( - name = "port_test", - srcs = ["port_test.go"], - library = ":port", -) diff --git a/pkg/sentry/socket/netlink/port/port_state_autogen.go b/pkg/sentry/socket/netlink/port/port_state_autogen.go new file mode 100755 index 000000000..c509cc7d5 --- /dev/null +++ b/pkg/sentry/socket/netlink/port/port_state_autogen.go @@ -0,0 +1,22 @@ +// automatically generated by stateify. + +package port + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Manager) beforeSave() {} +func (x *Manager) save(m state.Map) { + x.beforeSave() + m.Save("ports", &x.ports) +} + +func (x *Manager) afterLoad() {} +func (x *Manager) load(m state.Map) { + m.Load("ports", &x.ports) +} + +func init() { + state.Register("pkg/sentry/socket/netlink/port.Manager", (*Manager)(nil), state.Fns{Save: (*Manager).save, Load: (*Manager).load}) +} diff --git a/pkg/sentry/socket/netlink/port/port_test.go b/pkg/sentry/socket/netlink/port/port_test.go deleted file mode 100644 index 516f6cd6c..000000000 --- a/pkg/sentry/socket/netlink/port/port_test.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package port - -import ( - "testing" -) - -func TestAllocateHint(t *testing.T) { - m := New() - - // We can get the hint port. - p, ok := m.Allocate(0, 1) - if !ok { - t.Errorf("m.Allocate got !ok want ok") - } - if p != 1 { - t.Errorf("m.Allocate(0, 1) got %d want 1", p) - } - - // Hint is taken. - p, ok = m.Allocate(0, 1) - if !ok { - t.Errorf("m.Allocate got !ok want ok") - } - if p == 1 { - t.Errorf("m.Allocate(0, 1) got 1 want anything else") - } - - // Hint is available for a different protocol. - p, ok = m.Allocate(1, 1) - if !ok { - t.Errorf("m.Allocate got !ok want ok") - } - if p != 1 { - t.Errorf("m.Allocate(1, 1) got %d want 1", p) - } - - m.Release(0, 1) - - // Hint is available again after release. - p, ok = m.Allocate(0, 1) - if !ok { - t.Errorf("m.Allocate got !ok want ok") - } - if p != 1 { - t.Errorf("m.Allocate(0, 1) got %d want 1", p) - } -} - -func TestAllocateExhausted(t *testing.T) { - m := New() - - // Fill all ports (0 is already reserved). - for i := int32(1); i < maxPorts; i++ { - p, ok := m.Allocate(0, i) - if !ok { - t.Fatalf("m.Allocate got !ok want ok") - } - if p != i { - t.Fatalf("m.Allocate(0, %d) got %d want %d", i, p, i) - } - } - - // Now no more can be allocated. - p, ok := m.Allocate(0, 1) - if ok { - t.Errorf("m.Allocate got %d, ok want !ok", p) - } -} diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD deleted file mode 100644 index 93127398d..000000000 --- a/pkg/sentry/socket/netlink/route/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "route", - srcs = [ - "protocol.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/socket/netlink", - "//pkg/syserr", - ], -) diff --git a/pkg/sentry/socket/netlink/route/route_state_autogen.go b/pkg/sentry/socket/netlink/route/route_state_autogen.go new file mode 100755 index 000000000..bd10fe189 --- /dev/null +++ b/pkg/sentry/socket/netlink/route/route_state_autogen.go @@ -0,0 +1,20 @@ +// automatically generated by stateify. + +package route + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Protocol) beforeSave() {} +func (x *Protocol) save(m state.Map) { + x.beforeSave() +} + +func (x *Protocol) afterLoad() {} +func (x *Protocol) load(m state.Map) { +} + +func init() { + state.Register("pkg/sentry/socket/netlink/route.Protocol", (*Protocol)(nil), state.Fns{Save: (*Protocol).save, Load: (*Protocol).load}) +} diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD deleted file mode 100644 index b6434923c..000000000 --- a/pkg/sentry/socket/netlink/uevent/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "uevent", - srcs = ["protocol.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/kernel", - "//pkg/sentry/socket/netlink", - "//pkg/syserr", - ], -) diff --git a/pkg/sentry/socket/netlink/uevent/protocol.go b/pkg/sentry/socket/netlink/uevent/protocol.go index 029ba21b5..029ba21b5 100644..100755 --- a/pkg/sentry/socket/netlink/uevent/protocol.go +++ b/pkg/sentry/socket/netlink/uevent/protocol.go diff --git a/pkg/sentry/socket/netlink/uevent/uevent_state_autogen.go b/pkg/sentry/socket/netlink/uevent/uevent_state_autogen.go new file mode 100755 index 000000000..b82dddf32 --- /dev/null +++ b/pkg/sentry/socket/netlink/uevent/uevent_state_autogen.go @@ -0,0 +1,20 @@ +// automatically generated by stateify. + +package uevent + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Protocol) beforeSave() {} +func (x *Protocol) save(m state.Map) { + x.beforeSave() +} + +func (x *Protocol) afterLoad() {} +func (x *Protocol) load(m state.Map) { +} + +func init() { + state.Register("pkg/sentry/socket/netlink/uevent.Protocol", (*Protocol)(nil), state.Fns{Save: (*Protocol).save, Load: (*Protocol).load}) +} diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD deleted file mode 100644 index ab01cb4fa..000000000 --- a/pkg/sentry/socket/netstack/BUILD +++ /dev/null @@ -1,50 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "netstack", - srcs = [ - "device.go", - "netstack.go", - "provider.go", - "save_restore.go", - "stack.go", - ], - visibility = [ - "//pkg/sentry:internal", - ], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/context", - "//pkg/log", - "//pkg/metric", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/netfilter", - "//pkg/sentry/unimpl", - "//pkg/sync", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/socket/netstack/device.go b/pkg/sentry/socket/netstack/device.go index fbeb89fb8..fbeb89fb8 100644..100755 --- a/pkg/sentry/socket/netstack/device.go +++ b/pkg/sentry/socket/netstack/device.go diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 13a9a60b4..13a9a60b4 100644..100755 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go diff --git a/pkg/sentry/socket/netstack/netstack_state_autogen.go b/pkg/sentry/socket/netstack/netstack_state_autogen.go new file mode 100755 index 000000000..608f23f63 --- /dev/null +++ b/pkg/sentry/socket/netstack/netstack_state_autogen.go @@ -0,0 +1,56 @@ +// automatically generated by stateify. + +package netstack + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *SocketOperations) beforeSave() {} +func (x *SocketOperations) save(m state.Map) { + x.beforeSave() + m.Save("SendReceiveTimeout", &x.SendReceiveTimeout) + m.Save("Queue", &x.Queue) + m.Save("family", &x.family) + m.Save("Endpoint", &x.Endpoint) + m.Save("skType", &x.skType) + m.Save("protocol", &x.protocol) + m.Save("readView", &x.readView) + m.Save("readCM", &x.readCM) + m.Save("sender", &x.sender) + m.Save("sockOptTimestamp", &x.sockOptTimestamp) + m.Save("timestampValid", &x.timestampValid) + m.Save("timestampNS", &x.timestampNS) + m.Save("sockOptInq", &x.sockOptInq) +} + +func (x *SocketOperations) afterLoad() {} +func (x *SocketOperations) load(m state.Map) { + m.Load("SendReceiveTimeout", &x.SendReceiveTimeout) + m.Load("Queue", &x.Queue) + m.Load("family", &x.family) + m.Load("Endpoint", &x.Endpoint) + m.Load("skType", &x.skType) + m.Load("protocol", &x.protocol) + m.Load("readView", &x.readView) + m.Load("readCM", &x.readCM) + m.Load("sender", &x.sender) + m.Load("sockOptTimestamp", &x.sockOptTimestamp) + m.Load("timestampValid", &x.timestampValid) + m.Load("timestampNS", &x.timestampNS) + m.Load("sockOptInq", &x.sockOptInq) +} + +func (x *Stack) beforeSave() {} +func (x *Stack) save(m state.Map) { + x.beforeSave() +} + +func (x *Stack) load(m state.Map) { + m.AfterLoad(x.afterLoad) +} + +func init() { + state.Register("pkg/sentry/socket/netstack.SocketOperations", (*SocketOperations)(nil), state.Fns{Save: (*SocketOperations).save, Load: (*SocketOperations).load}) + state.Register("pkg/sentry/socket/netstack.Stack", (*Stack)(nil), state.Fns{Save: (*Stack).save, Load: (*Stack).load}) +} diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index 5f181f017..5f181f017 100644..100755 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go diff --git a/pkg/sentry/socket/netstack/save_restore.go b/pkg/sentry/socket/netstack/save_restore.go index c7aaf722a..c7aaf722a 100644..100755 --- a/pkg/sentry/socket/netstack/save_restore.go +++ b/pkg/sentry/socket/netstack/save_restore.go diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 0692482e9..0692482e9 100644..100755 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go diff --git a/pkg/sentry/socket/socket_state_autogen.go b/pkg/sentry/socket/socket_state_autogen.go new file mode 100755 index 000000000..900c217c7 --- /dev/null +++ b/pkg/sentry/socket/socket_state_autogen.go @@ -0,0 +1,24 @@ +// automatically generated by stateify. + +package socket + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *SendReceiveTimeout) beforeSave() {} +func (x *SendReceiveTimeout) save(m state.Map) { + x.beforeSave() + m.Save("send", &x.send) + m.Save("recv", &x.recv) +} + +func (x *SendReceiveTimeout) afterLoad() {} +func (x *SendReceiveTimeout) load(m state.Map) { + m.Load("send", &x.send) + m.Load("recv", &x.recv) +} + +func init() { + state.Register("pkg/sentry/socket.SendReceiveTimeout", (*SendReceiveTimeout)(nil), state.Fns{Save: (*SendReceiveTimeout).save, Load: (*SendReceiveTimeout).load}) +} diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD deleted file mode 100644 index 08743deba..000000000 --- a/pkg/sentry/socket/unix/BUILD +++ /dev/null @@ -1,34 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "unix", - srcs = [ - "device.go", - "io.go", - "unix.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/refs", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/control", - "//pkg/sentry/socket/netstack", - "//pkg/sentry/socket/unix/transport", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD deleted file mode 100644 index 74bcd6300..000000000 --- a/pkg/sentry/socket/unix/transport/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "transport_message_list", - out = "transport_message_list.go", - package = "transport", - prefix = "message", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*message", - "Linker": "*message", - }, -) - -go_library( - name = "transport", - srcs = [ - "connectioned.go", - "connectioned_state.go", - "connectionless.go", - "queue.go", - "transport_message_list.go", - "unix.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/ilist", - "//pkg/refs", - "//pkg/sync", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/socket/unix/transport/transport_message_list.go b/pkg/sentry/socket/unix/transport/transport_message_list.go new file mode 100755 index 000000000..9edc731b4 --- /dev/null +++ b/pkg/sentry/socket/unix/transport/transport_message_list.go @@ -0,0 +1,186 @@ +package transport + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type messageElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (messageElementMapper) linkerFor(elem *message) *message { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type messageList struct { + head *message + tail *message +} + +// Reset resets list l to the empty state. +func (l *messageList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *messageList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *messageList) Front() *message { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *messageList) Back() *message { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *messageList) PushFront(e *message) { + linker := messageElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + messageElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *messageList) PushBack(e *message) { + linker := messageElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + messageElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *messageList) PushBackList(m *messageList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + messageElementMapper{}.linkerFor(l.tail).SetNext(m.head) + messageElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *messageList) InsertAfter(b, e *message) { + bLinker := messageElementMapper{}.linkerFor(b) + eLinker := messageElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + messageElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *messageList) InsertBefore(a, e *message) { + aLinker := messageElementMapper{}.linkerFor(a) + eLinker := messageElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + messageElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *messageList) Remove(e *message) { + linker := messageElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + messageElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + messageElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type messageEntry struct { + next *message + prev *message +} + +// Next returns the entry that follows e in the list. +func (e *messageEntry) Next() *message { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *messageEntry) Prev() *message { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *messageEntry) SetNext(elem *message) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *messageEntry) SetPrev(elem *message) { + e.prev = elem +} diff --git a/pkg/sentry/socket/unix/transport/transport_state_autogen.go b/pkg/sentry/socket/unix/transport/transport_state_autogen.go new file mode 100755 index 000000000..b47951498 --- /dev/null +++ b/pkg/sentry/socket/unix/transport/transport_state_autogen.go @@ -0,0 +1,193 @@ +// automatically generated by stateify. + +package transport + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *connectionedEndpoint) beforeSave() {} +func (x *connectionedEndpoint) save(m state.Map) { + x.beforeSave() + var acceptedChan []*connectionedEndpoint = x.saveAcceptedChan() + m.SaveValue("acceptedChan", acceptedChan) + m.Save("baseEndpoint", &x.baseEndpoint) + m.Save("id", &x.id) + m.Save("idGenerator", &x.idGenerator) + m.Save("stype", &x.stype) +} + +func (x *connectionedEndpoint) afterLoad() {} +func (x *connectionedEndpoint) load(m state.Map) { + m.Load("baseEndpoint", &x.baseEndpoint) + m.Load("id", &x.id) + m.Load("idGenerator", &x.idGenerator) + m.Load("stype", &x.stype) + m.LoadValue("acceptedChan", new([]*connectionedEndpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*connectionedEndpoint)) }) +} + +func (x *connectionlessEndpoint) beforeSave() {} +func (x *connectionlessEndpoint) save(m state.Map) { + x.beforeSave() + m.Save("baseEndpoint", &x.baseEndpoint) +} + +func (x *connectionlessEndpoint) afterLoad() {} +func (x *connectionlessEndpoint) load(m state.Map) { + m.Load("baseEndpoint", &x.baseEndpoint) +} + +func (x *queue) beforeSave() {} +func (x *queue) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("ReaderQueue", &x.ReaderQueue) + m.Save("WriterQueue", &x.WriterQueue) + m.Save("closed", &x.closed) + m.Save("unread", &x.unread) + m.Save("used", &x.used) + m.Save("limit", &x.limit) + m.Save("dataList", &x.dataList) +} + +func (x *queue) afterLoad() {} +func (x *queue) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("ReaderQueue", &x.ReaderQueue) + m.Load("WriterQueue", &x.WriterQueue) + m.Load("closed", &x.closed) + m.Load("unread", &x.unread) + m.Load("used", &x.used) + m.Load("limit", &x.limit) + m.Load("dataList", &x.dataList) +} + +func (x *messageList) beforeSave() {} +func (x *messageList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *messageList) afterLoad() {} +func (x *messageList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *messageEntry) beforeSave() {} +func (x *messageEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *messageEntry) afterLoad() {} +func (x *messageEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *ControlMessages) beforeSave() {} +func (x *ControlMessages) save(m state.Map) { + x.beforeSave() + m.Save("Rights", &x.Rights) + m.Save("Credentials", &x.Credentials) +} + +func (x *ControlMessages) afterLoad() {} +func (x *ControlMessages) load(m state.Map) { + m.Load("Rights", &x.Rights) + m.Load("Credentials", &x.Credentials) +} + +func (x *message) beforeSave() {} +func (x *message) save(m state.Map) { + x.beforeSave() + m.Save("messageEntry", &x.messageEntry) + m.Save("Data", &x.Data) + m.Save("Control", &x.Control) + m.Save("Address", &x.Address) +} + +func (x *message) afterLoad() {} +func (x *message) load(m state.Map) { + m.Load("messageEntry", &x.messageEntry) + m.Load("Data", &x.Data) + m.Load("Control", &x.Control) + m.Load("Address", &x.Address) +} + +func (x *queueReceiver) beforeSave() {} +func (x *queueReceiver) save(m state.Map) { + x.beforeSave() + m.Save("readQueue", &x.readQueue) +} + +func (x *queueReceiver) afterLoad() {} +func (x *queueReceiver) load(m state.Map) { + m.Load("readQueue", &x.readQueue) +} + +func (x *streamQueueReceiver) beforeSave() {} +func (x *streamQueueReceiver) save(m state.Map) { + x.beforeSave() + m.Save("queueReceiver", &x.queueReceiver) + m.Save("buffer", &x.buffer) + m.Save("control", &x.control) + m.Save("addr", &x.addr) +} + +func (x *streamQueueReceiver) afterLoad() {} +func (x *streamQueueReceiver) load(m state.Map) { + m.Load("queueReceiver", &x.queueReceiver) + m.Load("buffer", &x.buffer) + m.Load("control", &x.control) + m.Load("addr", &x.addr) +} + +func (x *connectedEndpoint) beforeSave() {} +func (x *connectedEndpoint) save(m state.Map) { + x.beforeSave() + m.Save("endpoint", &x.endpoint) + m.Save("writeQueue", &x.writeQueue) +} + +func (x *connectedEndpoint) afterLoad() {} +func (x *connectedEndpoint) load(m state.Map) { + m.Load("endpoint", &x.endpoint) + m.Load("writeQueue", &x.writeQueue) +} + +func (x *baseEndpoint) beforeSave() {} +func (x *baseEndpoint) save(m state.Map) { + x.beforeSave() + m.Save("Queue", &x.Queue) + m.Save("passcred", &x.passcred) + m.Save("receiver", &x.receiver) + m.Save("connected", &x.connected) + m.Save("path", &x.path) +} + +func (x *baseEndpoint) afterLoad() {} +func (x *baseEndpoint) load(m state.Map) { + m.Load("Queue", &x.Queue) + m.Load("passcred", &x.passcred) + m.Load("receiver", &x.receiver) + m.Load("connected", &x.connected) + m.Load("path", &x.path) +} + +func init() { + state.Register("pkg/sentry/socket/unix/transport.connectionedEndpoint", (*connectionedEndpoint)(nil), state.Fns{Save: (*connectionedEndpoint).save, Load: (*connectionedEndpoint).load}) + state.Register("pkg/sentry/socket/unix/transport.connectionlessEndpoint", (*connectionlessEndpoint)(nil), state.Fns{Save: (*connectionlessEndpoint).save, Load: (*connectionlessEndpoint).load}) + state.Register("pkg/sentry/socket/unix/transport.queue", (*queue)(nil), state.Fns{Save: (*queue).save, Load: (*queue).load}) + state.Register("pkg/sentry/socket/unix/transport.messageList", (*messageList)(nil), state.Fns{Save: (*messageList).save, Load: (*messageList).load}) + state.Register("pkg/sentry/socket/unix/transport.messageEntry", (*messageEntry)(nil), state.Fns{Save: (*messageEntry).save, Load: (*messageEntry).load}) + state.Register("pkg/sentry/socket/unix/transport.ControlMessages", (*ControlMessages)(nil), state.Fns{Save: (*ControlMessages).save, Load: (*ControlMessages).load}) + state.Register("pkg/sentry/socket/unix/transport.message", (*message)(nil), state.Fns{Save: (*message).save, Load: (*message).load}) + state.Register("pkg/sentry/socket/unix/transport.queueReceiver", (*queueReceiver)(nil), state.Fns{Save: (*queueReceiver).save, Load: (*queueReceiver).load}) + state.Register("pkg/sentry/socket/unix/transport.streamQueueReceiver", (*streamQueueReceiver)(nil), state.Fns{Save: (*streamQueueReceiver).save, Load: (*streamQueueReceiver).load}) + state.Register("pkg/sentry/socket/unix/transport.connectedEndpoint", (*connectedEndpoint)(nil), state.Fns{Save: (*connectedEndpoint).save, Load: (*connectedEndpoint).load}) + state.Register("pkg/sentry/socket/unix/transport.baseEndpoint", (*baseEndpoint)(nil), state.Fns{Save: (*baseEndpoint).save, Load: (*baseEndpoint).load}) +} diff --git a/pkg/sentry/socket/unix/unix_state_autogen.go b/pkg/sentry/socket/unix/unix_state_autogen.go new file mode 100755 index 000000000..755373941 --- /dev/null +++ b/pkg/sentry/socket/unix/unix_state_autogen.go @@ -0,0 +1,28 @@ +// automatically generated by stateify. + +package unix + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *SocketOperations) beforeSave() {} +func (x *SocketOperations) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("SendReceiveTimeout", &x.SendReceiveTimeout) + m.Save("ep", &x.ep) + m.Save("stype", &x.stype) +} + +func (x *SocketOperations) afterLoad() {} +func (x *SocketOperations) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("SendReceiveTimeout", &x.SendReceiveTimeout) + m.Load("ep", &x.ep) + m.Load("stype", &x.stype) +} + +func init() { + state.Register("pkg/sentry/socket/unix.SocketOperations", (*SocketOperations)(nil), state.Fns{Save: (*SocketOperations).save, Load: (*SocketOperations).load}) +} diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD deleted file mode 100644 index 0ea4aab8b..000000000 --- a/pkg/sentry/state/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "state", - srcs = [ - "state.go", - "state_metadata.go", - "state_unsafe.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/log", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/time", - "//pkg/sentry/watchdog", - "//pkg/state/statefile", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/state/state_state_autogen.go b/pkg/sentry/state/state_state_autogen.go new file mode 100755 index 000000000..6c2b29632 --- /dev/null +++ b/pkg/sentry/state/state_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package state diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD deleted file mode 100644 index 88d5db9fc..000000000 --- a/pkg/sentry/strace/BUILD +++ /dev/null @@ -1,45 +0,0 @@ -load("//tools:defs.bzl", "go_library", "proto_library") - -package(licenses = ["notice"]) - -go_library( - name = "strace", - srcs = [ - "capability.go", - "clone.go", - "epoll.go", - "futex.go", - "linux64_amd64.go", - "linux64_arm64.go", - "open.go", - "poll.go", - "ptrace.go", - "select.go", - "signal.go", - "socket.go", - "strace.go", - "syscalls.go", - ], - visibility = ["//:sandbox"], - deps = [ - ":strace_go_proto", - "//pkg/abi", - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/bits", - "//pkg/eventchannel", - "//pkg/seccomp", - "//pkg/sentry/arch", - "//pkg/sentry/kernel", - "//pkg/sentry/socket/netlink", - "//pkg/sentry/socket/netstack", - "//pkg/sentry/syscalls/linux", - "//pkg/usermem", - ], -) - -proto_library( - name = "strace", - srcs = ["strace.proto"], - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go index a6e48b836..a6e48b836 100644..100755 --- a/pkg/sentry/strace/epoll.go +++ b/pkg/sentry/strace/epoll.go diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go index 71b92eaee..71b92eaee 100644..100755 --- a/pkg/sentry/strace/linux64_amd64.go +++ b/pkg/sentry/strace/linux64_amd64.go diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go index bd7361a52..bd7361a52 100644..100755 --- a/pkg/sentry/strace/linux64_arm64.go +++ b/pkg/sentry/strace/linux64_arm64.go diff --git a/pkg/sentry/strace/select.go b/pkg/sentry/strace/select.go index 3a4c32aa0..3a4c32aa0 100644..100755 --- a/pkg/sentry/strace/select.go +++ b/pkg/sentry/strace/select.go diff --git a/pkg/sentry/strace/strace.proto b/pkg/sentry/strace/strace.proto deleted file mode 100644 index 906c52c51..000000000 --- a/pkg/sentry/strace/strace.proto +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -message Strace { - // Process name that made the syscall. - string process = 1; - - // Syscall function name. - string function = 2; - - // List of syscall arguments formatted as strings. - repeated string args = 3; - - oneof info { - StraceEnter enter = 4; - StraceExit exit = 5; - } -} - -message StraceEnter {} - -message StraceExit { - // Return value formatted as string. - string return = 1; - - // Formatted error string in case syscall failed. - string error = 2; - - // Value of errno upon syscall exit. - int64 err_no = 3; // errno is a macro and gets expanded :-( - - // Time elapsed between syscall enter and exit. - int64 elapsed_ns = 4; -} diff --git a/pkg/sentry/strace/strace_amd64_state_autogen.go b/pkg/sentry/strace/strace_amd64_state_autogen.go new file mode 100755 index 000000000..c7d4b3eb4 --- /dev/null +++ b/pkg/sentry/strace/strace_amd64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build amd64 + +package strace diff --git a/pkg/sentry/strace/strace_arm64_state_autogen.go b/pkg/sentry/strace/strace_arm64_state_autogen.go new file mode 100755 index 000000000..9b8f66dc9 --- /dev/null +++ b/pkg/sentry/strace/strace_arm64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build arm64 + +package strace diff --git a/pkg/sentry/strace/strace_go_proto/strace.pb.go b/pkg/sentry/strace/strace_go_proto/strace.pb.go new file mode 100755 index 000000000..ef45661bc --- /dev/null +++ b/pkg/sentry/strace/strace_go_proto/strace.pb.go @@ -0,0 +1,247 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/sentry/strace/strace.proto + +package gvisor + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type Strace struct { + Process string `protobuf:"bytes,1,opt,name=process,proto3" json:"process,omitempty"` + Function string `protobuf:"bytes,2,opt,name=function,proto3" json:"function,omitempty"` + Args []string `protobuf:"bytes,3,rep,name=args,proto3" json:"args,omitempty"` + // Types that are valid to be assigned to Info: + // *Strace_Enter + // *Strace_Exit + Info isStrace_Info `protobuf_oneof:"info"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Strace) Reset() { *m = Strace{} } +func (m *Strace) String() string { return proto.CompactTextString(m) } +func (*Strace) ProtoMessage() {} +func (*Strace) Descriptor() ([]byte, []int) { + return fileDescriptor_50c4b43677c82b5f, []int{0} +} + +func (m *Strace) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Strace.Unmarshal(m, b) +} +func (m *Strace) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Strace.Marshal(b, m, deterministic) +} +func (m *Strace) XXX_Merge(src proto.Message) { + xxx_messageInfo_Strace.Merge(m, src) +} +func (m *Strace) XXX_Size() int { + return xxx_messageInfo_Strace.Size(m) +} +func (m *Strace) XXX_DiscardUnknown() { + xxx_messageInfo_Strace.DiscardUnknown(m) +} + +var xxx_messageInfo_Strace proto.InternalMessageInfo + +func (m *Strace) GetProcess() string { + if m != nil { + return m.Process + } + return "" +} + +func (m *Strace) GetFunction() string { + if m != nil { + return m.Function + } + return "" +} + +func (m *Strace) GetArgs() []string { + if m != nil { + return m.Args + } + return nil +} + +type isStrace_Info interface { + isStrace_Info() +} + +type Strace_Enter struct { + Enter *StraceEnter `protobuf:"bytes,4,opt,name=enter,proto3,oneof"` +} + +type Strace_Exit struct { + Exit *StraceExit `protobuf:"bytes,5,opt,name=exit,proto3,oneof"` +} + +func (*Strace_Enter) isStrace_Info() {} + +func (*Strace_Exit) isStrace_Info() {} + +func (m *Strace) GetInfo() isStrace_Info { + if m != nil { + return m.Info + } + return nil +} + +func (m *Strace) GetEnter() *StraceEnter { + if x, ok := m.GetInfo().(*Strace_Enter); ok { + return x.Enter + } + return nil +} + +func (m *Strace) GetExit() *StraceExit { + if x, ok := m.GetInfo().(*Strace_Exit); ok { + return x.Exit + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*Strace) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*Strace_Enter)(nil), + (*Strace_Exit)(nil), + } +} + +type StraceEnter struct { + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *StraceEnter) Reset() { *m = StraceEnter{} } +func (m *StraceEnter) String() string { return proto.CompactTextString(m) } +func (*StraceEnter) ProtoMessage() {} +func (*StraceEnter) Descriptor() ([]byte, []int) { + return fileDescriptor_50c4b43677c82b5f, []int{1} +} + +func (m *StraceEnter) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_StraceEnter.Unmarshal(m, b) +} +func (m *StraceEnter) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_StraceEnter.Marshal(b, m, deterministic) +} +func (m *StraceEnter) XXX_Merge(src proto.Message) { + xxx_messageInfo_StraceEnter.Merge(m, src) +} +func (m *StraceEnter) XXX_Size() int { + return xxx_messageInfo_StraceEnter.Size(m) +} +func (m *StraceEnter) XXX_DiscardUnknown() { + xxx_messageInfo_StraceEnter.DiscardUnknown(m) +} + +var xxx_messageInfo_StraceEnter proto.InternalMessageInfo + +type StraceExit struct { + Return string `protobuf:"bytes,1,opt,name=return,proto3" json:"return,omitempty"` + Error string `protobuf:"bytes,2,opt,name=error,proto3" json:"error,omitempty"` + ErrNo int64 `protobuf:"varint,3,opt,name=err_no,json=errNo,proto3" json:"err_no,omitempty"` + ElapsedNs int64 `protobuf:"varint,4,opt,name=elapsed_ns,json=elapsedNs,proto3" json:"elapsed_ns,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *StraceExit) Reset() { *m = StraceExit{} } +func (m *StraceExit) String() string { return proto.CompactTextString(m) } +func (*StraceExit) ProtoMessage() {} +func (*StraceExit) Descriptor() ([]byte, []int) { + return fileDescriptor_50c4b43677c82b5f, []int{2} +} + +func (m *StraceExit) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_StraceExit.Unmarshal(m, b) +} +func (m *StraceExit) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_StraceExit.Marshal(b, m, deterministic) +} +func (m *StraceExit) XXX_Merge(src proto.Message) { + xxx_messageInfo_StraceExit.Merge(m, src) +} +func (m *StraceExit) XXX_Size() int { + return xxx_messageInfo_StraceExit.Size(m) +} +func (m *StraceExit) XXX_DiscardUnknown() { + xxx_messageInfo_StraceExit.DiscardUnknown(m) +} + +var xxx_messageInfo_StraceExit proto.InternalMessageInfo + +func (m *StraceExit) GetReturn() string { + if m != nil { + return m.Return + } + return "" +} + +func (m *StraceExit) GetError() string { + if m != nil { + return m.Error + } + return "" +} + +func (m *StraceExit) GetErrNo() int64 { + if m != nil { + return m.ErrNo + } + return 0 +} + +func (m *StraceExit) GetElapsedNs() int64 { + if m != nil { + return m.ElapsedNs + } + return 0 +} + +func init() { + proto.RegisterType((*Strace)(nil), "gvisor.Strace") + proto.RegisterType((*StraceEnter)(nil), "gvisor.StraceEnter") + proto.RegisterType((*StraceExit)(nil), "gvisor.StraceExit") +} + +func init() { proto.RegisterFile("pkg/sentry/strace/strace.proto", fileDescriptor_50c4b43677c82b5f) } + +var fileDescriptor_50c4b43677c82b5f = []byte{ + // 255 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x5c, 0x90, 0xdd, 0x4a, 0xf4, 0x30, + 0x10, 0x86, 0xb7, 0x5f, 0xdb, 0x7c, 0x76, 0x16, 0x4f, 0xc6, 0x1f, 0x82, 0xa0, 0x94, 0x1e, 0x05, + 0x84, 0x2e, 0xe8, 0x1d, 0x08, 0xc2, 0x1e, 0xed, 0x41, 0xbc, 0x80, 0xa5, 0xd6, 0xd9, 0x12, 0x94, + 0x24, 0x4c, 0xb2, 0xb2, 0x5e, 0x96, 0x77, 0x28, 0xa6, 0xf1, 0x07, 0x8f, 0x92, 0x67, 0xde, 0x87, + 0x0c, 0x6f, 0xe0, 0xca, 0x3f, 0x4f, 0xab, 0x40, 0x36, 0xf2, 0xdb, 0x2a, 0x44, 0x1e, 0x46, 0xca, + 0x47, 0xef, 0xd9, 0x45, 0x87, 0x62, 0x7a, 0x35, 0xc1, 0x71, 0xf7, 0x5e, 0x80, 0x78, 0x48, 0x01, + 0x4a, 0xf8, 0xef, 0xd9, 0x8d, 0x14, 0x82, 0x2c, 0xda, 0x42, 0x35, 0xfa, 0x0b, 0xf1, 0x02, 0x8e, + 0x76, 0x7b, 0x3b, 0x46, 0xe3, 0xac, 0xfc, 0x97, 0xa2, 0x6f, 0x46, 0x84, 0x6a, 0xe0, 0x29, 0xc8, + 0xb2, 0x2d, 0x55, 0xa3, 0xd3, 0x1d, 0xaf, 0xa1, 0x26, 0x1b, 0x89, 0x65, 0xd5, 0x16, 0x6a, 0x79, + 0x73, 0xd2, 0xcf, 0xcb, 0xfa, 0x79, 0xd1, 0xfd, 0x67, 0xb4, 0x5e, 0xe8, 0xd9, 0x41, 0x05, 0x15, + 0x1d, 0x4c, 0x94, 0x75, 0x72, 0xf1, 0x8f, 0x7b, 0x30, 0x71, 0xbd, 0xd0, 0xc9, 0xb8, 0x13, 0x50, + 0x19, 0xbb, 0x73, 0xdd, 0x31, 0x2c, 0x7f, 0xbd, 0xd4, 0x79, 0x80, 0x1f, 0x19, 0xcf, 0x41, 0x30, + 0xc5, 0x3d, 0xdb, 0x5c, 0x22, 0x13, 0x9e, 0x42, 0x4d, 0xcc, 0x8e, 0x73, 0x81, 0x19, 0xf0, 0x0c, + 0x04, 0x31, 0x6f, 0xad, 0x93, 0x65, 0x5b, 0xa8, 0x32, 0x8d, 0x37, 0x0e, 0x2f, 0x01, 0xe8, 0x65, + 0xf0, 0x81, 0x9e, 0xb6, 0x36, 0xa4, 0x16, 0xa5, 0x6e, 0xf2, 0x64, 0x13, 0x1e, 0x45, 0xfa, 0xc3, + 0xdb, 0x8f, 0x00, 0x00, 0x00, 0xff, 0xff, 0x42, 0x9a, 0xbc, 0x81, 0x65, 0x01, 0x00, 0x00, +} diff --git a/pkg/sentry/strace/strace_state_autogen.go b/pkg/sentry/strace/strace_state_autogen.go new file mode 100755 index 000000000..33f6a7a54 --- /dev/null +++ b/pkg/sentry/strace/strace_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package strace diff --git a/pkg/sentry/syscalls/BUILD b/pkg/sentry/syscalls/BUILD deleted file mode 100644 index b8d1bd415..000000000 --- a/pkg/sentry/syscalls/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "syscalls", - srcs = [ - "epoll.go", - "syscalls.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/arch", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/epoll", - "//pkg/sentry/kernel/time", - "//pkg/syserror", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD deleted file mode 100644 index 0d24fd3c4..000000000 --- a/pkg/sentry/syscalls/linux/BUILD +++ /dev/null @@ -1,104 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "linux", - srcs = [ - "error.go", - "flags.go", - "linux64.go", - "linux64_amd64.go", - "linux64_arm64.go", - "sigset.go", - "sys_aio.go", - "sys_capability.go", - "sys_clone_amd64.go", - "sys_clone_arm64.go", - "sys_epoll.go", - "sys_eventfd.go", - "sys_file.go", - "sys_futex.go", - "sys_getdents.go", - "sys_identity.go", - "sys_inotify.go", - "sys_lseek.go", - "sys_mempolicy.go", - "sys_mmap.go", - "sys_mount.go", - "sys_pipe.go", - "sys_poll.go", - "sys_prctl.go", - "sys_random.go", - "sys_read.go", - "sys_rlimit.go", - "sys_rseq.go", - "sys_rusage.go", - "sys_sched.go", - "sys_seccomp.go", - "sys_sem.go", - "sys_shm.go", - "sys_signal.go", - "sys_socket.go", - "sys_splice.go", - "sys_stat.go", - "sys_stat_amd64.go", - "sys_stat_arm64.go", - "sys_sync.go", - "sys_sysinfo.go", - "sys_syslog.go", - "sys_thread.go", - "sys_time.go", - "sys_timer.go", - "sys_timerfd.go", - "sys_tls.go", - "sys_utsname.go", - "sys_write.go", - "sys_xattr.go", - "timespec.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi", - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/bpf", - "//pkg/context", - "//pkg/log", - "//pkg/metric", - "//pkg/rand", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fs/timerfd", - "//pkg/sentry/fs/tmpfs", - "//pkg/sentry/fsbridge", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/epoll", - "//pkg/sentry/kernel/eventfd", - "//pkg/sentry/kernel/fasync", - "//pkg/sentry/kernel/pipe", - "//pkg/sentry/kernel/sched", - "//pkg/sentry/kernel/shm", - "//pkg/sentry/kernel/signalfd", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/loader", - "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/socket", - "//pkg/sentry/socket/control", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/syscalls", - "//pkg/sentry/usage", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go index 79066ad2a..79066ad2a 100644..100755 --- a/pkg/sentry/syscalls/linux/linux64_amd64.go +++ b/pkg/sentry/syscalls/linux/linux64_amd64.go diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go index 7421619de..7421619de 100644..100755 --- a/pkg/sentry/syscalls/linux/linux64_arm64.go +++ b/pkg/sentry/syscalls/linux/linux64_arm64.go diff --git a/pkg/sentry/syscalls/linux/linux_amd64_state_autogen.go b/pkg/sentry/syscalls/linux/linux_amd64_state_autogen.go new file mode 100755 index 000000000..a98193a5b --- /dev/null +++ b/pkg/sentry/syscalls/linux/linux_amd64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +// +build amd64 +// +build amd64 + +package linux diff --git a/pkg/sentry/syscalls/linux/linux_arm64_state_autogen.go b/pkg/sentry/syscalls/linux/linux_arm64_state_autogen.go new file mode 100755 index 000000000..b144adbda --- /dev/null +++ b/pkg/sentry/syscalls/linux/linux_arm64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +// +build arm64 +// +build arm64 + +package linux diff --git a/pkg/sentry/syscalls/linux/linux_state_autogen.go b/pkg/sentry/syscalls/linux/linux_state_autogen.go new file mode 100755 index 000000000..5107dd246 --- /dev/null +++ b/pkg/sentry/syscalls/linux/linux_state_autogen.go @@ -0,0 +1,83 @@ +// automatically generated by stateify. + +// +build amd64 +// +build amd64 arm64 + +package linux + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *ioEvent) beforeSave() {} +func (x *ioEvent) save(m state.Map) { + x.beforeSave() + m.Save("Data", &x.Data) + m.Save("Obj", &x.Obj) + m.Save("Result", &x.Result) + m.Save("Result2", &x.Result2) +} + +func (x *ioEvent) afterLoad() {} +func (x *ioEvent) load(m state.Map) { + m.Load("Data", &x.Data) + m.Load("Obj", &x.Obj) + m.Load("Result", &x.Result) + m.Load("Result2", &x.Result2) +} + +func (x *futexWaitRestartBlock) beforeSave() {} +func (x *futexWaitRestartBlock) save(m state.Map) { + x.beforeSave() + m.Save("duration", &x.duration) + m.Save("addr", &x.addr) + m.Save("private", &x.private) + m.Save("val", &x.val) + m.Save("mask", &x.mask) +} + +func (x *futexWaitRestartBlock) afterLoad() {} +func (x *futexWaitRestartBlock) load(m state.Map) { + m.Load("duration", &x.duration) + m.Load("addr", &x.addr) + m.Load("private", &x.private) + m.Load("val", &x.val) + m.Load("mask", &x.mask) +} + +func (x *pollRestartBlock) beforeSave() {} +func (x *pollRestartBlock) save(m state.Map) { + x.beforeSave() + m.Save("pfdAddr", &x.pfdAddr) + m.Save("nfds", &x.nfds) + m.Save("timeout", &x.timeout) +} + +func (x *pollRestartBlock) afterLoad() {} +func (x *pollRestartBlock) load(m state.Map) { + m.Load("pfdAddr", &x.pfdAddr) + m.Load("nfds", &x.nfds) + m.Load("timeout", &x.timeout) +} + +func (x *clockNanosleepRestartBlock) beforeSave() {} +func (x *clockNanosleepRestartBlock) save(m state.Map) { + x.beforeSave() + m.Save("c", &x.c) + m.Save("duration", &x.duration) + m.Save("rem", &x.rem) +} + +func (x *clockNanosleepRestartBlock) afterLoad() {} +func (x *clockNanosleepRestartBlock) load(m state.Map) { + m.Load("c", &x.c) + m.Load("duration", &x.duration) + m.Load("rem", &x.rem) +} + +func init() { + state.Register("pkg/sentry/syscalls/linux.ioEvent", (*ioEvent)(nil), state.Fns{Save: (*ioEvent).save, Load: (*ioEvent).load}) + state.Register("pkg/sentry/syscalls/linux.futexWaitRestartBlock", (*futexWaitRestartBlock)(nil), state.Fns{Save: (*futexWaitRestartBlock).save, Load: (*futexWaitRestartBlock).load}) + state.Register("pkg/sentry/syscalls/linux.pollRestartBlock", (*pollRestartBlock)(nil), state.Fns{Save: (*pollRestartBlock).save, Load: (*pollRestartBlock).load}) + state.Register("pkg/sentry/syscalls/linux.clockNanosleepRestartBlock", (*clockNanosleepRestartBlock)(nil), state.Fns{Save: (*clockNanosleepRestartBlock).save, Load: (*clockNanosleepRestartBlock).load}) +} diff --git a/pkg/sentry/syscalls/linux/sys_clone_amd64.go b/pkg/sentry/syscalls/linux/sys_clone_amd64.go index dd43cf18d..dd43cf18d 100644..100755 --- a/pkg/sentry/syscalls/linux/sys_clone_amd64.go +++ b/pkg/sentry/syscalls/linux/sys_clone_amd64.go diff --git a/pkg/sentry/syscalls/linux/sys_clone_arm64.go b/pkg/sentry/syscalls/linux/sys_clone_arm64.go index cf68a8949..cf68a8949 100644..100755 --- a/pkg/sentry/syscalls/linux/sys_clone_arm64.go +++ b/pkg/sentry/syscalls/linux/sys_clone_arm64.go diff --git a/pkg/sentry/syscalls/linux/sys_rseq.go b/pkg/sentry/syscalls/linux/sys_rseq.go index 90db10ea6..90db10ea6 100644..100755 --- a/pkg/sentry/syscalls/linux/sys_rseq.go +++ b/pkg/sentry/syscalls/linux/sys_rseq.go diff --git a/pkg/sentry/syscalls/linux/sys_stat_amd64.go b/pkg/sentry/syscalls/linux/sys_stat_amd64.go index 0a04a6113..0a04a6113 100644..100755 --- a/pkg/sentry/syscalls/linux/sys_stat_amd64.go +++ b/pkg/sentry/syscalls/linux/sys_stat_amd64.go diff --git a/pkg/sentry/syscalls/linux/sys_stat_arm64.go b/pkg/sentry/syscalls/linux/sys_stat_arm64.go index 5a3b1bfad..5a3b1bfad 100644..100755 --- a/pkg/sentry/syscalls/linux/sys_stat_arm64.go +++ b/pkg/sentry/syscalls/linux/sys_stat_arm64.go diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go index 2de5e3422..2de5e3422 100644..100755 --- a/pkg/sentry/syscalls/linux/sys_xattr.go +++ b/pkg/sentry/syscalls/linux/sys_xattr.go diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD deleted file mode 100644 index e7695e995..000000000 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "vfs2", - srcs = [ - "epoll.go", - "epoll_unsafe.go", - "execve.go", - "fd.go", - "filesystem.go", - "fscontext.go", - "getdents.go", - "ioctl.go", - "linux64.go", - "linux64_override_amd64.go", - "linux64_override_arm64.go", - "mmap.go", - "path.go", - "poll.go", - "read_write.go", - "setstat.go", - "stat.go", - "stat_amd64.go", - "stat_arm64.go", - "sync.go", - "xattr.go", - ], - marshal = True, - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/fspath", - "//pkg/gohacks", - "//pkg/sentry/arch", - "//pkg/sentry/fsbridge", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/loader", - "//pkg/sentry/memmap", - "//pkg/sentry/syscalls", - "//pkg/sentry/syscalls/linux", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go index d6cb0e79a..d6cb0e79a 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/epoll.go +++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go index 825f325bf..825f325bf 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go +++ b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go index aef0078a8..aef0078a8 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/execve.go +++ b/pkg/sentry/syscalls/linux/vfs2/execve.go diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index 3afcea665..3afcea665 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go index fc5ceea4c..fc5ceea4c 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go +++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go diff --git a/pkg/sentry/syscalls/linux/vfs2/fscontext.go b/pkg/sentry/syscalls/linux/vfs2/fscontext.go index 317409a18..317409a18 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/fscontext.go +++ b/pkg/sentry/syscalls/linux/vfs2/fscontext.go diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go index ddc140b65..ddc140b65 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/getdents.go +++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go index 5a2418da9..5a2418da9 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go +++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64.go b/pkg/sentry/syscalls/linux/vfs2/linux64.go index 19ee36081..19ee36081 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/linux64.go +++ b/pkg/sentry/syscalls/linux/vfs2/linux64.go diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go index 7d220bc20..7d220bc20 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go +++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go index a6b367468..a6b367468 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go +++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go index 60a43f0a0..60a43f0a0 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/mmap.go +++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go index 97da6c647..97da6c647 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/path.go +++ b/pkg/sentry/syscalls/linux/vfs2/path.go diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go index dbf4882da..dbf4882da 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/poll.go +++ b/pkg/sentry/syscalls/linux/vfs2/poll.go diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go index 35f6308d6..35f6308d6 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/read_write.go +++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go index 9250659ff..9250659ff 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/setstat.go +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go index a74ea6fd5..a74ea6fd5 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/stat.go +++ b/pkg/sentry/syscalls/linux/vfs2/stat.go diff --git a/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go b/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go index 2da538fc6..2da538fc6 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go +++ b/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go diff --git a/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go b/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go index 88b9c7627..88b9c7627 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go +++ b/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go index 365250b0b..365250b0b 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/sync.go +++ b/pkg/sentry/syscalls/linux/vfs2/sync.go diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2_abi_autogen_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/vfs2_abi_autogen_unsafe.go new file mode 100755 index 000000000..fb2182415 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2_abi_autogen_unsafe.go @@ -0,0 +1,122 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/safecopy" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "io" + "reflect" + "runtime" + "unsafe" +) + +// Marshallable types used by this file. +var _ marshal.Marshallable = (*sigSetWithSize)(nil) + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *sigSetWithSize) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *sigSetWithSize) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.sigsetAddr)) + dst = dst[8:] + usermem.ByteOrder.PutUint64(dst[:8], uint64(s.sizeofSigset)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *sigSetWithSize) UnmarshalBytes(src []byte) { + s.sigsetAddr = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.sizeofSigset = uint64(usermem.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +func (s *sigSetWithSize) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *sigSetWithSize) MarshalUnsafe(dst []byte) { + safecopy.CopyIn(dst, unsafe.Pointer(s)) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *sigSetWithSize) UnmarshalUnsafe(src []byte) { + safecopy.CopyOut(unsafe.Pointer(s), src) +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (s *sigSetWithSize) CopyOut(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + _, err := task.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the CopyOutBytes. + runtime.KeepAlive(s) + return err +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (s *sigSetWithSize) CopyIn(task marshal.Task, addr usermem.Addr) error { + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + _, err := task.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the CopyInBytes. + runtime.KeepAlive(s) + return err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *sigSetWithSize) WriteTo(w io.Writer) (int64, error) { + // Bypass escape analysis on s. The no-op arithmetic operation on the + // pointer makes the compiler think val doesn't depend on s. + // See src/runtime/stubs.go:noescape() in the golang toolchain. + ptr := unsafe.Pointer(s) + val := uintptr(ptr) + val = val^0 + + // Construct a slice backed by s's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = val + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + len, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until after the Write. + runtime.KeepAlive(s) + return int64(len), err +} + diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2_amd64_abi_autogen_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/vfs2_amd64_abi_autogen_unsafe.go new file mode 100755 index 000000000..fc1b597de --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2_amd64_abi_autogen_unsafe.go @@ -0,0 +1,10 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// +build amd64 +// +build amd64 + +package vfs2 + +import ( +) + diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2_amd64_state_autogen.go b/pkg/sentry/syscalls/linux/vfs2/vfs2_amd64_state_autogen.go new file mode 100755 index 000000000..b44f63872 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2_amd64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +// +build amd64 +// +build amd64 + +package vfs2 diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2_arm64_abi_autogen_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/vfs2_arm64_abi_autogen_unsafe.go new file mode 100755 index 000000000..c2958877a --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2_arm64_abi_autogen_unsafe.go @@ -0,0 +1,10 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// +build arm64 +// +build arm64 + +package vfs2 + +import ( +) + diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2_arm64_state_autogen.go b/pkg/sentry/syscalls/linux/vfs2/vfs2_arm64_state_autogen.go new file mode 100755 index 000000000..b61fa85e9 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2_arm64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +// +build arm64 +// +build arm64 + +package vfs2 diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2_state_autogen.go b/pkg/sentry/syscalls/linux/vfs2/vfs2_state_autogen.go new file mode 100755 index 000000000..570100331 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2_state_autogen.go @@ -0,0 +1,26 @@ +// automatically generated by stateify. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *pollRestartBlock) beforeSave() {} +func (x *pollRestartBlock) save(m state.Map) { + x.beforeSave() + m.Save("pfdAddr", &x.pfdAddr) + m.Save("nfds", &x.nfds) + m.Save("timeout", &x.timeout) +} + +func (x *pollRestartBlock) afterLoad() {} +func (x *pollRestartBlock) load(m state.Map) { + m.Load("pfdAddr", &x.pfdAddr) + m.Load("nfds", &x.nfds) + m.Load("timeout", &x.timeout) +} + +func init() { + state.Register("pkg/sentry/syscalls/linux/vfs2.pollRestartBlock", (*pollRestartBlock)(nil), state.Fns{Save: (*pollRestartBlock).save, Load: (*pollRestartBlock).load}) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go index 89e9ff4d7..89e9ff4d7 100644..100755 --- a/pkg/sentry/syscalls/linux/vfs2/xattr.go +++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go diff --git a/pkg/sentry/syscalls/syscalls_state_autogen.go b/pkg/sentry/syscalls/syscalls_state_autogen.go new file mode 100755 index 000000000..b577e39a3 --- /dev/null +++ b/pkg/sentry/syscalls/syscalls_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package syscalls diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD deleted file mode 100644 index 04f81a35b..000000000 --- a/pkg/sentry/time/BUILD +++ /dev/null @@ -1,50 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "seqatomic_parameters", - out = "seqatomic_parameters_unsafe.go", - package = "time", - suffix = "Parameters", - template = "//pkg/sync:generic_seqatomic", - types = { - "Value": "Parameters", - }, -) - -go_library( - name = "time", - srcs = [ - "arith_arm64.go", - "calibrated_clock.go", - "clock_id.go", - "clocks.go", - "muldiv_amd64.s", - "muldiv_arm64.s", - "parameters.go", - "sampler.go", - "sampler_unsafe.go", - "seqatomic_parameters_unsafe.go", - "tsc_amd64.s", - "tsc_arm64.s", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/log", - "//pkg/metric", - "//pkg/sync", - "//pkg/syserror", - ], -) - -go_test( - name = "time_test", - srcs = [ - "calibrated_clock_test.go", - "parameters_test.go", - "sampler_test.go", - ], - library = ":time", -) diff --git a/pkg/sentry/time/LICENSE b/pkg/sentry/time/LICENSE deleted file mode 100644 index 6a66aea5e..000000000 --- a/pkg/sentry/time/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/sentry/time/calibrated_clock_test.go b/pkg/sentry/time/calibrated_clock_test.go deleted file mode 100644 index d6622bfe2..000000000 --- a/pkg/sentry/time/calibrated_clock_test.go +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package time - -import ( - "testing" - "time" -) - -// newTestCalibratedClock returns a CalibratedClock that collects samples from -// the given sample list and cycle counts from the given cycle list. -func newTestCalibratedClock(samples []sample, cycles []TSCValue) *CalibratedClock { - return &CalibratedClock{ - ref: newTestSampler(samples, cycles), - } -} - -func TestConstantFrequency(t *testing.T) { - // Perfectly constant frequency. - samples := []sample{ - {before: 100000, after: 100000 + defaultOverheadCycles, ref: 100}, - {before: 200000, after: 200000 + defaultOverheadCycles, ref: 200}, - {before: 300000, after: 300000 + defaultOverheadCycles, ref: 300}, - {before: 400000, after: 400000 + defaultOverheadCycles, ref: 400}, - {before: 500000, after: 500000 + defaultOverheadCycles, ref: 500}, - {before: 600000, after: 600000 + defaultOverheadCycles, ref: 600}, - {before: 700000, after: 700000 + defaultOverheadCycles, ref: 700}, - } - - c := newTestCalibratedClock(samples, nil) - - // Update from all samples. - for range samples { - c.Update() - } - - c.mu.RLock() - if !c.ready { - c.mu.RUnlock() - t.Fatalf("clock not ready") - } - // A bit after the last sample. - now, ok := c.params.ComputeTime(750000) - c.mu.RUnlock() - if !ok { - t.Fatalf("ComputeTime ok got %v want true", ok) - } - - t.Logf("now: %v", now) - - // Time should be between the current sample and where we'd expect the - // next sample. - if now < 700 || now > 800 { - t.Errorf("now got %v want > 700 && < 800", now) - } -} - -func TestErrorCorrection(t *testing.T) { - testCases := []struct { - name string - samples [5]sample - projectedTimeStart int64 - projectedTimeEnd int64 - }{ - // Initial calibration should be ~1MHz for each of these, and - // the reference clock changes in samples[2]. - { - name: "slow-down", - samples: [5]sample{ - {before: 1000000, after: 1000001, ref: ReferenceNS(1 * ApproxUpdateInterval.Nanoseconds())}, - {before: 2000000, after: 2000001, ref: ReferenceNS(2 * ApproxUpdateInterval.Nanoseconds())}, - // Reference clock has slowed down, causing 100ms of error. - {before: 3010000, after: 3010001, ref: ReferenceNS(3 * ApproxUpdateInterval.Nanoseconds())}, - {before: 4020000, after: 4020001, ref: ReferenceNS(4 * ApproxUpdateInterval.Nanoseconds())}, - {before: 5030000, after: 5030001, ref: ReferenceNS(5 * ApproxUpdateInterval.Nanoseconds())}, - }, - projectedTimeStart: 3005 * time.Millisecond.Nanoseconds(), - projectedTimeEnd: 3015 * time.Millisecond.Nanoseconds(), - }, - { - name: "speed-up", - samples: [5]sample{ - {before: 1000000, after: 1000001, ref: ReferenceNS(1 * ApproxUpdateInterval.Nanoseconds())}, - {before: 2000000, after: 2000001, ref: ReferenceNS(2 * ApproxUpdateInterval.Nanoseconds())}, - // Reference clock has sped up, causing 100ms of error. - {before: 2990000, after: 2990001, ref: ReferenceNS(3 * ApproxUpdateInterval.Nanoseconds())}, - {before: 3980000, after: 3980001, ref: ReferenceNS(4 * ApproxUpdateInterval.Nanoseconds())}, - {before: 4970000, after: 4970001, ref: ReferenceNS(5 * ApproxUpdateInterval.Nanoseconds())}, - }, - projectedTimeStart: 2985 * time.Millisecond.Nanoseconds(), - projectedTimeEnd: 2995 * time.Millisecond.Nanoseconds(), - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := newTestCalibratedClock(tc.samples[:], nil) - - // Initial calibration takes two updates. - _, ok := c.Update() - if ok { - t.Fatalf("Update ready too early") - } - - params, ok := c.Update() - if !ok { - t.Fatalf("Update not ready") - } - - // Initial calibration is ~1MHz. - hz := params.Frequency - if hz < 990000 || hz > 1010000 { - t.Fatalf("Frequency got %v want > 990kHz && < 1010kHz", hz) - } - - // Project time at the next update. Given the 1MHz - // calibration, it is expected to be ~3.1s/2.9s, not - // the actual 3s. - // - // N.B. the next update time is the "after" time above. - projected, ok := params.ComputeTime(tc.samples[2].after) - if !ok { - t.Fatalf("ComputeTime ok got %v want true", ok) - } - if projected < tc.projectedTimeStart || projected > tc.projectedTimeEnd { - t.Fatalf("ComputeTime(%v) got %v want > %v && < %v", tc.samples[2].after, projected, tc.projectedTimeStart, tc.projectedTimeEnd) - } - - // Update again to see the changed reference clock. - params, ok = c.Update() - if !ok { - t.Fatalf("Update not ready") - } - - // We now know that TSC = tc.samples[2].after -> 3s, - // but with the previous params indicated that TSC - // tc.samples[2].after -> 3.5s/2.5s. We can't allow the - // clock to go backwards, and having the clock jump - // forwards is undesirable. There should be a smooth - // transition that corrects the clock error over time. - // Check that the clock is continuous at TSC = - // tc.samples[2].after. - newProjected, ok := params.ComputeTime(tc.samples[2].after) - if !ok { - t.Fatalf("ComputeTime ok got %v want true", ok) - } - if newProjected != projected { - t.Errorf("Discontinuous time; ComputeTime(%v) got %v want %v", tc.samples[2].after, newProjected, projected) - } - - // As the reference clock stablizes, ensure that the clock error - // decreases. - initialErr := c.errorNS - t.Logf("initial error: %v ns", initialErr) - - _, ok = c.Update() - if !ok { - t.Fatalf("Update not ready") - } - if c.errorNS.Magnitude() > initialErr.Magnitude() { - t.Errorf("errorNS increased, got %v want |%v| <= |%v|", c.errorNS, c.errorNS, initialErr) - } - - _, ok = c.Update() - if !ok { - t.Fatalf("Update not ready") - } - if c.errorNS.Magnitude() > initialErr.Magnitude() { - t.Errorf("errorNS increased, got %v want |%v| <= |%v|", c.errorNS, c.errorNS, initialErr) - } - - t.Logf("final error: %v ns", c.errorNS) - }) - } -} diff --git a/pkg/sentry/time/parameters_test.go b/pkg/sentry/time/parameters_test.go deleted file mode 100644 index e1b9084ac..000000000 --- a/pkg/sentry/time/parameters_test.go +++ /dev/null @@ -1,486 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package time - -import ( - "math" - "testing" - "time" -) - -func TestParametersComputeTime(t *testing.T) { - testCases := []struct { - name string - params Parameters - now TSCValue - want int64 - }{ - { - // Now is the same as the base cycles. - name: "base-cycles", - params: Parameters{ - BaseCycles: 10000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 10000, - want: 5000 * time.Millisecond.Nanoseconds(), - }, - { - // Now is the behind the base cycles. Time is frozen. - name: "backwards", - params: Parameters{ - BaseCycles: 10000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 9000, - want: 5000 * time.Millisecond.Nanoseconds(), - }, - { - // Now is ahead of the base cycles. - name: "ahead", - params: Parameters{ - BaseCycles: 10000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 15000, - want: 5500 * time.Millisecond.Nanoseconds(), - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - got, ok := tc.params.ComputeTime(tc.now) - if !ok { - t.Errorf("ComputeTime ok got %v want true", got) - } - if got != tc.want { - t.Errorf("ComputeTime got %+v want %+v", got, tc.want) - } - }) - } -} - -func TestParametersErrorAdjust(t *testing.T) { - testCases := []struct { - name string - oldParams Parameters - now TSCValue - newParams Parameters - want Parameters - errorNS ReferenceNS - wantErr bool - }{ - { - // newParams are perfectly continuous with oldParams - // and don't need adjustment. - name: "continuous", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 50000, - newParams: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - want: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - }, - { - // Same as "continuous", but with now ahead of - // newParams.BaseCycles. The result is the same as - // there is no error to correct. - name: "continuous-nowdiff", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 60000, - newParams: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - want: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - }, - { - // errorAdjust bails out if the TSC goes backwards. - name: "tsc-backwards", - oldParams: Parameters{ - BaseCycles: 10000, - BaseRef: ReferenceNS(1000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 9000, - newParams: Parameters{ - BaseCycles: 9000, - BaseRef: ReferenceNS(1100 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - wantErr: true, - }, - { - // errorAdjust bails out if new params are from after now. - name: "params-after-now", - oldParams: Parameters{ - BaseCycles: 10000, - BaseRef: ReferenceNS(1000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 11000, - newParams: Parameters{ - BaseCycles: 12000, - BaseRef: ReferenceNS(1200 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - wantErr: true, - }, - { - // Host clock sped up. - name: "speed-up", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 45000, - // Host frequency changed to 9000 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 45000, - // From oldParams, we think ref = 4.5s at cycles = 45000. - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 9000, - }, - want: Parameters{ - BaseCycles: 45000, - BaseRef: ReferenceNS(4500 * time.Millisecond.Nanoseconds()), - // We must decrease the new frequency by 50% to - // correct 0.5s of error in 1s - // (ApproxUpdateInterval). - Frequency: 4500, - }, - errorNS: ReferenceNS(-500 * time.Millisecond.Nanoseconds()), - }, - { - // Host clock sped up, with now ahead of newParams. - name: "speed-up-nowdiff", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 50000, - // Host frequency changed to 9000 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 45000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 9000, - }, - // nextRef = 6000ms - // nextCycles = 9000 * (6000ms - 5000ms) + 45000 - // nextCycles = 9000 * (1s) + 45000 - // nextCycles = 54000 - // f = (54000 - 50000) / 1s = 4000 - // - // ref = 5000ms - (50000 - 45000) / 4000 - // ref = 3.75s - want: Parameters{ - BaseCycles: 45000, - BaseRef: ReferenceNS(3750 * time.Millisecond.Nanoseconds()), - Frequency: 4000, - }, - // oldNow = 50000 * 10000 = 5s - // newNow = (50000 - 45000) / 9000 + 5s = 5.555s - errorNS: ReferenceNS((5000*time.Millisecond - 5555555555).Nanoseconds()), - }, - { - // Host clock sped up. The new parameters are so far - // ahead that the next update time already passed. - name: "speed-up-uncorrectable-baseref", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 50000, - // Host frequency changed to 5000 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 45000, - BaseRef: ReferenceNS(9000 * time.Millisecond.Nanoseconds()), - Frequency: 5000, - }, - // The next update should be at 10s, but newParams - // already passed 6s. Thus it is impossible to correct - // the clock by then. - wantErr: true, - }, - { - // Host clock sped up. The new parameters are moving so - // fast that the next update should be before now. - name: "speed-up-uncorrectable-frequency", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 55000, - // Host frequency changed to 7500 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 45000, - BaseRef: ReferenceNS(6000 * time.Millisecond.Nanoseconds()), - Frequency: 7500, - }, - // The next update should be at 6.5s, but newParams are - // so far ahead and fast that they reach 6.5s at cycle - // 48750, which before now! Thus it is impossible to - // correct the clock by then. - wantErr: true, - }, - { - // Host clock slowed down. - name: "slow-down", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 55000, - // Host frequency changed to 11000 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 55000, - // From oldParams, we think ref = 5.5s at cycles = 55000. - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 11000, - }, - want: Parameters{ - BaseCycles: 55000, - BaseRef: ReferenceNS(5500 * time.Millisecond.Nanoseconds()), - // We must increase the new frequency by 50% to - // correct 0.5s of error in 1s - // (ApproxUpdateInterval). - Frequency: 16500, - }, - errorNS: ReferenceNS(500 * time.Millisecond.Nanoseconds()), - }, - { - // Host clock slowed down, with now ahead of newParams. - name: "slow-down-nowdiff", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 60000, - // Host frequency changed to 11000 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 55000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 11000, - }, - // nextRef = 7000ms - // nextCycles = 11000 * (7000ms - 5000ms) + 55000 - // nextCycles = 11000 * (2000ms) + 55000 - // nextCycles = 77000 - // f = (77000 - 60000) / 1s = 17000 - // - // ref = 6000ms - (60000 - 55000) / 17000 - // ref = 5705882353ns - want: Parameters{ - BaseCycles: 55000, - BaseRef: ReferenceNS(5705882353), - Frequency: 17000, - }, - // oldNow = 60000 * 10000 = 6s - // newNow = (60000 - 55000) / 11000 + 5s = 5.4545s - errorNS: ReferenceNS((6*time.Second - 5454545454).Nanoseconds()), - }, - { - // Host time went backwards. - name: "time-backwards", - oldParams: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 60000, - newParams: Parameters{ - BaseCycles: 60000, - // From oldParams, we think ref = 6s at cycles = 60000. - BaseRef: ReferenceNS(4000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - want: Parameters{ - BaseCycles: 60000, - BaseRef: ReferenceNS(6000 * time.Millisecond.Nanoseconds()), - // We must increase the frequency by 200% to - // correct 2s of error in 1s - // (ApproxUpdateInterval). - Frequency: 30000, - }, - errorNS: ReferenceNS(2000 * time.Millisecond.Nanoseconds()), - }, - { - // Host time went backwards, with now ahead of newParams. - name: "time-backwards-nowdiff", - oldParams: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 65000, - // nextRef = 7500ms - // nextCycles = 10000 * (7500ms - 4000ms) + 60000 - // nextCycles = 10000 * (3500ms) + 60000 - // nextCycles = 95000 - // f = (95000 - 65000) / 1s = 30000 - // - // ref = 6500ms - (65000 - 60000) / 30000 - // ref = 6333333333ns - newParams: Parameters{ - BaseCycles: 60000, - BaseRef: ReferenceNS(4000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - want: Parameters{ - BaseCycles: 60000, - BaseRef: ReferenceNS(6333333334), - Frequency: 30000, - }, - // oldNow = 65000 * 10000 = 6.5s - // newNow = (65000 - 60000) / 10000 + 4s = 4.5s - errorNS: ReferenceNS(2000 * time.Millisecond.Nanoseconds()), - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - got, errorNS, err := errorAdjust(tc.oldParams, tc.newParams, tc.now) - if err != nil && !tc.wantErr { - t.Errorf("err got %v want nil", err) - } else if err == nil && tc.wantErr { - t.Errorf("err got nil want non-nil") - } - - if got != tc.want { - t.Errorf("Parameters got %+v want %+v", got, tc.want) - } - if errorNS != tc.errorNS { - t.Errorf("errorNS got %v want %v", errorNS, tc.errorNS) - } - }) - } -} - -func testMuldiv(t *testing.T, v uint64) { - for i := uint64(1); i <= 1000000; i++ { - mult := uint64(1000000000) - div := i * mult - res, ok := muldiv64(v, mult, div) - if !ok { - t.Errorf("Result of %v * %v / %v ok got false want true", v, mult, div) - } - if want := v / i; res != want { - t.Errorf("Bad result of %v * %v / %v: got %v, want %v", v, mult, div, res, want) - } - } -} - -func TestMulDiv(t *testing.T) { - testMuldiv(t, math.MaxUint64) - for i := int64(-10); i <= 10; i++ { - testMuldiv(t, uint64(i)) - } -} - -func TestMulDivZero(t *testing.T) { - if r, ok := muldiv64(2, 4, 0); ok { - t.Errorf("muldiv64(2, 4, 0) got %d, ok want !ok", r) - } - - if r, ok := muldiv64(0, 0, 0); ok { - t.Errorf("muldiv64(0, 0, 0) got %d, ok want !ok", r) - } -} - -func TestMulDivOverflow(t *testing.T) { - testCases := []struct { - name string - val uint64 - mult uint64 - div uint64 - ok bool - ret uint64 - }{ - { - name: "2^62", - val: 1 << 63, - mult: 4, - div: 8, - ok: true, - ret: 1 << 62, - }, - { - name: "2^64-1", - val: 0xffffffffffffffff, - mult: 1, - div: 1, - ok: true, - ret: 0xffffffffffffffff, - }, - { - name: "2^64", - val: 1 << 63, - mult: 4, - div: 2, - ok: false, - }, - { - name: "2^125", - val: 1 << 63, - mult: 1 << 63, - div: 2, - ok: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - r, ok := muldiv64(tc.val, tc.mult, tc.div) - if ok != tc.ok { - t.Errorf("ok got %v want %v", ok, tc.ok) - } - if tc.ok && r != tc.ret { - t.Errorf("ret got %v want %v", r, tc.ret) - } - }) - } -} diff --git a/pkg/sentry/time/sampler_test.go b/pkg/sentry/time/sampler_test.go deleted file mode 100644 index 3e70a1134..000000000 --- a/pkg/sentry/time/sampler_test.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package time - -import ( - "errors" - "testing" -) - -// errNoSamples is returned when testReferenceClocks runs out of samples. -var errNoSamples = errors.New("no samples available") - -// testReferenceClocks returns a preset list of samples and cycle counts. -type testReferenceClocks struct { - samples []sample - cycles []TSCValue -} - -// Sample implements referenceClocks.Sample, returning the next sample in the list. -func (t *testReferenceClocks) Sample(_ ClockID) (sample, error) { - if len(t.samples) == 0 { - return sample{}, errNoSamples - } - - s := t.samples[0] - if len(t.samples) == 1 { - t.samples = nil - } else { - t.samples = t.samples[1:] - } - - return s, nil -} - -// Cycles implements referenceClocks.Cycles, returning the next TSCValue in the list. -func (t *testReferenceClocks) Cycles() TSCValue { - if len(t.cycles) == 0 { - return 0 - } - - c := t.cycles[0] - if len(t.cycles) == 1 { - t.cycles = nil - } else { - t.cycles = t.cycles[1:] - } - - return c -} - -// newTestSampler returns a sampler that collects samples from -// the given sample list and cycle counts from the given cycle list. -func newTestSampler(samples []sample, cycles []TSCValue) *sampler { - return &sampler{ - clocks: &testReferenceClocks{ - samples: samples, - cycles: cycles, - }, - overhead: defaultOverheadCycles, - } -} - -// generateSamples generates n samples with the given overhead. -func generateSamples(n int, overhead TSCValue) []sample { - samples := []sample{{before: 1000000, after: 1000000 + overhead, ref: 100}} - for i := 0; i < n-1; i++ { - prev := samples[len(samples)-1] - samples = append(samples, sample{ - before: prev.before + 1000000, - after: prev.after + 1000000, - ref: prev.ref + 100, - }) - } - return samples -} - -// TestSample ensures that samples can be collected. -func TestSample(t *testing.T) { - testCases := []struct { - name string - samples []sample - err error - }{ - { - name: "basic", - samples: []sample{ - {before: 100000, after: 100000 + defaultOverheadCycles, ref: 100}, - }, - err: nil, - }, - { - // Sample with backwards TSC ignored. - // referenceClock should retry and get errNoSamples. - name: "backwards-tsc-ignored", - samples: []sample{ - {before: 100000, after: 90000, ref: 100}, - }, - err: errNoSamples, - }, - { - // Sample far above overhead skipped. - // referenceClock should retry and get errNoSamples. - name: "reject-overhead", - samples: []sample{ - {before: 100000, after: 100000 + 5*defaultOverheadCycles, ref: 100}, - }, - err: errNoSamples, - }, - { - // Maximum overhead allowed is bounded. - name: "over-max-overhead", - // Generate a bunch of samples. The reference clock - // needs a while to ramp up its expected overhead. - samples: generateSamples(100, 2*maxOverheadCycles), - err: errOverheadTooHigh, - }, - { - // Overhead at maximum overhead is allowed. - name: "max-overhead", - // Generate a bunch of samples. The reference clock - // needs a while to ramp up its expected overhead. - samples: generateSamples(100, maxOverheadCycles), - err: nil, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - s := newTestSampler(tc.samples, nil) - err := s.Sample() - if err != tc.err { - t.Errorf("Sample err got %v want %v", err, tc.err) - } - }) - } -} - -// TestOutliersIgnored tests that referenceClock ignores samples with very high -// overhead. -func TestOutliersIgnored(t *testing.T) { - s := newTestSampler([]sample{ - {before: 100000, after: 100000 + defaultOverheadCycles, ref: 100}, - {before: 200000, after: 200000 + defaultOverheadCycles, ref: 200}, - {before: 300000, after: 300000 + defaultOverheadCycles, ref: 300}, - {before: 400000, after: 400000 + defaultOverheadCycles, ref: 400}, - {before: 500000, after: 500000 + 5*defaultOverheadCycles, ref: 500}, // Ignored - {before: 600000, after: 600000 + defaultOverheadCycles, ref: 600}, - {before: 700000, after: 700000 + defaultOverheadCycles, ref: 700}, - }, nil) - - // Collect 5 samples. - for i := 0; i < 5; i++ { - err := s.Sample() - if err != nil { - t.Fatalf("Unexpected error while sampling: %v", err) - } - } - - oldest, newest, ok := s.Range() - if !ok { - t.Fatalf("Range not ok") - } - - if oldest.ref != 100 { - t.Errorf("oldest.ref got %v want %v", oldest.ref, 100) - } - - // We skipped the high-overhead sample. - if newest.ref != 600 { - t.Errorf("newest.ref got %v want %v", newest.ref, 600) - } -} diff --git a/pkg/sentry/time/seqatomic_parameters_unsafe.go b/pkg/sentry/time/seqatomic_parameters_unsafe.go new file mode 100755 index 000000000..efd3ccae2 --- /dev/null +++ b/pkg/sentry/time/seqatomic_parameters_unsafe.go @@ -0,0 +1,55 @@ +package time + +import ( + "fmt" + "reflect" + "strings" + "unsafe" + + "gvisor.dev/gvisor/pkg/sync" +) + +// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race +// with any writer critical sections in sc. +func SeqAtomicLoadParameters(sc *sync.SeqCount, ptr *Parameters) Parameters { + // This function doesn't use SeqAtomicTryLoad because doing so is + // measurably, significantly (~20%) slower; Go is awful at inlining. + var val Parameters + for { + epoch := sc.BeginRead() + if sync.RaceEnabled { + + sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + + val = *ptr + } + if sc.ReadOk(epoch) { + break + } + } + return val +} + +// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section +// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read +// would race with a writer critical section, SeqAtomicTryLoad returns +// (unspecified, false). +func SeqAtomicTryLoadParameters(sc *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Parameters) (Parameters, bool) { + var val Parameters + if sync.RaceEnabled { + sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + val = *ptr + } + return val, sc.ReadOk(epoch) +} + +func initParameters() { + var val Parameters + typ := reflect.TypeOf(val) + name := typ.Name() + if ptrs := sync.PointersInType(typ, name); len(ptrs) != 0 { + panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n"))) + } +} diff --git a/pkg/sentry/time/time_arm64_state_autogen.go b/pkg/sentry/time/time_arm64_state_autogen.go new file mode 100755 index 000000000..2adc9c9e0 --- /dev/null +++ b/pkg/sentry/time/time_arm64_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package time diff --git a/pkg/sentry/time/time_state_autogen.go b/pkg/sentry/time/time_state_autogen.go new file mode 100755 index 000000000..2adc9c9e0 --- /dev/null +++ b/pkg/sentry/time/time_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package time diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD deleted file mode 100644 index 5d4aa3a63..000000000 --- a/pkg/sentry/unimpl/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library", "proto_library") - -package(licenses = ["notice"]) - -proto_library( - name = "unimplemented_syscall", - srcs = ["unimplemented_syscall.proto"], - visibility = ["//visibility:public"], - deps = ["//pkg/sentry/arch:registers_proto"], -) - -go_library( - name = "unimpl", - srcs = ["events.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/context", - "//pkg/log", - ], -) diff --git a/pkg/sentry/unimpl/unimpl_state_autogen.go b/pkg/sentry/unimpl/unimpl_state_autogen.go new file mode 100755 index 000000000..b37d16f87 --- /dev/null +++ b/pkg/sentry/unimpl/unimpl_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package unimpl diff --git a/pkg/sentry/unimpl/unimplemented_syscall.proto b/pkg/sentry/unimpl/unimplemented_syscall.proto deleted file mode 100644 index 0d7a94be7..000000000 --- a/pkg/sentry/unimpl/unimplemented_syscall.proto +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -import "pkg/sentry/arch/registers.proto"; - -message UnimplementedSyscall { - // Task ID. - int32 tid = 1; - - // Registers at the time of the call. - Registers registers = 2; -} diff --git a/pkg/sentry/unimpl/unimplemented_syscall_go_proto/unimplemented_syscall.pb.go b/pkg/sentry/unimpl/unimplemented_syscall_go_proto/unimplemented_syscall.pb.go new file mode 100755 index 000000000..4dfb169cc --- /dev/null +++ b/pkg/sentry/unimpl/unimplemented_syscall_go_proto/unimplemented_syscall.pb.go @@ -0,0 +1,91 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/sentry/unimpl/unimplemented_syscall.proto + +package gvisor + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + registers_go_proto "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type UnimplementedSyscall struct { + Tid int32 `protobuf:"varint,1,opt,name=tid,proto3" json:"tid,omitempty"` + Registers *registers_go_proto.Registers `protobuf:"bytes,2,opt,name=registers,proto3" json:"registers,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *UnimplementedSyscall) Reset() { *m = UnimplementedSyscall{} } +func (m *UnimplementedSyscall) String() string { return proto.CompactTextString(m) } +func (*UnimplementedSyscall) ProtoMessage() {} +func (*UnimplementedSyscall) Descriptor() ([]byte, []int) { + return fileDescriptor_ddc2fcd2bea3c75d, []int{0} +} + +func (m *UnimplementedSyscall) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_UnimplementedSyscall.Unmarshal(m, b) +} +func (m *UnimplementedSyscall) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_UnimplementedSyscall.Marshal(b, m, deterministic) +} +func (m *UnimplementedSyscall) XXX_Merge(src proto.Message) { + xxx_messageInfo_UnimplementedSyscall.Merge(m, src) +} +func (m *UnimplementedSyscall) XXX_Size() int { + return xxx_messageInfo_UnimplementedSyscall.Size(m) +} +func (m *UnimplementedSyscall) XXX_DiscardUnknown() { + xxx_messageInfo_UnimplementedSyscall.DiscardUnknown(m) +} + +var xxx_messageInfo_UnimplementedSyscall proto.InternalMessageInfo + +func (m *UnimplementedSyscall) GetTid() int32 { + if m != nil { + return m.Tid + } + return 0 +} + +func (m *UnimplementedSyscall) GetRegisters() *registers_go_proto.Registers { + if m != nil { + return m.Registers + } + return nil +} + +func init() { + proto.RegisterType((*UnimplementedSyscall)(nil), "gvisor.UnimplementedSyscall") +} + +func init() { + proto.RegisterFile("pkg/sentry/unimpl/unimplemented_syscall.proto", fileDescriptor_ddc2fcd2bea3c75d) +} + +var fileDescriptor_ddc2fcd2bea3c75d = []byte{ + // 149 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xd2, 0x2d, 0xc8, 0x4e, 0xd7, + 0x2f, 0x4e, 0xcd, 0x2b, 0x29, 0xaa, 0xd4, 0x2f, 0xcd, 0xcb, 0xcc, 0x2d, 0xc8, 0x81, 0x52, 0xa9, + 0xb9, 0xa9, 0x79, 0x25, 0xa9, 0x29, 0xf1, 0xc5, 0x95, 0xc5, 0xc9, 0x89, 0x39, 0x39, 0x7a, 0x05, + 0x45, 0xf9, 0x25, 0xf9, 0x42, 0x6c, 0xe9, 0x65, 0x99, 0xc5, 0xf9, 0x45, 0x52, 0xf2, 0x48, 0xda, + 0x12, 0x8b, 0x92, 0x33, 0xf4, 0x8b, 0x52, 0xd3, 0x33, 0x8b, 0x4b, 0x52, 0x8b, 0x8a, 0x21, 0x0a, + 0x95, 0x22, 0xb9, 0x44, 0x42, 0x91, 0xcd, 0x09, 0x86, 0x18, 0x23, 0x24, 0xc0, 0xc5, 0x5c, 0x92, + 0x99, 0x22, 0xc1, 0xa8, 0xc0, 0xa8, 0xc1, 0x1a, 0x04, 0x62, 0x0a, 0xe9, 0x73, 0x71, 0xc2, 0x35, + 0x4b, 0x30, 0x29, 0x30, 0x6a, 0x70, 0x1b, 0x09, 0xea, 0x41, 0xac, 0xd1, 0x0b, 0x82, 0x49, 0x04, + 0x21, 0xd4, 0x24, 0xb1, 0x81, 0x6d, 0x30, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x51, 0x4a, 0x47, + 0x79, 0xbb, 0x00, 0x00, 0x00, +} diff --git a/pkg/sentry/uniqueid/BUILD b/pkg/sentry/uniqueid/BUILD deleted file mode 100644 index 7467e6398..000000000 --- a/pkg/sentry/uniqueid/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "uniqueid", - srcs = ["context.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/sentry/socket/unix/transport", - ], -) diff --git a/pkg/sentry/uniqueid/uniqueid_state_autogen.go b/pkg/sentry/uniqueid/uniqueid_state_autogen.go new file mode 100755 index 000000000..1890fdf46 --- /dev/null +++ b/pkg/sentry/uniqueid/uniqueid_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package uniqueid diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD deleted file mode 100644 index 099315613..000000000 --- a/pkg/sentry/usage/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "usage", - srcs = [ - "cpu.go", - "io.go", - "memory.go", - "memory_unsafe.go", - "usage.go", - ], - visibility = [ - "//:sandbox", - ], - deps = [ - "//pkg/bits", - "//pkg/memutil", - "//pkg/sync", - ], -) diff --git a/pkg/sentry/usage/usage_state_autogen.go b/pkg/sentry/usage/usage_state_autogen.go new file mode 100755 index 000000000..42979ea25 --- /dev/null +++ b/pkg/sentry/usage/usage_state_autogen.go @@ -0,0 +1,50 @@ +// automatically generated by stateify. + +package usage + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *CPUStats) beforeSave() {} +func (x *CPUStats) save(m state.Map) { + x.beforeSave() + m.Save("UserTime", &x.UserTime) + m.Save("SysTime", &x.SysTime) + m.Save("VoluntarySwitches", &x.VoluntarySwitches) +} + +func (x *CPUStats) afterLoad() {} +func (x *CPUStats) load(m state.Map) { + m.Load("UserTime", &x.UserTime) + m.Load("SysTime", &x.SysTime) + m.Load("VoluntarySwitches", &x.VoluntarySwitches) +} + +func (x *IO) beforeSave() {} +func (x *IO) save(m state.Map) { + x.beforeSave() + m.Save("CharsRead", &x.CharsRead) + m.Save("CharsWritten", &x.CharsWritten) + m.Save("ReadSyscalls", &x.ReadSyscalls) + m.Save("WriteSyscalls", &x.WriteSyscalls) + m.Save("BytesRead", &x.BytesRead) + m.Save("BytesWritten", &x.BytesWritten) + m.Save("BytesWriteCancelled", &x.BytesWriteCancelled) +} + +func (x *IO) afterLoad() {} +func (x *IO) load(m state.Map) { + m.Load("CharsRead", &x.CharsRead) + m.Load("CharsWritten", &x.CharsWritten) + m.Load("ReadSyscalls", &x.ReadSyscalls) + m.Load("WriteSyscalls", &x.WriteSyscalls) + m.Load("BytesRead", &x.BytesRead) + m.Load("BytesWritten", &x.BytesWritten) + m.Load("BytesWriteCancelled", &x.BytesWriteCancelled) +} + +func init() { + state.Register("pkg/sentry/usage.CPUStats", (*CPUStats)(nil), state.Fns{Save: (*CPUStats).save, Load: (*CPUStats).load}) + state.Register("pkg/sentry/usage.IO", (*IO)(nil), state.Fns{Save: (*IO).save, Load: (*IO).load}) +} diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD deleted file mode 100644 index cb4deb068..000000000 --- a/pkg/sentry/vfs/BUILD +++ /dev/null @@ -1,79 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -licenses(["notice"]) - -go_template_instance( - name = "epoll_interest_list", - out = "epoll_interest_list.go", - package = "vfs", - prefix = "epollInterest", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*epollInterest", - "Linker": "*epollInterest", - }, -) - -go_library( - name = "vfs", - srcs = [ - "anonfs.go", - "context.go", - "debug.go", - "dentry.go", - "device.go", - "epoll.go", - "epoll_interest_list.go", - "file_description.go", - "file_description_impl_util.go", - "filesystem.go", - "filesystem_impl_util.go", - "filesystem_type.go", - "mount.go", - "mount_unsafe.go", - "options.go", - "pathname.go", - "permissions.go", - "resolving_path.go", - "vfs.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/fd", - "//pkg/fspath", - "//pkg/gohacks", - "//pkg/log", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/fs", - "//pkg/sentry/fs/lock", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/memmap", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "vfs_test", - size = "small", - srcs = [ - "file_description_impl_util_test.go", - "mount_test.go", - ], - library = ":vfs", - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/contexttest", - "//pkg/sync", - "//pkg/syserror", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md deleted file mode 100644 index 9aa133bcb..000000000 --- a/pkg/sentry/vfs/README.md +++ /dev/null @@ -1,197 +0,0 @@ -# The gVisor Virtual Filesystem - -THIS PACKAGE IS CURRENTLY EXPERIMENTAL AND NOT READY OR ENABLED FOR PRODUCTION -USE. For the filesystem implementation currently used by gVisor, see the `fs` -package. - -## Implementation Notes - -### Reference Counting - -Filesystem, Dentry, Mount, MountNamespace, and FileDescription are all -reference-counted. Mount and MountNamespace are exclusively VFS-managed; when -their reference count reaches zero, VFS releases their resources. Filesystem and -FileDescription management is shared between VFS and filesystem implementations; -when their reference count reaches zero, VFS notifies the implementation by -calling `FilesystemImpl.Release()` or `FileDescriptionImpl.Release()` -respectively and then releases VFS-owned resources. Dentries are exclusively -managed by filesystem implementations; reference count changes are abstracted -through DentryImpl, which should release resources when reference count reaches -zero. - -Filesystem references are held by: - -- Mount: Each referenced Mount holds a reference on the mounted Filesystem. - -Dentry references are held by: - -- FileDescription: Each referenced FileDescription holds a reference on the - Dentry through which it was opened, via `FileDescription.vd.dentry`. - -- Mount: Each referenced Mount holds a reference on its mount point and on the - mounted filesystem root. The mount point is mutable (`mount(MS_MOVE)`). - -Mount references are held by: - -- FileDescription: Each referenced FileDescription holds a reference on the - Mount on which it was opened, via `FileDescription.vd.mount`. - -- Mount: Each referenced Mount holds a reference on its parent, which is the - mount containing its mount point. - -- VirtualFilesystem: A reference is held on each Mount that has not been - umounted. - -MountNamespace and FileDescription references are held by users of VFS. The -expectation is that each `kernel.Task` holds a reference on its corresponding -MountNamespace, and each file descriptor holds a reference on its represented -FileDescription. - -Notes: - -- Dentries do not hold a reference on their owning Filesystem. Instead, all - uses of a Dentry occur in the context of a Mount, which holds a reference on - the relevant Filesystem (see e.g. the VirtualDentry type). As a corollary, - when releasing references on both a Dentry and its corresponding Mount, the - Dentry's reference must be released first (because releasing the Mount's - reference may release the last reference on the Filesystem, whose state may - be required to release the Dentry reference). - -### The Inheritance Pattern - -Filesystem, Dentry, and FileDescription are all concepts featuring both state -that must be shared between VFS and filesystem implementations, and operations -that are implementation-defined. To facilitate this, each of these three -concepts follows the same pattern, shown below for Dentry: - -```go -// Dentry represents a node in a filesystem tree. -type Dentry struct { - // VFS-required dentry state. - parent *Dentry - // ... - - // impl is the DentryImpl associated with this Dentry. impl is immutable. - // This should be the last field in Dentry. - impl DentryImpl -} - -// Init must be called before first use of d. -func (d *Dentry) Init(impl DentryImpl) { - d.impl = impl -} - -// Impl returns the DentryImpl associated with d. -func (d *Dentry) Impl() DentryImpl { - return d.impl -} - -// DentryImpl contains implementation-specific details of a Dentry. -// Implementations of DentryImpl should contain their associated Dentry by -// value as their first field. -type DentryImpl interface { - // VFS-required implementation-defined dentry operations. - IncRef() - // ... -} -``` - -This construction, which is essentially a type-safe analogue to Linux's -`container_of` pattern, has the following properties: - -- VFS works almost exclusively with pointers to Dentry rather than DentryImpl - interface objects, such as in the type of `Dentry.parent`. This avoids - interface method calls (which are somewhat expensive to perform, and defeat - inlining and escape analysis), reduces the size of VFS types (since an - interface object is two pointers in size), and allows pointers to be loaded - and stored atomically using `sync/atomic`. Implementation-defined behavior - is accessed via `Dentry.impl` when required. - -- Filesystem implementations can access the implementation-defined state - associated with objects of VFS types by type-asserting or type-switching - (e.g. `Dentry.Impl().(*myDentry)`). Type assertions to a concrete type - require only an equality comparison of the interface object's type pointer - to a static constant, and are consequently very fast. - -- Filesystem implementations can access the VFS state associated with objects - of implementation-defined types directly. - -- VFS and implementation-defined state for a given type occupy the same - object, minimizing memory allocations and maximizing memory locality. `impl` - is the last field in `Dentry`, and `Dentry` is the first field in - `DentryImpl` implementations, for similar reasons: this tends to cause - fetching of the `Dentry.impl` interface object to also fetch `DentryImpl` - fields, either because they are in the same cache line or via next-line - prefetching. - -## Future Work - -- Most `mount(2)` features, and unmounting, are incomplete. - -- VFS1 filesystems are not directly compatible with VFS2. It may be possible - to implement shims that implement `vfs.FilesystemImpl` for - `fs.MountNamespace`, `vfs.DentryImpl` for `fs.Dirent`, and - `vfs.FileDescriptionImpl` for `fs.File`, which may be adequate for - filesystems that are not performance-critical (e.g. sysfs); however, it is - not clear that this will be less effort than simply porting the filesystems - in question. Practically speaking, the following filesystems will probably - need to be ported or made compatible through a shim to evaluate filesystem - performance on realistic workloads: - - - devfs/procfs/sysfs, which will realistically be necessary to execute - most applications. (Note that procfs and sysfs do not support hard - links, so they do not require the complexity of separate inode objects. - Also note that Linux's /dev is actually a variant of tmpfs called - devtmpfs.) - - - tmpfs. This should be relatively straightforward: copy/paste memfs, - store regular file contents in pgalloc-allocated memory instead of - `[]byte`, and add support for file timestamps. (In fact, it probably - makes more sense to convert memfs to tmpfs and not keep the former.) - - - A remote filesystem, either lisafs (if it is ready by the time that - other benchmarking prerequisites are) or v9fs (aka 9P, aka gofers). - - - epoll files. - - Filesystems that will need to be ported before switching to VFS2, but can - probably be skipped for early testing: - - - overlayfs, which is needed for (at least) synthetic mount points. - - - Support for host ttys. - - - timerfd files. - - Filesystems that can be probably dropped: - - - ashmem, which is far too incomplete to use. - - - binder, which is similarly far too incomplete to use. - - - whitelistfs, which we are already actively attempting to remove. - -- Save/restore. For instance, it is unclear if the current implementation of - the `state` package supports the inheritance pattern described above. - -- Many features that were previously implemented by VFS must now be - implemented by individual filesystems (though, in most cases, this should - consist of calls to hooks or libraries provided by `vfs` or other packages). - This includes, but is not necessarily limited to: - - - Block and character device special files - - - Inotify - - - File locking - - - `O_ASYNC` - -- Reference counts in the `vfs` package do not use the `refs` package since - `refs.AtomicRefCount` adds 64 bytes of overhead to each 8-byte reference - count, resulting in considerable cache bloat. 24 bytes of this overhead is - for weak reference support, which have poor performance and will not be used - by VFS2. The remaining 40 bytes is to store a descriptive string and stack - trace for reference leak checking; we can support reference leak checking - without incurring this space overhead by including the applicable - information directly in finalizers for applicable types. diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index 925996517..925996517 100644..100755 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go diff --git a/pkg/sentry/vfs/context.go b/pkg/sentry/vfs/context.go index 82781e6d3..82781e6d3 100644..100755 --- a/pkg/sentry/vfs/context.go +++ b/pkg/sentry/vfs/context.go diff --git a/pkg/sentry/vfs/debug.go b/pkg/sentry/vfs/debug.go index 0ed20f249..0ed20f249 100644..100755 --- a/pkg/sentry/vfs/debug.go +++ b/pkg/sentry/vfs/debug.go diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index 35b208721..35b208721 100644..100755 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go diff --git a/pkg/sentry/vfs/device.go b/pkg/sentry/vfs/device.go index bda5576fa..bda5576fa 100644..100755 --- a/pkg/sentry/vfs/device.go +++ b/pkg/sentry/vfs/device.go diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index 3da45d744..3da45d744 100644..100755 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go diff --git a/pkg/sentry/vfs/epoll_interest_list.go b/pkg/sentry/vfs/epoll_interest_list.go new file mode 100755 index 000000000..1bd41f400 --- /dev/null +++ b/pkg/sentry/vfs/epoll_interest_list.go @@ -0,0 +1,186 @@ +package vfs + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type epollInterestElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (epollInterestElementMapper) linkerFor(elem *epollInterest) *epollInterest { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type epollInterestList struct { + head *epollInterest + tail *epollInterest +} + +// Reset resets list l to the empty state. +func (l *epollInterestList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *epollInterestList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *epollInterestList) Front() *epollInterest { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *epollInterestList) Back() *epollInterest { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *epollInterestList) PushFront(e *epollInterest) { + linker := epollInterestElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + epollInterestElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *epollInterestList) PushBack(e *epollInterest) { + linker := epollInterestElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + epollInterestElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *epollInterestList) PushBackList(m *epollInterestList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + epollInterestElementMapper{}.linkerFor(l.tail).SetNext(m.head) + epollInterestElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *epollInterestList) InsertAfter(b, e *epollInterest) { + bLinker := epollInterestElementMapper{}.linkerFor(b) + eLinker := epollInterestElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + epollInterestElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *epollInterestList) InsertBefore(a, e *epollInterest) { + aLinker := epollInterestElementMapper{}.linkerFor(a) + eLinker := epollInterestElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + epollInterestElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *epollInterestList) Remove(e *epollInterest) { + linker := epollInterestElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + epollInterestElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + epollInterestElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type epollInterestEntry struct { + next *epollInterest + prev *epollInterest +} + +// Next returns the entry that follows e in the list. +func (e *epollInterestEntry) Next() *epollInterest { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *epollInterestEntry) Prev() *epollInterest { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *epollInterestEntry) SetNext(elem *epollInterest) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *epollInterestEntry) SetPrev(elem *epollInterest) { + e.prev = elem +} diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 9a1ad630c..9a1ad630c 100644..100755 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 45191d1c3..45191d1c3 100644..100755 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go deleted file mode 100644 index 3a75d4d62..000000000 --- a/pkg/sentry/vfs/file_description_impl_util_test.go +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package vfs - -import ( - "bytes" - "fmt" - "io" - "sync/atomic" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" -) - -// fileDescription is the common fd struct which a filesystem implementation -// embeds in all of its file description implementations as required. -type fileDescription struct { - vfsfd FileDescription - FileDescriptionDefaultImpl -} - -// genCount contains the number of times its DynamicBytesSource.Generate() -// implementation has been called. -type genCount struct { - count uint64 // accessed using atomic memory ops -} - -// Generate implements DynamicBytesSource.Generate. -func (g *genCount) Generate(ctx context.Context, buf *bytes.Buffer) error { - fmt.Fprintf(buf, "%d", atomic.AddUint64(&g.count, 1)) - return nil -} - -type storeData struct { - data string -} - -var _ WritableDynamicBytesSource = (*storeData)(nil) - -// Generate implements DynamicBytesSource. -func (d *storeData) Generate(ctx context.Context, buf *bytes.Buffer) error { - buf.WriteString(d.data) - return nil -} - -// Generate implements WritableDynamicBytesSource. -func (d *storeData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { - buf := make([]byte, src.NumBytes()) - n, err := src.CopyIn(ctx, buf) - if err != nil { - return 0, err - } - - d.data = string(buf[:n]) - return 0, nil -} - -// testFD is a read-only FileDescriptionImpl representing a regular file. -type testFD struct { - fileDescription - DynamicBytesFileDescriptionImpl - - data DynamicBytesSource -} - -func newTestFD(vfsObj *VirtualFilesystem, statusFlags uint32, data DynamicBytesSource) *FileDescription { - vd := vfsObj.NewAnonVirtualDentry("genCountFD") - defer vd.DecRef() - var fd testFD - fd.vfsfd.Init(&fd, statusFlags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{}) - fd.DynamicBytesFileDescriptionImpl.SetDataSource(data) - return &fd.vfsfd -} - -// Release implements FileDescriptionImpl.Release. -func (fd *testFD) Release() { -} - -// SetStatusFlags implements FileDescriptionImpl.SetStatusFlags. -// Stat implements FileDescriptionImpl.Stat. -func (fd *testFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) { - // Note that Statx.Mask == 0 in the return value. - return linux.Statx{}, nil -} - -// SetStat implements FileDescriptionImpl.SetStat. -func (fd *testFD) SetStat(ctx context.Context, opts SetStatOptions) error { - return syserror.EPERM -} - -func TestGenCountFD(t *testing.T) { - ctx := contexttest.Context(t) - - vfsObj := &VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { - t.Fatalf("VFS init: %v", err) - } - fd := newTestFD(vfsObj, linux.O_RDWR, &genCount{}) - defer fd.DecRef() - - // The first read causes Generate to be called to fill the FD's buffer. - buf := make([]byte, 2) - ioseq := usermem.BytesIOSequence(buf) - n, err := fd.Read(ctx, ioseq, ReadOptions{}) - if n != 1 || (err != nil && err != io.EOF) { - t.Fatalf("first Read: got (%d, %v), wanted (1, nil or EOF)", n, err) - } - if want := byte('1'); buf[0] != want { - t.Errorf("first Read: got byte %c, wanted %c", buf[0], want) - } - - // A second read without seeking is still at EOF. - n, err = fd.Read(ctx, ioseq, ReadOptions{}) - if n != 0 || err != io.EOF { - t.Fatalf("second Read: got (%d, %v), wanted (0, EOF)", n, err) - } - - // Seeking to the beginning of the file causes it to be regenerated. - n, err = fd.Seek(ctx, 0, linux.SEEK_SET) - if n != 0 || err != nil { - t.Fatalf("Seek: got (%d, %v), wanted (0, nil)", n, err) - } - n, err = fd.Read(ctx, ioseq, ReadOptions{}) - if n != 1 || (err != nil && err != io.EOF) { - t.Fatalf("Read after Seek: got (%d, %v), wanted (1, nil or EOF)", n, err) - } - if want := byte('2'); buf[0] != want { - t.Errorf("Read after Seek: got byte %c, wanted %c", buf[0], want) - } - - // PRead at the beginning of the file also causes it to be regenerated. - n, err = fd.PRead(ctx, ioseq, 0, ReadOptions{}) - if n != 1 || (err != nil && err != io.EOF) { - t.Fatalf("PRead: got (%d, %v), wanted (1, nil or EOF)", n, err) - } - if want := byte('3'); buf[0] != want { - t.Errorf("PRead: got byte %c, wanted %c", buf[0], want) - } - - // Write and PWrite fails. - if _, err := fd.Write(ctx, ioseq, WriteOptions{}); err != syserror.EINVAL { - t.Errorf("Write: got err %v, wanted %v", err, syserror.EINVAL) - } - if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); err != syserror.EINVAL { - t.Errorf("Write: got err %v, wanted %v", err, syserror.EINVAL) - } -} - -func TestWritable(t *testing.T) { - ctx := contexttest.Context(t) - - vfsObj := &VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { - t.Fatalf("VFS init: %v", err) - } - fd := newTestFD(vfsObj, linux.O_RDWR, &storeData{data: "init"}) - defer fd.DecRef() - - buf := make([]byte, 10) - ioseq := usermem.BytesIOSequence(buf) - if n, err := fd.Read(ctx, ioseq, ReadOptions{}); n != 4 && err != io.EOF { - t.Fatalf("Read: got (%v, %v), wanted (4, EOF)", n, err) - } - if want := "init"; want == string(buf) { - t.Fatalf("Read: got %v, wanted %v", string(buf), want) - } - - // Test PWrite. - want := "write" - writeIOSeq := usermem.BytesIOSequence([]byte(want)) - if n, err := fd.PWrite(ctx, writeIOSeq, 0, WriteOptions{}); int(n) != len(want) && err != nil { - t.Errorf("PWrite: got err (%v, %v), wanted (%v, nil)", n, err, len(want)) - } - if n, err := fd.PRead(ctx, ioseq, 0, ReadOptions{}); int(n) != len(want) && err != io.EOF { - t.Fatalf("PRead: got (%v, %v), wanted (%v, EOF)", n, err, len(want)) - } - if want == string(buf) { - t.Fatalf("PRead: got %v, wanted %v", string(buf), want) - } - - // Test Seek to 0 followed by Write. - want = "write2" - writeIOSeq = usermem.BytesIOSequence([]byte(want)) - if n, err := fd.Seek(ctx, 0, linux.SEEK_SET); n != 0 && err != nil { - t.Errorf("Seek: got err (%v, %v), wanted (0, nil)", n, err) - } - if n, err := fd.Write(ctx, writeIOSeq, WriteOptions{}); int(n) != len(want) && err != nil { - t.Errorf("Write: got err (%v, %v), wanted (%v, nil)", n, err, len(want)) - } - if n, err := fd.PRead(ctx, ioseq, 0, ReadOptions{}); int(n) != len(want) && err != io.EOF { - t.Fatalf("PRead: got (%v, %v), wanted (%v, EOF)", n, err, len(want)) - } - if want == string(buf) { - t.Fatalf("PRead: got %v, wanted %v", string(buf), want) - } - - // Test failure if offset != 0. - if n, err := fd.Seek(ctx, 1, linux.SEEK_SET); n != 0 && err != nil { - t.Errorf("Seek: got err (%v, %v), wanted (0, nil)", n, err) - } - if n, err := fd.Write(ctx, writeIOSeq, WriteOptions{}); n != 0 && err != syserror.EINVAL { - t.Errorf("Write: got err (%v, %v), wanted (0, EINVAL)", n, err) - } - if n, err := fd.PWrite(ctx, writeIOSeq, 2, WriteOptions{}); n != 0 && err != syserror.EINVAL { - t.Errorf("PWrite: got err (%v, %v), wanted (0, EINVAL)", n, err) - } -} diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index c43dcff3d..c43dcff3d 100644..100755 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go diff --git a/pkg/sentry/vfs/filesystem_impl_util.go b/pkg/sentry/vfs/filesystem_impl_util.go index 7315a588e..7315a588e 100644..100755 --- a/pkg/sentry/vfs/filesystem_impl_util.go +++ b/pkg/sentry/vfs/filesystem_impl_util.go diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go index bb9cada81..bb9cada81 100644..100755 --- a/pkg/sentry/vfs/filesystem_type.go +++ b/pkg/sentry/vfs/filesystem_type.go diff --git a/pkg/sentry/vfs/lock/BUILD b/pkg/sentry/vfs/lock/BUILD deleted file mode 100644 index d9ab063b7..000000000 --- a/pkg/sentry/vfs/lock/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "lock", - srcs = ["lock.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/sentry/fs/lock", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/vfs/lock/lock.go b/pkg/sentry/vfs/lock/lock.go deleted file mode 100644 index 724dfe743..000000000 --- a/pkg/sentry/vfs/lock/lock.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package lock provides POSIX and BSD style file locking for VFS2 file -// implementations. -// -// The actual implementations can be found in the lock package under -// sentry/fs/lock. -package lock - -import ( - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" - "gvisor.dev/gvisor/pkg/syserror" -) - -// FileLocks supports POSIX and BSD style locks, which correspond to fcntl(2) -// and flock(2) respectively in Linux. It can be embedded into various file -// implementations for VFS2 that support locking. -// -// Note that in Linux these two types of locks are _not_ cooperative, because -// race and deadlock conditions make merging them prohibitive. We do the same -// and keep them oblivious to each other. -type FileLocks struct { - // bsd is a set of BSD-style advisory file wide locks, see flock(2). - bsd fslock.Locks - - // posix is a set of POSIX-style regional advisory locks, see fcntl(2). - posix fslock.Locks -} - -// LockBSD tries to acquire a BSD-style lock on the entire file. -func (fl *FileLocks) LockBSD(uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { - if fl.bsd.LockRegion(uid, t, fslock.LockRange{0, fslock.LockEOF}, block) { - return nil - } - return syserror.ErrWouldBlock -} - -// UnlockBSD releases a BSD-style lock on the entire file. -// -// This operation is always successful, even if there did not exist a lock on -// the requested region held by uid in the first place. -func (fl *FileLocks) UnlockBSD(uid fslock.UniqueID) { - fl.bsd.UnlockRegion(uid, fslock.LockRange{0, fslock.LockEOF}) -} - -// LockPOSIX tries to acquire a POSIX-style lock on a file region. -func (fl *FileLocks) LockPOSIX(uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error { - if fl.posix.LockRegion(uid, t, rng, block) { - return nil - } - return syserror.ErrWouldBlock -} - -// UnlockPOSIX releases a POSIX-style lock on a file region. -// -// This operation is always successful, even if there did not exist a lock on -// the requested region held by uid in the first place. -func (fl *FileLocks) UnlockPOSIX(uid fslock.UniqueID, rng fslock.LockRange) { - fl.posix.UnlockRegion(uid, rng) -} diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 05f6233f9..05f6233f9 100644..100755 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go deleted file mode 100644 index 3b933468d..000000000 --- a/pkg/sentry/vfs/mount_test.go +++ /dev/null @@ -1,458 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package vfs - -import ( - "fmt" - "runtime" - "testing" - - "gvisor.dev/gvisor/pkg/sync" -) - -func TestMountTableLookupEmpty(t *testing.T) { - var mt mountTable - mt.Init() - - parent := &Mount{} - point := &Dentry{} - if m := mt.Lookup(parent, point); m != nil { - t.Errorf("empty mountTable lookup: got %p, wanted nil", m) - } -} - -func TestMountTableInsertLookup(t *testing.T) { - var mt mountTable - mt.Init() - - mount := &Mount{} - mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}}) - mt.Insert(mount) - - if m := mt.Lookup(mount.parent(), mount.point()); m != mount { - t.Errorf("mountTable positive lookup: got %p, wanted %p", m, mount) - } - - otherParent := &Mount{} - if m := mt.Lookup(otherParent, mount.point()); m != nil { - t.Errorf("mountTable lookup with wrong mount parent: got %p, wanted nil", m) - } - otherPoint := &Dentry{} - if m := mt.Lookup(mount.parent(), otherPoint); m != nil { - t.Errorf("mountTable lookup with wrong mount point: got %p, wanted nil", m) - } -} - -// TODO: concurrent lookup/insertion/removal - -// must be powers of 2 -var benchNumMounts = []int{1 << 2, 1 << 5, 1 << 8} - -// For all of the following: -// -// - BenchmarkMountTableFoo tests usage pattern "Foo" for mountTable. -// -// - BenchmarkMountMapFoo tests usage pattern "Foo" for a -// sync.RWMutex-protected map. (Mutator benchmarks do not use a RWMutex, since -// mountTable also requires external synchronization between mutators.) -// -// - BenchmarkMountSyncMapFoo tests usage pattern "Foo" for a sync.Map. -// -// ParallelLookup is by far the most common and performance-sensitive operation -// for this application. NegativeLookup is also important, but less so (only -// relevant with multiple mount namespaces and significant differences in -// mounts between them). Insertion and removal are benchmarked for -// completeness. -const enableComparativeBenchmarks = false - -func newBenchMount() *Mount { - mount := &Mount{} - mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}}) - return mount -} - -func BenchmarkMountTableParallelLookup(b *testing.B) { - for numG, maxG := 1, runtime.GOMAXPROCS(0); numG >= 0 && numG <= maxG; numG *= 2 { - for _, numMounts := range benchNumMounts { - desc := fmt.Sprintf("%dx%d", numG, numMounts) - b.Run(desc, func(b *testing.B) { - var mt mountTable - mt.Init() - keys := make([]VirtualDentry, 0, numMounts) - for i := 0; i < numMounts; i++ { - mount := newBenchMount() - mt.Insert(mount) - keys = append(keys, mount.loadKey()) - } - - var ready sync.WaitGroup - begin := make(chan struct{}) - var end sync.WaitGroup - for g := 0; g < numG; g++ { - ready.Add(1) - end.Add(1) - go func() { - defer end.Done() - ready.Done() - <-begin - for i := 0; i < b.N; i++ { - k := keys[i&(numMounts-1)] - m := mt.Lookup(k.mount, k.dentry) - if m == nil { - b.Fatalf("lookup failed") - } - if parent := m.parent(); parent != k.mount { - b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount) - } - if point := m.point(); point != k.dentry { - b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry) - } - } - }() - } - - ready.Wait() - b.ResetTimer() - close(begin) - end.Wait() - }) - } - } -} - -func BenchmarkMountMapParallelLookup(b *testing.B) { - if !enableComparativeBenchmarks { - b.Skipf("comparative benchmarks are disabled") - } - - for numG, maxG := 1, runtime.GOMAXPROCS(0); numG >= 0 && numG <= maxG; numG *= 2 { - for _, numMounts := range benchNumMounts { - desc := fmt.Sprintf("%dx%d", numG, numMounts) - b.Run(desc, func(b *testing.B) { - var mu sync.RWMutex - ms := make(map[VirtualDentry]*Mount) - keys := make([]VirtualDentry, 0, numMounts) - for i := 0; i < numMounts; i++ { - mount := newBenchMount() - key := mount.loadKey() - ms[key] = mount - keys = append(keys, key) - } - - var ready sync.WaitGroup - begin := make(chan struct{}) - var end sync.WaitGroup - for g := 0; g < numG; g++ { - ready.Add(1) - end.Add(1) - go func() { - defer end.Done() - ready.Done() - <-begin - for i := 0; i < b.N; i++ { - k := keys[i&(numMounts-1)] - mu.RLock() - m := ms[k] - mu.RUnlock() - if m == nil { - b.Fatalf("lookup failed") - } - if parent := m.parent(); parent != k.mount { - b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount) - } - if point := m.point(); point != k.dentry { - b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry) - } - } - }() - } - - ready.Wait() - b.ResetTimer() - close(begin) - end.Wait() - }) - } - } -} - -func BenchmarkMountSyncMapParallelLookup(b *testing.B) { - if !enableComparativeBenchmarks { - b.Skipf("comparative benchmarks are disabled") - } - - for numG, maxG := 1, runtime.GOMAXPROCS(0); numG >= 0 && numG <= maxG; numG *= 2 { - for _, numMounts := range benchNumMounts { - desc := fmt.Sprintf("%dx%d", numG, numMounts) - b.Run(desc, func(b *testing.B) { - var ms sync.Map - keys := make([]VirtualDentry, 0, numMounts) - for i := 0; i < numMounts; i++ { - mount := newBenchMount() - key := mount.loadKey() - ms.Store(key, mount) - keys = append(keys, key) - } - - var ready sync.WaitGroup - begin := make(chan struct{}) - var end sync.WaitGroup - for g := 0; g < numG; g++ { - ready.Add(1) - end.Add(1) - go func() { - defer end.Done() - ready.Done() - <-begin - for i := 0; i < b.N; i++ { - k := keys[i&(numMounts-1)] - mi, ok := ms.Load(k) - if !ok { - b.Fatalf("lookup failed") - } - m := mi.(*Mount) - if parent := m.parent(); parent != k.mount { - b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount) - } - if point := m.point(); point != k.dentry { - b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry) - } - } - }() - } - - ready.Wait() - b.ResetTimer() - close(begin) - end.Wait() - }) - } - } -} - -func BenchmarkMountTableNegativeLookup(b *testing.B) { - for _, numMounts := range benchNumMounts { - desc := fmt.Sprintf("%d", numMounts) - b.Run(desc, func(b *testing.B) { - var mt mountTable - mt.Init() - for i := 0; i < numMounts; i++ { - mt.Insert(newBenchMount()) - } - negkeys := make([]VirtualDentry, 0, numMounts) - for i := 0; i < numMounts; i++ { - negkeys = append(negkeys, VirtualDentry{ - mount: &Mount{}, - dentry: &Dentry{}, - }) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - k := negkeys[i&(numMounts-1)] - m := mt.Lookup(k.mount, k.dentry) - if m != nil { - b.Fatalf("lookup got %p, wanted nil", m) - } - } - }) - } -} - -func BenchmarkMountMapNegativeLookup(b *testing.B) { - if !enableComparativeBenchmarks { - b.Skipf("comparative benchmarks are disabled") - } - - for _, numMounts := range benchNumMounts { - desc := fmt.Sprintf("%d", numMounts) - b.Run(desc, func(b *testing.B) { - var mu sync.RWMutex - ms := make(map[VirtualDentry]*Mount) - for i := 0; i < numMounts; i++ { - mount := newBenchMount() - ms[mount.loadKey()] = mount - } - negkeys := make([]VirtualDentry, 0, numMounts) - for i := 0; i < numMounts; i++ { - negkeys = append(negkeys, VirtualDentry{ - mount: &Mount{}, - dentry: &Dentry{}, - }) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - k := negkeys[i&(numMounts-1)] - mu.RLock() - m := ms[k] - mu.RUnlock() - if m != nil { - b.Fatalf("lookup got %p, wanted nil", m) - } - } - }) - } -} - -func BenchmarkMountSyncMapNegativeLookup(b *testing.B) { - if !enableComparativeBenchmarks { - b.Skipf("comparative benchmarks are disabled") - } - - for _, numMounts := range benchNumMounts { - desc := fmt.Sprintf("%d", numMounts) - b.Run(desc, func(b *testing.B) { - var ms sync.Map - for i := 0; i < numMounts; i++ { - mount := newBenchMount() - ms.Store(mount.loadKey(), mount) - } - negkeys := make([]VirtualDentry, 0, numMounts) - for i := 0; i < numMounts; i++ { - negkeys = append(negkeys, VirtualDentry{ - mount: &Mount{}, - dentry: &Dentry{}, - }) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - k := negkeys[i&(numMounts-1)] - m, _ := ms.Load(k) - if m != nil { - b.Fatalf("lookup got %p, wanted nil", m) - } - } - }) - } -} - -func BenchmarkMountTableInsert(b *testing.B) { - // Preallocate Mounts so that allocation time isn't included in the - // benchmark. - mounts := make([]*Mount, 0, b.N) - for i := 0; i < b.N; i++ { - mounts = append(mounts, newBenchMount()) - } - - var mt mountTable - mt.Init() - b.ResetTimer() - for i := range mounts { - mt.Insert(mounts[i]) - } -} - -func BenchmarkMountMapInsert(b *testing.B) { - if !enableComparativeBenchmarks { - b.Skipf("comparative benchmarks are disabled") - } - - // Preallocate Mounts so that allocation time isn't included in the - // benchmark. - mounts := make([]*Mount, 0, b.N) - for i := 0; i < b.N; i++ { - mounts = append(mounts, newBenchMount()) - } - - ms := make(map[VirtualDentry]*Mount) - b.ResetTimer() - for i := range mounts { - mount := mounts[i] - ms[mount.loadKey()] = mount - } -} - -func BenchmarkMountSyncMapInsert(b *testing.B) { - if !enableComparativeBenchmarks { - b.Skipf("comparative benchmarks are disabled") - } - - // Preallocate Mounts so that allocation time isn't included in the - // benchmark. - mounts := make([]*Mount, 0, b.N) - for i := 0; i < b.N; i++ { - mounts = append(mounts, newBenchMount()) - } - - var ms sync.Map - b.ResetTimer() - for i := range mounts { - mount := mounts[i] - ms.Store(mount.loadKey(), mount) - } -} - -func BenchmarkMountTableRemove(b *testing.B) { - mounts := make([]*Mount, 0, b.N) - for i := 0; i < b.N; i++ { - mounts = append(mounts, newBenchMount()) - } - var mt mountTable - mt.Init() - for i := range mounts { - mt.Insert(mounts[i]) - } - - b.ResetTimer() - for i := range mounts { - mt.Remove(mounts[i]) - } -} - -func BenchmarkMountMapRemove(b *testing.B) { - if !enableComparativeBenchmarks { - b.Skipf("comparative benchmarks are disabled") - } - - mounts := make([]*Mount, 0, b.N) - for i := 0; i < b.N; i++ { - mounts = append(mounts, newBenchMount()) - } - ms := make(map[VirtualDentry]*Mount) - for i := range mounts { - mount := mounts[i] - ms[mount.loadKey()] = mount - } - - b.ResetTimer() - for i := range mounts { - mount := mounts[i] - delete(ms, mount.loadKey()) - } -} - -func BenchmarkMountSyncMapRemove(b *testing.B) { - if !enableComparativeBenchmarks { - b.Skipf("comparative benchmarks are disabled") - } - - mounts := make([]*Mount, 0, b.N) - for i := 0; i < b.N; i++ { - mounts = append(mounts, newBenchMount()) - } - var ms sync.Map - for i := range mounts { - mount := mounts[i] - ms.Store(mount.loadKey(), mount) - } - - b.ResetTimer() - for i := range mounts { - mount := mounts[i] - ms.Delete(mount.loadKey()) - } -} diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index bc7581698..bc7581698 100644..100755 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 3e90dc4ed..3e90dc4ed 100644..100755 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go index b318c681a..b318c681a 100644..100755 --- a/pkg/sentry/vfs/pathname.go +++ b/pkg/sentry/vfs/pathname.go diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go index 8e250998a..8e250998a 100644..100755 --- a/pkg/sentry/vfs/permissions.go +++ b/pkg/sentry/vfs/permissions.go diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index eb4ebb511..eb4ebb511 100644..100755 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 2e2880171..2e2880171 100644..100755 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go diff --git a/pkg/sentry/vfs/vfs_state_autogen.go b/pkg/sentry/vfs/vfs_state_autogen.go new file mode 100755 index 000000000..036defa97 --- /dev/null +++ b/pkg/sentry/vfs/vfs_state_autogen.go @@ -0,0 +1,219 @@ +// automatically generated by stateify. + +// +build go1.12 +// +build !go1.15 + +package vfs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Dentry) beforeSave() {} +func (x *Dentry) save(m state.Map) { + x.beforeSave() + m.Save("parent", &x.parent) + m.Save("name", &x.name) + m.Save("flags", &x.flags) + m.Save("mounts", &x.mounts) + m.Save("children", &x.children) + m.Save("impl", &x.impl) +} + +func (x *Dentry) afterLoad() {} +func (x *Dentry) load(m state.Map) { + m.Load("parent", &x.parent) + m.Load("name", &x.name) + m.Load("flags", &x.flags) + m.Load("mounts", &x.mounts) + m.Load("children", &x.children) + m.Load("impl", &x.impl) +} + +func (x *registeredDevice) beforeSave() {} +func (x *registeredDevice) save(m state.Map) { + x.beforeSave() + m.Save("dev", &x.dev) + m.Save("opts", &x.opts) +} + +func (x *registeredDevice) afterLoad() {} +func (x *registeredDevice) load(m state.Map) { + m.Load("dev", &x.dev) + m.Load("opts", &x.opts) +} + +func (x *RegisterDeviceOptions) beforeSave() {} +func (x *RegisterDeviceOptions) save(m state.Map) { + x.beforeSave() + m.Save("GroupName", &x.GroupName) +} + +func (x *RegisterDeviceOptions) afterLoad() {} +func (x *RegisterDeviceOptions) load(m state.Map) { + m.Load("GroupName", &x.GroupName) +} + +func (x *epollInterestList) beforeSave() {} +func (x *epollInterestList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *epollInterestList) afterLoad() {} +func (x *epollInterestList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *epollInterestEntry) beforeSave() {} +func (x *epollInterestEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *epollInterestEntry) afterLoad() {} +func (x *epollInterestEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *Filesystem) beforeSave() {} +func (x *Filesystem) save(m state.Map) { + x.beforeSave() + m.Save("refs", &x.refs) + m.Save("vfs", &x.vfs) + m.Save("impl", &x.impl) +} + +func (x *Filesystem) afterLoad() {} +func (x *Filesystem) load(m state.Map) { + m.Load("refs", &x.refs) + m.Load("vfs", &x.vfs) + m.Load("impl", &x.impl) +} + +func (x *registeredFilesystemType) beforeSave() {} +func (x *registeredFilesystemType) save(m state.Map) { + x.beforeSave() + m.Save("fsType", &x.fsType) + m.Save("opts", &x.opts) +} + +func (x *registeredFilesystemType) afterLoad() {} +func (x *registeredFilesystemType) load(m state.Map) { + m.Load("fsType", &x.fsType) + m.Load("opts", &x.opts) +} + +func (x *Mount) beforeSave() {} +func (x *Mount) save(m state.Map) { + x.beforeSave() + m.Save("vfs", &x.vfs) + m.Save("fs", &x.fs) + m.Save("root", &x.root) + m.Save("key", &x.key) + m.Save("ns", &x.ns) + m.Save("refs", &x.refs) + m.Save("children", &x.children) + m.Save("umounted", &x.umounted) + m.Save("flags", &x.flags) + m.Save("writers", &x.writers) +} + +func (x *Mount) afterLoad() {} +func (x *Mount) load(m state.Map) { + m.Load("vfs", &x.vfs) + m.Load("fs", &x.fs) + m.Load("root", &x.root) + m.Load("key", &x.key) + m.Load("ns", &x.ns) + m.Load("refs", &x.refs) + m.Load("children", &x.children) + m.Load("umounted", &x.umounted) + m.Load("flags", &x.flags) + m.Load("writers", &x.writers) +} + +func (x *MountNamespace) beforeSave() {} +func (x *MountNamespace) save(m state.Map) { + x.beforeSave() + m.Save("root", &x.root) + m.Save("refs", &x.refs) + m.Save("mountpoints", &x.mountpoints) +} + +func (x *MountNamespace) afterLoad() {} +func (x *MountNamespace) load(m state.Map) { + m.Load("root", &x.root) + m.Load("refs", &x.refs) + m.Load("mountpoints", &x.mountpoints) +} + +func (x *mountTable) beforeSave() {} +func (x *mountTable) save(m state.Map) { + x.beforeSave() + m.Save("seed", &x.seed) + m.Save("size", &x.size) +} + +func (x *mountTable) afterLoad() {} +func (x *mountTable) load(m state.Map) { + m.Load("seed", &x.seed) + m.Load("size", &x.size) +} + +func (x *VirtualFilesystem) beforeSave() {} +func (x *VirtualFilesystem) save(m state.Map) { + x.beforeSave() + m.Save("mounts", &x.mounts) + m.Save("mountpoints", &x.mountpoints) + m.Save("anonMount", &x.anonMount) + m.Save("devices", &x.devices) + m.Save("anonBlockDevMinorNext", &x.anonBlockDevMinorNext) + m.Save("anonBlockDevMinor", &x.anonBlockDevMinor) + m.Save("fsTypes", &x.fsTypes) + m.Save("filesystems", &x.filesystems) +} + +func (x *VirtualFilesystem) afterLoad() {} +func (x *VirtualFilesystem) load(m state.Map) { + m.Load("mounts", &x.mounts) + m.Load("mountpoints", &x.mountpoints) + m.Load("anonMount", &x.anonMount) + m.Load("devices", &x.devices) + m.Load("anonBlockDevMinorNext", &x.anonBlockDevMinorNext) + m.Load("anonBlockDevMinor", &x.anonBlockDevMinor) + m.Load("fsTypes", &x.fsTypes) + m.Load("filesystems", &x.filesystems) +} + +func (x *VirtualDentry) beforeSave() {} +func (x *VirtualDentry) save(m state.Map) { + x.beforeSave() + m.Save("mount", &x.mount) + m.Save("dentry", &x.dentry) +} + +func (x *VirtualDentry) afterLoad() {} +func (x *VirtualDentry) load(m state.Map) { + m.Load("mount", &x.mount) + m.Load("dentry", &x.dentry) +} + +func init() { + state.Register("pkg/sentry/vfs.Dentry", (*Dentry)(nil), state.Fns{Save: (*Dentry).save, Load: (*Dentry).load}) + state.Register("pkg/sentry/vfs.registeredDevice", (*registeredDevice)(nil), state.Fns{Save: (*registeredDevice).save, Load: (*registeredDevice).load}) + state.Register("pkg/sentry/vfs.RegisterDeviceOptions", (*RegisterDeviceOptions)(nil), state.Fns{Save: (*RegisterDeviceOptions).save, Load: (*RegisterDeviceOptions).load}) + state.Register("pkg/sentry/vfs.epollInterestList", (*epollInterestList)(nil), state.Fns{Save: (*epollInterestList).save, Load: (*epollInterestList).load}) + state.Register("pkg/sentry/vfs.epollInterestEntry", (*epollInterestEntry)(nil), state.Fns{Save: (*epollInterestEntry).save, Load: (*epollInterestEntry).load}) + state.Register("pkg/sentry/vfs.Filesystem", (*Filesystem)(nil), state.Fns{Save: (*Filesystem).save, Load: (*Filesystem).load}) + state.Register("pkg/sentry/vfs.registeredFilesystemType", (*registeredFilesystemType)(nil), state.Fns{Save: (*registeredFilesystemType).save, Load: (*registeredFilesystemType).load}) + state.Register("pkg/sentry/vfs.Mount", (*Mount)(nil), state.Fns{Save: (*Mount).save, Load: (*Mount).load}) + state.Register("pkg/sentry/vfs.MountNamespace", (*MountNamespace)(nil), state.Fns{Save: (*MountNamespace).save, Load: (*MountNamespace).load}) + state.Register("pkg/sentry/vfs.mountTable", (*mountTable)(nil), state.Fns{Save: (*mountTable).save, Load: (*mountTable).load}) + state.Register("pkg/sentry/vfs.VirtualFilesystem", (*VirtualFilesystem)(nil), state.Fns{Save: (*VirtualFilesystem).save, Load: (*VirtualFilesystem).load}) + state.Register("pkg/sentry/vfs.VirtualDentry", (*VirtualDentry)(nil), state.Fns{Save: (*VirtualDentry).save, Load: (*VirtualDentry).load}) +} diff --git a/pkg/sentry/watchdog/BUILD b/pkg/sentry/watchdog/BUILD deleted file mode 100644 index 1c5a1c9b6..000000000 --- a/pkg/sentry/watchdog/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "watchdog", - srcs = ["watchdog.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/log", - "//pkg/metric", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/time", - "//pkg/sync", - ], -) diff --git a/pkg/sentry/watchdog/watchdog_state_autogen.go b/pkg/sentry/watchdog/watchdog_state_autogen.go new file mode 100755 index 000000000..bce0200e7 --- /dev/null +++ b/pkg/sentry/watchdog/watchdog_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package watchdog diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD deleted file mode 100644 index e131455f7..000000000 --- a/pkg/sleep/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "sleep", - srcs = [ - "commit_amd64.s", - "commit_arm64.s", - "commit_asm.go", - "commit_noasm.go", - "sleep_unsafe.go", - ], - visibility = ["//:sandbox"], -) - -go_test( - name = "sleep_test", - size = "medium", - srcs = [ - "sleep_test.go", - ], - library = ":sleep", -) diff --git a/pkg/sleep/commit_arm64.s b/pkg/sleep/commit_arm64.s index d0ef15b20..d0ef15b20 100644..100755 --- a/pkg/sleep/commit_arm64.s +++ b/pkg/sleep/commit_arm64.s diff --git a/pkg/sleep/empty.s b/pkg/sleep/empty.s deleted file mode 100644 index fb37360ac..000000000 --- a/pkg/sleep/empty.s +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Empty assembly file so empty func definitions work. diff --git a/pkg/sleep/sleep_state_autogen.go b/pkg/sleep/sleep_state_autogen.go new file mode 100755 index 000000000..e8727e1c9 --- /dev/null +++ b/pkg/sleep/sleep_state_autogen.go @@ -0,0 +1,9 @@ +// automatically generated by stateify. + +// +build amd64 arm64 +// +build !race +// +build !amd64,!arm64 +// +build go1.11 +// +build !go1.15 + +package sleep diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go deleted file mode 100644 index af47e2ba1..000000000 --- a/pkg/sleep/sleep_test.go +++ /dev/null @@ -1,573 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sleep - -import ( - "math/rand" - "runtime" - "testing" - "time" -) - -// ZeroWakerNotAsserted tests that a zero-value waker is in non-asserted state. -func ZeroWakerNotAsserted(t *testing.T) { - var w Waker - if w.IsAsserted() { - t.Fatalf("Zero waker is asserted") - } - - if w.Clear() { - t.Fatalf("Zero waker is asserted") - } -} - -// AssertedWakerAfterAssert tests that a waker properly reports its state as -// asserted once its Assert() method is called. -func AssertedWakerAfterAssert(t *testing.T) { - var w Waker - w.Assert() - if !w.IsAsserted() { - t.Fatalf("Asserted waker is not reported as such") - } - - if !w.Clear() { - t.Fatalf("Asserted waker is not reported as such") - } -} - -// AssertedWakerAfterTwoAsserts tests that a waker properly reports its state as -// asserted once its Assert() method is called twice. -func AssertedWakerAfterTwoAsserts(t *testing.T) { - var w Waker - w.Assert() - w.Assert() - if !w.IsAsserted() { - t.Fatalf("Asserted waker is not reported as such") - } - - if !w.Clear() { - t.Fatalf("Asserted waker is not reported as such") - } -} - -// NotAssertedWakerWithSleeper tests that a waker properly reports its state as -// not asserted after a sleeper is associated with it. -func NotAssertedWakerWithSleeper(t *testing.T) { - var w Waker - var s Sleeper - s.AddWaker(&w, 0) - if w.IsAsserted() { - t.Fatalf("Non-asserted waker is reported as asserted") - } - - if w.Clear() { - t.Fatalf("Non-asserted waker is reported as asserted") - } -} - -// NotAssertedWakerAfterWake tests that a waker properly reports its state as -// not asserted after a previous assert is consumed by a sleeper. That is, tests -// the "edge-triggered" behavior. -func NotAssertedWakerAfterWake(t *testing.T) { - var w Waker - var s Sleeper - s.AddWaker(&w, 0) - w.Assert() - s.Fetch(true) - if w.IsAsserted() { - t.Fatalf("Consumed waker is reported as asserted") - } - - if w.Clear() { - t.Fatalf("Consumed waker is reported as asserted") - } -} - -// AssertedWakerBeforeAdd tests that a waker causes a sleeper to not sleep if -// it's already asserted before being added. -func AssertedWakerBeforeAdd(t *testing.T) { - var w Waker - var s Sleeper - w.Assert() - s.AddWaker(&w, 0) - - if _, ok := s.Fetch(false); !ok { - t.Fatalf("Fetch failed even though asserted waker was added") - } -} - -// ClearedWaker tests that a waker properly reports its state as not asserted -// after it is cleared. -func ClearedWaker(t *testing.T) { - var w Waker - w.Assert() - w.Clear() - if w.IsAsserted() { - t.Fatalf("Cleared waker is reported as asserted") - } - - if w.Clear() { - t.Fatalf("Cleared waker is reported as asserted") - } -} - -// ClearedWakerWithSleeper tests that a waker properly reports its state as -// not asserted when it is cleared while it has a sleeper associated with it. -func ClearedWakerWithSleeper(t *testing.T) { - var w Waker - var s Sleeper - s.AddWaker(&w, 0) - w.Clear() - if w.IsAsserted() { - t.Fatalf("Cleared waker is reported as asserted") - } - - if w.Clear() { - t.Fatalf("Cleared waker is reported as asserted") - } -} - -// ClearedWakerAssertedWithSleeper tests that a waker properly reports its state -// as not asserted when it is cleared while it has a sleeper associated with it -// and has been asserted. -func ClearedWakerAssertedWithSleeper(t *testing.T) { - var w Waker - var s Sleeper - s.AddWaker(&w, 0) - w.Assert() - w.Clear() - if w.IsAsserted() { - t.Fatalf("Cleared waker is reported as asserted") - } - - if w.Clear() { - t.Fatalf("Cleared waker is reported as asserted") - } -} - -// TestBlock tests that a sleeper actually blocks waiting for the waker to -// assert its state. -func TestBlock(t *testing.T) { - var w Waker - var s Sleeper - - s.AddWaker(&w, 0) - - // Assert waker after one second. - before := time.Now() - go func() { - time.Sleep(1 * time.Second) - w.Assert() - }() - - // Fetch the result and make sure it took at least 500ms. - if _, ok := s.Fetch(true); !ok { - t.Fatalf("Fetch failed unexpectedly") - } - if d := time.Now().Sub(before); d < 500*time.Millisecond { - t.Fatalf("Duration was too short: %v", d) - } - - // Check that already-asserted waker completes inline. - w.Assert() - if _, ok := s.Fetch(true); !ok { - t.Fatalf("Fetch failed unexpectedly") - } - - // Check that fetch sleeps if waker had been asserted but was reset - // before Fetch is called. - w.Assert() - w.Clear() - before = time.Now() - go func() { - time.Sleep(1 * time.Second) - w.Assert() - }() - if _, ok := s.Fetch(true); !ok { - t.Fatalf("Fetch failed unexpectedly") - } - if d := time.Now().Sub(before); d < 500*time.Millisecond { - t.Fatalf("Duration was too short: %v", d) - } -} - -// TestNonBlock checks that a sleeper won't block if waker isn't asserted. -func TestNonBlock(t *testing.T) { - var w Waker - var s Sleeper - - // Don't block when there's no waker. - if _, ok := s.Fetch(false); ok { - t.Fatalf("Fetch succeeded when there is no waker") - } - - // Don't block when waker isn't asserted. - s.AddWaker(&w, 0) - if _, ok := s.Fetch(false); ok { - t.Fatalf("Fetch succeeded when waker was not asserted") - } - - // Don't block when waker was asserted, but isn't anymore. - w.Assert() - w.Clear() - if _, ok := s.Fetch(false); ok { - t.Fatalf("Fetch succeeded when waker was not asserted anymore") - } - - // Don't block when waker was consumed by previous Fetch(). - w.Assert() - if _, ok := s.Fetch(false); !ok { - t.Fatalf("Fetch failed even though waker was asserted") - } - - if _, ok := s.Fetch(false); ok { - t.Fatalf("Fetch succeeded when waker had been consumed") - } -} - -// TestMultiple checks that a sleeper can wait for and receives notifications -// from multiple wakers. -func TestMultiple(t *testing.T) { - s := Sleeper{} - w1 := Waker{} - w2 := Waker{} - - s.AddWaker(&w1, 0) - s.AddWaker(&w2, 1) - - w1.Assert() - w2.Assert() - - v, ok := s.Fetch(false) - if !ok { - t.Fatalf("Fetch failed when there are asserted wakers") - } - - if v != 0 && v != 1 { - t.Fatalf("Unexpected waker id: %v", v) - } - - want := 1 - v - v, ok = s.Fetch(false) - if !ok { - t.Fatalf("Fetch failed when there is an asserted waker") - } - - if v != want { - t.Fatalf("Unexpected waker id, got %v, want %v", v, want) - } -} - -// TestDoneFunction tests if calling Done() on a sleeper works properly. -func TestDoneFunction(t *testing.T) { - // Trivial case of no waker. - s := Sleeper{} - s.Done() - - // Cases when the sleeper has n wakers, but none are asserted. - for n := 1; n < 20; n++ { - s := Sleeper{} - w := make([]Waker, n) - for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) - } - s.Done() - } - - // Cases when the sleeper has n wakers, and only the i-th one is - // asserted. - for n := 1; n < 20; n++ { - for i := 0; i < n; i++ { - s := Sleeper{} - w := make([]Waker, n) - for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) - } - w[i].Assert() - s.Done() - } - } - - // Cases when the sleeper has n wakers, and the i-th one is asserted - // and cleared. - for n := 1; n < 20; n++ { - for i := 0; i < n; i++ { - s := Sleeper{} - w := make([]Waker, n) - for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) - } - w[i].Assert() - w[i].Clear() - s.Done() - } - } - - // Cases when the sleeper has n wakers, with a random number of them - // asserted. - for n := 1; n < 20; n++ { - for iters := 0; iters < 1000; iters++ { - s := Sleeper{} - w := make([]Waker, n) - for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) - } - - // Pick the number of asserted elements, then assert - // random wakers. - asserted := rand.Int() % (n + 1) - for j := 0; j < asserted; j++ { - w[rand.Int()%n].Assert() - } - s.Done() - } - } -} - -// TestRace tests that multiple wakers can continuously send wake requests to -// the sleeper. -func TestRace(t *testing.T) { - const wakers = 100 - const wakeRequests = 10000 - - counts := make([]int, wakers) - w := make([]Waker, wakers) - s := Sleeper{} - - // Associate each waker and start goroutines that will assert them. - for i := range w { - s.AddWaker(&w[i], i) - go func(w *Waker) { - n := 0 - for n < wakeRequests { - if !w.IsAsserted() { - w.Assert() - n++ - } else { - runtime.Gosched() - } - } - }(&w[i]) - } - - // Wait for all wake up notifications from all wakers. - for i := 0; i < wakers*wakeRequests; i++ { - v, _ := s.Fetch(true) - counts[v]++ - } - - // Check that we got the right number for each. - for i, v := range counts { - if v != wakeRequests { - t.Errorf("Waker %v only got %v wakes", i, v) - } - } -} - -// TestRaceInOrder tests that multiple wakers can continuously send wake requests to -// the sleeper and that the wakers are retrieved in the order asserted. -func TestRaceInOrder(t *testing.T) { - const wakers = 100 - const wakeRequests = 10000 - - w := make([]Waker, wakers) - s := Sleeper{} - - // Associate each waker and start goroutines that will assert them. - for i := range w { - s.AddWaker(&w[i], i) - } - go func() { - n := 0 - for n < wakeRequests { - wk := w[n%len(w)] - wk.Assert() - n++ - } - }() - - // Wait for all wake up notifications from all wakers. - for i := 0; i < wakeRequests; i++ { - v, _ := s.Fetch(true) - if got, want := v, i%wakers; got != want { - t.Fatalf("got %d want %d", got, want) - } - } -} - -// BenchmarkSleeperMultiSelect measures how long it takes to fetch a wake up -// from 4 wakers when at least one is already asserted. -func BenchmarkSleeperMultiSelect(b *testing.B) { - const count = 4 - s := Sleeper{} - w := make([]Waker, count) - for i := range w { - s.AddWaker(&w[i], i) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w[count-1].Assert() - s.Fetch(true) - } -} - -// BenchmarkGoMultiSelect measures how long it takes to fetch a zero-length -// struct from one of 4 channels when at least one is ready. -func BenchmarkGoMultiSelect(b *testing.B) { - const count = 4 - ch := make([]chan struct{}, count) - for i := range ch { - ch[i] = make(chan struct{}, 1) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - ch[count-1] <- struct{}{} - select { - case <-ch[0]: - case <-ch[1]: - case <-ch[2]: - case <-ch[3]: - } - } -} - -// BenchmarkSleeperSingleSelect measures how long it takes to fetch a wake up -// from one waker that is already asserted. -func BenchmarkSleeperSingleSelect(b *testing.B) { - s := Sleeper{} - w := Waker{} - s.AddWaker(&w, 0) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w.Assert() - s.Fetch(true) - } -} - -// BenchmarkGoSingleSelect measures how long it takes to fetch a zero-length -// struct from a channel that already has it buffered. -func BenchmarkGoSingleSelect(b *testing.B) { - ch := make(chan struct{}, 1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - ch <- struct{}{} - <-ch - } -} - -// BenchmarkSleeperAssertNonWaiting measures how long it takes to assert a -// channel that is already asserted. -func BenchmarkSleeperAssertNonWaiting(b *testing.B) { - w := Waker{} - w.Assert() - for i := 0; i < b.N; i++ { - w.Assert() - } - -} - -// BenchmarkGoAssertNonWaiting measures how long it takes to write to a channel -// that has already something written to it. -func BenchmarkGoAssertNonWaiting(b *testing.B) { - ch := make(chan struct{}, 1) - ch <- struct{}{} - for i := 0; i < b.N; i++ { - select { - case ch <- struct{}{}: - default: - } - } -} - -// BenchmarkSleeperWaitOnSingleSelect measures how long it takes to wait on one -// waker channel while another goroutine wakes up the sleeper. This assumes that -// a new goroutine doesn't run immediately (i.e., the creator of a new goroutine -// is allowed to go to sleep before the new goroutine has a chance to run). -func BenchmarkSleeperWaitOnSingleSelect(b *testing.B) { - s := Sleeper{} - w := Waker{} - s.AddWaker(&w, 0) - for i := 0; i < b.N; i++ { - go func() { - w.Assert() - }() - s.Fetch(true) - } - -} - -// BenchmarkGoWaitOnSingleSelect measures how long it takes to wait on one -// channel while another goroutine wakes up the sleeper. This assumes that a new -// goroutine doesn't run immediately (i.e., the creator of a new goroutine is -// allowed to go to sleep before the new goroutine has a chance to run). -func BenchmarkGoWaitOnSingleSelect(b *testing.B) { - ch := make(chan struct{}, 1) - for i := 0; i < b.N; i++ { - go func() { - ch <- struct{}{} - }() - <-ch - } -} - -// BenchmarkSleeperWaitOnMultiSelect measures how long it takes to wait on 4 -// wakers while another goroutine wakes up the sleeper. This assumes that a new -// goroutine doesn't run immediately (i.e., the creator of a new goroutine is -// allowed to go to sleep before the new goroutine has a chance to run). -func BenchmarkSleeperWaitOnMultiSelect(b *testing.B) { - const count = 4 - s := Sleeper{} - w := make([]Waker, count) - for i := range w { - s.AddWaker(&w[i], i) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - w[count-1].Assert() - }() - s.Fetch(true) - } -} - -// BenchmarkGoWaitOnMultiSelect measures how long it takes to wait on 4 channels -// while another goroutine wakes up the sleeper. This assumes that a new -// goroutine doesn't run immediately (i.e., the creator of a new goroutine is -// allowed to go to sleep before the new goroutine has a chance to run). -func BenchmarkGoWaitOnMultiSelect(b *testing.B) { - const count = 4 - ch := make([]chan struct{}, count) - for i := range ch { - ch[i] = make(chan struct{}, 1) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - ch[count-1] <- struct{}{} - }() - select { - case <-ch[0]: - case <-ch[1]: - case <-ch[2]: - case <-ch[3]: - } - } -} diff --git a/pkg/state/BUILD b/pkg/state/BUILD deleted file mode 100644 index 921af9d63..000000000 --- a/pkg/state/BUILD +++ /dev/null @@ -1,69 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test", "proto_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "addr_range", - out = "addr_range.go", - package = "state", - prefix = "addr", - template = "//pkg/segment:generic_range", - types = { - "T": "uintptr", - }, -) - -go_template_instance( - name = "addr_set", - out = "addr_set.go", - consts = { - "minDegree": "10", - }, - imports = { - "reflect": "reflect", - }, - package = "state", - prefix = "addr", - template = "//pkg/segment:generic_set", - types = { - "Key": "uintptr", - "Range": "addrRange", - "Value": "reflect.Value", - "Functions": "addrSetFunctions", - }, -) - -go_library( - name = "state", - srcs = [ - "addr_range.go", - "addr_set.go", - "decode.go", - "encode.go", - "encode_unsafe.go", - "map.go", - "printer.go", - "state.go", - "stats.go", - ], - stateify = False, - visibility = ["//:sandbox"], - deps = [ - ":object_go_proto", - "@com_github_golang_protobuf//proto:go_default_library", - ], -) - -proto_library( - name = "object", - srcs = ["object.proto"], - visibility = ["//:sandbox"], -) - -go_test( - name = "state_test", - timeout = "long", - srcs = ["state_test.go"], - library = ":state", -) diff --git a/pkg/state/addr_range.go b/pkg/state/addr_range.go new file mode 100755 index 000000000..45720c643 --- /dev/null +++ b/pkg/state/addr_range.go @@ -0,0 +1,62 @@ +package state + +// A Range represents a contiguous range of T. +// +// +stateify savable +type addrRange struct { + // Start is the inclusive start of the range. + Start uintptr + + // End is the exclusive end of the range. + End uintptr +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +func (r addrRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +func (r addrRange) Length() uintptr { + return r.End - r.Start +} + +// Contains returns true if r contains x. +func (r addrRange) Contains(x uintptr) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +func (r addrRange) Overlaps(r2 addrRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +func (r addrRange) IsSupersetOf(r2 addrRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +func (r addrRange) Intersect(r2 addrRange) addrRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +func (r addrRange) CanSplitAt(x uintptr) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/state/addr_set.go b/pkg/state/addr_set.go new file mode 100755 index 000000000..5261aa488 --- /dev/null +++ b/pkg/state/addr_set.go @@ -0,0 +1,1274 @@ +package state + +import ( + __generics_imported0 "reflect" +) + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + addrminDegree = 10 + + addrmaxDegree = 2 * addrminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type addrSet struct { + root addrnode `state:".(*addrSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *addrSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *addrSet) IsEmptyRange(r addrRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *addrSet) Span() uintptr { + var sz uintptr + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *addrSet) SpanRange(r addrRange) uintptr { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uintptr + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *addrSet) FirstSegment() addrIterator { + if s.root.nrSegments == 0 { + return addrIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *addrSet) LastSegment() addrIterator { + if s.root.nrSegments == 0 { + return addrIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *addrSet) FirstGap() addrGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return addrGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *addrSet) LastGap() addrGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return addrGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *addrSet) Find(key uintptr) (addrIterator, addrGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return addrIterator{n, i}, addrGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return addrIterator{}, addrGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *addrSet) FindSegment(key uintptr) addrIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *addrSet) LowerBoundSegment(min uintptr) addrIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *addrSet) UpperBoundSegment(max uintptr) addrIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *addrSet) FindGap(key uintptr) addrGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *addrSet) LowerBoundGap(min uintptr) addrGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *addrSet) UpperBoundGap(max uintptr) addrGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *addrSet) Add(r addrRange, val __generics_imported0.Value) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *addrSet) AddWithoutMerging(r addrRange, val __generics_imported0.Value) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *addrSet) Insert(gap addrGapIterator, r addrRange, val __generics_imported0.Value) addrIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (addrSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (addrSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (addrSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *addrSet) InsertWithoutMerging(gap addrGapIterator, r addrRange, val __generics_imported0.Value) addrIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *addrSet) InsertWithoutMergingUnchecked(gap addrGapIterator, r addrRange, val __generics_imported0.Value) addrIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return addrIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *addrSet) Remove(seg addrIterator) addrGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + addrSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(addrGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *addrSet) RemoveAll() { + s.root = addrnode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *addrSet) RemoveRange(r addrRange) addrGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *addrSet) Merge(first, second addrIterator) addrIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *addrSet) MergeUnchecked(first, second addrIterator) addrIterator { + if first.End() == second.Start() { + if mval, ok := (addrSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return addrIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *addrSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *addrSet) MergeRange(r addrRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *addrSet) MergeAdjacent(r addrRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *addrSet) Split(seg addrIterator, split uintptr) (addrIterator, addrIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *addrSet) SplitUnchecked(seg addrIterator, split uintptr) (addrIterator, addrIterator) { + val1, val2 := (addrSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), addrRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *addrSet) SplitAt(split uintptr) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *addrSet) Isolate(seg addrIterator, r addrRange) addrIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *addrSet) ApplyContiguous(r addrRange, fn func(seg addrIterator)) addrGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return addrGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return addrGapIterator{} + } + } +} + +// +stateify savable +type addrnode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *addrnode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [addrmaxDegree - 1]addrRange + values [addrmaxDegree - 1]__generics_imported0.Value + children [addrmaxDegree]*addrnode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *addrnode) firstSegment() addrIterator { + for n.hasChildren { + n = n.children[0] + } + return addrIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *addrnode) lastSegment() addrIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return addrIterator{n, n.nrSegments - 1} +} + +func (n *addrnode) prevSibling() *addrnode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *addrnode) nextSibling() *addrnode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *addrnode) rebalanceBeforeInsert(gap addrGapIterator) addrGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < addrmaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &addrnode{ + nrSegments: addrminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &addrnode{ + nrSegments: addrminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:addrminDegree-1], n.keys[:addrminDegree-1]) + copy(left.values[:addrminDegree-1], n.values[:addrminDegree-1]) + copy(right.keys[:addrminDegree-1], n.keys[addrminDegree:]) + copy(right.values[:addrminDegree-1], n.values[addrminDegree:]) + n.keys[0], n.values[0] = n.keys[addrminDegree-1], n.values[addrminDegree-1] + addrzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:addrminDegree], n.children[:addrminDegree]) + copy(right.children[:addrminDegree], n.children[addrminDegree:]) + addrzeroNodeSlice(n.children[2:]) + for i := 0; i < addrminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < addrminDegree { + return addrGapIterator{left, gap.index} + } + return addrGapIterator{right, gap.index - addrminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[addrminDegree-1], n.values[addrminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &addrnode{ + nrSegments: addrminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:addrminDegree-1], n.keys[addrminDegree:]) + copy(sibling.values[:addrminDegree-1], n.values[addrminDegree:]) + addrzeroValueSlice(n.values[addrminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:addrminDegree], n.children[addrminDegree:]) + addrzeroNodeSlice(n.children[addrminDegree:]) + for i := 0; i < addrminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = addrminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < addrminDegree { + return gap + } + return addrGapIterator{sibling, gap.index - addrminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *addrnode) rebalanceAfterRemove(gap addrGapIterator) addrGapIterator { + for { + if n.nrSegments >= addrminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= addrminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + addrSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return addrGapIterator{n, 0} + } + if gap.node == n { + return addrGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= addrminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + addrSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return addrGapIterator{n, n.nrSegments} + } + return addrGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return addrGapIterator{p, gap.index} + } + if gap.node == right { + return addrGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *addrnode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = addrGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + addrSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type addrIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *addrnode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg addrIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg addrIterator) Range() addrRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg addrIterator) Start() uintptr { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg addrIterator) End() uintptr { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg addrIterator) SetRangeUnchecked(r addrRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg addrIterator) SetRange(r addrRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg addrIterator) SetStartUnchecked(start uintptr) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg addrIterator) SetStart(start uintptr) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg addrIterator) SetEndUnchecked(end uintptr) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg addrIterator) SetEnd(end uintptr) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg addrIterator) Value() __generics_imported0.Value { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg addrIterator) ValuePtr() *__generics_imported0.Value { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg addrIterator) SetValue(val __generics_imported0.Value) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg addrIterator) PrevSegment() addrIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return addrIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return addrIterator{} + } + return addrsegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg addrIterator) NextSegment() addrIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return addrIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return addrIterator{} + } + return addrsegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg addrIterator) PrevGap() addrGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return addrGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg addrIterator) NextGap() addrGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return addrGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg addrIterator) PrevNonEmpty() (addrIterator, addrGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return addrIterator{}, gap + } + return gap.PrevSegment(), addrGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg addrIterator) NextNonEmpty() (addrIterator, addrGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return addrIterator{}, gap + } + return gap.NextSegment(), addrGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type addrGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *addrnode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap addrGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap addrGapIterator) Range() addrRange { + return addrRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap addrGapIterator) Start() uintptr { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return addrSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap addrGapIterator) End() uintptr { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return addrSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap addrGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap addrGapIterator) PrevSegment() addrIterator { + return addrsegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap addrGapIterator) NextSegment() addrIterator { + return addrsegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap addrGapIterator) PrevGap() addrGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return addrGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap addrGapIterator) NextGap() addrGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return addrGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func addrsegmentBeforePosition(n *addrnode, i int) addrIterator { + for i == 0 { + if n.parent == nil { + return addrIterator{} + } + n, i = n.parent, n.parentIndex + } + return addrIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func addrsegmentAfterPosition(n *addrnode, i int) addrIterator { + for i == n.nrSegments { + if n.parent == nil { + return addrIterator{} + } + n, i = n.parent, n.parentIndex + } + return addrIterator{n, i} +} + +func addrzeroValueSlice(slice []__generics_imported0.Value) { + + for i := range slice { + addrSetFunctions{}.ClearValue(&slice[i]) + } +} + +func addrzeroNodeSlice(slice []*addrnode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *addrSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *addrnode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *addrnode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type addrSegmentDataSlices struct { + Start []uintptr + End []uintptr + Values []__generics_imported0.Value +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *addrSet) ExportSortedSlices() *addrSegmentDataSlices { + var sds addrSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *addrSet) ImportSortedSlices(sds *addrSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := addrRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *addrSet) saveRoot() *addrSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *addrSet) loadRoot(sds *addrSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/state/object.proto b/pkg/state/object.proto deleted file mode 100644 index 5ebcfb151..000000000 --- a/pkg/state/object.proto +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor.state.statefile; - -// Slice is a slice value. -message Slice { - uint32 length = 1; - uint32 capacity = 2; - uint64 ref_value = 3; -} - -// Array is an array value. -message Array { - repeated Object contents = 1; -} - -// Map is a map value. -message Map { - repeated Object keys = 1; - repeated Object values = 2; -} - -// Interface is an interface value. -message Interface { - string type = 1; - Object value = 2; -} - -// Struct is a basic composite value. -message Struct { - repeated Field fields = 1; -} - -// Field encodes a single field. -message Field { - string name = 1; - Object value = 2; -} - -// Uint16s encodes an uint16 array. To be used inside oneof structure. -message Uint16s { - // There is no 16-bit type in protobuf so we use variable length 32-bit here. - repeated uint32 values = 1; -} - -// Uint32s encodes an uint32 array. To be used inside oneof structure. -message Uint32s { - repeated fixed32 values = 1; -} - -// Uint64s encodes an uint64 array. To be used inside oneof structure. -message Uint64s { - repeated fixed64 values = 1; -} - -// Uintptrs encodes an uintptr array. To be used inside oneof structure. -message Uintptrs { - repeated fixed64 values = 1; -} - -// Int8s encodes an int8 array. To be used inside oneof structure. -message Int8s { - bytes values = 1; -} - -// Int16s encodes an int16 array. To be used inside oneof structure. -message Int16s { - // There is no 16-bit type in protobuf so we use variable length 32-bit here. - repeated int32 values = 1; -} - -// Int32s encodes an int32 array. To be used inside oneof structure. -message Int32s { - repeated sfixed32 values = 1; -} - -// Int64s encodes an int64 array. To be used inside oneof structure. -message Int64s { - repeated sfixed64 values = 1; -} - -// Bools encodes a boolean array. To be used inside oneof structure. -message Bools { - repeated bool values = 1; -} - -// Float64s encodes a float64 array. To be used inside oneof structure. -message Float64s { - repeated double values = 1; -} - -// Float32s encodes a float32 array. To be used inside oneof structure. -message Float32s { - repeated float values = 1; -} - -// Object are primitive encodings. -// -// Note that ref_value references an Object.id, below. -message Object { - oneof value { - bool bool_value = 1; - bytes string_value = 2; - int64 int64_value = 3; - uint64 uint64_value = 4; - double double_value = 5; - uint64 ref_value = 6; - Slice slice_value = 7; - Array array_value = 8; - Interface interface_value = 9; - Struct struct_value = 10; - Map map_value = 11; - bytes byte_array_value = 12; - Uint16s uint16_array_value = 13; - Uint32s uint32_array_value = 14; - Uint64s uint64_array_value = 15; - Uintptrs uintptr_array_value = 16; - Int8s int8_array_value = 17; - Int16s int16_array_value = 18; - Int32s int32_array_value = 19; - Int64s int64_array_value = 20; - Bools bool_array_value = 21; - Float64s float64_array_value = 22; - Float32s float32_array_value = 23; - } -} diff --git a/pkg/state/object_go_proto/object.pb.go b/pkg/state/object_go_proto/object.pb.go new file mode 100755 index 000000000..dc5127149 --- /dev/null +++ b/pkg/state/object_go_proto/object.pb.go @@ -0,0 +1,1195 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/state/object.proto + +package gvisor_state_statefile + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type Slice struct { + Length uint32 `protobuf:"varint,1,opt,name=length,proto3" json:"length,omitempty"` + Capacity uint32 `protobuf:"varint,2,opt,name=capacity,proto3" json:"capacity,omitempty"` + RefValue uint64 `protobuf:"varint,3,opt,name=ref_value,json=refValue,proto3" json:"ref_value,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Slice) Reset() { *m = Slice{} } +func (m *Slice) String() string { return proto.CompactTextString(m) } +func (*Slice) ProtoMessage() {} +func (*Slice) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{0} +} + +func (m *Slice) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Slice.Unmarshal(m, b) +} +func (m *Slice) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Slice.Marshal(b, m, deterministic) +} +func (m *Slice) XXX_Merge(src proto.Message) { + xxx_messageInfo_Slice.Merge(m, src) +} +func (m *Slice) XXX_Size() int { + return xxx_messageInfo_Slice.Size(m) +} +func (m *Slice) XXX_DiscardUnknown() { + xxx_messageInfo_Slice.DiscardUnknown(m) +} + +var xxx_messageInfo_Slice proto.InternalMessageInfo + +func (m *Slice) GetLength() uint32 { + if m != nil { + return m.Length + } + return 0 +} + +func (m *Slice) GetCapacity() uint32 { + if m != nil { + return m.Capacity + } + return 0 +} + +func (m *Slice) GetRefValue() uint64 { + if m != nil { + return m.RefValue + } + return 0 +} + +type Array struct { + Contents []*Object `protobuf:"bytes,1,rep,name=contents,proto3" json:"contents,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Array) Reset() { *m = Array{} } +func (m *Array) String() string { return proto.CompactTextString(m) } +func (*Array) ProtoMessage() {} +func (*Array) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{1} +} + +func (m *Array) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Array.Unmarshal(m, b) +} +func (m *Array) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Array.Marshal(b, m, deterministic) +} +func (m *Array) XXX_Merge(src proto.Message) { + xxx_messageInfo_Array.Merge(m, src) +} +func (m *Array) XXX_Size() int { + return xxx_messageInfo_Array.Size(m) +} +func (m *Array) XXX_DiscardUnknown() { + xxx_messageInfo_Array.DiscardUnknown(m) +} + +var xxx_messageInfo_Array proto.InternalMessageInfo + +func (m *Array) GetContents() []*Object { + if m != nil { + return m.Contents + } + return nil +} + +type Map struct { + Keys []*Object `protobuf:"bytes,1,rep,name=keys,proto3" json:"keys,omitempty"` + Values []*Object `protobuf:"bytes,2,rep,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Map) Reset() { *m = Map{} } +func (m *Map) String() string { return proto.CompactTextString(m) } +func (*Map) ProtoMessage() {} +func (*Map) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{2} +} + +func (m *Map) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Map.Unmarshal(m, b) +} +func (m *Map) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Map.Marshal(b, m, deterministic) +} +func (m *Map) XXX_Merge(src proto.Message) { + xxx_messageInfo_Map.Merge(m, src) +} +func (m *Map) XXX_Size() int { + return xxx_messageInfo_Map.Size(m) +} +func (m *Map) XXX_DiscardUnknown() { + xxx_messageInfo_Map.DiscardUnknown(m) +} + +var xxx_messageInfo_Map proto.InternalMessageInfo + +func (m *Map) GetKeys() []*Object { + if m != nil { + return m.Keys + } + return nil +} + +func (m *Map) GetValues() []*Object { + if m != nil { + return m.Values + } + return nil +} + +type Interface struct { + Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"` + Value *Object `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Interface) Reset() { *m = Interface{} } +func (m *Interface) String() string { return proto.CompactTextString(m) } +func (*Interface) ProtoMessage() {} +func (*Interface) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{3} +} + +func (m *Interface) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Interface.Unmarshal(m, b) +} +func (m *Interface) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Interface.Marshal(b, m, deterministic) +} +func (m *Interface) XXX_Merge(src proto.Message) { + xxx_messageInfo_Interface.Merge(m, src) +} +func (m *Interface) XXX_Size() int { + return xxx_messageInfo_Interface.Size(m) +} +func (m *Interface) XXX_DiscardUnknown() { + xxx_messageInfo_Interface.DiscardUnknown(m) +} + +var xxx_messageInfo_Interface proto.InternalMessageInfo + +func (m *Interface) GetType() string { + if m != nil { + return m.Type + } + return "" +} + +func (m *Interface) GetValue() *Object { + if m != nil { + return m.Value + } + return nil +} + +type Struct struct { + Fields []*Field `protobuf:"bytes,1,rep,name=fields,proto3" json:"fields,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Struct) Reset() { *m = Struct{} } +func (m *Struct) String() string { return proto.CompactTextString(m) } +func (*Struct) ProtoMessage() {} +func (*Struct) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{4} +} + +func (m *Struct) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Struct.Unmarshal(m, b) +} +func (m *Struct) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Struct.Marshal(b, m, deterministic) +} +func (m *Struct) XXX_Merge(src proto.Message) { + xxx_messageInfo_Struct.Merge(m, src) +} +func (m *Struct) XXX_Size() int { + return xxx_messageInfo_Struct.Size(m) +} +func (m *Struct) XXX_DiscardUnknown() { + xxx_messageInfo_Struct.DiscardUnknown(m) +} + +var xxx_messageInfo_Struct proto.InternalMessageInfo + +func (m *Struct) GetFields() []*Field { + if m != nil { + return m.Fields + } + return nil +} + +type Field struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Value *Object `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Field) Reset() { *m = Field{} } +func (m *Field) String() string { return proto.CompactTextString(m) } +func (*Field) ProtoMessage() {} +func (*Field) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{5} +} + +func (m *Field) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Field.Unmarshal(m, b) +} +func (m *Field) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Field.Marshal(b, m, deterministic) +} +func (m *Field) XXX_Merge(src proto.Message) { + xxx_messageInfo_Field.Merge(m, src) +} +func (m *Field) XXX_Size() int { + return xxx_messageInfo_Field.Size(m) +} +func (m *Field) XXX_DiscardUnknown() { + xxx_messageInfo_Field.DiscardUnknown(m) +} + +var xxx_messageInfo_Field proto.InternalMessageInfo + +func (m *Field) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func (m *Field) GetValue() *Object { + if m != nil { + return m.Value + } + return nil +} + +type Uint16S struct { + Values []uint32 `protobuf:"varint,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Uint16S) Reset() { *m = Uint16S{} } +func (m *Uint16S) String() string { return proto.CompactTextString(m) } +func (*Uint16S) ProtoMessage() {} +func (*Uint16S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{6} +} + +func (m *Uint16S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Uint16S.Unmarshal(m, b) +} +func (m *Uint16S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Uint16S.Marshal(b, m, deterministic) +} +func (m *Uint16S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Uint16S.Merge(m, src) +} +func (m *Uint16S) XXX_Size() int { + return xxx_messageInfo_Uint16S.Size(m) +} +func (m *Uint16S) XXX_DiscardUnknown() { + xxx_messageInfo_Uint16S.DiscardUnknown(m) +} + +var xxx_messageInfo_Uint16S proto.InternalMessageInfo + +func (m *Uint16S) GetValues() []uint32 { + if m != nil { + return m.Values + } + return nil +} + +type Uint32S struct { + Values []uint32 `protobuf:"fixed32,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Uint32S) Reset() { *m = Uint32S{} } +func (m *Uint32S) String() string { return proto.CompactTextString(m) } +func (*Uint32S) ProtoMessage() {} +func (*Uint32S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{7} +} + +func (m *Uint32S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Uint32S.Unmarshal(m, b) +} +func (m *Uint32S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Uint32S.Marshal(b, m, deterministic) +} +func (m *Uint32S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Uint32S.Merge(m, src) +} +func (m *Uint32S) XXX_Size() int { + return xxx_messageInfo_Uint32S.Size(m) +} +func (m *Uint32S) XXX_DiscardUnknown() { + xxx_messageInfo_Uint32S.DiscardUnknown(m) +} + +var xxx_messageInfo_Uint32S proto.InternalMessageInfo + +func (m *Uint32S) GetValues() []uint32 { + if m != nil { + return m.Values + } + return nil +} + +type Uint64S struct { + Values []uint64 `protobuf:"fixed64,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Uint64S) Reset() { *m = Uint64S{} } +func (m *Uint64S) String() string { return proto.CompactTextString(m) } +func (*Uint64S) ProtoMessage() {} +func (*Uint64S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{8} +} + +func (m *Uint64S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Uint64S.Unmarshal(m, b) +} +func (m *Uint64S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Uint64S.Marshal(b, m, deterministic) +} +func (m *Uint64S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Uint64S.Merge(m, src) +} +func (m *Uint64S) XXX_Size() int { + return xxx_messageInfo_Uint64S.Size(m) +} +func (m *Uint64S) XXX_DiscardUnknown() { + xxx_messageInfo_Uint64S.DiscardUnknown(m) +} + +var xxx_messageInfo_Uint64S proto.InternalMessageInfo + +func (m *Uint64S) GetValues() []uint64 { + if m != nil { + return m.Values + } + return nil +} + +type Uintptrs struct { + Values []uint64 `protobuf:"fixed64,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Uintptrs) Reset() { *m = Uintptrs{} } +func (m *Uintptrs) String() string { return proto.CompactTextString(m) } +func (*Uintptrs) ProtoMessage() {} +func (*Uintptrs) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{9} +} + +func (m *Uintptrs) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Uintptrs.Unmarshal(m, b) +} +func (m *Uintptrs) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Uintptrs.Marshal(b, m, deterministic) +} +func (m *Uintptrs) XXX_Merge(src proto.Message) { + xxx_messageInfo_Uintptrs.Merge(m, src) +} +func (m *Uintptrs) XXX_Size() int { + return xxx_messageInfo_Uintptrs.Size(m) +} +func (m *Uintptrs) XXX_DiscardUnknown() { + xxx_messageInfo_Uintptrs.DiscardUnknown(m) +} + +var xxx_messageInfo_Uintptrs proto.InternalMessageInfo + +func (m *Uintptrs) GetValues() []uint64 { + if m != nil { + return m.Values + } + return nil +} + +type Int8S struct { + Values []byte `protobuf:"bytes,1,opt,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Int8S) Reset() { *m = Int8S{} } +func (m *Int8S) String() string { return proto.CompactTextString(m) } +func (*Int8S) ProtoMessage() {} +func (*Int8S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{10} +} + +func (m *Int8S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Int8S.Unmarshal(m, b) +} +func (m *Int8S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Int8S.Marshal(b, m, deterministic) +} +func (m *Int8S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Int8S.Merge(m, src) +} +func (m *Int8S) XXX_Size() int { + return xxx_messageInfo_Int8S.Size(m) +} +func (m *Int8S) XXX_DiscardUnknown() { + xxx_messageInfo_Int8S.DiscardUnknown(m) +} + +var xxx_messageInfo_Int8S proto.InternalMessageInfo + +func (m *Int8S) GetValues() []byte { + if m != nil { + return m.Values + } + return nil +} + +type Int16S struct { + Values []int32 `protobuf:"varint,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Int16S) Reset() { *m = Int16S{} } +func (m *Int16S) String() string { return proto.CompactTextString(m) } +func (*Int16S) ProtoMessage() {} +func (*Int16S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{11} +} + +func (m *Int16S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Int16S.Unmarshal(m, b) +} +func (m *Int16S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Int16S.Marshal(b, m, deterministic) +} +func (m *Int16S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Int16S.Merge(m, src) +} +func (m *Int16S) XXX_Size() int { + return xxx_messageInfo_Int16S.Size(m) +} +func (m *Int16S) XXX_DiscardUnknown() { + xxx_messageInfo_Int16S.DiscardUnknown(m) +} + +var xxx_messageInfo_Int16S proto.InternalMessageInfo + +func (m *Int16S) GetValues() []int32 { + if m != nil { + return m.Values + } + return nil +} + +type Int32S struct { + Values []int32 `protobuf:"fixed32,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Int32S) Reset() { *m = Int32S{} } +func (m *Int32S) String() string { return proto.CompactTextString(m) } +func (*Int32S) ProtoMessage() {} +func (*Int32S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{12} +} + +func (m *Int32S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Int32S.Unmarshal(m, b) +} +func (m *Int32S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Int32S.Marshal(b, m, deterministic) +} +func (m *Int32S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Int32S.Merge(m, src) +} +func (m *Int32S) XXX_Size() int { + return xxx_messageInfo_Int32S.Size(m) +} +func (m *Int32S) XXX_DiscardUnknown() { + xxx_messageInfo_Int32S.DiscardUnknown(m) +} + +var xxx_messageInfo_Int32S proto.InternalMessageInfo + +func (m *Int32S) GetValues() []int32 { + if m != nil { + return m.Values + } + return nil +} + +type Int64S struct { + Values []int64 `protobuf:"fixed64,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Int64S) Reset() { *m = Int64S{} } +func (m *Int64S) String() string { return proto.CompactTextString(m) } +func (*Int64S) ProtoMessage() {} +func (*Int64S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{13} +} + +func (m *Int64S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Int64S.Unmarshal(m, b) +} +func (m *Int64S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Int64S.Marshal(b, m, deterministic) +} +func (m *Int64S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Int64S.Merge(m, src) +} +func (m *Int64S) XXX_Size() int { + return xxx_messageInfo_Int64S.Size(m) +} +func (m *Int64S) XXX_DiscardUnknown() { + xxx_messageInfo_Int64S.DiscardUnknown(m) +} + +var xxx_messageInfo_Int64S proto.InternalMessageInfo + +func (m *Int64S) GetValues() []int64 { + if m != nil { + return m.Values + } + return nil +} + +type Bools struct { + Values []bool `protobuf:"varint,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Bools) Reset() { *m = Bools{} } +func (m *Bools) String() string { return proto.CompactTextString(m) } +func (*Bools) ProtoMessage() {} +func (*Bools) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{14} +} + +func (m *Bools) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Bools.Unmarshal(m, b) +} +func (m *Bools) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Bools.Marshal(b, m, deterministic) +} +func (m *Bools) XXX_Merge(src proto.Message) { + xxx_messageInfo_Bools.Merge(m, src) +} +func (m *Bools) XXX_Size() int { + return xxx_messageInfo_Bools.Size(m) +} +func (m *Bools) XXX_DiscardUnknown() { + xxx_messageInfo_Bools.DiscardUnknown(m) +} + +var xxx_messageInfo_Bools proto.InternalMessageInfo + +func (m *Bools) GetValues() []bool { + if m != nil { + return m.Values + } + return nil +} + +type Float64S struct { + Values []float64 `protobuf:"fixed64,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Float64S) Reset() { *m = Float64S{} } +func (m *Float64S) String() string { return proto.CompactTextString(m) } +func (*Float64S) ProtoMessage() {} +func (*Float64S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{15} +} + +func (m *Float64S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Float64S.Unmarshal(m, b) +} +func (m *Float64S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Float64S.Marshal(b, m, deterministic) +} +func (m *Float64S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Float64S.Merge(m, src) +} +func (m *Float64S) XXX_Size() int { + return xxx_messageInfo_Float64S.Size(m) +} +func (m *Float64S) XXX_DiscardUnknown() { + xxx_messageInfo_Float64S.DiscardUnknown(m) +} + +var xxx_messageInfo_Float64S proto.InternalMessageInfo + +func (m *Float64S) GetValues() []float64 { + if m != nil { + return m.Values + } + return nil +} + +type Float32S struct { + Values []float32 `protobuf:"fixed32,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Float32S) Reset() { *m = Float32S{} } +func (m *Float32S) String() string { return proto.CompactTextString(m) } +func (*Float32S) ProtoMessage() {} +func (*Float32S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{16} +} + +func (m *Float32S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Float32S.Unmarshal(m, b) +} +func (m *Float32S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Float32S.Marshal(b, m, deterministic) +} +func (m *Float32S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Float32S.Merge(m, src) +} +func (m *Float32S) XXX_Size() int { + return xxx_messageInfo_Float32S.Size(m) +} +func (m *Float32S) XXX_DiscardUnknown() { + xxx_messageInfo_Float32S.DiscardUnknown(m) +} + +var xxx_messageInfo_Float32S proto.InternalMessageInfo + +func (m *Float32S) GetValues() []float32 { + if m != nil { + return m.Values + } + return nil +} + +type Object struct { + // Types that are valid to be assigned to Value: + // *Object_BoolValue + // *Object_StringValue + // *Object_Int64Value + // *Object_Uint64Value + // *Object_DoubleValue + // *Object_RefValue + // *Object_SliceValue + // *Object_ArrayValue + // *Object_InterfaceValue + // *Object_StructValue + // *Object_MapValue + // *Object_ByteArrayValue + // *Object_Uint16ArrayValue + // *Object_Uint32ArrayValue + // *Object_Uint64ArrayValue + // *Object_UintptrArrayValue + // *Object_Int8ArrayValue + // *Object_Int16ArrayValue + // *Object_Int32ArrayValue + // *Object_Int64ArrayValue + // *Object_BoolArrayValue + // *Object_Float64ArrayValue + // *Object_Float32ArrayValue + Value isObject_Value `protobuf_oneof:"value"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Object) Reset() { *m = Object{} } +func (m *Object) String() string { return proto.CompactTextString(m) } +func (*Object) ProtoMessage() {} +func (*Object) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{17} +} + +func (m *Object) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Object.Unmarshal(m, b) +} +func (m *Object) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Object.Marshal(b, m, deterministic) +} +func (m *Object) XXX_Merge(src proto.Message) { + xxx_messageInfo_Object.Merge(m, src) +} +func (m *Object) XXX_Size() int { + return xxx_messageInfo_Object.Size(m) +} +func (m *Object) XXX_DiscardUnknown() { + xxx_messageInfo_Object.DiscardUnknown(m) +} + +var xxx_messageInfo_Object proto.InternalMessageInfo + +type isObject_Value interface { + isObject_Value() +} + +type Object_BoolValue struct { + BoolValue bool `protobuf:"varint,1,opt,name=bool_value,json=boolValue,proto3,oneof"` +} + +type Object_StringValue struct { + StringValue []byte `protobuf:"bytes,2,opt,name=string_value,json=stringValue,proto3,oneof"` +} + +type Object_Int64Value struct { + Int64Value int64 `protobuf:"varint,3,opt,name=int64_value,json=int64Value,proto3,oneof"` +} + +type Object_Uint64Value struct { + Uint64Value uint64 `protobuf:"varint,4,opt,name=uint64_value,json=uint64Value,proto3,oneof"` +} + +type Object_DoubleValue struct { + DoubleValue float64 `protobuf:"fixed64,5,opt,name=double_value,json=doubleValue,proto3,oneof"` +} + +type Object_RefValue struct { + RefValue uint64 `protobuf:"varint,6,opt,name=ref_value,json=refValue,proto3,oneof"` +} + +type Object_SliceValue struct { + SliceValue *Slice `protobuf:"bytes,7,opt,name=slice_value,json=sliceValue,proto3,oneof"` +} + +type Object_ArrayValue struct { + ArrayValue *Array `protobuf:"bytes,8,opt,name=array_value,json=arrayValue,proto3,oneof"` +} + +type Object_InterfaceValue struct { + InterfaceValue *Interface `protobuf:"bytes,9,opt,name=interface_value,json=interfaceValue,proto3,oneof"` +} + +type Object_StructValue struct { + StructValue *Struct `protobuf:"bytes,10,opt,name=struct_value,json=structValue,proto3,oneof"` +} + +type Object_MapValue struct { + MapValue *Map `protobuf:"bytes,11,opt,name=map_value,json=mapValue,proto3,oneof"` +} + +type Object_ByteArrayValue struct { + ByteArrayValue []byte `protobuf:"bytes,12,opt,name=byte_array_value,json=byteArrayValue,proto3,oneof"` +} + +type Object_Uint16ArrayValue struct { + Uint16ArrayValue *Uint16S `protobuf:"bytes,13,opt,name=uint16_array_value,json=uint16ArrayValue,proto3,oneof"` +} + +type Object_Uint32ArrayValue struct { + Uint32ArrayValue *Uint32S `protobuf:"bytes,14,opt,name=uint32_array_value,json=uint32ArrayValue,proto3,oneof"` +} + +type Object_Uint64ArrayValue struct { + Uint64ArrayValue *Uint64S `protobuf:"bytes,15,opt,name=uint64_array_value,json=uint64ArrayValue,proto3,oneof"` +} + +type Object_UintptrArrayValue struct { + UintptrArrayValue *Uintptrs `protobuf:"bytes,16,opt,name=uintptr_array_value,json=uintptrArrayValue,proto3,oneof"` +} + +type Object_Int8ArrayValue struct { + Int8ArrayValue *Int8S `protobuf:"bytes,17,opt,name=int8_array_value,json=int8ArrayValue,proto3,oneof"` +} + +type Object_Int16ArrayValue struct { + Int16ArrayValue *Int16S `protobuf:"bytes,18,opt,name=int16_array_value,json=int16ArrayValue,proto3,oneof"` +} + +type Object_Int32ArrayValue struct { + Int32ArrayValue *Int32S `protobuf:"bytes,19,opt,name=int32_array_value,json=int32ArrayValue,proto3,oneof"` +} + +type Object_Int64ArrayValue struct { + Int64ArrayValue *Int64S `protobuf:"bytes,20,opt,name=int64_array_value,json=int64ArrayValue,proto3,oneof"` +} + +type Object_BoolArrayValue struct { + BoolArrayValue *Bools `protobuf:"bytes,21,opt,name=bool_array_value,json=boolArrayValue,proto3,oneof"` +} + +type Object_Float64ArrayValue struct { + Float64ArrayValue *Float64S `protobuf:"bytes,22,opt,name=float64_array_value,json=float64ArrayValue,proto3,oneof"` +} + +type Object_Float32ArrayValue struct { + Float32ArrayValue *Float32S `protobuf:"bytes,23,opt,name=float32_array_value,json=float32ArrayValue,proto3,oneof"` +} + +func (*Object_BoolValue) isObject_Value() {} + +func (*Object_StringValue) isObject_Value() {} + +func (*Object_Int64Value) isObject_Value() {} + +func (*Object_Uint64Value) isObject_Value() {} + +func (*Object_DoubleValue) isObject_Value() {} + +func (*Object_RefValue) isObject_Value() {} + +func (*Object_SliceValue) isObject_Value() {} + +func (*Object_ArrayValue) isObject_Value() {} + +func (*Object_InterfaceValue) isObject_Value() {} + +func (*Object_StructValue) isObject_Value() {} + +func (*Object_MapValue) isObject_Value() {} + +func (*Object_ByteArrayValue) isObject_Value() {} + +func (*Object_Uint16ArrayValue) isObject_Value() {} + +func (*Object_Uint32ArrayValue) isObject_Value() {} + +func (*Object_Uint64ArrayValue) isObject_Value() {} + +func (*Object_UintptrArrayValue) isObject_Value() {} + +func (*Object_Int8ArrayValue) isObject_Value() {} + +func (*Object_Int16ArrayValue) isObject_Value() {} + +func (*Object_Int32ArrayValue) isObject_Value() {} + +func (*Object_Int64ArrayValue) isObject_Value() {} + +func (*Object_BoolArrayValue) isObject_Value() {} + +func (*Object_Float64ArrayValue) isObject_Value() {} + +func (*Object_Float32ArrayValue) isObject_Value() {} + +func (m *Object) GetValue() isObject_Value { + if m != nil { + return m.Value + } + return nil +} + +func (m *Object) GetBoolValue() bool { + if x, ok := m.GetValue().(*Object_BoolValue); ok { + return x.BoolValue + } + return false +} + +func (m *Object) GetStringValue() []byte { + if x, ok := m.GetValue().(*Object_StringValue); ok { + return x.StringValue + } + return nil +} + +func (m *Object) GetInt64Value() int64 { + if x, ok := m.GetValue().(*Object_Int64Value); ok { + return x.Int64Value + } + return 0 +} + +func (m *Object) GetUint64Value() uint64 { + if x, ok := m.GetValue().(*Object_Uint64Value); ok { + return x.Uint64Value + } + return 0 +} + +func (m *Object) GetDoubleValue() float64 { + if x, ok := m.GetValue().(*Object_DoubleValue); ok { + return x.DoubleValue + } + return 0 +} + +func (m *Object) GetRefValue() uint64 { + if x, ok := m.GetValue().(*Object_RefValue); ok { + return x.RefValue + } + return 0 +} + +func (m *Object) GetSliceValue() *Slice { + if x, ok := m.GetValue().(*Object_SliceValue); ok { + return x.SliceValue + } + return nil +} + +func (m *Object) GetArrayValue() *Array { + if x, ok := m.GetValue().(*Object_ArrayValue); ok { + return x.ArrayValue + } + return nil +} + +func (m *Object) GetInterfaceValue() *Interface { + if x, ok := m.GetValue().(*Object_InterfaceValue); ok { + return x.InterfaceValue + } + return nil +} + +func (m *Object) GetStructValue() *Struct { + if x, ok := m.GetValue().(*Object_StructValue); ok { + return x.StructValue + } + return nil +} + +func (m *Object) GetMapValue() *Map { + if x, ok := m.GetValue().(*Object_MapValue); ok { + return x.MapValue + } + return nil +} + +func (m *Object) GetByteArrayValue() []byte { + if x, ok := m.GetValue().(*Object_ByteArrayValue); ok { + return x.ByteArrayValue + } + return nil +} + +func (m *Object) GetUint16ArrayValue() *Uint16S { + if x, ok := m.GetValue().(*Object_Uint16ArrayValue); ok { + return x.Uint16ArrayValue + } + return nil +} + +func (m *Object) GetUint32ArrayValue() *Uint32S { + if x, ok := m.GetValue().(*Object_Uint32ArrayValue); ok { + return x.Uint32ArrayValue + } + return nil +} + +func (m *Object) GetUint64ArrayValue() *Uint64S { + if x, ok := m.GetValue().(*Object_Uint64ArrayValue); ok { + return x.Uint64ArrayValue + } + return nil +} + +func (m *Object) GetUintptrArrayValue() *Uintptrs { + if x, ok := m.GetValue().(*Object_UintptrArrayValue); ok { + return x.UintptrArrayValue + } + return nil +} + +func (m *Object) GetInt8ArrayValue() *Int8S { + if x, ok := m.GetValue().(*Object_Int8ArrayValue); ok { + return x.Int8ArrayValue + } + return nil +} + +func (m *Object) GetInt16ArrayValue() *Int16S { + if x, ok := m.GetValue().(*Object_Int16ArrayValue); ok { + return x.Int16ArrayValue + } + return nil +} + +func (m *Object) GetInt32ArrayValue() *Int32S { + if x, ok := m.GetValue().(*Object_Int32ArrayValue); ok { + return x.Int32ArrayValue + } + return nil +} + +func (m *Object) GetInt64ArrayValue() *Int64S { + if x, ok := m.GetValue().(*Object_Int64ArrayValue); ok { + return x.Int64ArrayValue + } + return nil +} + +func (m *Object) GetBoolArrayValue() *Bools { + if x, ok := m.GetValue().(*Object_BoolArrayValue); ok { + return x.BoolArrayValue + } + return nil +} + +func (m *Object) GetFloat64ArrayValue() *Float64S { + if x, ok := m.GetValue().(*Object_Float64ArrayValue); ok { + return x.Float64ArrayValue + } + return nil +} + +func (m *Object) GetFloat32ArrayValue() *Float32S { + if x, ok := m.GetValue().(*Object_Float32ArrayValue); ok { + return x.Float32ArrayValue + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*Object) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*Object_BoolValue)(nil), + (*Object_StringValue)(nil), + (*Object_Int64Value)(nil), + (*Object_Uint64Value)(nil), + (*Object_DoubleValue)(nil), + (*Object_RefValue)(nil), + (*Object_SliceValue)(nil), + (*Object_ArrayValue)(nil), + (*Object_InterfaceValue)(nil), + (*Object_StructValue)(nil), + (*Object_MapValue)(nil), + (*Object_ByteArrayValue)(nil), + (*Object_Uint16ArrayValue)(nil), + (*Object_Uint32ArrayValue)(nil), + (*Object_Uint64ArrayValue)(nil), + (*Object_UintptrArrayValue)(nil), + (*Object_Int8ArrayValue)(nil), + (*Object_Int16ArrayValue)(nil), + (*Object_Int32ArrayValue)(nil), + (*Object_Int64ArrayValue)(nil), + (*Object_BoolArrayValue)(nil), + (*Object_Float64ArrayValue)(nil), + (*Object_Float32ArrayValue)(nil), + } +} + +func init() { + proto.RegisterType((*Slice)(nil), "gvisor.state.statefile.Slice") + proto.RegisterType((*Array)(nil), "gvisor.state.statefile.Array") + proto.RegisterType((*Map)(nil), "gvisor.state.statefile.Map") + proto.RegisterType((*Interface)(nil), "gvisor.state.statefile.Interface") + proto.RegisterType((*Struct)(nil), "gvisor.state.statefile.Struct") + proto.RegisterType((*Field)(nil), "gvisor.state.statefile.Field") + proto.RegisterType((*Uint16S)(nil), "gvisor.state.statefile.Uint16s") + proto.RegisterType((*Uint32S)(nil), "gvisor.state.statefile.Uint32s") + proto.RegisterType((*Uint64S)(nil), "gvisor.state.statefile.Uint64s") + proto.RegisterType((*Uintptrs)(nil), "gvisor.state.statefile.Uintptrs") + proto.RegisterType((*Int8S)(nil), "gvisor.state.statefile.Int8s") + proto.RegisterType((*Int16S)(nil), "gvisor.state.statefile.Int16s") + proto.RegisterType((*Int32S)(nil), "gvisor.state.statefile.Int32s") + proto.RegisterType((*Int64S)(nil), "gvisor.state.statefile.Int64s") + proto.RegisterType((*Bools)(nil), "gvisor.state.statefile.Bools") + proto.RegisterType((*Float64S)(nil), "gvisor.state.statefile.Float64s") + proto.RegisterType((*Float32S)(nil), "gvisor.state.statefile.Float32s") + proto.RegisterType((*Object)(nil), "gvisor.state.statefile.Object") +} + +func init() { proto.RegisterFile("pkg/state/object.proto", fileDescriptor_3dee2c1912d4d62d) } + +var fileDescriptor_3dee2c1912d4d62d = []byte{ + // 781 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x96, 0x6f, 0x4f, 0xda, 0x5e, + 0x14, 0xc7, 0xa9, 0x40, 0x29, 0x07, 0x14, 0xb8, 0xfe, 0x7e, 0x8c, 0xcc, 0x38, 0xb1, 0x7b, 0x42, + 0xf6, 0x00, 0x33, 0x60, 0xc4, 0xf8, 0x64, 0x53, 0x13, 0x03, 0xc9, 0x8c, 0x59, 0x8d, 0xcb, 0x9e, + 0x99, 0x52, 0x2f, 0xac, 0xb3, 0xb6, 0x5d, 0x7b, 0x6b, 0xc2, 0xcb, 0xdc, 0x3b, 0x5a, 0xee, 0x1f, + 0xae, 0xfd, 0x03, 0xc5, 0xec, 0x89, 0xa1, 0xb7, 0xdf, 0xf3, 0xe1, 0xdc, 0xf3, 0x3d, 0xe7, 0x08, + 0xb4, 0xfd, 0xc7, 0xc5, 0x49, 0x48, 0x4c, 0x82, 0x4f, 0xbc, 0xd9, 0x2f, 0x6c, 0x91, 0xbe, 0x1f, + 0x78, 0xc4, 0x43, 0xed, 0xc5, 0xb3, 0x1d, 0x7a, 0x41, 0x9f, 0xbd, 0xe2, 0x7f, 0xe7, 0xb6, 0x83, + 0xf5, 0x1f, 0x50, 0xbe, 0x75, 0x6c, 0x0b, 0xa3, 0x36, 0xa8, 0x0e, 0x76, 0x17, 0xe4, 0x67, 0x47, + 0xe9, 0x2a, 0xbd, 0x5d, 0x43, 0x3c, 0xa1, 0xb7, 0xa0, 0x59, 0xa6, 0x6f, 0x5a, 0x36, 0x59, 0x76, + 0x76, 0xd8, 0x1b, 0xf9, 0x8c, 0x0e, 0xa0, 0x1a, 0xe0, 0xf9, 0xfd, 0xb3, 0xe9, 0x44, 0xb8, 0x53, + 0xec, 0x2a, 0xbd, 0x92, 0xa1, 0x05, 0x78, 0xfe, 0x9d, 0x3e, 0xeb, 0x97, 0x50, 0x3e, 0x0f, 0x02, + 0x73, 0x89, 0xce, 0x40, 0xb3, 0x3c, 0x97, 0x60, 0x97, 0x84, 0x1d, 0xa5, 0x5b, 0xec, 0xd5, 0x06, + 0xef, 0xfa, 0xeb, 0xb3, 0xe9, 0xdf, 0xb0, 0x94, 0x0d, 0xa9, 0xd7, 0x7f, 0x43, 0xf1, 0xda, 0xf4, + 0xd1, 0x00, 0x4a, 0x8f, 0x78, 0xf9, 0xda, 0x70, 0xa6, 0x45, 0x63, 0x50, 0x59, 0x62, 0x61, 0x67, + 0xe7, 0x55, 0x51, 0x42, 0xad, 0xdf, 0x41, 0x75, 0xea, 0x12, 0x1c, 0xcc, 0x4d, 0x0b, 0x23, 0x04, + 0x25, 0xb2, 0xf4, 0x31, 0xab, 0x49, 0xd5, 0x60, 0x9f, 0xd1, 0x08, 0xca, 0xfc, 0xc6, 0xb4, 0x1c, + 0xdb, 0xb9, 0x5c, 0xac, 0x7f, 0x06, 0xf5, 0x96, 0x04, 0x91, 0x45, 0xd0, 0x27, 0x50, 0xe7, 0x36, + 0x76, 0x1e, 0x56, 0xd7, 0x39, 0xdc, 0x04, 0xb8, 0xa2, 0x2a, 0x43, 0x88, 0xf5, 0x6f, 0x50, 0x66, + 0x07, 0x34, 0x27, 0xd7, 0x7c, 0x92, 0x39, 0xd1, 0xcf, 0xff, 0x98, 0xd3, 0x31, 0x54, 0xee, 0x6c, + 0x97, 0x7c, 0x1c, 0x87, 0xd4, 0x7e, 0x51, 0x2d, 0x9a, 0xd4, 0xae, 0xac, 0x86, 0x90, 0x0c, 0x07, + 0x69, 0x49, 0x25, 0x2d, 0x19, 0x8f, 0xd2, 0x12, 0x55, 0x4a, 0x74, 0xd0, 0xa8, 0xc4, 0x27, 0xc1, + 0x66, 0xcd, 0x11, 0x94, 0xa7, 0x2e, 0x39, 0x4d, 0x0a, 0x94, 0x5e, 0x5d, 0x0a, 0xba, 0xa0, 0x4e, + 0xd7, 0x25, 0x5b, 0x4e, 0x29, 0xb2, 0xb9, 0x36, 0x52, 0x8a, 0x6c, 0xaa, 0xcd, 0x78, 0x1a, 0x17, + 0x9e, 0xe7, 0xa4, 0x05, 0x5a, 0xfc, 0x2e, 0x57, 0x8e, 0x67, 0xae, 0x81, 0x28, 0x19, 0x4d, 0x36, + 0x95, 0x1d, 0xa9, 0xf9, 0x53, 0x03, 0x95, 0xdb, 0x81, 0x8e, 0x00, 0x66, 0x9e, 0xe7, 0x88, 0x41, + 0xa2, 0xb7, 0xd6, 0x26, 0x05, 0xa3, 0x4a, 0xcf, 0xd8, 0x2c, 0xa1, 0xf7, 0x50, 0x0f, 0x49, 0x60, + 0xbb, 0x8b, 0xfb, 0x17, 0x97, 0xeb, 0x93, 0x82, 0x51, 0xe3, 0xa7, 0x5c, 0x74, 0x0c, 0x35, 0x66, + 0x43, 0x6c, 0x1e, 0x8b, 0x93, 0x82, 0x01, 0xec, 0x50, 0x72, 0xa2, 0xb8, 0xa6, 0x44, 0x67, 0x96, + 0x72, 0xa2, 0xa4, 0xe8, 0xc1, 0x8b, 0x66, 0x0e, 0x16, 0xa2, 0x72, 0x57, 0xe9, 0x29, 0x54, 0xc4, + 0x4f, 0xb9, 0xe8, 0x30, 0x3e, 0xfa, 0xaa, 0xc0, 0xc8, 0xe1, 0x47, 0x5f, 0xa0, 0x16, 0xd2, 0xb5, + 0x22, 0x04, 0x15, 0xd6, 0x95, 0x1b, 0x1b, 0x9d, 0x6d, 0x20, 0x9a, 0x2a, 0x8b, 0x91, 0x04, 0x93, + 0xae, 0x0f, 0x41, 0xd0, 0xf2, 0x09, 0x6c, 0xd3, 0x50, 0x02, 0x8b, 0xe1, 0x84, 0xaf, 0xd0, 0xb0, + 0x57, 0x83, 0x2c, 0x28, 0x55, 0x46, 0x39, 0xde, 0x44, 0x91, 0x73, 0x3f, 0x29, 0x18, 0x7b, 0x32, + 0x96, 0xd3, 0x2e, 0x99, 0x05, 0x91, 0x45, 0x04, 0x0a, 0xf2, 0x07, 0x8d, 0xcf, 0xba, 0xb0, 0x28, + 0xb2, 0x08, 0x87, 0x9c, 0x41, 0xf5, 0xc9, 0xf4, 0x05, 0xa1, 0xc6, 0x08, 0x07, 0x9b, 0x08, 0xd7, + 0xa6, 0x4f, 0x4b, 0xfa, 0x64, 0xfa, 0x3c, 0xf6, 0x03, 0x34, 0x67, 0x4b, 0x82, 0xef, 0xe3, 0x55, + 0xa9, 0x8b, 0x3e, 0xd8, 0xa3, 0x6f, 0xce, 0x5f, 0xae, 0x7e, 0x03, 0x28, 0x62, 0x83, 0x9d, 0x50, + 0xef, 0xb2, 0x2f, 0x3c, 0xda, 0xf4, 0x85, 0x62, 0x15, 0x4c, 0x0a, 0x46, 0x93, 0x07, 0x67, 0x81, + 0xc3, 0x41, 0x02, 0xb8, 0xb7, 0x1d, 0x38, 0x1c, 0x48, 0xe0, 0x70, 0x90, 0x05, 0x8e, 0x47, 0x09, + 0x60, 0x63, 0x3b, 0x70, 0x3c, 0x92, 0xc0, 0xf1, 0x28, 0x06, 0x34, 0x60, 0x3f, 0xe2, 0x2b, 0x26, + 0x41, 0x6c, 0x32, 0x62, 0x37, 0x8f, 0x48, 0xb7, 0xd2, 0xa4, 0x60, 0xb4, 0x44, 0x78, 0x8c, 0x39, + 0x85, 0xa6, 0xed, 0x92, 0xd3, 0x04, 0xb0, 0x95, 0xdf, 0x88, 0x6c, 0x85, 0x89, 0xf6, 0x39, 0x3d, + 0x8f, 0x37, 0x63, 0x2b, 0x6b, 0x08, 0xca, 0xef, 0xa1, 0xe9, 0xca, 0x8f, 0x46, 0xda, 0x0e, 0x4e, + 0x4b, 0xb9, 0xb1, 0xbf, 0x95, 0xc6, 0xcd, 0x68, 0xa4, 0xbd, 0xe0, 0xb4, 0x94, 0x15, 0xff, 0x6d, + 0xa5, 0x71, 0x27, 0x1a, 0x69, 0x23, 0xa6, 0xd0, 0x64, 0xcb, 0x2c, 0x0e, 0xfb, 0x3f, 0xbf, 0x68, + 0x6c, 0xe1, 0xb2, 0x36, 0xf6, 0x3c, 0x27, 0xe9, 0xe9, 0x9c, 0xaf, 0xda, 0x04, 0xad, 0x9d, 0xef, + 0xe9, 0x6a, 0x3b, 0x53, 0x4f, 0x45, 0xf8, 0x1a, 0x66, 0xaa, 0x78, 0x6f, 0x5e, 0xc1, 0xe4, 0xe5, + 0x6b, 0x89, 0xf0, 0x17, 0xe6, 0x45, 0x45, 0xfc, 0xf7, 0x9d, 0xa9, 0xec, 0xc7, 0xd6, 0xf0, 0x6f, + 0x00, 0x00, 0x00, 0xff, 0xff, 0x84, 0x69, 0xc9, 0x45, 0x86, 0x09, 0x00, 0x00, +} diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go deleted file mode 100644 index d7221e9e8..000000000 --- a/pkg/state/state_test.go +++ /dev/null @@ -1,721 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package state - -import ( - "bytes" - "context" - "io/ioutil" - "math" - "reflect" - "testing" -) - -// TestCase is used to define a single success/failure testcase of -// serialization of a set of objects. -type TestCase struct { - // Name is the name of the test case. - Name string - - // Objects is the list of values to serialize. - Objects []interface{} - - // Fail is whether the test case is supposed to fail or not. - Fail bool -} - -// runTest runs all testcases. -func runTest(t *testing.T, tests []TestCase) { - for _, test := range tests { - t.Logf("TEST %s:", test.Name) - for i, root := range test.Objects { - t.Logf(" case#%d: %#v", i, root) - - // Save the passed object. - saveBuffer := &bytes.Buffer{} - saveObjectPtr := reflect.New(reflect.TypeOf(root)) - saveObjectPtr.Elem().Set(reflect.ValueOf(root)) - if err := Save(context.Background(), saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail { - t.Errorf(" FAIL: Save failed unexpectedly: %v", err) - continue - } else if err != nil { - t.Logf(" PASS: Save failed as expected: %v", err) - continue - } - - // Load a new copy of the object. - loadObjectPtr := reflect.New(reflect.TypeOf(root)) - if err := Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail { - t.Errorf(" FAIL: Load failed unexpectedly: %v", err) - continue - } else if err != nil { - t.Logf(" PASS: Load failed as expected: %v", err) - continue - } - - // Compare the values. - loadedValue := loadObjectPtr.Elem().Interface() - if eq := reflect.DeepEqual(root, loadedValue); !eq && !test.Fail { - t.Errorf(" FAIL: Objects differs; got %#v", loadedValue) - continue - } else if !eq { - t.Logf(" PASS: Object different as expected.") - continue - } - - // Everything went okay. Is that good? - if test.Fail { - t.Errorf(" FAIL: Unexpected success.") - } else { - t.Logf(" PASS: Success.") - } - } - } -} - -// dumbStruct is a struct which does not implement the loader/saver interface. -// We expect that serialization of this struct will fail. -type dumbStruct struct { - A int - B int -} - -// smartStruct is a struct which does implement the loader/saver interface. -// We expect that serialization of this struct will succeed. -type smartStruct struct { - A int - B int -} - -func (s *smartStruct) save(m Map) { - m.Save("A", &s.A) - m.Save("B", &s.B) -} - -func (s *smartStruct) load(m Map) { - m.Load("A", &s.A) - m.Load("B", &s.B) -} - -// valueLoadStruct uses a value load. -type valueLoadStruct struct { - v int -} - -func (v *valueLoadStruct) save(m Map) { - m.SaveValue("v", v.v) -} - -func (v *valueLoadStruct) load(m Map) { - m.LoadValue("v", new(int), func(value interface{}) { - v.v = value.(int) - }) -} - -// afterLoadStruct has an AfterLoad function. -type afterLoadStruct struct { - v int -} - -func (a *afterLoadStruct) save(m Map) { -} - -func (a *afterLoadStruct) load(m Map) { - m.AfterLoad(func() { - a.v++ - }) -} - -// genericContainer is a generic dispatcher. -type genericContainer struct { - v interface{} -} - -func (g *genericContainer) save(m Map) { - m.Save("v", &g.v) -} - -func (g *genericContainer) load(m Map) { - m.Load("v", &g.v) -} - -// sliceContainer is a generic slice. -type sliceContainer struct { - v []interface{} -} - -func (s *sliceContainer) save(m Map) { - m.Save("v", &s.v) -} - -func (s *sliceContainer) load(m Map) { - m.Load("v", &s.v) -} - -// mapContainer is a generic map. -type mapContainer struct { - v map[int]interface{} -} - -func (mc *mapContainer) save(m Map) { - m.Save("v", &mc.v) -} - -func (mc *mapContainer) load(m Map) { - // Some of the test cases below assume legacy behavior wherein maps - // will automatically inherit dependencies. - m.LoadWait("v", &mc.v) -} - -// dumbMap is a map which does not implement the loader/saver interface. -// Serialization of this map will default to the standard encode/decode logic. -type dumbMap map[string]int - -// pointerStruct contains various pointers, shared and non-shared, and pointers -// to pointers. We expect that serialization will respect the structure. -type pointerStruct struct { - A *int - B *int - C *int - D *int - - AA **int - BB **int -} - -func (p *pointerStruct) save(m Map) { - m.Save("A", &p.A) - m.Save("B", &p.B) - m.Save("C", &p.C) - m.Save("D", &p.D) - m.Save("AA", &p.AA) - m.Save("BB", &p.BB) -} - -func (p *pointerStruct) load(m Map) { - m.Load("A", &p.A) - m.Load("B", &p.B) - m.Load("C", &p.C) - m.Load("D", &p.D) - m.Load("AA", &p.AA) - m.Load("BB", &p.BB) -} - -// testInterface is a trivial interface example. -type testInterface interface { - Foo() -} - -// testImpl is a trivial implementation of testInterface. -type testImpl struct { -} - -// Foo satisfies testInterface. -func (t *testImpl) Foo() { -} - -// testImpl is trivially serializable. -func (t *testImpl) save(m Map) { -} - -// testImpl is trivially serializable. -func (t *testImpl) load(m Map) { -} - -// testI demonstrates interface dispatching. -type testI struct { - I testInterface -} - -func (t *testI) save(m Map) { - m.Save("I", &t.I) -} - -func (t *testI) load(m Map) { - m.Load("I", &t.I) -} - -// cycleStruct is used to implement basic cycles. -type cycleStruct struct { - c *cycleStruct -} - -func (c *cycleStruct) save(m Map) { - m.Save("c", &c.c) -} - -func (c *cycleStruct) load(m Map) { - m.Load("c", &c.c) -} - -// badCycleStruct actually has deadlocking dependencies. -// -// This should pass if b.b = {nil|b} and fail otherwise. -type badCycleStruct struct { - b *badCycleStruct -} - -func (b *badCycleStruct) save(m Map) { - m.Save("b", &b.b) -} - -func (b *badCycleStruct) load(m Map) { - m.LoadWait("b", &b.b) - m.AfterLoad(func() { - // This is not executable, since AfterLoad requires that the - // object and all dependencies are complete. This should cause - // a deadlock error during load. - }) -} - -// emptyStructPointer points to an empty struct. -type emptyStructPointer struct { - nothing *struct{} -} - -func (e *emptyStructPointer) save(m Map) { - m.Save("nothing", &e.nothing) -} - -func (e *emptyStructPointer) load(m Map) { - m.Load("nothing", &e.nothing) -} - -// truncateInteger truncates an integer. -type truncateInteger struct { - v int64 - v2 int32 -} - -func (t *truncateInteger) save(m Map) { - t.v2 = int32(t.v) - m.Save("v", &t.v) -} - -func (t *truncateInteger) load(m Map) { - m.Load("v", &t.v2) - t.v = int64(t.v2) -} - -// truncateUnsignedInteger truncates an unsigned integer. -type truncateUnsignedInteger struct { - v uint64 - v2 uint32 -} - -func (t *truncateUnsignedInteger) save(m Map) { - t.v2 = uint32(t.v) - m.Save("v", &t.v) -} - -func (t *truncateUnsignedInteger) load(m Map) { - m.Load("v", &t.v2) - t.v = uint64(t.v2) -} - -// truncateFloat truncates a floating point number. -type truncateFloat struct { - v float64 - v2 float32 -} - -func (t *truncateFloat) save(m Map) { - t.v2 = float32(t.v) - m.Save("v", &t.v) -} - -func (t *truncateFloat) load(m Map) { - m.Load("v", &t.v2) - t.v = float64(t.v2) -} - -func TestTypes(t *testing.T) { - // x and y are basic integers, while xp points to x. - x := 1 - y := 2 - xp := &x - - // cs is a single object cycle. - cs := cycleStruct{nil} - cs.c = &cs - - // cs1 and cs2 are in a two object cycle. - cs1 := cycleStruct{nil} - cs2 := cycleStruct{nil} - cs1.c = &cs2 - cs2.c = &cs1 - - // bs is a single object cycle. - bs := badCycleStruct{nil} - bs.b = &bs - - // bs2 and bs2 are in a deadlocking cycle. - bs1 := badCycleStruct{nil} - bs2 := badCycleStruct{nil} - bs1.b = &bs2 - bs2.b = &bs1 - - // regular nils. - var ( - nilmap dumbMap - nilslice []byte - ) - - // embed points to embedded fields. - embed1 := pointerStruct{} - embed1.AA = &embed1.A - embed2 := pointerStruct{} - embed2.BB = &embed2.B - - // es1 contains two structs pointing to the same empty struct. - es := emptyStructPointer{new(struct{})} - es1 := []emptyStructPointer{es, es} - - tests := []TestCase{ - { - Name: "bool", - Objects: []interface{}{ - true, - false, - }, - }, - { - Name: "integers", - Objects: []interface{}{ - int(0), - int(1), - int(-1), - int8(0), - int8(1), - int8(-1), - int16(0), - int16(1), - int16(-1), - int32(0), - int32(1), - int32(-1), - int64(0), - int64(1), - int64(-1), - }, - }, - { - Name: "unsigned integers", - Objects: []interface{}{ - uint(0), - uint(1), - uint8(0), - uint8(1), - uint16(0), - uint16(1), - uint32(1), - uint64(0), - uint64(1), - }, - }, - { - Name: "strings", - Objects: []interface{}{ - "", - "foo", - "bar", - "\xa0", - }, - }, - { - Name: "slices", - Objects: []interface{}{ - []int{-1, 0, 1}, - []*int{&x, &x, &x}, - []int{1, 2, 3}[0:1], - []int{1, 2, 3}[1:2], - make([]byte, 32), - make([]byte, 32)[:16], - make([]byte, 32)[:16:20], - nilslice, - }, - }, - { - Name: "arrays", - Objects: []interface{}{ - &[1048576]bool{false, true, false, true}, - &[1048576]uint8{0, 1, 2, 3}, - &[1048576]byte{0, 1, 2, 3}, - &[1048576]uint16{0, 1, 2, 3}, - &[1048576]uint{0, 1, 2, 3}, - &[1048576]uint32{0, 1, 2, 3}, - &[1048576]uint64{0, 1, 2, 3}, - &[1048576]uintptr{0, 1, 2, 3}, - &[1048576]int8{0, -1, -2, -3}, - &[1048576]int16{0, -1, -2, -3}, - &[1048576]int32{0, -1, -2, -3}, - &[1048576]int64{0, -1, -2, -3}, - &[1048576]float32{0, 1.1, 2.2, 3.3}, - &[1048576]float64{0, 1.1, 2.2, 3.3}, - }, - }, - { - Name: "pointers", - Objects: []interface{}{ - &pointerStruct{A: &x, B: &x, C: &y, D: &y, AA: &xp, BB: &xp}, - &pointerStruct{}, - }, - }, - { - Name: "empty struct", - Objects: []interface{}{ - struct{}{}, - }, - }, - { - Name: "unenlightened structs", - Objects: []interface{}{ - &dumbStruct{A: 1, B: 2}, - }, - Fail: true, - }, - { - Name: "enlightened structs", - Objects: []interface{}{ - &smartStruct{A: 1, B: 2}, - }, - }, - { - Name: "load-hooks", - Objects: []interface{}{ - &afterLoadStruct{v: 1}, - &valueLoadStruct{v: 1}, - &genericContainer{v: &afterLoadStruct{v: 1}}, - &genericContainer{v: &valueLoadStruct{v: 1}}, - &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}}, - &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}}, - }, - }, - { - Name: "maps", - Objects: []interface{}{ - dumbMap{"a": -1, "b": 0, "c": 1}, - map[smartStruct]int{{}: 0, {A: 1}: 1}, - nilmap, - &mapContainer{v: map[int]interface{}{0: &smartStruct{A: 1}}}, - }, - }, - { - Name: "interfaces", - Objects: []interface{}{ - &testI{&testImpl{}}, - &testI{nil}, - &testI{(*testImpl)(nil)}, - }, - }, - { - Name: "unregistered-interfaces", - Objects: []interface{}{ - &genericContainer{v: afterLoadStruct{v: 1}}, - &genericContainer{v: valueLoadStruct{v: 1}}, - &sliceContainer{v: []interface{}{afterLoadStruct{v: 1}}}, - &sliceContainer{v: []interface{}{valueLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: afterLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: valueLoadStruct{v: 1}}}, - }, - Fail: true, - }, - { - Name: "cycles", - Objects: []interface{}{ - &cs, - &cs1, - &cycleStruct{&cs1}, - &cycleStruct{&cs}, - &badCycleStruct{nil}, - &bs, - }, - }, - { - Name: "deadlock", - Objects: []interface{}{ - &bs1, - }, - Fail: true, - }, - { - Name: "embed", - Objects: []interface{}{ - &embed1, - &embed2, - }, - Fail: true, - }, - { - Name: "empty structs", - Objects: []interface{}{ - new(struct{}), - es, - es1, - }, - }, - { - Name: "truncated okay", - Objects: []interface{}{ - &truncateInteger{v: 1}, - &truncateUnsignedInteger{v: 1}, - &truncateFloat{v: 1.0}, - }, - }, - { - Name: "truncated bad", - Objects: []interface{}{ - &truncateInteger{v: math.MaxInt32 + 1}, - &truncateUnsignedInteger{v: math.MaxUint32 + 1}, - &truncateFloat{v: math.MaxFloat32 * 2}, - }, - Fail: true, - }, - } - - runTest(t, tests) -} - -// benchStruct is used for benchmarking. -type benchStruct struct { - b *benchStruct - - // Dummy data is included to ensure that these objects are large. - // This is to detect possible regression when registering objects. - _ [4096]byte -} - -func (b *benchStruct) save(m Map) { - m.Save("b", &b.b) -} - -func (b *benchStruct) load(m Map) { - m.LoadWait("b", &b.b) - m.AfterLoad(b.afterLoad) -} - -func (b *benchStruct) afterLoad() { - // Do nothing, just force scheduling. -} - -// buildObject builds a benchmark object. -func buildObject(n int) (b *benchStruct) { - for i := 0; i < n; i++ { - b = &benchStruct{b: b} - } - return -} - -func BenchmarkEncoding(b *testing.B) { - b.StopTimer() - bs := buildObject(b.N) - var stats Stats - b.StartTimer() - if err := Save(context.Background(), ioutil.Discard, bs, &stats); err != nil { - b.Errorf("save failed: %v", err) - } - b.StopTimer() - if b.N > 1000 { - b.Logf("breakdown (n=%d): %s", b.N, &stats) - } -} - -func BenchmarkDecoding(b *testing.B) { - b.StopTimer() - bs := buildObject(b.N) - var newBS benchStruct - buf := &bytes.Buffer{} - if err := Save(context.Background(), buf, bs, nil); err != nil { - b.Errorf("save failed: %v", err) - } - var stats Stats - b.StartTimer() - if err := Load(context.Background(), buf, &newBS, &stats); err != nil { - b.Errorf("load failed: %v", err) - } - b.StopTimer() - if b.N > 1000 { - b.Logf("breakdown (n=%d): %s", b.N, &stats) - } -} - -func init() { - Register("stateTest.smartStruct", (*smartStruct)(nil), Fns{ - Save: (*smartStruct).save, - Load: (*smartStruct).load, - }) - Register("stateTest.afterLoadStruct", (*afterLoadStruct)(nil), Fns{ - Save: (*afterLoadStruct).save, - Load: (*afterLoadStruct).load, - }) - Register("stateTest.valueLoadStruct", (*valueLoadStruct)(nil), Fns{ - Save: (*valueLoadStruct).save, - Load: (*valueLoadStruct).load, - }) - Register("stateTest.genericContainer", (*genericContainer)(nil), Fns{ - Save: (*genericContainer).save, - Load: (*genericContainer).load, - }) - Register("stateTest.sliceContainer", (*sliceContainer)(nil), Fns{ - Save: (*sliceContainer).save, - Load: (*sliceContainer).load, - }) - Register("stateTest.mapContainer", (*mapContainer)(nil), Fns{ - Save: (*mapContainer).save, - Load: (*mapContainer).load, - }) - Register("stateTest.pointerStruct", (*pointerStruct)(nil), Fns{ - Save: (*pointerStruct).save, - Load: (*pointerStruct).load, - }) - Register("stateTest.testImpl", (*testImpl)(nil), Fns{ - Save: (*testImpl).save, - Load: (*testImpl).load, - }) - Register("stateTest.testI", (*testI)(nil), Fns{ - Save: (*testI).save, - Load: (*testI).load, - }) - Register("stateTest.cycleStruct", (*cycleStruct)(nil), Fns{ - Save: (*cycleStruct).save, - Load: (*cycleStruct).load, - }) - Register("stateTest.badCycleStruct", (*badCycleStruct)(nil), Fns{ - Save: (*badCycleStruct).save, - Load: (*badCycleStruct).load, - }) - Register("stateTest.emptyStructPointer", (*emptyStructPointer)(nil), Fns{ - Save: (*emptyStructPointer).save, - Load: (*emptyStructPointer).load, - }) - Register("stateTest.truncateInteger", (*truncateInteger)(nil), Fns{ - Save: (*truncateInteger).save, - Load: (*truncateInteger).load, - }) - Register("stateTest.truncateUnsignedInteger", (*truncateUnsignedInteger)(nil), Fns{ - Save: (*truncateUnsignedInteger).save, - Load: (*truncateUnsignedInteger).load, - }) - Register("stateTest.truncateFloat", (*truncateFloat)(nil), Fns{ - Save: (*truncateFloat).save, - Load: (*truncateFloat).load, - }) - Register("stateTest.benchStruct", (*benchStruct)(nil), Fns{ - Save: (*benchStruct).save, - Load: (*benchStruct).load, - }) -} diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD deleted file mode 100644 index e7581c09b..000000000 --- a/pkg/state/statefile/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "statefile", - srcs = ["statefile.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/binary", - "//pkg/compressio", - ], -) - -go_test( - name = "statefile_test", - size = "small", - srcs = ["statefile_test.go"], - library = ":statefile", - deps = ["//pkg/compressio"], -) diff --git a/pkg/state/statefile/statefile_state_autogen.go b/pkg/state/statefile/statefile_state_autogen.go new file mode 100755 index 000000000..a2cdaa3f1 --- /dev/null +++ b/pkg/state/statefile/statefile_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package statefile diff --git a/pkg/state/statefile/statefile_test.go b/pkg/state/statefile/statefile_test.go deleted file mode 100644 index 0b470fdec..000000000 --- a/pkg/state/statefile/statefile_test.go +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package statefile - -import ( - "bytes" - crand "crypto/rand" - "encoding/base64" - "io" - "math/rand" - "runtime" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/compressio" -) - -func randomKey() ([]byte, error) { - r := make([]byte, base64.RawStdEncoding.DecodedLen(keySize)) - if _, err := io.ReadFull(crand.Reader, r); err != nil { - return nil, err - } - key := make([]byte, keySize) - base64.RawStdEncoding.Encode(key, r) - return key, nil -} - -type testCase struct { - name string - data []byte - metadata map[string]string -} - -func TestStatefile(t *testing.T) { - rand.Seed(time.Now().Unix()) - - cases := []testCase{ - // Various data sizes. - {"nil", nil, nil}, - {"empty", []byte(""), nil}, - {"some", []byte("_"), nil}, - {"one", []byte("0"), nil}, - {"two", []byte("01"), nil}, - {"three", []byte("012"), nil}, - {"four", []byte("0123"), nil}, - {"five", []byte("01234"), nil}, - {"six", []byte("012356"), nil}, - {"seven", []byte("0123567"), nil}, - {"eight", []byte("01235678"), nil}, - - // Make sure we have one longer than the hash length. - {"longer than hash", []byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"), nil}, - - // Make sure we have one longer than the chunk size. - {"chunks", make([]byte, 3*compressionChunkSize), nil}, - {"large", make([]byte, 30*compressionChunkSize), nil}, - - // Different metadata. - {"one metadata", []byte("data"), map[string]string{"foo": "bar"}}, - {"two metadata", []byte("data"), map[string]string{"foo": "bar", "one": "two"}}, - } - - for _, c := range cases { - // Generate a key. - integrityKey, err := randomKey() - if err != nil { - t.Errorf("can't generate key: got %v, excepted nil", err) - continue - } - - t.Run(c.name, func(t *testing.T) { - for _, key := range [][]byte{nil, integrityKey} { - t.Run("key="+string(key), func(t *testing.T) { - // Encoding happens via a buffer. - var bufEncoded bytes.Buffer - var bufDecoded bytes.Buffer - - // Do all the writing. - w, err := NewWriter(&bufEncoded, key, c.metadata) - if err != nil { - t.Fatalf("error creating writer: got %v, expected nil", err) - } - if _, err := io.Copy(w, bytes.NewBuffer(c.data)); err != nil { - t.Fatalf("error during write: got %v, expected nil", err) - } - - // Finish the sum. - if err := w.Close(); err != nil { - t.Fatalf("error during close: got %v, expected nil", err) - } - - t.Logf("original data: %d bytes, encoded: %d bytes.", - len(c.data), len(bufEncoded.Bytes())) - - // Do all the reading. - r, metadata, err := NewReader(bytes.NewReader(bufEncoded.Bytes()), key) - if err != nil { - t.Fatalf("error creating reader: got %v, expected nil", err) - } - if _, err := io.Copy(&bufDecoded, r); err != nil { - t.Fatalf("error during read: got %v, expected nil", err) - } - - // Check that the data matches. - if !bytes.Equal(c.data, bufDecoded.Bytes()) { - t.Fatalf("data didn't match (%d vs %d bytes)", len(bufDecoded.Bytes()), len(c.data)) - } - - // Check that the metadata matches. - for k, v := range c.metadata { - nv, ok := metadata[k] - if !ok { - t.Fatalf("missing metadata: %s", k) - } - if v != nv { - t.Fatalf("mismatched metdata for %s: got %s, expected %s", k, nv, v) - } - } - - // Change the data and verify that it fails. - if key != nil { - b := append([]byte(nil), bufEncoded.Bytes()...) - b[rand.Intn(len(b))]++ - bufDecoded.Reset() - r, _, err = NewReader(bytes.NewReader(b), key) - if err == nil { - _, err = io.Copy(&bufDecoded, r) - } - if err == nil { - t.Error("got no error: expected error on data corruption") - } - } - - // Change the key and verify that it fails. - newKey := integrityKey - if len(key) > 0 { - newKey = append([]byte{}, key...) - newKey[rand.Intn(len(newKey))]++ - } - bufDecoded.Reset() - r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), newKey) - if err == nil { - _, err = io.Copy(&bufDecoded, r) - } - if err != compressio.ErrHashMismatch { - t.Errorf("got error: %v, expected ErrHashMismatch on key mismatch", err) - } - }) - } - }) - } -} - -const benchmarkDataSize = 100 * 1024 * 1024 - -func benchmark(b *testing.B, size int, write bool, compressible bool) { - b.StopTimer() - b.SetBytes(benchmarkDataSize) - - // Generate source data. - var source []byte - if compressible { - // For compressible data, we use essentially all zeros. - source = make([]byte, benchmarkDataSize) - } else { - // For non-compressible data, we use random base64 data (to - // make it marginally compressible, a ratio of 75%). - var sourceBuf bytes.Buffer - bufW := base64.NewEncoder(base64.RawStdEncoding, &sourceBuf) - bufR := rand.New(rand.NewSource(0)) - if _, err := io.CopyN(bufW, bufR, benchmarkDataSize); err != nil { - b.Fatalf("unable to seed random data: %v", err) - } - source = sourceBuf.Bytes() - } - - // Generate a random key for integrity check. - key, err := randomKey() - if err != nil { - b.Fatalf("error generating key: %v", err) - } - - // Define our benchmark functions. Prior to running the readState - // function here, you must execute the writeState function at least - // once (done below). - var stateBuf bytes.Buffer - writeState := func() { - stateBuf.Reset() - w, err := NewWriter(&stateBuf, key, nil) - if err != nil { - b.Fatalf("error creating writer: %v", err) - } - for done := 0; done < len(source); { - chunk := size // limit size. - if done+chunk > len(source) { - chunk = len(source) - done - } - n, err := w.Write(source[done : done+chunk]) - done += n - if n == 0 && err != nil { - b.Fatalf("error during write: %v", err) - } - } - if err := w.Close(); err != nil { - b.Fatalf("error closing writer: %v", err) - } - } - readState := func() { - tmpBuf := bytes.NewBuffer(stateBuf.Bytes()) - r, _, err := NewReader(tmpBuf, key) - if err != nil { - b.Fatalf("error creating reader: %v", err) - } - for done := 0; done < len(source); { - chunk := size // limit size. - if done+chunk > len(source) { - chunk = len(source) - done - } - n, err := r.Read(source[done : done+chunk]) - done += n - if n == 0 && err != nil { - b.Fatalf("error during read: %v", err) - } - } - } - // Generate the state once without timing to ensure that buffers have - // been appropriately allocated. - writeState() - if write { - b.StartTimer() - for i := 0; i < b.N; i++ { - writeState() - } - b.StopTimer() - } else { - b.StartTimer() - for i := 0; i < b.N; i++ { - readState() - } - b.StopTimer() - } -} - -func BenchmarkWrite4KCompressible(b *testing.B) { - benchmark(b, 4096, true, true) -} - -func BenchmarkWrite4KNoncompressible(b *testing.B) { - benchmark(b, 4096, true, false) -} - -func BenchmarkWrite1MCompressible(b *testing.B) { - benchmark(b, 1024*1024, true, true) -} - -func BenchmarkWrite1MNoncompressible(b *testing.B) { - benchmark(b, 1024*1024, true, false) -} - -func BenchmarkRead4KCompressible(b *testing.B) { - benchmark(b, 4096, false, true) -} - -func BenchmarkRead4KNoncompressible(b *testing.B) { - benchmark(b, 4096, false, false) -} - -func BenchmarkRead1MCompressible(b *testing.B) { - benchmark(b, 1024*1024, false, true) -} - -func BenchmarkRead1MNoncompressible(b *testing.B) { - benchmark(b, 1024*1024, false, false) -} - -func init() { - runtime.GOMAXPROCS(runtime.NumCPU()) -} diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD deleted file mode 100644 index 5340cf0d6..000000000 --- a/pkg/sync/BUILD +++ /dev/null @@ -1,53 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -exports_files(["LICENSE"]) - -go_template( - name = "generic_atomicptr", - srcs = ["atomicptr_unsafe.go"], - types = [ - "Value", - ], -) - -go_template( - name = "generic_seqatomic", - srcs = ["seqatomic_unsafe.go"], - types = [ - "Value", - ], - deps = [ - ":sync", - ], -) - -go_library( - name = "sync", - srcs = [ - "aliases.go", - "downgradable_rwmutex_unsafe.go", - "memmove_unsafe.go", - "norace_unsafe.go", - "race_unsafe.go", - "seqcount.go", - "syncutil.go", - "tmutex_unsafe.go", - ], -) - -go_test( - name = "sync_test", - size = "small", - srcs = [ - "downgradable_rwmutex_test.go", - "seqcount_test.go", - "tmutex_test.go", - ], - library = ":sync", -) diff --git a/pkg/sync/LICENSE b/pkg/sync/LICENSE deleted file mode 100644 index 6a66aea5e..000000000 --- a/pkg/sync/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/sync/README.md b/pkg/sync/README.md deleted file mode 100644 index 2183c4e20..000000000 --- a/pkg/sync/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Syncutil - -This package provides additional synchronization primitives not provided by the -Go stdlib 'sync' package. It is partially derived from the upstream 'sync' -package from go1.10. diff --git a/pkg/sync/aliases.go b/pkg/sync/aliases.go index d2d7132fa..d2d7132fa 100644..100755 --- a/pkg/sync/aliases.go +++ b/pkg/sync/aliases.go diff --git a/pkg/sync/atomicptrtest/BUILD b/pkg/sync/atomicptrtest/BUILD deleted file mode 100644 index e97553254..000000000 --- a/pkg/sync/atomicptrtest/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "atomicptr_int", - out = "atomicptr_int_unsafe.go", - package = "atomicptr", - suffix = "Int", - template = "//pkg/sync:generic_atomicptr", - types = { - "Value": "int", - }, -) - -go_library( - name = "atomicptr", - srcs = ["atomicptr_int_unsafe.go"], -) - -go_test( - name = "atomicptr_test", - size = "small", - srcs = ["atomicptr_test.go"], - library = ":atomicptr", -) diff --git a/pkg/sync/atomicptrtest/atomicptr_test.go b/pkg/sync/atomicptrtest/atomicptr_test.go deleted file mode 100644 index 8fdc5112e..000000000 --- a/pkg/sync/atomicptrtest/atomicptr_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package atomicptr - -import ( - "testing" -) - -func newInt(val int) *int { - return &val -} - -func TestAtomicPtr(t *testing.T) { - var p AtomicPtrInt - if got := p.Load(); got != nil { - t.Errorf("initial value is %p (%v), wanted nil", got, got) - } - want := newInt(42) - p.Store(want) - if got := p.Load(); got != want { - t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) - } - want = newInt(100) - p.Store(want) - if got := p.Load(); got != want { - t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) - } -} diff --git a/pkg/sync/downgradable_rwmutex_test.go b/pkg/sync/downgradable_rwmutex_test.go deleted file mode 100644 index ce667e825..000000000 --- a/pkg/sync/downgradable_rwmutex_test.go +++ /dev/null @@ -1,205 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Copyright 2019 The gVisor Authors. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// GOMAXPROCS=10 go test - -// Copy/pasted from the standard library's sync/rwmutex_test.go, except for the -// addition of downgradingWriter and the renaming of num_iterations to -// numIterations to shut up Golint. - -package sync - -import ( - "fmt" - "runtime" - "sync/atomic" - "testing" -) - -func parallelReader(m *RWMutex, clocked, cunlock, cdone chan bool) { - m.RLock() - clocked <- true - <-cunlock - m.RUnlock() - cdone <- true -} - -func doTestParallelReaders(numReaders, gomaxprocs int) { - runtime.GOMAXPROCS(gomaxprocs) - var m RWMutex - clocked := make(chan bool) - cunlock := make(chan bool) - cdone := make(chan bool) - for i := 0; i < numReaders; i++ { - go parallelReader(&m, clocked, cunlock, cdone) - } - // Wait for all parallel RLock()s to succeed. - for i := 0; i < numReaders; i++ { - <-clocked - } - for i := 0; i < numReaders; i++ { - cunlock <- true - } - // Wait for the goroutines to finish. - for i := 0; i < numReaders; i++ { - <-cdone - } -} - -func TestParallelReaders(t *testing.T) { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) - doTestParallelReaders(1, 4) - doTestParallelReaders(3, 4) - doTestParallelReaders(4, 2) -} - -func reader(rwm *RWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.RLock() - n := atomic.AddInt32(activity, 1) - if n < 1 || n >= 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -1) - rwm.RUnlock() - } - cdone <- true -} - -func writer(rwm *RWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.Lock() - n := atomic.AddInt32(activity, 10000) - if n != 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -10000) - rwm.Unlock() - } - cdone <- true -} - -func downgradingWriter(rwm *RWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.Lock() - n := atomic.AddInt32(activity, 10000) - if n != 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -10000) - rwm.DowngradeLock() - n = atomic.AddInt32(activity, 1) - if n < 1 || n >= 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - n = atomic.AddInt32(activity, -1) - rwm.RUnlock() - } - cdone <- true -} - -func HammerDowngradableRWMutex(gomaxprocs, numReaders, numIterations int) { - runtime.GOMAXPROCS(gomaxprocs) - // Number of active readers + 10000 * number of active writers. - var activity int32 - var rwm RWMutex - cdone := make(chan bool) - go writer(&rwm, numIterations, &activity, cdone) - go downgradingWriter(&rwm, numIterations, &activity, cdone) - var i int - for i = 0; i < numReaders/2; i++ { - go reader(&rwm, numIterations, &activity, cdone) - } - go writer(&rwm, numIterations, &activity, cdone) - go downgradingWriter(&rwm, numIterations, &activity, cdone) - for ; i < numReaders; i++ { - go reader(&rwm, numIterations, &activity, cdone) - } - // Wait for the 4 writers and all readers to finish. - for i := 0; i < 4+numReaders; i++ { - <-cdone - } -} - -func TestDowngradableRWMutex(t *testing.T) { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) - n := 1000 - if testing.Short() { - n = 5 - } - HammerDowngradableRWMutex(1, 1, n) - HammerDowngradableRWMutex(1, 3, n) - HammerDowngradableRWMutex(1, 10, n) - HammerDowngradableRWMutex(4, 1, n) - HammerDowngradableRWMutex(4, 3, n) - HammerDowngradableRWMutex(4, 10, n) - HammerDowngradableRWMutex(10, 1, n) - HammerDowngradableRWMutex(10, 3, n) - HammerDowngradableRWMutex(10, 10, n) - HammerDowngradableRWMutex(10, 5, n) -} - -func TestRWDoubleTryLock(t *testing.T) { - var rwm RWMutex - if !rwm.TryLock() { - t.Fatal("failed to aquire lock") - } - if rwm.TryLock() { - t.Fatal("unexpectedly succeeded in aquiring locked mutex") - } -} - -func TestRWTryLockAfterLock(t *testing.T) { - var rwm RWMutex - rwm.Lock() - if rwm.TryLock() { - t.Fatal("unexpectedly succeeded in aquiring locked mutex") - } -} - -func TestRWTryLockUnlock(t *testing.T) { - var rwm RWMutex - if !rwm.TryLock() { - t.Fatal("failed to aquire lock") - } - rwm.Unlock() - if !rwm.TryLock() { - t.Fatal("failed to aquire lock after unlock") - } -} - -func TestTryRLockAfterLock(t *testing.T) { - var rwm RWMutex - rwm.Lock() - if rwm.TryRLock() { - t.Fatal("unexpectedly succeeded in aquiring locked mutex") - } -} - -func TestTryLockAfterRLock(t *testing.T) { - var rwm RWMutex - rwm.RLock() - if rwm.TryLock() { - t.Fatal("unexpectedly succeeded in aquiring locked mutex") - } -} - -func TestDoubleTryRLock(t *testing.T) { - var rwm RWMutex - if !rwm.TryRLock() { - t.Fatal("failed to aquire lock") - } - if !rwm.TryRLock() { - t.Fatal("failed to read aquire read locked lock") - } -} diff --git a/pkg/sync/downgradable_rwmutex_unsafe.go b/pkg/sync/downgradable_rwmutex_unsafe.go index ea6cdc447..ea6cdc447 100644..100755 --- a/pkg/sync/downgradable_rwmutex_unsafe.go +++ b/pkg/sync/downgradable_rwmutex_unsafe.go diff --git a/pkg/sync/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go index ad4a3a37e..ad4a3a37e 100644..100755 --- a/pkg/sync/memmove_unsafe.go +++ b/pkg/sync/memmove_unsafe.go diff --git a/pkg/sync/norace_unsafe.go b/pkg/sync/norace_unsafe.go index 006055dd6..006055dd6 100644..100755 --- a/pkg/sync/norace_unsafe.go +++ b/pkg/sync/norace_unsafe.go diff --git a/pkg/sync/race_unsafe.go b/pkg/sync/race_unsafe.go index 31d8fa9a6..31d8fa9a6 100644..100755 --- a/pkg/sync/race_unsafe.go +++ b/pkg/sync/race_unsafe.go diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomictest/BUILD deleted file mode 100644 index 5c38c783e..000000000 --- a/pkg/sync/seqatomictest/BUILD +++ /dev/null @@ -1,31 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "seqatomic_int", - out = "seqatomic_int_unsafe.go", - package = "seqatomic", - suffix = "Int", - template = "//pkg/sync:generic_seqatomic", - types = { - "Value": "int", - }, -) - -go_library( - name = "seqatomic", - srcs = ["seqatomic_int_unsafe.go"], - deps = [ - "//pkg/sync", - ], -) - -go_test( - name = "seqatomic_test", - size = "small", - srcs = ["seqatomic_test.go"], - library = ":seqatomic", - deps = ["//pkg/sync"], -) diff --git a/pkg/sync/seqatomictest/seqatomic_test.go b/pkg/sync/seqatomictest/seqatomic_test.go deleted file mode 100644 index 2c4568b07..000000000 --- a/pkg/sync/seqatomictest/seqatomic_test.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package seqatomic - -import ( - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" -) - -func TestSeqAtomicLoadUncontended(t *testing.T) { - var seq sync.SeqCount - const want = 1 - data := want - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicLoadAfterWrite(t *testing.T) { - var seq sync.SeqCount - var data int - const want = 1 - seq.BeginWrite() - data = want - seq.EndWrite() - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicLoadDuringWrite(t *testing.T) { - var seq sync.SeqCount - var data int - const want = 1 - seq.BeginWrite() - go func() { - time.Sleep(time.Second) - data = want - seq.EndWrite() - }() - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicTryLoadUncontended(t *testing.T) { - var seq sync.SeqCount - const want = 1 - data := want - epoch := seq.BeginRead() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { - t.Errorf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) - } -} - -func TestSeqAtomicTryLoadDuringWrite(t *testing.T) { - var seq sync.SeqCount - var data int - epoch := seq.BeginRead() - seq.BeginWrite() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { - t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) - } - seq.EndWrite() -} - -func TestSeqAtomicTryLoadAfterWrite(t *testing.T) { - var seq sync.SeqCount - var data int - epoch := seq.BeginRead() - seq.BeginWrite() - seq.EndWrite() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { - t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) - } -} - -func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) { - var seq sync.SeqCount - const want = 42 - data := want - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if got := SeqAtomicLoadInt(&seq, &data); got != want { - b.Fatalf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } - } - }) -} - -func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) { - var seq sync.SeqCount - const want = 42 - data := want - b.RunParallel(func(pb *testing.PB) { - epoch := seq.BeginRead() - for pb.Next() { - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { - b.Fatalf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) - } - } - }) -} - -// For comparison: -func BenchmarkAtomicValueLoadIntUncontended(b *testing.B) { - var a atomic.Value - const want = 42 - a.Store(int(want)) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if got := a.Load().(int); got != want { - b.Fatalf("atomic.Value.Load: got %v, wanted %v", got, want) - } - } - }) -} diff --git a/pkg/sync/seqcount.go b/pkg/sync/seqcount.go index a1e895352..a1e895352 100644..100755 --- a/pkg/sync/seqcount.go +++ b/pkg/sync/seqcount.go diff --git a/pkg/sync/seqcount_test.go b/pkg/sync/seqcount_test.go deleted file mode 100644 index 6eb7b4b59..000000000 --- a/pkg/sync/seqcount_test.go +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package sync - -import ( - "reflect" - "testing" - "time" -) - -func TestSeqCountWriteUncontended(t *testing.T) { - var seq SeqCount - seq.BeginWrite() - seq.EndWrite() -} - -func TestSeqCountReadUncontended(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountBeginReadAfterWrite(t *testing.T) { - var seq SeqCount - var data int32 - const want = 1 - seq.BeginWrite() - data = want - seq.EndWrite() - epoch := seq.BeginRead() - if data != want { - t.Errorf("Reader: got %v, wanted %v", data, want) - } - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountBeginReadDuringWrite(t *testing.T) { - var seq SeqCount - var data int - const want = 1 - seq.BeginWrite() - go func() { - time.Sleep(time.Second) - data = want - seq.EndWrite() - }() - epoch := seq.BeginRead() - if data != want { - t.Errorf("Reader: got %v, wanted %v", data, want) - } - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountReadOkAfterWrite(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - seq.BeginWrite() - seq.EndWrite() - if seq.ReadOk(epoch) { - t.Errorf("ReadOk: got true, wanted false") - } -} - -func TestSeqCountReadOkDuringWrite(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - seq.BeginWrite() - if seq.ReadOk(epoch) { - t.Errorf("ReadOk: got true, wanted false") - } - seq.EndWrite() -} - -func BenchmarkSeqCountWriteUncontended(b *testing.B) { - var seq SeqCount - for i := 0; i < b.N; i++ { - seq.BeginWrite() - seq.EndWrite() - } -} - -func BenchmarkSeqCountReadUncontended(b *testing.B) { - var seq SeqCount - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - epoch := seq.BeginRead() - if !seq.ReadOk(epoch) { - b.Fatalf("ReadOk: got false, wanted true") - } - } - }) -} - -func TestPointersInType(t *testing.T) { - for _, test := range []struct { - name string // used for both test and value name - val interface{} - ptrs []string - }{ - { - name: "EmptyStruct", - val: struct{}{}, - }, - { - name: "Int", - val: int(0), - }, - { - name: "MixedStruct", - val: struct { - b bool - I int - ExportedPtr *struct{} - unexportedPtr *struct{} - arr [2]int - ptrArr [2]*int - nestedStruct struct { - nestedNonptr int - nestedPtr *int - } - structArr [1]struct { - nonptr int - ptr *int - } - }{}, - ptrs: []string{ - "MixedStruct.ExportedPtr", - "MixedStruct.unexportedPtr", - "MixedStruct.ptrArr[]", - "MixedStruct.nestedStruct.nestedPtr", - "MixedStruct.structArr[].ptr", - }, - }, - } { - t.Run(test.name, func(t *testing.T) { - typ := reflect.TypeOf(test.val) - ptrs := PointersInType(typ, test.name) - t.Logf("Found pointers: %v", ptrs) - if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) { - t.Errorf("Got %v, wanted %v", ptrs, test.ptrs) - } - }) - } -} diff --git a/pkg/sync/sync_state_autogen.go b/pkg/sync/sync_state_autogen.go new file mode 100755 index 000000000..7ce796ad8 --- /dev/null +++ b/pkg/sync/sync_state_autogen.go @@ -0,0 +1,12 @@ +// automatically generated by stateify. + +// +build go1.13 +// +build !go1.15 +// +build go1.12 +// +build !go1.15 +// +build !race +// +build race +// +build go1.13 +// +build !go1.15 + +package sync diff --git a/pkg/sync/syncutil.go b/pkg/sync/syncutil.go index b16cf5333..b16cf5333 100644..100755 --- a/pkg/sync/syncutil.go +++ b/pkg/sync/syncutil.go diff --git a/pkg/sync/tmutex_test.go b/pkg/sync/tmutex_test.go deleted file mode 100644 index 0838248b4..000000000 --- a/pkg/sync/tmutex_test.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package sync - -import ( - "sync" - "testing" - "unsafe" -) - -// TestStructSize verifies that syncMutex's size hasn't drifted from the -// standard library's version. -// -// The correctness of this package relies on these remaining in sync. -func TestStructSize(t *testing.T) { - const ( - got = unsafe.Sizeof(syncMutex{}) - want = unsafe.Sizeof(sync.Mutex{}) - ) - if got != want { - t.Errorf("got sizeof(syncMutex) = %d, want = sizeof(sync.Mutex) = %d", got, want) - } -} - -// TestFieldValues verifies that the semantics of syncMutex.state from the -// standard library's implementation. -// -// The correctness of this package relies on these remaining in sync. -func TestFieldValues(t *testing.T) { - var m Mutex - m.Lock() - if got := *m.state(); got != mutexLocked { - t.Errorf("got locked sync.Mutex.state = %d, want = %d", got, mutexLocked) - } - m.Unlock() - if got := *m.state(); got != mutexUnlocked { - t.Errorf("got unlocked sync.Mutex.state = %d, want = %d", got, mutexUnlocked) - } -} - -func TestDoubleTryLock(t *testing.T) { - var m Mutex - if !m.TryLock() { - t.Fatal("failed to aquire lock") - } - if m.TryLock() { - t.Fatal("unexpectedly succeeded in aquiring locked mutex") - } -} - -func TestTryLockAfterLock(t *testing.T) { - var m Mutex - m.Lock() - if m.TryLock() { - t.Fatal("unexpectedly succeeded in aquiring locked mutex") - } -} - -func TestTryLockUnlock(t *testing.T) { - var m Mutex - if !m.TryLock() { - t.Fatal("failed to aquire lock") - } - m.Unlock() - if !m.TryLock() { - t.Fatal("failed to aquire lock after unlock") - } -} diff --git a/pkg/sync/tmutex_unsafe.go b/pkg/sync/tmutex_unsafe.go index 3dd15578b..3dd15578b 100644..100755 --- a/pkg/sync/tmutex_unsafe.go +++ b/pkg/sync/tmutex_unsafe.go diff --git a/pkg/syncevent/BUILD b/pkg/syncevent/BUILD deleted file mode 100644 index 0500a22cf..000000000 --- a/pkg/syncevent/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -licenses(["notice"]) - -go_library( - name = "syncevent", - srcs = [ - "broadcaster.go", - "receiver.go", - "source.go", - "syncevent.go", - "waiter_amd64.s", - "waiter_arm64.s", - "waiter_asm_unsafe.go", - "waiter_noasm_unsafe.go", - "waiter_unsafe.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/atomicbitops", - "//pkg/sync", - ], -) - -go_test( - name = "syncevent_test", - size = "small", - srcs = [ - "broadcaster_test.go", - "syncevent_example_test.go", - "waiter_test.go", - ], - library = ":syncevent", - deps = [ - "//pkg/sleep", - "//pkg/sync", - "//pkg/waiter", - ], -) diff --git a/pkg/syncevent/broadcaster.go b/pkg/syncevent/broadcaster.go deleted file mode 100644 index 4bff59e7d..000000000 --- a/pkg/syncevent/broadcaster.go +++ /dev/null @@ -1,218 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syncevent - -import ( - "gvisor.dev/gvisor/pkg/sync" -) - -// Broadcaster is an implementation of Source that supports any number of -// subscribed Receivers. -// -// The zero value of Broadcaster is valid and has no subscribed Receivers. -// Broadcaster is not copyable by value. -// -// All Broadcaster methods may be called concurrently from multiple goroutines. -type Broadcaster struct { - // Broadcaster is implemented as a hash table where keys are assigned by - // the Broadcaster and returned as SubscriptionIDs, making it safe to use - // the identity function for hashing. The hash table resolves collisions - // using linear probing and features Robin Hood insertion and backward - // shift deletion in order to support a relatively high load factor - // efficiently, which matters since the cost of Broadcast is linear in the - // size of the table. - - // mu protects the following fields. - mu sync.Mutex - - // Invariants: len(table) is 0 or a power of 2. - table []broadcasterSlot - - // load is the number of entries in table with receiver != nil. - load int - - lastID SubscriptionID -} - -type broadcasterSlot struct { - // Invariants: If receiver == nil, then filter == NoEvents and id == 0. - // Otherwise, id != 0. - receiver *Receiver - filter Set - id SubscriptionID -} - -const ( - broadcasterMinNonZeroTableSize = 2 // must be a power of 2 > 1 - - broadcasterMaxLoadNum = 13 - broadcasterMaxLoadDen = 16 -) - -// SubscribeEvents implements Source.SubscribeEvents. -func (b *Broadcaster) SubscribeEvents(r *Receiver, filter Set) SubscriptionID { - b.mu.Lock() - - // Assign an ID for this subscription. - b.lastID++ - id := b.lastID - - // Expand the table if over the maximum load factor: - // - // load / len(b.table) > broadcasterMaxLoadNum / broadcasterMaxLoadDen - // load * broadcasterMaxLoadDen > broadcasterMaxLoadNum * len(b.table) - b.load++ - if (b.load * broadcasterMaxLoadDen) > (broadcasterMaxLoadNum * len(b.table)) { - // Double the number of slots in the new table. - newlen := broadcasterMinNonZeroTableSize - if len(b.table) != 0 { - newlen = 2 * len(b.table) - } - if newlen <= cap(b.table) { - // Reuse excess capacity in the current table, moving entries not - // already in their first-probed positions to better ones. - newtable := b.table[:newlen] - newmask := uint64(newlen - 1) - for i := range b.table { - if b.table[i].receiver != nil && uint64(b.table[i].id)&newmask != uint64(i) { - entry := b.table[i] - b.table[i] = broadcasterSlot{} - broadcasterTableInsert(newtable, entry.id, entry.receiver, entry.filter) - } - } - b.table = newtable - } else { - newtable := make([]broadcasterSlot, newlen) - // Copy existing entries to the new table. - for i := range b.table { - if b.table[i].receiver != nil { - broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter) - } - } - // Switch to the new table. - b.table = newtable - } - } - - broadcasterTableInsert(b.table, id, r, filter) - b.mu.Unlock() - return id -} - -// Preconditions: table must not be full. len(table) is a power of 2. -func broadcasterTableInsert(table []broadcasterSlot, id SubscriptionID, r *Receiver, filter Set) { - entry := broadcasterSlot{ - receiver: r, - filter: filter, - id: id, - } - mask := uint64(len(table) - 1) - i := uint64(id) & mask - disp := uint64(0) - for { - if table[i].receiver == nil { - table[i] = entry - return - } - // If we've been displaced farther from our first-probed slot than the - // element stored in this one, swap elements and switch to inserting - // the replaced one. (This is Robin Hood insertion.) - slotDisp := (i - uint64(table[i].id)) & mask - if disp > slotDisp { - table[i], entry = entry, table[i] - disp = slotDisp - } - i = (i + 1) & mask - disp++ - } -} - -// UnsubscribeEvents implements Source.UnsubscribeEvents. -func (b *Broadcaster) UnsubscribeEvents(id SubscriptionID) { - b.mu.Lock() - - mask := uint64(len(b.table) - 1) - i := uint64(id) & mask - for { - if b.table[i].id == id { - // Found the element to remove. Move all subsequent elements - // backward until we either find an empty slot, or an element that - // is already in its first-probed slot. (This is backward shift - // deletion.) - for { - next := (i + 1) & mask - if b.table[next].receiver == nil { - break - } - if uint64(b.table[next].id)&mask == next { - break - } - b.table[i] = b.table[next] - i = next - } - b.table[i] = broadcasterSlot{} - break - } - i = (i + 1) & mask - } - - // If a table 1/4 of the current size would still be at or under the - // maximum load factor (i.e. the current table size is at least two - // expansions bigger than necessary), halve the size of the table to reduce - // the cost of Broadcast. Since we are concerned with iteration time and - // not memory usage, reuse the existing slice to reduce future allocations - // from table re-expansion. - b.load-- - if len(b.table) > broadcasterMinNonZeroTableSize && (b.load*(4*broadcasterMaxLoadDen)) <= (broadcasterMaxLoadNum*len(b.table)) { - newlen := len(b.table) / 2 - newtable := b.table[:newlen] - for i := newlen; i < len(b.table); i++ { - if b.table[i].receiver != nil { - broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter) - b.table[i] = broadcasterSlot{} - } - } - b.table = newtable - } - - b.mu.Unlock() -} - -// Broadcast notifies all Receivers subscribed to the Broadcaster of the subset -// of events to which they subscribed. The order in which Receivers are -// notified is unspecified. -func (b *Broadcaster) Broadcast(events Set) { - b.mu.Lock() - for i := range b.table { - if intersection := events & b.table[i].filter; intersection != 0 { - // We don't need to check if broadcasterSlot.receiver is nil, since - // if it is then broadcasterSlot.filter is 0. - b.table[i].receiver.Notify(intersection) - } - } - b.mu.Unlock() -} - -// FilteredEvents returns the set of events for which Broadcast will notify at -// least one Receiver, i.e. the union of filters for all subscribed Receivers. -func (b *Broadcaster) FilteredEvents() Set { - var es Set - b.mu.Lock() - for i := range b.table { - es |= b.table[i].filter - } - b.mu.Unlock() - return es -} diff --git a/pkg/syncevent/broadcaster_test.go b/pkg/syncevent/broadcaster_test.go deleted file mode 100644 index e88779e23..000000000 --- a/pkg/syncevent/broadcaster_test.go +++ /dev/null @@ -1,376 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syncevent - -import ( - "fmt" - "math/rand" - "testing" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestBroadcasterFilter(t *testing.T) { - const numReceivers = 2 * MaxEvents - - var br Broadcaster - ws := make([]Waiter, numReceivers) - for i := range ws { - ws[i].Init() - br.SubscribeEvents(ws[i].Receiver(), 1<<(i%MaxEvents)) - } - for ev := 0; ev < MaxEvents; ev++ { - br.Broadcast(1 << ev) - for i := range ws { - want := NoEvents - if i%MaxEvents == ev { - want = 1 << ev - } - if got := ws[i].Receiver().PendingAndAckAll(); got != want { - t.Errorf("after Broadcast of event %d: waiter %d has pending event set %#x, wanted %#x", ev, i, got, want) - } - } - } -} - -// TestBroadcasterManySubscriptions tests that subscriptions are not lost by -// table expansion/compaction. -func TestBroadcasterManySubscriptions(t *testing.T) { - const numReceivers = 5000 // arbitrary - - var br Broadcaster - ws := make([]Waiter, numReceivers) - for i := range ws { - ws[i].Init() - } - - ids := make([]SubscriptionID, numReceivers) - for i := 0; i < numReceivers; i++ { - // Subscribe receiver i. - ids[i] = br.SubscribeEvents(ws[i].Receiver(), 1) - // Check that receivers [0, i] are subscribed. - br.Broadcast(1) - for j := 0; j <= i; j++ { - if ws[j].Pending() != 1 { - t.Errorf("receiver %d did not receive an event after subscription of receiver %d", j, i) - } - ws[j].Ack(1) - } - } - - // Generate a random order for unsubscriptions. - unsub := rand.Perm(numReceivers) - for i := 0; i < numReceivers; i++ { - // Unsubscribe receiver unsub[i]. - br.UnsubscribeEvents(ids[unsub[i]]) - // Check that receivers [unsub[0], unsub[i]] are not subscribed, and that - // receivers (unsub[i], unsub[numReceivers]) are still subscribed. - br.Broadcast(1) - for j := 0; j <= i; j++ { - if ws[unsub[j]].Pending() != 0 { - t.Errorf("unsub iteration %d: receiver %d received an event after unsubscription of receiver %d", i, unsub[j], unsub[i]) - } - } - for j := i + 1; j < numReceivers; j++ { - if ws[unsub[j]].Pending() != 1 { - t.Errorf("unsub iteration %d: receiver %d did not receive an event after unsubscription of receiver %d", i, unsub[j], unsub[i]) - } - ws[unsub[j]].Ack(1) - } - } -} - -var ( - receiverCountsNonZero = []int{1, 4, 16, 64} - receiverCountsIncludingZero = append([]int{0}, receiverCountsNonZero...) -) - -// BenchmarkBroadcasterX, BenchmarkMapX, and BenchmarkQueueX benchmark usage -// pattern X (described in terms of Broadcaster) with Broadcaster, a -// Mutex-protected map[*Receiver]Set, and waiter.Queue respectively. - -// BenchmarkXxxSubscribeUnsubscribe measures the cost of a Subscribe/Unsubscribe -// cycle. - -func BenchmarkBroadcasterSubscribeUnsubscribe(b *testing.B) { - var br Broadcaster - var w Waiter - w.Init() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - id := br.SubscribeEvents(w.Receiver(), 1) - br.UnsubscribeEvents(id) - } -} - -func BenchmarkMapSubscribeUnsubscribe(b *testing.B) { - var mu sync.Mutex - m := make(map[*Receiver]Set) - var w Waiter - w.Init() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mu.Lock() - m[w.Receiver()] = Set(1) - mu.Unlock() - mu.Lock() - delete(m, w.Receiver()) - mu.Unlock() - } -} - -func BenchmarkQueueSubscribeUnsubscribe(b *testing.B) { - var q waiter.Queue - e, _ := waiter.NewChannelEntry(nil) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - q.EventRegister(&e, 1) - q.EventUnregister(&e) - } -} - -// BenchmarkXxxSubscribeUnsubscribeBatch is similar to -// BenchmarkXxxSubscribeUnsubscribe, but subscribes and unsubscribes a large -// number of Receivers at a time in order to measure the amortized overhead of -// table expansion/compaction. (Since waiter.Queue is implemented using a -// linked list, BenchmarkQueueSubscribeUnsubscribe and -// BenchmarkQueueSubscribeUnsubscribeBatch should produce nearly the same -// result.) - -const numBatchReceivers = 1000 - -func BenchmarkBroadcasterSubscribeUnsubscribeBatch(b *testing.B) { - var br Broadcaster - ws := make([]Waiter, numBatchReceivers) - for i := range ws { - ws[i].Init() - } - ids := make([]SubscriptionID, numBatchReceivers) - - // Generate a random order for unsubscriptions. - unsub := rand.Perm(numBatchReceivers) - - b.ResetTimer() - for i := 0; i < b.N/numBatchReceivers; i++ { - for j := 0; j < numBatchReceivers; j++ { - ids[j] = br.SubscribeEvents(ws[j].Receiver(), 1) - } - for j := 0; j < numBatchReceivers; j++ { - br.UnsubscribeEvents(ids[unsub[j]]) - } - } -} - -func BenchmarkMapSubscribeUnsubscribeBatch(b *testing.B) { - var mu sync.Mutex - m := make(map[*Receiver]Set) - ws := make([]Waiter, numBatchReceivers) - for i := range ws { - ws[i].Init() - } - - // Generate a random order for unsubscriptions. - unsub := rand.Perm(numBatchReceivers) - - b.ResetTimer() - for i := 0; i < b.N/numBatchReceivers; i++ { - for j := 0; j < numBatchReceivers; j++ { - mu.Lock() - m[ws[j].Receiver()] = Set(1) - mu.Unlock() - } - for j := 0; j < numBatchReceivers; j++ { - mu.Lock() - delete(m, ws[unsub[j]].Receiver()) - mu.Unlock() - } - } -} - -func BenchmarkQueueSubscribeUnsubscribeBatch(b *testing.B) { - var q waiter.Queue - es := make([]waiter.Entry, numBatchReceivers) - for i := range es { - es[i], _ = waiter.NewChannelEntry(nil) - } - - // Generate a random order for unsubscriptions. - unsub := rand.Perm(numBatchReceivers) - - b.ResetTimer() - for i := 0; i < b.N/numBatchReceivers; i++ { - for j := 0; j < numBatchReceivers; j++ { - q.EventRegister(&es[j], 1) - } - for j := 0; j < numBatchReceivers; j++ { - q.EventUnregister(&es[unsub[j]]) - } - } -} - -// BenchmarkXxxBroadcastRedundant measures how long it takes to Broadcast -// already-pending events to multiple Receivers. - -func BenchmarkBroadcasterBroadcastRedundant(b *testing.B) { - for _, n := range receiverCountsIncludingZero { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var br Broadcaster - ws := make([]Waiter, n) - for i := range ws { - ws[i].Init() - br.SubscribeEvents(ws[i].Receiver(), 1) - } - br.Broadcast(1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - br.Broadcast(1) - } - }) - } -} - -func BenchmarkMapBroadcastRedundant(b *testing.B) { - for _, n := range receiverCountsIncludingZero { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var mu sync.Mutex - m := make(map[*Receiver]Set) - ws := make([]Waiter, n) - for i := range ws { - ws[i].Init() - m[ws[i].Receiver()] = Set(1) - } - mu.Lock() - for r := range m { - r.Notify(1) - } - mu.Unlock() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mu.Lock() - for r := range m { - r.Notify(1) - } - mu.Unlock() - } - }) - } -} - -func BenchmarkQueueBroadcastRedundant(b *testing.B) { - for _, n := range receiverCountsIncludingZero { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var q waiter.Queue - for i := 0; i < n; i++ { - e, _ := waiter.NewChannelEntry(nil) - q.EventRegister(&e, 1) - } - q.Notify(1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - q.Notify(1) - } - }) - } -} - -// BenchmarkXxxBroadcastAck measures how long it takes to Broadcast events to -// multiple Receivers, check that all Receivers have received the event, and -// clear the event from all Receivers. - -func BenchmarkBroadcasterBroadcastAck(b *testing.B) { - for _, n := range receiverCountsNonZero { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var br Broadcaster - ws := make([]Waiter, n) - for i := range ws { - ws[i].Init() - br.SubscribeEvents(ws[i].Receiver(), 1) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - br.Broadcast(1) - for j := range ws { - if got, want := ws[j].Pending(), Set(1); got != want { - b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want) - } - ws[j].Ack(1) - } - } - }) - } -} - -func BenchmarkMapBroadcastAck(b *testing.B) { - for _, n := range receiverCountsNonZero { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var mu sync.Mutex - m := make(map[*Receiver]Set) - ws := make([]Waiter, n) - for i := range ws { - ws[i].Init() - m[ws[i].Receiver()] = Set(1) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mu.Lock() - for r := range m { - r.Notify(1) - } - mu.Unlock() - for j := range ws { - if got, want := ws[j].Pending(), Set(1); got != want { - b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want) - } - ws[j].Ack(1) - } - } - }) - } -} - -func BenchmarkQueueBroadcastAck(b *testing.B) { - for _, n := range receiverCountsNonZero { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var q waiter.Queue - chs := make([]chan struct{}, n) - for i := range chs { - e, ch := waiter.NewChannelEntry(nil) - q.EventRegister(&e, 1) - chs[i] = ch - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - q.Notify(1) - for _, ch := range chs { - select { - case <-ch: - default: - b.Fatalf("channel did not receive event") - } - } - } - }) - } -} diff --git a/pkg/syncevent/receiver.go b/pkg/syncevent/receiver.go deleted file mode 100644 index 5c86e5400..000000000 --- a/pkg/syncevent/receiver.go +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syncevent - -import ( - "sync/atomic" - - "gvisor.dev/gvisor/pkg/atomicbitops" -) - -// Receiver is an event sink that holds pending events and invokes a callback -// whenever new events become pending. Receiver's methods may be called -// concurrently from multiple goroutines. -// -// Receiver.Init() must be called before first use. -type Receiver struct { - // pending is the set of pending events. pending is accessed using atomic - // memory operations. - pending uint64 - - // cb is notified when new events become pending. cb is immutable after - // Init(). - cb ReceiverCallback -} - -// ReceiverCallback receives callbacks from a Receiver. -type ReceiverCallback interface { - // NotifyPending is called when the corresponding Receiver has new pending - // events. - // - // NotifyPending is called synchronously from Receiver.Notify(), so - // implementations must not take locks that may be held by callers of - // Receiver.Notify(). NotifyPending may be called concurrently from - // multiple goroutines. - NotifyPending() -} - -// Init must be called before first use of r. -func (r *Receiver) Init(cb ReceiverCallback) { - r.cb = cb -} - -// Pending returns the set of pending events. -func (r *Receiver) Pending() Set { - return Set(atomic.LoadUint64(&r.pending)) -} - -// Notify sets the given events as pending. -func (r *Receiver) Notify(es Set) { - p := Set(atomic.LoadUint64(&r.pending)) - // Optimization: Skip the atomic CAS on r.pending if all events are - // already pending. - if p&es == es { - return - } - // When this is uncontended (the common case), CAS is faster than - // atomic-OR because the former is inlined and the latter (which we - // implement in assembly ourselves) is not. - if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p|es)) { - // If the CAS fails, fall back to atomic-OR. - atomicbitops.OrUint64(&r.pending, uint64(es)) - } - r.cb.NotifyPending() -} - -// Ack unsets the given events as pending. -func (r *Receiver) Ack(es Set) { - p := Set(atomic.LoadUint64(&r.pending)) - // Optimization: Skip the atomic CAS on r.pending if all events are - // already not pending. - if p&es == 0 { - return - } - // When this is uncontended (the common case), CAS is faster than - // atomic-AND because the former is inlined and the latter (which we - // implement in assembly ourselves) is not. - if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p&^es)) { - // If the CAS fails, fall back to atomic-AND. - atomicbitops.AndUint64(&r.pending, ^uint64(es)) - } -} - -// PendingAndAckAll unsets all events as pending and returns the set of -// previously-pending events. -// -// PendingAndAckAll should only be used in preference to a call to Pending -// followed by a conditional call to Ack when the caller expects events to be -// pending (e.g. after a call to ReceiverCallback.NotifyPending()). -func (r *Receiver) PendingAndAckAll() Set { - return Set(atomic.SwapUint64(&r.pending, 0)) -} diff --git a/pkg/syncevent/source.go b/pkg/syncevent/source.go deleted file mode 100644 index ddffb171a..000000000 --- a/pkg/syncevent/source.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syncevent - -// Source represents an event source. -type Source interface { - // SubscribeEvents causes the Source to notify the given Receiver of the - // given subset of events. - // - // Preconditions: r != nil. The ReceiverCallback for r must not take locks - // that are ordered prior to the Source; for example, it cannot call any - // Source methods. - SubscribeEvents(r *Receiver, filter Set) SubscriptionID - - // UnsubscribeEvents causes the Source to stop notifying the Receiver - // subscribed by a previous call to SubscribeEvents that returned the given - // SubscriptionID. - // - // Preconditions: UnsubscribeEvents may be called at most once for any - // given SubscriptionID. - UnsubscribeEvents(id SubscriptionID) -} - -// SubscriptionID identifies a call to Source.SubscribeEvents. -type SubscriptionID uint64 - -// UnsubscribeAndAck is a convenience function that unsubscribes r from the -// given events from src and also clears them from r. -func UnsubscribeAndAck(src Source, r *Receiver, filter Set, id SubscriptionID) { - src.UnsubscribeEvents(id) - r.Ack(filter) -} - -// NoopSource implements Source by never sending events to subscribed -// Receivers. -type NoopSource struct{} - -// SubscribeEvents implements Source.SubscribeEvents. -func (NoopSource) SubscribeEvents(*Receiver, Set) SubscriptionID { - return 0 -} - -// UnsubscribeEvents implements Source.UnsubscribeEvents. -func (NoopSource) UnsubscribeEvents(SubscriptionID) { -} - -// See Broadcaster for a non-noop implementations of Source. diff --git a/pkg/syncevent/syncevent.go b/pkg/syncevent/syncevent.go deleted file mode 100644 index 9fb6a06de..000000000 --- a/pkg/syncevent/syncevent.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package syncevent provides efficient primitives for goroutine -// synchronization based on event bitmasks. -package syncevent - -// Set is a bitmask where each bit represents a distinct user-defined event. -// The event package does not treat any bits in Set specially. -type Set uint64 - -const ( - // NoEvents is a Set containing no events. - NoEvents = Set(0) - - // AllEvents is a Set containing all possible events. - AllEvents = ^Set(0) - - // MaxEvents is the number of distinct events that can be represented by a Set. - MaxEvents = 64 -) diff --git a/pkg/syncevent/syncevent_example_test.go b/pkg/syncevent/syncevent_example_test.go deleted file mode 100644 index bfb18e2ea..000000000 --- a/pkg/syncevent/syncevent_example_test.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syncevent - -import ( - "fmt" - "sync/atomic" - "time" -) - -func Example_ioReadinessInterrputible() { - const ( - evReady = Set(1 << iota) - evInterrupt - ) - errNotReady := fmt.Errorf("not ready for I/O") - - // State of some I/O object. - var ( - br Broadcaster - ready uint32 - ) - doIO := func() error { - if atomic.LoadUint32(&ready) == 0 { - return errNotReady - } - return nil - } - go func() { - // The I/O object eventually becomes ready for I/O. - time.Sleep(100 * time.Millisecond) - // When it does, it first ensures that future calls to isReady() return - // true, then broadcasts the readiness event to Receivers. - atomic.StoreUint32(&ready, 1) - br.Broadcast(evReady) - }() - - // Each user of the I/O object owns a Waiter. - var w Waiter - w.Init() - // The Waiter may be asynchronously interruptible, e.g. for signal - // handling in the sentry. - go func() { - time.Sleep(200 * time.Millisecond) - w.Receiver().Notify(evInterrupt) - }() - - // To use the I/O object: - // - // Optionally, if the I/O object is likely to be ready, attempt I/O first. - err := doIO() - if err == nil { - // Success, we're done. - return /* nil */ - } - if err != errNotReady { - // Failure, I/O failed for some reason other than readiness. - return /* err */ - } - // Subscribe for readiness events from the I/O object. - id := br.SubscribeEvents(w.Receiver(), evReady) - // When we are finished blocking, unsubscribe from readiness events and - // remove readiness events from the pending event set. - defer UnsubscribeAndAck(&br, w.Receiver(), evReady, id) - for { - // Attempt I/O again. This must be done after the call to SubscribeEvents, - // since the I/O object might have become ready between the previous call - // to doIO and the call to SubscribeEvents. - err = doIO() - if err == nil { - return /* nil */ - } - if err != errNotReady { - return /* err */ - } - // Block until either the I/O object indicates it is ready, or we are - // interrupted. - events := w.Wait() - if events&evInterrupt != 0 { - // In the specific case of sentry signal handling, signal delivery - // is handled by another system, so we aren't responsible for - // acknowledging evInterrupt. - return /* errInterrupted */ - } - // Note that, in a concurrent context, the I/O object might become - // ready and then not ready again. To handle this: - // - // - evReady must be acknowledged before calling doIO() again (rather - // than after), so that if the I/O object becomes ready *again* after - // the call to doIO(), the readiness event is not lost. - // - // - We must loop instead of just calling doIO() once after receiving - // evReady. - w.Ack(evReady) - } -} diff --git a/pkg/syncevent/waiter_amd64.s b/pkg/syncevent/waiter_amd64.s deleted file mode 100644 index 985b56ae5..000000000 --- a/pkg/syncevent/waiter_amd64.s +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" - -// See waiter_noasm_unsafe.go for a description of waiterUnlock. -// -// func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool -TEXT ·waiterUnlock(SB),NOSPLIT,$0-24 - MOVQ g+0(FP), DI - MOVQ wg+8(FP), SI - - MOVQ $·preparingG(SB), AX - LOCK - CMPXCHGQ DI, 0(SI) - - SETEQ AX - MOVB AX, ret+16(FP) - - RET - diff --git a/pkg/syncevent/waiter_arm64.s b/pkg/syncevent/waiter_arm64.s deleted file mode 100644 index 20d7ac23b..000000000 --- a/pkg/syncevent/waiter_arm64.s +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" - -// See waiter_noasm_unsafe.go for a description of waiterUnlock. -// -// func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool -TEXT ·waiterUnlock(SB),NOSPLIT,$0-24 - MOVD wg+8(FP), R0 - MOVD $·preparingG(SB), R1 - MOVD g+0(FP), R2 -again: - LDAXR (R0), R3 - CMP R1, R3 - BNE ok - STLXR R2, (R0), R3 - CBNZ R3, again -ok: - CSET EQ, R0 - MOVB R0, ret+16(FP) - RET - diff --git a/pkg/syncevent/waiter_asm_unsafe.go b/pkg/syncevent/waiter_asm_unsafe.go deleted file mode 100644 index 0995e9053..000000000 --- a/pkg/syncevent/waiter_asm_unsafe.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build amd64 arm64 - -package syncevent - -import ( - "unsafe" -) - -// See waiter_noasm_unsafe.go for a description of waiterUnlock. -func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool diff --git a/pkg/syncevent/waiter_noasm_unsafe.go b/pkg/syncevent/waiter_noasm_unsafe.go deleted file mode 100644 index 1c4b0e39a..000000000 --- a/pkg/syncevent/waiter_noasm_unsafe.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// waiterUnlock is called from g0, so when the race detector is enabled, -// waiterUnlock must be implemented in assembly since no race context is -// available. -// -// +build !race -// +build !amd64,!arm64 - -package syncevent - -import ( - "sync/atomic" - "unsafe" -) - -// waiterUnlock is the "unlock function" passed to runtime.gopark by -// Waiter.Wait*. wg is &Waiter.g, and g is a pointer to the calling runtime.g. -// waiterUnlock returns true if Waiter.Wait should sleep and false if sleeping -// should be aborted. -// -//go:nosplit -func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool { - // The only way this CAS can fail is if a call to Waiter.NotifyPending() - // has replaced *wg with nil, in which case we should not sleep. - return atomic.CompareAndSwapPointer(wg, (unsafe.Pointer)(&preparingG), g) -} diff --git a/pkg/syncevent/waiter_test.go b/pkg/syncevent/waiter_test.go deleted file mode 100644 index 3c8cbcdd8..000000000 --- a/pkg/syncevent/waiter_test.go +++ /dev/null @@ -1,414 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syncevent - -import ( - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sleep" - "gvisor.dev/gvisor/pkg/sync" -) - -func TestWaiterAlreadyPending(t *testing.T) { - var w Waiter - w.Init() - want := Set(1) - w.Notify(want) - if got := w.Wait(); got != want { - t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want) - } -} - -func TestWaiterAsyncNotify(t *testing.T) { - var w Waiter - w.Init() - want := Set(1) - go func() { - time.Sleep(100 * time.Millisecond) - w.Notify(want) - }() - if got := w.Wait(); got != want { - t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want) - } -} - -func TestWaiterWaitFor(t *testing.T) { - var w Waiter - w.Init() - evWaited := Set(1) - evOther := Set(2) - w.Notify(evOther) - notifiedEvent := uint32(0) - go func() { - time.Sleep(100 * time.Millisecond) - atomic.StoreUint32(¬ifiedEvent, 1) - w.Notify(evWaited) - }() - if got, want := w.WaitFor(evWaited), evWaited|evOther; got != want { - t.Errorf("Waiter.WaitFor: got %#x, wanted %#x", got, want) - } - if atomic.LoadUint32(¬ifiedEvent) == 0 { - t.Errorf("Waiter.WaitFor returned before goroutine notified waited-for event") - } -} - -func TestWaiterWaitAndAckAll(t *testing.T) { - var w Waiter - w.Init() - w.Notify(AllEvents) - if got := w.WaitAndAckAll(); got != AllEvents { - t.Errorf("Waiter.WaitAndAckAll: got %#x, wanted %#x", got, AllEvents) - } - if got := w.Pending(); got != NoEvents { - t.Errorf("Waiter.WaitAndAckAll did not ack all events: got %#x, wanted 0", got) - } -} - -// BenchmarkWaiterX, BenchmarkSleeperX, and BenchmarkChannelX benchmark usage -// pattern X (described in terms of Waiter) with Waiter, sleep.Sleeper, and -// buffered chan struct{} respectively. When the maximum number of event -// sources is relevant, we use 3 event sources because this is representative -// of the kernel.Task.block() use case: an interrupt source, a timeout source, -// and the actual event source being waited on. - -// Event set used by most benchmarks. -const evBench Set = 1 - -// BenchmarkXxxNotifyRedundant measures how long it takes to notify a Waiter of -// an event that is already pending. - -func BenchmarkWaiterNotifyRedundant(b *testing.B) { - var w Waiter - w.Init() - w.Notify(evBench) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w.Notify(evBench) - } -} - -func BenchmarkSleeperNotifyRedundant(b *testing.B) { - var s sleep.Sleeper - var w sleep.Waker - s.AddWaker(&w, 0) - w.Assert() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w.Assert() - } -} - -func BenchmarkChannelNotifyRedundant(b *testing.B) { - ch := make(chan struct{}, 1) - ch <- struct{}{} - - b.ResetTimer() - for i := 0; i < b.N; i++ { - select { - case ch <- struct{}{}: - default: - } - } -} - -// BenchmarkXxxNotifyWaitAck measures how long it takes to notify a Waiter an -// event, return that event using a blocking check, and then unset the event as -// pending. - -func BenchmarkWaiterNotifyWaitAck(b *testing.B) { - var w Waiter - w.Init() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w.Notify(evBench) - w.Wait() - w.Ack(evBench) - } -} - -func BenchmarkSleeperNotifyWaitAck(b *testing.B) { - var s sleep.Sleeper - var w sleep.Waker - s.AddWaker(&w, 0) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w.Assert() - s.Fetch(true) - } -} - -func BenchmarkChannelNotifyWaitAck(b *testing.B) { - ch := make(chan struct{}, 1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - // notify - select { - case ch <- struct{}{}: - default: - } - - // wait + ack - <-ch - } -} - -// BenchmarkSleeperMultiNotifyWaitAck is equivalent to -// BenchmarkSleeperNotifyWaitAck, but also includes allocation of a -// temporary sleep.Waker. This is necessary when multiple goroutines may wait -// for the same event, since each sleep.Waker can wake only a single -// sleep.Sleeper. -// -// The syncevent package does not require a distinct object for each -// waiter-waker relationship, so BenchmarkWaiterNotifyWaitAck and -// BenchmarkWaiterMultiNotifyWaitAck would be identical. The analogous state -// for channels, runtime.sudog, is inescapably runtime-allocated, so -// BenchmarkChannelNotifyWaitAck and BenchmarkChannelMultiNotifyWaitAck would -// also be identical. - -func BenchmarkSleeperMultiNotifyWaitAck(b *testing.B) { - var s sleep.Sleeper - // The sleep package doesn't provide sync.Pool allocation of Wakers; - // we do for a fairer comparison. - wakerPool := sync.Pool{ - New: func() interface{} { - return &sleep.Waker{} - }, - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w := wakerPool.Get().(*sleep.Waker) - s.AddWaker(w, 0) - w.Assert() - s.Fetch(true) - s.Done() - wakerPool.Put(w) - } -} - -// BenchmarkXxxTempNotifyWaitAck is equivalent to NotifyWaitAck, but also -// includes allocation of a temporary Waiter. This models the case where a -// goroutine not already associated with a Waiter needs one in order to block. -// -// The analogous state for channels is built into runtime.g, so -// BenchmarkChannelNotifyWaitAck and BenchmarkChannelTempNotifyWaitAck would be -// identical. - -func BenchmarkWaiterTempNotifyWaitAck(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - w := GetWaiter() - w.Notify(evBench) - w.Wait() - w.Ack(evBench) - PutWaiter(w) - } -} - -func BenchmarkSleeperTempNotifyWaitAck(b *testing.B) { - // The sleep package doesn't provide sync.Pool allocation of Sleepers; - // we do for a fairer comparison. - sleeperPool := sync.Pool{ - New: func() interface{} { - return &sleep.Sleeper{} - }, - } - var w sleep.Waker - - b.ResetTimer() - for i := 0; i < b.N; i++ { - s := sleeperPool.Get().(*sleep.Sleeper) - s.AddWaker(&w, 0) - w.Assert() - s.Fetch(true) - s.Done() - sleeperPool.Put(s) - } -} - -// BenchmarkXxxNotifyWaitMultiAck is equivalent to NotifyWaitAck, but allows -// for multiple event sources. - -func BenchmarkWaiterNotifyWaitMultiAck(b *testing.B) { - var w Waiter - w.Init() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w.Notify(evBench) - if e := w.Wait(); e != evBench { - b.Fatalf("Wait: got %#x, wanted %#x", e, evBench) - } - w.Ack(evBench) - } -} - -func BenchmarkSleeperNotifyWaitMultiAck(b *testing.B) { - var s sleep.Sleeper - var ws [3]sleep.Waker - for i := range ws { - s.AddWaker(&ws[i], i) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - ws[0].Assert() - if id, _ := s.Fetch(true); id != 0 { - b.Fatalf("Fetch: got %d, wanted 0", id) - } - } -} - -func BenchmarkChannelNotifyWaitMultiAck(b *testing.B) { - ch0 := make(chan struct{}, 1) - ch1 := make(chan struct{}, 1) - ch2 := make(chan struct{}, 1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - // notify - select { - case ch0 <- struct{}{}: - default: - } - - // wait + clear - select { - case <-ch0: - // ok - case <-ch1: - b.Fatalf("received from ch1") - case <-ch2: - b.Fatalf("received from ch2") - } - } -} - -// BenchmarkXxxNotifyAsyncWaitAck measures how long it takes to wait for an -// event while another goroutine signals the event. This assumes that a new -// goroutine doesn't run immediately (i.e. the creator of a new goroutine is -// allowed to go to sleep before the new goroutine has a chance to run). - -func BenchmarkWaiterNotifyAsyncWaitAck(b *testing.B) { - var w Waiter - w.Init() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - w.Notify(1) - }() - w.Wait() - w.Ack(evBench) - } -} - -func BenchmarkSleeperNotifyAsyncWaitAck(b *testing.B) { - var s sleep.Sleeper - var w sleep.Waker - s.AddWaker(&w, 0) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - w.Assert() - }() - s.Fetch(true) - } -} - -func BenchmarkChannelNotifyAsyncWaitAck(b *testing.B) { - ch := make(chan struct{}, 1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - select { - case ch <- struct{}{}: - default: - } - }() - <-ch - } -} - -// BenchmarkXxxNotifyAsyncWaitMultiAck is equivalent to NotifyAsyncWaitAck, but -// allows for multiple event sources. - -func BenchmarkWaiterNotifyAsyncWaitMultiAck(b *testing.B) { - var w Waiter - w.Init() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - w.Notify(evBench) - }() - if e := w.Wait(); e != evBench { - b.Fatalf("Wait: got %#x, wanted %#x", e, evBench) - } - w.Ack(evBench) - } -} - -func BenchmarkSleeperNotifyAsyncWaitMultiAck(b *testing.B) { - var s sleep.Sleeper - var ws [3]sleep.Waker - for i := range ws { - s.AddWaker(&ws[i], i) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - ws[0].Assert() - }() - if id, _ := s.Fetch(true); id != 0 { - b.Fatalf("Fetch: got %d, expected 0", id) - } - } -} - -func BenchmarkChannelNotifyAsyncWaitMultiAck(b *testing.B) { - ch0 := make(chan struct{}, 1) - ch1 := make(chan struct{}, 1) - ch2 := make(chan struct{}, 1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - select { - case ch0 <- struct{}{}: - default: - } - }() - - select { - case <-ch0: - // ok - case <-ch1: - b.Fatalf("received from ch1") - case <-ch2: - b.Fatalf("received from ch2") - } - } -} diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go deleted file mode 100644 index 112e0e604..000000000 --- a/pkg/syncevent/waiter_unsafe.go +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build go1.11 -// +build !go1.15 - -// Check go:linkname function signatures when updating Go version. - -package syncevent - -import ( - "sync/atomic" - "unsafe" - - "gvisor.dev/gvisor/pkg/sync" -) - -//go:linkname gopark runtime.gopark -func gopark(unlockf func(unsafe.Pointer, *unsafe.Pointer) bool, wg *unsafe.Pointer, reason uint8, traceEv byte, traceskip int) - -//go:linkname goready runtime.goready -func goready(g unsafe.Pointer, traceskip int) - -const ( - waitReasonSelect = 9 // Go: src/runtime/runtime2.go - traceEvGoBlockSelect = 24 // Go: src/runtime/trace.go -) - -// Waiter allows a goroutine to block on pending events received by a Receiver. -// -// Waiter.Init() must be called before first use. -type Waiter struct { - r Receiver - - // g is one of: - // - // - nil: No goroutine is blocking in Wait. - // - // - &preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet - // completed waiterUnlock(). Thus the wait can only be interrupted by - // replacing the value of g with nil (the G may not be in state Gwaiting - // yet, so we can't call goready.) - // - // - Otherwise: g is a pointer to the runtime.g in state Gwaiting for the - // goroutine blocked in Wait, which can only be woken by calling goready. - g unsafe.Pointer `state:"zerovalue"` -} - -// Sentinel object for Waiter.g. -var preparingG struct{} - -// Init must be called before first use of w. -func (w *Waiter) Init() { - w.r.Init(w) -} - -// Receiver returns the Receiver that receives events that unblock calls to -// w.Wait(). -func (w *Waiter) Receiver() *Receiver { - return &w.r -} - -// Pending returns the set of pending events. -func (w *Waiter) Pending() Set { - return w.r.Pending() -} - -// Wait blocks until at least one event is pending, then returns the set of -// pending events. It does not affect the set of pending events; callers must -// call w.Ack() to do so, or use w.WaitAndAck() instead. -// -// Precondition: Only one goroutine may call any Wait* method at a time. -func (w *Waiter) Wait() Set { - return w.WaitFor(AllEvents) -} - -// WaitFor blocks until at least one event in es is pending, then returns the -// set of pending events (including those not in es). It does not affect the -// set of pending events; callers must call w.Ack() to do so. -// -// Precondition: Only one goroutine may call any Wait* method at a time. -func (w *Waiter) WaitFor(es Set) Set { - for { - // Optimization: Skip the atomic store to w.g if an event is already - // pending. - if p := w.r.Pending(); p&es != NoEvents { - return p - } - - // Indicate that we're preparing to go to sleep. - atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG)) - - // If an event is pending, abort the sleep. - if p := w.r.Pending(); p&es != NoEvents { - atomic.StorePointer(&w.g, nil) - return p - } - - // If w.g is still preparingG (i.e. w.NotifyPending() has not been - // called or has not reached atomic.SwapPointer()), go to sleep until - // w.NotifyPending() => goready(). - gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0) - } -} - -// Ack marks the given events as not pending. -func (w *Waiter) Ack(es Set) { - w.r.Ack(es) -} - -// WaitAndAckAll blocks until at least one event is pending, then marks all -// events as not pending and returns the set of previously-pending events. -// -// Precondition: Only one goroutine may call any Wait* method at a time. -func (w *Waiter) WaitAndAckAll() Set { - // Optimization: Skip the atomic store to w.g if an event is already - // pending. Call Pending() first since, in the common case that events are - // not yet pending, this skips an atomic swap on w.r.pending. - if w.r.Pending() != NoEvents { - if p := w.r.PendingAndAckAll(); p != NoEvents { - return p - } - } - - for { - // Indicate that we're preparing to go to sleep. - atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG)) - - // If an event is pending, abort the sleep. - if w.r.Pending() != NoEvents { - if p := w.r.PendingAndAckAll(); p != NoEvents { - atomic.StorePointer(&w.g, nil) - return p - } - } - - // If w.g is still preparingG (i.e. w.NotifyPending() has not been - // called or has not reached atomic.SwapPointer()), go to sleep until - // w.NotifyPending() => goready(). - gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0) - - // Check for pending events. We call PendingAndAckAll() directly now since - // we only expect to be woken after events become pending. - if p := w.r.PendingAndAckAll(); p != NoEvents { - return p - } - } -} - -// Notify marks the given events as pending, possibly unblocking concurrent -// calls to w.Wait() or w.WaitFor(). -func (w *Waiter) Notify(es Set) { - w.r.Notify(es) -} - -// NotifyPending implements ReceiverCallback.NotifyPending. Users of Waiter -// should not call NotifyPending. -func (w *Waiter) NotifyPending() { - // Optimization: Skip the atomic swap on w.g if there is no sleeping - // goroutine. NotifyPending is called after w.r.Pending() is updated, so - // concurrent and future calls to w.Wait() will observe pending events and - // abort sleeping. - if atomic.LoadPointer(&w.g) == nil { - return - } - // Wake a sleeping G, or prevent a G that is preparing to sleep from doing - // so. Swap is needed here to ensure that only one call to NotifyPending - // calls goready. - if g := atomic.SwapPointer(&w.g, nil); g != nil && g != (unsafe.Pointer)(&preparingG) { - goready(g, 0) - } -} - -var waiterPool = sync.Pool{ - New: func() interface{} { - w := &Waiter{} - w.Init() - return w - }, -} - -// GetWaiter returns an unused Waiter. PutWaiter should be called to release -// the Waiter once it is no longer needed. -// -// Where possible, users should prefer to associate each goroutine that calls -// Waiter.Wait() with a distinct pre-allocated Waiter to avoid allocation of -// Waiters in hot paths. -func GetWaiter() *Waiter { - return waiterPool.Get().(*Waiter) -} - -// PutWaiter releases an unused Waiter previously returned by GetWaiter. -func PutWaiter(w *Waiter) { - waiterPool.Put(w) -} diff --git a/pkg/syserr/BUILD b/pkg/syserr/BUILD deleted file mode 100644 index 7d760344a..000000000 --- a/pkg/syserr/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "syserr", - srcs = [ - "host_linux.go", - "netstack.go", - "syserr.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/abi/linux", - "//pkg/syserror", - "//pkg/tcpip", - ], -) diff --git a/pkg/syserr/syserr_linux_state_autogen.go b/pkg/syserr/syserr_linux_state_autogen.go new file mode 100755 index 000000000..7fd5a68b8 --- /dev/null +++ b/pkg/syserr/syserr_linux_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build linux + +package syserr diff --git a/pkg/syserr/syserr_state_autogen.go b/pkg/syserr/syserr_state_autogen.go new file mode 100755 index 000000000..712631a64 --- /dev/null +++ b/pkg/syserr/syserr_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package syserr diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD deleted file mode 100644 index b13c15d9b..000000000 --- a/pkg/syserror/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "syserror", - srcs = ["syserror.go"], - visibility = ["//visibility:public"], -) - -go_test( - name = "syserror_test", - srcs = ["syserror_test.go"], - deps = [ - ":syserror", - ], -) diff --git a/pkg/syserror/syserror_state_autogen.go b/pkg/syserror/syserror_state_autogen.go new file mode 100755 index 000000000..456dcf093 --- /dev/null +++ b/pkg/syserror/syserror_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package syserror diff --git a/pkg/syserror/syserror_test.go b/pkg/syserror/syserror_test.go deleted file mode 100644 index 29719752e..000000000 --- a/pkg/syserror/syserror_test.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syserror_test - -import ( - "errors" - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/syserror" -) - -var globalError error - -func returnErrnoAsError() error { - return syscall.EINVAL -} - -func returnError() error { - return syserror.EINVAL -} - -func BenchmarkReturnErrnoAsError(b *testing.B) { - for i := b.N; i > 0; i-- { - returnErrnoAsError() - } -} - -func BenchmarkReturnError(b *testing.B) { - for i := b.N; i > 0; i-- { - returnError() - } -} - -func BenchmarkCompareErrno(b *testing.B) { - j := 0 - for i := b.N; i > 0; i-- { - if globalError == syscall.EINVAL { - j++ - } - } -} - -func BenchmarkCompareError(b *testing.B) { - j := 0 - for i := b.N; i > 0; i-- { - if globalError == syserror.EINVAL { - j++ - } - } -} - -func BenchmarkSwitchErrno(b *testing.B) { - j := 0 - for i := b.N; i > 0; i-- { - switch globalError { - case syscall.EINVAL: - j += 1 - case syscall.EINTR: - j += 2 - case syscall.EAGAIN: - j += 3 - } - } -} - -func BenchmarkSwitchError(b *testing.B) { - j := 0 - for i := b.N; i > 0; i-- { - switch globalError { - case syserror.EINVAL: - j += 1 - case syserror.EINTR: - j += 2 - case syserror.EAGAIN: - j += 3 - } - } -} - -type translationTestTable struct { - fn string - errIn error - syscallErrorIn syscall.Errno - expectedBool bool - expectedTranslation syscall.Errno -} - -func TestErrorTranslation(t *testing.T) { - myError := errors.New("My test error") - myError2 := errors.New("Another test error") - testTable := []translationTestTable{ - {"TranslateError", myError, 0, false, 0}, - {"TranslateError", myError2, 0, false, 0}, - {"AddErrorTranslation", myError, syscall.EAGAIN, true, 0}, - {"AddErrorTranslation", myError, syscall.EAGAIN, false, 0}, - {"AddErrorTranslation", myError, syscall.EPERM, false, 0}, - {"TranslateError", myError, 0, true, syscall.EAGAIN}, - {"TranslateError", myError2, 0, false, 0}, - {"AddErrorTranslation", myError2, syscall.EPERM, true, 0}, - {"AddErrorTranslation", myError2, syscall.EPERM, false, 0}, - {"AddErrorTranslation", myError2, syscall.EAGAIN, false, 0}, - {"TranslateError", myError, 0, true, syscall.EAGAIN}, - {"TranslateError", myError2, 0, true, syscall.EPERM}, - } - for _, tt := range testTable { - switch tt.fn { - case "TranslateError": - err, ok := syserror.TranslateError(tt.errIn) - if ok != tt.expectedBool { - t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool) - } else if err != tt.expectedTranslation { - t.Fatalf("%v(%v) (error) => %v expected %v", tt.fn, tt.errIn, err, tt.expectedTranslation) - } - case "AddErrorTranslation": - ok := syserror.AddErrorTranslation(tt.errIn, tt.syscallErrorIn) - if ok != tt.expectedBool { - t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool) - } - default: - t.Fatalf("Unknown function %v", tt.fn) - } - } -} diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD deleted file mode 100644 index 26f7ba86b..000000000 --- a/pkg/tcpip/BUILD +++ /dev/null @@ -1,34 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tcpip", - srcs = [ - "packet_buffer.go", - "packet_buffer_state.go", - "tcpip.go", - "time_unsafe.go", - "timer.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip/buffer", - "//pkg/waiter", - ], -) - -go_test( - name = "tcpip_test", - size = "small", - srcs = ["tcpip_test.go"], - library = ":tcpip", -) - -go_test( - name = "tcpip_x_test", - size = "small", - srcs = ["timer_test.go"], - deps = [":tcpip"], -) diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD deleted file mode 100644 index e57d45f2a..000000000 --- a/pkg/tcpip/adapters/gonet/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "gonet", - srcs = ["gonet.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - ], -) - -go_test( - name = "gonet_test", - size = "small", - srcs = ["gonet_test.go"], - library = ":gonet", - tags = ["flaky"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@org_golang_x_net//nettest:go_default_library", - ], -) diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 6e0db2741..6e0db2741 100644..100755 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go diff --git a/pkg/tcpip/adapters/gonet/gonet_state_autogen.go b/pkg/tcpip/adapters/gonet/gonet_state_autogen.go new file mode 100755 index 000000000..7a5c5419e --- /dev/null +++ b/pkg/tcpip/adapters/gonet/gonet_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package gonet diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go deleted file mode 100644 index 3c552988a..000000000 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ /dev/null @@ -1,716 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gonet - -import ( - "context" - "fmt" - "io" - "net" - "reflect" - "strings" - "testing" - "time" - - "golang.org/x/net/nettest" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - NICID = 1 -) - -func TestTimeouts(t *testing.T) { - nc := NewTCPConn(nil, nil) - dlfs := []struct { - name string - f func(time.Time) error - }{ - {"SetDeadline", nc.SetDeadline}, - {"SetReadDeadline", nc.SetReadDeadline}, - {"SetWriteDeadline", nc.SetWriteDeadline}, - } - - for _, dlf := range dlfs { - if err := dlf.f(time.Time{}); err != nil { - t.Errorf("got %s(time.Time{}) = %v, want = %v", dlf.name, err, nil) - } - } -} - -func newLoopbackStack() (*stack.Stack, *tcpip.Error) { - // Create the stack and add a NIC. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()}, - }) - - if err := s.CreateNIC(NICID, loopback.New()); err != nil { - return nil, err - } - - // Add default route. - s.SetRouteTable([]tcpip.Route{ - // IPv4 - { - Destination: header.IPv4EmptySubnet, - NIC: NICID, - }, - - // IPv6 - { - Destination: header.IPv6EmptySubnet, - NIC: NICID, - }, - }) - - return s, nil -} - -type testConnection struct { - wq *waiter.Queue - e *waiter.Entry - ch chan struct{} - ep tcpip.Endpoint -} - -func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) { - wq := &waiter.Queue{} - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - - entry, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&entry, waiter.EventOut) - - err = ep.Connect(addr) - if err == tcpip.ErrConnectStarted { - <-ch - err = ep.GetSockOpt(tcpip.ErrorOption{}) - } - if err != nil { - return nil, err - } - - wq.EventUnregister(&entry) - wq.EventRegister(&entry, waiter.EventIn) - - return &testConnection{wq, &entry, ch, ep}, nil -} - -func (c *testConnection) close() { - c.wq.EventUnregister(c.e) - c.ep.Close() -} - -// TestCloseReader tests that Conn.Close() causes Conn.Read() to unblock. -func TestCloseReader(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) - if e != nil { - t.Fatalf("NewListener() = %v", e) - } - done := make(chan struct{}) - go func() { - defer close(done) - c, err := l.Accept() - if err != nil { - t.Fatalf("l.Accept() = %v", err) - } - - // Give c.Read() a chance to block before closing the connection. - time.AfterFunc(time.Millisecond*50, func() { - c.Close() - }) - - buf := make([]byte, 256) - n, err := c.Read(buf) - if n != 0 || err != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err) - } - }() - sender, err := connect(s, addr) - if err != nil { - t.Fatalf("connect() = %v", err) - } - - select { - case <-done: - case <-time.After(5 * time.Second): - t.Errorf("c.Read() didn't unblock") - } - sender.close() -} - -// TestCloseReaderWithForwarder tests that TCPConn.Close wakes TCPConn.Read when -// using tcp.Forwarder. -func TestCloseReaderWithForwarder(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - done := make(chan struct{}) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - defer close(done) - - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - r.Complete(false) - - c := NewTCPConn(&wq, ep) - - // Give c.Read() a chance to block before closing the connection. - time.AfterFunc(time.Millisecond*50, func() { - c.Close() - }) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if n != 0 || e != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e) - } - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - sender, err := connect(s, addr) - if err != nil { - t.Fatalf("connect() = %v", err) - } - - select { - case <-done: - case <-time.After(5 * time.Second): - t.Errorf("c.Read() didn't unblock") - } - sender.close() -} - -func TestCloseRead(t *testing.T) { - s, terr := newLoopbackStack() - if terr != nil { - t.Fatalf("newLoopbackStack() = %v", terr) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - var wq waiter.Queue - _, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - // Endpoint will be closed in deferred s.Close (above). - }) - - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - tc, terr := connect(s, addr) - if terr != nil { - t.Fatalf("connect() = %v", terr) - } - c := NewTCPConn(tc.wq, tc.ep) - - if err := c.CloseRead(); err != nil { - t.Errorf("c.CloseRead() = %v", err) - } - - buf := make([]byte, 256) - if n, err := c.Read(buf); err != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, err) - } - - if n, err := c.Write([]byte("abc123")); n != 6 || err != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, err) - } -} - -func TestCloseWrite(t *testing.T) { - s, terr := newLoopbackStack() - if terr != nil { - t.Fatalf("newLoopbackStack() = %v", terr) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - r.Complete(false) - - c := NewTCPConn(&wq, ep) - - n, e := c.Read(make([]byte, 256)) - if n != 0 || e != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, e) - } - - if n, e = c.Write([]byte("abc123")); n != 6 || e != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) - } - }) - - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - tc, terr := connect(s, addr) - if terr != nil { - t.Fatalf("connect() = %v", terr) - } - c := NewTCPConn(tc.wq, tc.ep) - - if err := c.CloseWrite(); err != nil { - t.Errorf("c.CloseWrite() = %v", err) - } - - buf := make([]byte, 256) - n, err := c.Read(buf) - if err != nil || string(buf[:n]) != "abc123" { - t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, err) - } - - n, err = c.Write([]byte("abc123")) - got, ok := err.(*net.OpError) - want := "endpoint is closed for send" - if n != 0 || !ok || got.Op != "write" || got.Err == nil || !strings.HasSuffix(got.Err.Error(), want) { - t.Errorf("c.Write() = (%d, %v), want (0, OpError(Op: write, Err: %s))", n, err, want) - } -} - -func TestUDPForwarder(t *testing.T) { - s, terr := newLoopbackStack() - if terr != nil { - t.Fatalf("newLoopbackStack() = %v", terr) - } - defer func() { - s.Close() - s.Wait() - }() - - ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) - ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) - addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) - - done := make(chan struct{}) - fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) { - defer close(done) - - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - - c := NewTCPConn(&wq, ep) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if e != nil { - t.Errorf("c.Read() = %v", e) - } - - if _, e := c.Write(buf[:n]); e != nil { - t.Errorf("c.Write() = %v", e) - } - }) - s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket) - - c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 5):", err) - } - - sent := "abc123" - sendAddr := fullToUDPAddr(addr1) - if n, err := c2.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) { - t.Errorf("c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil) - } - - buf := make([]byte, 256) - n, recvAddr, err := c2.ReadFrom(buf) - if err != nil || recvAddr.String() != sendAddr.String() { - t.Errorf("c1.ReadFrom() = %d, %v, %v, want = %d, %v, %v", n, recvAddr, err, len(sent), sendAddr, nil) - } -} - -// TestDeadlineChange tests that changing the deadline affects currently blocked reads. -func TestDeadlineChange(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) - if e != nil { - t.Fatalf("NewListener() = %v", e) - } - done := make(chan struct{}) - go func() { - defer close(done) - c, err := l.Accept() - if err != nil { - t.Fatalf("l.Accept() = %v", err) - } - - c.SetDeadline(time.Now().Add(time.Minute)) - // Give c.Read() a chance to block before closing the connection. - time.AfterFunc(time.Millisecond*50, func() { - c.SetDeadline(time.Now().Add(time.Millisecond * 10)) - }) - - buf := make([]byte, 256) - n, err := c.Read(buf) - got, ok := err.(*net.OpError) - want := "i/o timeout" - if n != 0 || !ok || got.Err == nil || got.Err.Error() != want { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%s))", n, err, want) - } - }() - sender, err := connect(s, addr) - if err != nil { - t.Fatalf("connect() = %v", err) - } - - select { - case <-done: - case <-time.After(time.Millisecond * 500): - t.Errorf("c.Read() didn't unblock") - } - sender.close() -} - -func TestPacketConnTransfer(t *testing.T) { - s, e := newLoopbackStack() - if e != nil { - t.Fatalf("newLoopbackStack() = %v", e) - } - defer func() { - s.Close() - s.Wait() - }() - - ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) - ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) - addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) - - c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 4):", err) - } - c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 5):", err) - } - - c1.SetDeadline(time.Now().Add(time.Second)) - c2.SetDeadline(time.Now().Add(time.Second)) - - sent := "abc123" - sendAddr := fullToUDPAddr(addr2) - if n, err := c1.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) { - t.Errorf("got c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil) - } - recv := make([]byte, len(sent)) - n, recvAddr, err := c2.ReadFrom(recv) - if err != nil || n != len(recv) { - t.Errorf("got c2.ReadFrom() = %d, %v, want = %d, %v", n, err, len(recv), nil) - } - - if recv := string(recv); recv != sent { - t.Errorf("got recv = %q, want = %q", recv, sent) - } - - if want := fullToUDPAddr(addr1); !reflect.DeepEqual(recvAddr, want) { - t.Errorf("got recvAddr = %v, want = %v", recvAddr, want) - } - - if err := c1.Close(); err != nil { - t.Error("c1.Close():", err) - } - if err := c2.Close(); err != nil { - t.Error("c2.Close():", err) - } -} - -func TestConnectedPacketConnTransfer(t *testing.T) { - s, e := newLoopbackStack() - if e != nil { - t.Fatalf("newLoopbackStack() = %v", e) - } - defer func() { - s.Close() - s.Wait() - }() - - ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) - - c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 4):", err) - } - c2, err := DialUDP(s, nil, &addr, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 5):", err) - } - - c1.SetDeadline(time.Now().Add(time.Second)) - c2.SetDeadline(time.Now().Add(time.Second)) - - sent := "abc123" - if n, err := c2.Write([]byte(sent)); err != nil || n != len(sent) { - t.Errorf("got c2.Write(%q) = %d, %v, want = %d, %v", sent, n, err, len(sent), nil) - } - recv := make([]byte, len(sent)) - n, err := c1.Read(recv) - if err != nil || n != len(recv) { - t.Errorf("got c1.Read() = %d, %v, want = %d, %v", n, err, len(recv), nil) - } - - if recv := string(recv); recv != sent { - t.Errorf("got recv = %q, want = %q", recv, sent) - } - - if err := c1.Close(); err != nil { - t.Error("c1.Close():", err) - } - if err := c2.Close(); err != nil { - t.Error("c2.Close():", err) - } -} - -func makePipe() (c1, c2 net.Conn, stop func(), err error) { - s, e := newLoopbackStack() - if e != nil { - return nil, nil, nil, fmt.Errorf("newLoopbackStack() = %v", e) - } - - ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) - - l, err := ListenTCP(s, addr, ipv4.ProtocolNumber) - if err != nil { - return nil, nil, nil, fmt.Errorf("NewListener: %v", err) - } - - c1, err = DialTCP(s, addr, ipv4.ProtocolNumber) - if err != nil { - l.Close() - return nil, nil, nil, fmt.Errorf("DialTCP: %v", err) - } - - c2, err = l.Accept() - if err != nil { - l.Close() - c1.Close() - return nil, nil, nil, fmt.Errorf("l.Accept: %v", err) - } - - stop = func() { - c1.Close() - c2.Close() - s.Close() - s.Wait() - } - - if err := l.Close(); err != nil { - stop() - return nil, nil, nil, fmt.Errorf("l.Close(): %v", err) - } - - return c1, c2, stop, nil -} - -func TestTCPConnTransfer(t *testing.T) { - c1, c2, _, err := makePipe() - if err != nil { - t.Fatal(err) - } - defer func() { - if err := c1.Close(); err != nil { - t.Error("c1.Close():", err) - } - if err := c2.Close(); err != nil { - t.Error("c2.Close():", err) - } - }() - - c1.SetDeadline(time.Now().Add(time.Second)) - c2.SetDeadline(time.Now().Add(time.Second)) - - const sent = "abc123" - - tests := []struct { - name string - c1 net.Conn - c2 net.Conn - }{ - {"connected to accepted", c1, c2}, - {"accepted to connected", c2, c1}, - } - - for _, test := range tests { - if n, err := test.c1.Write([]byte(sent)); err != nil || n != len(sent) { - t.Errorf("%s: got test.c1.Write(%q) = %d, %v, want = %d, %v", test.name, sent, n, err, len(sent), nil) - continue - } - - recv := make([]byte, len(sent)) - n, err := test.c2.Read(recv) - if err != nil || n != len(recv) { - t.Errorf("%s: got test.c2.Read() = %d, %v, want = %d, %v", test.name, n, err, len(recv), nil) - continue - } - - if recv := string(recv); recv != sent { - t.Errorf("%s: got recv = %q, want = %q", test.name, recv, sent) - } - } -} - -func TestTCPDialError(t *testing.T) { - s, e := newLoopbackStack() - if e != nil { - t.Fatalf("newLoopbackStack() = %v", e) - } - defer func() { - s.Close() - s.Wait() - }() - - ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr := tcpip.FullAddress{NICID, ip, 11211} - - _, err := DialTCP(s, addr, ipv4.ProtocolNumber) - got, ok := err.(*net.OpError) - want := tcpip.ErrNoRoute - if !ok || got.Err.Error() != want.String() { - t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute) - } -} - -func TestDialContextTCPCanceled(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - cancel() - - if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.Canceled { - t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.Canceled) - } -} - -func TestDialContextTCPTimeout(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - time.Sleep(time.Second) - r.Complete(true) - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - ctx := context.Background() - ctx, cancel := context.WithDeadline(ctx, time.Now().Add(100*time.Millisecond)) - defer cancel() - - if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.DeadlineExceeded { - t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.DeadlineExceeded) - } -} - -func TestNetTest(t *testing.T) { - nettest.TestConn(t, makePipe) -} diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD deleted file mode 100644 index 563bc78ea..000000000 --- a/pkg/tcpip/buffer/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "buffer", - srcs = [ - "prependable.go", - "view.go", - ], - visibility = ["//visibility:public"], -) - -go_test( - name = "buffer_test", - size = "small", - srcs = ["view_test.go"], - library = ":buffer", -) diff --git a/pkg/tcpip/buffer/buffer_state_autogen.go b/pkg/tcpip/buffer/buffer_state_autogen.go new file mode 100755 index 000000000..954487771 --- /dev/null +++ b/pkg/tcpip/buffer/buffer_state_autogen.go @@ -0,0 +1,24 @@ +// automatically generated by stateify. + +package buffer + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *VectorisedView) beforeSave() {} +func (x *VectorisedView) save(m state.Map) { + x.beforeSave() + m.Save("views", &x.views) + m.Save("size", &x.size) +} + +func (x *VectorisedView) afterLoad() {} +func (x *VectorisedView) load(m state.Map) { + m.Load("views", &x.views) + m.Load("size", &x.size) +} + +func init() { + state.Register("pkg/tcpip/buffer.VectorisedView", (*VectorisedView)(nil), state.Fns{Save: (*VectorisedView).save, Load: (*VectorisedView).load}) +} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go deleted file mode 100644 index ebc3a17b7..000000000 --- a/pkg/tcpip/buffer/view_test.go +++ /dev/null @@ -1,235 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package buffer_test contains tests for the VectorisedView type. -package buffer - -import ( - "reflect" - "testing" -) - -// copy returns a deep-copy of the vectorised view. -func (vv VectorisedView) copy() VectorisedView { - uu := VectorisedView{ - views: make([]View, 0, len(vv.views)), - size: vv.size, - } - for _, v := range vv.views { - uu.views = append(uu.views, append(View(nil), v...)) - } - return uu -} - -// vv is an helper to build VectorisedView from different strings. -func vv(size int, pieces ...string) VectorisedView { - views := make([]View, len(pieces)) - for i, p := range pieces { - views[i] = []byte(p) - } - - return NewVectorisedView(size, views) -} - -var capLengthTestCases = []struct { - comment string - in VectorisedView - length int - want VectorisedView -}{ - { - comment: "Simple case", - in: vv(2, "12"), - length: 1, - want: vv(1, "1"), - }, - { - comment: "Case spanning across two Views", - in: vv(4, "123", "4"), - length: 2, - want: vv(2, "12"), - }, - { - comment: "Corner case with negative length", - in: vv(1, "1"), - length: -1, - want: vv(0), - }, - { - comment: "Corner case with length = 0", - in: vv(3, "12", "3"), - length: 0, - want: vv(0), - }, - { - comment: "Corner case with length = size", - in: vv(1, "1"), - length: 1, - want: vv(1, "1"), - }, - { - comment: "Corner case with length > size", - in: vv(1, "1"), - length: 2, - want: vv(1, "1"), - }, -} - -func TestCapLength(t *testing.T) { - for _, c := range capLengthTestCases { - orig := c.in.copy() - c.in.CapLength(c.length) - if !reflect.DeepEqual(c.in, c.want) { - t.Errorf("Test \"%s\" failed when calling CapLength(%d) on %v. Got %v. Want %v", - c.comment, c.length, orig, c.in, c.want) - } - } -} - -var trimFrontTestCases = []struct { - comment string - in VectorisedView - count int - want VectorisedView -}{ - { - comment: "Simple case", - in: vv(2, "12"), - count: 1, - want: vv(1, "2"), - }, - { - comment: "Case where we trim an entire View", - in: vv(2, "1", "2"), - count: 1, - want: vv(1, "2"), - }, - { - comment: "Case spanning across two Views", - in: vv(3, "1", "23"), - count: 2, - want: vv(1, "3"), - }, - { - comment: "Corner case with negative count", - in: vv(1, "1"), - count: -1, - want: vv(1, "1"), - }, - { - comment: " Corner case with count = 0", - in: vv(1, "1"), - count: 0, - want: vv(1, "1"), - }, - { - comment: "Corner case with count = size", - in: vv(1, "1"), - count: 1, - want: vv(0), - }, - { - comment: "Corner case with count > size", - in: vv(1, "1"), - count: 2, - want: vv(0), - }, -} - -func TestTrimFront(t *testing.T) { - for _, c := range trimFrontTestCases { - orig := c.in.copy() - c.in.TrimFront(c.count) - if !reflect.DeepEqual(c.in, c.want) { - t.Errorf("Test \"%s\" failed when calling TrimFront(%d) on %v. Got %v. Want %v", - c.comment, c.count, orig, c.in, c.want) - } - } -} - -var toViewCases = []struct { - comment string - in VectorisedView - want View -}{ - { - comment: "Simple case", - in: vv(2, "12"), - want: []byte("12"), - }, - { - comment: "Case with multiple views", - in: vv(2, "1", "2"), - want: []byte("12"), - }, - { - comment: "Empty case", - in: vv(0), - want: []byte(""), - }, -} - -func TestToView(t *testing.T) { - for _, c := range toViewCases { - got := c.in.ToView() - if !reflect.DeepEqual(got, c.want) { - t.Errorf("Test \"%s\" failed when calling ToView() on %v. Got %v. Want %v", - c.comment, c.in, got, c.want) - } - } -} - -var toCloneCases = []struct { - comment string - inView VectorisedView - inBuffer []View -}{ - { - comment: "Simple case", - inView: vv(1, "1"), - inBuffer: make([]View, 1), - }, - { - comment: "Case with multiple views", - inView: vv(2, "1", "2"), - inBuffer: make([]View, 2), - }, - { - comment: "Case with buffer too small", - inView: vv(2, "1", "2"), - inBuffer: make([]View, 1), - }, - { - comment: "Case with buffer larger than needed", - inView: vv(1, "1"), - inBuffer: make([]View, 2), - }, - { - comment: "Case with nil buffer", - inView: vv(1, "1"), - inBuffer: nil, - }, -} - -func TestToClone(t *testing.T) { - for _, c := range toCloneCases { - t.Run(c.comment, func(t *testing.T) { - got := c.inView.Clone(c.inBuffer) - if !reflect.DeepEqual(got, c.inView) { - t.Fatalf("got (%+v).Clone(%+v) = %+v, want = %+v", - c.inView, c.inBuffer, got, c.inView) - } - }) - } -} diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD deleted file mode 100644 index ed434807f..000000000 --- a/pkg/tcpip/checker/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "checker", - testonly = 1, - srcs = ["checker.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go deleted file mode 100644 index 8dc0f7c0e..000000000 --- a/pkg/tcpip/checker/checker.go +++ /dev/null @@ -1,872 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package checker provides helper functions to check networking packets for -// validity. -package checker - -import ( - "encoding/binary" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" -) - -// NetworkChecker is a function to check a property of a network packet. -type NetworkChecker func(*testing.T, []header.Network) - -// TransportChecker is a function to check a property of a transport packet. -type TransportChecker func(*testing.T, header.Transport) - -// ControlMessagesChecker is a function to check a property of ancillary data. -type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages) - -// IPv4 checks the validity and properties of the given IPv4 packet. It is -// expected to be used in conjunction with other network checkers for specific -// properties. For example, to check the source and destination address, one -// would call: -// -// checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y)) -func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { - t.Helper() - - ipv4 := header.IPv4(b) - - if !ipv4.IsValid(len(b)) { - t.Error("Not a valid IPv4 packet") - } - - xsum := ipv4.CalculateChecksum() - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum()) - } - - for _, f := range checkers { - f(t, []header.Network{ipv4}) - } - if t.Failed() { - t.FailNow() - } -} - -// IPv6 checks the validity and properties of the given IPv6 packet. The usage -// is similar to IPv4. -func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) { - t.Helper() - - ipv6 := header.IPv6(b) - if !ipv6.IsValid(len(b)) { - t.Error("Not a valid IPv6 packet") - } - - for _, f := range checkers { - f(t, []header.Network{ipv6}) - } - if t.Failed() { - t.FailNow() - } -} - -// SrcAddr creates a checker that checks the source address. -func SrcAddr(addr tcpip.Address) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if a := h[0].SourceAddress(); a != addr { - t.Errorf("Bad source address, got %v, want %v", a, addr) - } - } -} - -// DstAddr creates a checker that checks the destination address. -func DstAddr(addr tcpip.Address) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if a := h[0].DestinationAddress(); a != addr { - t.Errorf("Bad destination address, got %v, want %v", a, addr) - } - } -} - -// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6). -func TTL(ttl uint8) NetworkChecker { - return func(t *testing.T, h []header.Network) { - var v uint8 - switch ip := h[0].(type) { - case header.IPv4: - v = ip.TTL() - case header.IPv6: - v = ip.HopLimit() - } - if v != ttl { - t.Fatalf("Bad TTL, got %v, want %v", v, ttl) - } - } -} - -// PayloadLen creates a checker that checks the payload length. -func PayloadLen(plen int) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if l := len(h[0].Payload()); l != plen { - t.Errorf("Bad payload length, got %v, want %v", l, plen) - } - } -} - -// FragmentOffset creates a checker that checks the FragmentOffset field. -func FragmentOffset(offset uint16) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // We only do this of IPv4 for now. - switch ip := h[0].(type) { - case header.IPv4: - if v := ip.FragmentOffset(); v != offset { - t.Errorf("Bad fragment offset, got %v, want %v", v, offset) - } - } - } -} - -// FragmentFlags creates a checker that checks the fragment flags field. -func FragmentFlags(flags uint8) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // We only do this of IPv4 for now. - switch ip := h[0].(type) { - case header.IPv4: - if v := ip.Flags(); v != flags { - t.Errorf("Bad fragment offset, got %v, want %v", v, flags) - } - } - } -} - -// ReceiveTClass creates a checker that checks the TCLASS field in -// ControlMessages. -func ReceiveTClass(want uint32) ControlMessagesChecker { - return func(t *testing.T, cm tcpip.ControlMessages) { - t.Helper() - if !cm.HasTClass { - t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want) - } - if got := cm.TClass; got != want { - t.Fatalf("got cm.TClass = %d, want %d", got, want) - } - } -} - -// ReceiveTOS creates a checker that checks the TOS field in ControlMessages. -func ReceiveTOS(want uint8) ControlMessagesChecker { - return func(t *testing.T, cm tcpip.ControlMessages) { - t.Helper() - if !cm.HasTOS { - t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want) - } - if got := cm.TOS; got != want { - t.Fatalf("got cm.TOS = %d, want %d", got, want) - } - } -} - -// TOS creates a checker that checks the TOS field. -func TOS(tos uint8, label uint32) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if v, l := h[0].TOS(); v != tos || l != label { - t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label) - } - } -} - -// Raw creates a checker that checks the bytes of payload. -// The checker always checks the payload of the last network header. -// For instance, in case of IPv6 fragments, the payload that will be checked -// is the one containing the actual data that the packet is carrying, without -// the bytes added by the IPv6 fragmentation. -func Raw(want []byte) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) { - t.Errorf("Wrong payload, got %v, want %v", got, want) - } - } -} - -// IPv6Fragment creates a checker that validates an IPv6 fragment. -func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { - t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) - } - - ipv6Frag := header.IPv6Fragment(h[0].Payload()) - if !ipv6Frag.IsValid() { - t.Error("Not a valid IPv6 fragment") - } - - for _, f := range checkers { - f(t, []header.Network{h[0], ipv6Frag}) - } - if t.Failed() { - t.FailNow() - } - } -} - -// TCP creates a checker that checks that the transport protocol is TCP and -// potentially additional transport header fields. -func TCP(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - first := h[0] - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.TCPProtocolNumber { - t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber) - } - - // Verify the checksum. - tcp := header.TCP(last.Payload()) - l := uint16(len(tcp)) - - xsum := header.Checksum([]byte(first.SourceAddress()), 0) - xsum = header.Checksum([]byte(first.DestinationAddress()), xsum) - xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum) - xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum) - xsum = header.Checksum(tcp, xsum) - - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum()) - } - - // Run the transport checkers. - for _, f := range checkers { - f(t, tcp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// UDP creates a checker that checks that the transport protocol is UDP and -// potentially additional transport header fields. -func UDP(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.UDPProtocolNumber { - t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) - } - - udp := header.UDP(last.Payload()) - for _, f := range checkers { - f(t, udp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// SrcPort creates a checker that checks the source port. -func SrcPort(port uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - if p := h.SourcePort(); p != port { - t.Errorf("Bad source port, got %v, want %v", p, port) - } - } -} - -// DstPort creates a checker that checks the destination port. -func DstPort(port uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - if p := h.DestinationPort(); p != port { - t.Errorf("Bad destination port, got %v, want %v", p, port) - } - } -} - -// SeqNum creates a checker that checks the sequence number. -func SeqNum(seq uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - return - } - - if s := tcp.SequenceNumber(); s != seq { - t.Errorf("Bad sequence number, got %v, want %v", s, seq) - } - } -} - -// AckNum creates a checker that checks the ack number. -func AckNum(seq uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - tcp, ok := h.(header.TCP) - if !ok { - return - } - - if s := tcp.AckNumber(); s != seq { - t.Errorf("Bad ack number, got %v, want %v", s, seq) - } - } -} - -// Window creates a checker that checks the tcp window. -func Window(window uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - - if w := tcp.WindowSize(); w != window { - t.Errorf("Bad window, got 0x%x, want 0x%x", w, window) - } - } -} - -// TCPFlags creates a checker that checks the tcp flags. -func TCPFlags(flags uint8) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - return - } - - if f := tcp.Flags(); f != flags { - t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags) - } - } -} - -// TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the -// given mask, match the supplied flags. -func TCPFlagsMatch(flags, mask uint8) TransportChecker { - return func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - - if f := tcp.Flags(); (f & mask) != (flags & mask) { - t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask) - } - } -} - -// TCPSynOptions creates a checker that checks the presence of TCP options in -// SYN segments. -// -// If wndscale is negative, the window scale option must not be present. -func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { - return func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - opts := tcp.Options() - limit := len(opts) - foundMSS := false - foundWS := false - foundTS := false - foundSACKPermitted := false - tsVal := uint32(0) - tsEcr := uint32(0) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionMSS: - v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) - if wantOpts.MSS != v { - t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS) - } - foundMSS = true - i += 4 - case header.TCPOptionWS: - if wantOpts.WS < 0 { - t.Error("WS present when it shouldn't be") - } - v := int(opts[i+2]) - if v != wantOpts.WS { - t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS) - } - foundWS = true - i += 3 - case header.TCPOptionTS: - if i+9 >= limit { - t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i) - } - if opts[i+1] != 10 { - t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit) - } - tsVal = binary.BigEndian.Uint32(opts[i+2:]) - tsEcr = uint32(0) - if tcp.Flags()&header.TCPFlagAck != 0 { - // If the syn is an SYN-ACK then read - // the tsEcr value as well. - tsEcr = binary.BigEndian.Uint32(opts[i+6:]) - } - foundTS = true - i += 10 - case header.TCPOptionSACKPermitted: - if i+1 >= limit { - t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i) - } - if opts[i+1] != 2 { - t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit) - } - foundSACKPermitted = true - i += 2 - - default: - i += int(opts[i+1]) - } - } - - if !foundMSS { - t.Errorf("MSS option not found. Options: %x", opts) - } - - if !foundWS && wantOpts.WS >= 0 { - t.Errorf("WS option not found. Options: %x", opts) - } - if wantOpts.TS && !foundTS { - t.Errorf("TS option not found. Options: %x", opts) - } - if foundTS && tsVal == 0 { - t.Error("TS option specified but the timestamp value is zero") - } - if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { - t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr) - } - if wantOpts.SACKPermitted && !foundSACKPermitted { - t.Errorf("SACKPermitted option not found. Options: %x", opts) - } - } -} - -// TCPTimestampChecker creates a checker that validates that a TCP segment has a -// TCP Timestamp option if wantTS is true, it also compares the wantTSVal and -// wantTSEcr values with those in the TCP segment (if present). -// -// If wantTSVal or wantTSEcr is zero then the corresponding comparison is -// skipped. -func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - opts := []byte(tcp.Options()) - limit := len(opts) - foundTS := false - tsVal := uint32(0) - tsEcr := uint32(0) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionTS: - if i+9 >= limit { - t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) - } - if opts[i+1] != 10 { - t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1]) - } - tsVal = binary.BigEndian.Uint32(opts[i+2:]) - tsEcr = binary.BigEndian.Uint32(opts[i+6:]) - foundTS = true - i += 10 - default: - // We don't recognize this option, just skip over it. - if i+2 > limit { - return - } - l := int(opts[i+1]) - if i < 2 || i+l > limit { - return - } - i += l - } - } - - if wantTS != foundTS { - t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS) - } - if wantTS && wantTSVal != 0 && wantTSVal != tsVal { - t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal) - } - if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { - t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr) - } - } -} - -// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not -// contain any SACK blocks in the TCP options. -func TCPNoSACKBlockChecker() TransportChecker { - return TCPSACKBlockChecker(nil) -} - -// TCPSACKBlockChecker creates a checker that verifies that the segment does -// contain the specified SACK blocks in the TCP options. -func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - tcp, ok := h.(header.TCP) - if !ok { - return - } - var gotSACKBlocks []header.SACKBlock - - opts := []byte(tcp.Options()) - limit := len(opts) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionSACK: - if i+2 > limit { - // Malformed SACK block. - t.Errorf("malformed SACK option in options: %v", opts) - } - sackOptionLen := int(opts[i+1]) - if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 { - // Malformed SACK block. - t.Errorf("malformed SACK option length in options: %v", opts) - } - numBlocks := sackOptionLen / 8 - for j := 0; j < numBlocks; j++ { - start := binary.BigEndian.Uint32(opts[i+2+j*8:]) - end := binary.BigEndian.Uint32(opts[i+2+j*8+4:]) - gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{ - Start: seqnum.Value(start), - End: seqnum.Value(end), - }) - } - i += sackOptionLen - default: - // We don't recognize this option, just skip over it. - if i+2 > limit { - break - } - l := int(opts[i+1]) - if l < 2 || i+l > limit { - break - } - i += l - } - } - - if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) { - t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks) - } - } -} - -// Payload creates a checker that checks the payload. -func Payload(want []byte) TransportChecker { - return func(t *testing.T, h header.Transport) { - if got := h.Payload(); !reflect.DeepEqual(got, want) { - t.Errorf("Wrong payload, got %v, want %v", got, want) - } - } -} - -// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and -// potentially additional ICMPv4 header fields. -func ICMPv4(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber { - t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber) - } - - icmp := header.ICMPv4(last.Payload()) - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// ICMPv4Type creates a checker that checks the ICMPv4 Type field. -func ICMPv4Type(want header.ICMPv4Type) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) - } - if got := icmpv4.Type(); got != want { - t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) - } - } -} - -// ICMPv4Code creates a checker that checks the ICMPv4 Code field. -func ICMPv4Code(want byte) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) - } - if got := icmpv4.Code(); got != want { - t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) - } - } -} - -// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and -// potentially additional ICMPv6 header fields. -// -// ICMPv6 will validate the checksum field before calling checkers. -func ICMPv6(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber { - t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber) - } - - icmp := header.ICMPv6(last.Payload()) - if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want { - t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want) - } - - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// ICMPv6Type creates a checker that checks the ICMPv6 Type field. -func ICMPv6Type(want header.ICMPv6Type) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - icmpv6, ok := h.(header.ICMPv6) - if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) - } - if got := icmpv6.Type(); got != want { - t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) - } - } -} - -// ICMPv6Code creates a checker that checks the ICMPv6 Code field. -func ICMPv6Code(want byte) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - icmpv6, ok := h.(header.ICMPv6) - if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) - } - if got := icmpv6.Code(); got != want { - t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) - } - } -} - -// NDP creates a checker that checks that the packet contains a valid NDP -// message for type of ty, with potentially additional checks specified by -// checkers. -// -// checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDP message as far as the size of the message (minSize) is concerned. The -// values within the message are up to checkers to validate. -func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // Check normal ICMPv6 first. - ICMPv6( - ICMPv6Type(msgType), - ICMPv6Code(0))(t, h) - - last := h[len(h)-1] - - icmp := header.ICMPv6(last.Payload()) - if got := len(icmp.NDPPayload()); got < minSize { - t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) - } - - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// NDPNS creates a checker that checks that the packet contains a valid NDP -// Neighbor Solicitation message (as per the raw wire format), with potentially -// additional checks specified by checkers. -// -// checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDPNS message as far as the size of the messages concerned. The values within -// the message are up to checkers to validate. -func NDPNS(checkers ...TransportChecker) NetworkChecker { - return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...) -} - -// NDPNSTargetAddress creates a checker that checks the Target Address field of -// a header.NDPNeighborSolicit. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNS message as far as the size is concerned. -func NDPNSTargetAddress(want tcpip.Address) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) - - if got := ns.TargetAddress(); got != want { - t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want) - } - } -} - -// ndpOptions checks that optsBuf only contains opts. -func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) { - t.Helper() - - it, err := optsBuf.Iter(true) - if err != nil { - t.Errorf("optsBuf.Iter(true): %s", err) - return - } - - i := 0 - for { - opt, done, err := it.Next() - if err != nil { - // This should never happen as Iter(true) above did not return an error. - t.Fatalf("unexpected error when iterating over NDP options: %s", err) - } - if done { - break - } - - if i >= len(opts) { - t.Errorf("got unexpected option: %s", opt) - continue - } - - switch wantOpt := opts[i].(type) { - case header.NDPSourceLinkLayerAddressOption: - gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) - if !ok { - t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) - } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { - t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) - } - default: - t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt) - } - - i++ - } - - if missing := opts[i:]; len(missing) > 0 { - t.Errorf("missing options: %s", missing) - } -} - -// NDPNSOptions creates a checker that checks that the packet contains the -// provided NDP options within an NDP Neighbor Solicitation message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNS message as far as the size is concerned. -func NDPNSOptions(opts []header.NDPOption) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) - ndpOptions(t, ns.Options(), opts) - } -} - -// NDPRS creates a checker that checks that the packet contains a valid NDP -// Router Solicitation message (as per the raw wire format). -// -// checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDPRS as far as the size of the message is concerned. The values within the -// message are up to checkers to validate. -func NDPRS(checkers ...TransportChecker) NetworkChecker { - return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...) -} - -// NDPRSOptions creates a checker that checks that the packet contains the -// provided NDP options within an NDP Router Solicitation message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPRS message as far as the size is concerned. -func NDPRSOptions(opts []header.NDPOption) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - rs := header.NDPRouterSolicit(icmp.NDPPayload()) - ndpOptions(t, rs.Options(), opts) - } -} diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD deleted file mode 100644 index ff2719291..000000000 --- a/pkg/tcpip/hash/jenkins/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "jenkins", - srcs = ["jenkins.go"], - visibility = ["//visibility:public"], -) - -go_test( - name = "jenkins_test", - size = "small", - srcs = [ - "jenkins_test.go", - ], - library = ":jenkins", -) diff --git a/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go new file mode 100755 index 000000000..216cc5a2e --- /dev/null +++ b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package jenkins diff --git a/pkg/tcpip/hash/jenkins/jenkins_test.go b/pkg/tcpip/hash/jenkins/jenkins_test.go deleted file mode 100644 index 4c78b5808..000000000 --- a/pkg/tcpip/hash/jenkins/jenkins_test.go +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package jenkins - -import ( - "bytes" - "encoding/binary" - "hash" - "hash/fnv" - "math" - "testing" -) - -func TestGolden32(t *testing.T) { - var golden32 = []struct { - out []byte - in string - }{ - {[]byte{0x00, 0x00, 0x00, 0x00}, ""}, - {[]byte{0xca, 0x2e, 0x94, 0x42}, "a"}, - {[]byte{0x45, 0xe6, 0x1e, 0x58}, "ab"}, - {[]byte{0xed, 0x13, 0x1f, 0x5b}, "abc"}, - } - - hash := New32() - - for _, g := range golden32 { - hash.Reset() - done, error := hash.Write([]byte(g.in)) - if error != nil { - t.Fatalf("write error: %s", error) - } - if done != len(g.in) { - t.Fatalf("wrote only %d out of %d bytes", done, len(g.in)) - } - if actual := hash.Sum(nil); !bytes.Equal(g.out, actual) { - t.Errorf("hash(%q) = 0x%x want 0x%x", g.in, actual, g.out) - } - } -} - -func TestIntegrity32(t *testing.T) { - data := []byte{'1', '2', 3, 4, 5} - - h := New32() - h.Write(data) - sum := h.Sum(nil) - - if size := h.Size(); size != len(sum) { - t.Fatalf("Size()=%d but len(Sum())=%d", size, len(sum)) - } - - if a := h.Sum(nil); !bytes.Equal(sum, a) { - t.Fatalf("first Sum()=0x%x, second Sum()=0x%x", sum, a) - } - - h.Reset() - h.Write(data) - if a := h.Sum(nil); !bytes.Equal(sum, a) { - t.Fatalf("Sum()=0x%x, but after Reset() Sum()=0x%x", sum, a) - } - - h.Reset() - h.Write(data[:2]) - h.Write(data[2:]) - if a := h.Sum(nil); !bytes.Equal(sum, a) { - t.Fatalf("Sum()=0x%x, but with partial writes, Sum()=0x%x", sum, a) - } - - sum32 := h.(hash.Hash32).Sum32() - if sum32 != binary.BigEndian.Uint32(sum) { - t.Fatalf("Sum()=0x%x, but Sum32()=0x%x", sum, sum32) - } -} - -func BenchmarkJenkins32KB(b *testing.B) { - h := New32() - - b.SetBytes(1024) - data := make([]byte, 1024) - for i := range data { - data[i] = byte(i) - } - in := make([]byte, 0, h.Size()) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - h.Reset() - h.Write(data) - h.Sum(in) - } -} - -func BenchmarkFnv32(b *testing.B) { - arr := make([]int64, 1000) - for i := 0; i < b.N; i++ { - var payload [8]byte - binary.BigEndian.PutUint32(payload[:4], uint32(i)) - binary.BigEndian.PutUint32(payload[4:], uint32(i)) - - h := fnv.New32() - h.Write(payload[:]) - idx := int(h.Sum32()) % len(arr) - arr[idx]++ - } - b.StopTimer() - c := 0 - if b.N > 1000000 { - for i := 0; i < len(arr)-1; i++ { - if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { - if c == 0 { - b.Logf("i %d val[i] %d val[i+1] %d b.N %b\n", i, arr[i], arr[i+1], b.N) - } - c++ - } - } - if c > 0 { - b.Logf("Unbalanced buckets: %d", c) - } - } -} - -func BenchmarkSum32(b *testing.B) { - arr := make([]int64, 1000) - for i := 0; i < b.N; i++ { - var payload [8]byte - binary.BigEndian.PutUint32(payload[:4], uint32(i)) - binary.BigEndian.PutUint32(payload[4:], uint32(i)) - h := Sum32(0) - h.Write(payload[:]) - idx := int(h.Sum32()) % len(arr) - arr[idx]++ - } - b.StopTimer() - if b.N > 1000000 { - for i := 0; i < len(arr)-1; i++ { - if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { - b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N) - break - } - } - } -} - -func BenchmarkNew32(b *testing.B) { - arr := make([]int64, 1000) - for i := 0; i < b.N; i++ { - var payload [8]byte - binary.BigEndian.PutUint32(payload[:4], uint32(i)) - binary.BigEndian.PutUint32(payload[4:], uint32(i)) - h := New32() - h.Write(payload[:]) - idx := int(h.Sum32()) % len(arr) - arr[idx]++ - } - b.StopTimer() - if b.N > 1000000 { - for i := 0; i < len(arr)-1; i++ { - if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { - b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N) - break - } - } - } -} diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD deleted file mode 100644 index 9da0d71f8..000000000 --- a/pkg/tcpip/header/BUILD +++ /dev/null @@ -1,65 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "header", - srcs = [ - "arp.go", - "checksum.go", - "eth.go", - "gue.go", - "icmpv4.go", - "icmpv6.go", - "interfaces.go", - "ipv4.go", - "ipv6.go", - "ipv6_fragment.go", - "ndp_neighbor_advert.go", - "ndp_neighbor_solicit.go", - "ndp_options.go", - "ndp_router_advert.go", - "ndp_router_solicit.go", - "tcp.go", - "udp.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/seqnum", - "@com_github_google_btree//:go_default_library", - ], -) - -go_test( - name = "header_x_test", - size = "small", - srcs = [ - "checksum_test.go", - "ipv6_test.go", - "ipversion_test.go", - "tcp_test.go", - ], - deps = [ - ":header", - "//pkg/rand", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "@com_github_google_go-cmp//cmp:go_default_library", - ], -) - -go_test( - name = "header_test", - size = "small", - srcs = [ - "eth_test.go", - "ndp_test.go", - ], - library = ":header", - deps = [ - "//pkg/tcpip", - "@com_github_google_go-cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go deleted file mode 100644 index 309403482..000000000 --- a/pkg/tcpip/header/checksum_test.go +++ /dev/null @@ -1,171 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package header provides the implementation of the encoding and decoding of -// network protocol headers. -package header_test - -import ( - "fmt" - "math/rand" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestChecksumVVWithOffset(t *testing.T) { - testCases := []struct { - name string - vv buffer.VectorisedView - off, size int - initial uint16 - want uint16 - }{ - { - name: "empty", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - }), - off: 0, - size: 0, - want: 0, - }, - { - name: "OneView", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - }), - off: 0, - size: 5, - want: 1294, - }, - { - name: "TwoViews", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 0, - size: 11, - want: 33819, - }, - { - name: "TwoViewsWithOffset", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 1, - size: 11, - want: 33819, - }, - { - name: "ThreeViewsWithOffset", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 7, - size: 11, - want: 33819, - }, - { - name: "ThreeViewsWithInitial", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}), - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}), - }), - initial: 77, - off: 7, - size: 11, - want: 33896, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want { - t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want) - } - v := tc.vv.ToView() - v.TrimFront(tc.off) - v.CapLength(tc.size) - if got, want := header.Checksum(v, tc.initial), tc.want; got != want { - t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want) - } - }) - } -} - -func TestChecksum(t *testing.T) { - var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024} - type testCase struct { - buf []byte - initial uint16 - csumOrig uint16 - csumNew uint16 - } - testCases := make([]testCase, 100000) - // Ensure same buffer generation for test consistency. - rnd := rand.New(rand.NewSource(42)) - for i := range testCases { - testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)]) - testCases[i].initial = uint16(rnd.Intn(65536)) - rnd.Read(testCases[i].buf) - } - - for i := range testCases { - testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial) - testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial) - if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want { - t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want) - } - } -} - -func BenchmarkChecksum(b *testing.B) { - var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536} - - checkSumImpls := []struct { - fn func([]byte, uint16) uint16 - name string - }{ - {header.ChecksumOld, fmt.Sprintf("checksum_old")}, - {header.Checksum, fmt.Sprintf("checksum")}, - } - - for _, csumImpl := range checkSumImpls { - // Ensure same buffer generation for test consistency. - rnd := rand.New(rand.NewSource(42)) - for _, bufSz := range bufSizes { - b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) { - tc := struct { - buf []byte - initial uint16 - csum uint16 - }{ - buf: make([]byte, bufSz), - initial: uint16(rnd.Intn(65536)), - } - rnd.Read(tc.buf) - b.ResetTimer() - for i := 0; i < b.N; i++ { - tc.csum = csumImpl.fn(tc.buf, tc.initial) - } - }) - } - } -} diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go deleted file mode 100644 index 7a0014ad9..000000000 --- a/pkg/tcpip/header/eth_test.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" -) - -func TestIsValidUnicastEthernetAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.LinkAddress - expected bool - }{ - { - "Nil", - tcpip.LinkAddress([]byte(nil)), - false, - }, - { - "Empty", - tcpip.LinkAddress(""), - false, - }, - { - "InvalidLength", - tcpip.LinkAddress("\x01\x02\x03"), - false, - }, - { - "Unspecified", - unspecifiedEthernetAddress, - false, - }, - { - "Multicast", - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - false, - }, - { - "Valid", - tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"), - true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := IsValidUnicastEthernetAddress(test.addr); got != test.expected { - t.Fatalf("got IsValidUnicastEthernetAddress = %t, want = %t", got, test.expected) - } - }) - } -} - -func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "IPv4 Multicast without 24th bit set", - addr: "\xe0\x7e\xdc\xba", - expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba", - }, - { - name: "IPv4 Multicast with 24th bit set", - addr: "\xe0\xfe\xdc\xba", - expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := EthernetAddressFromMulticastIPv4Address(test.addr); got != test.expectedLinkAddr { - t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", got, test.expectedLinkAddr) - } - }) - } -} - -func TestEthernetAddressFromMulticastIPv6Address(t *testing.T) { - addr := tcpip.Address("\xff\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x1a") - if got, want := EthernetAddressFromMulticastIPv6Address(addr), tcpip.LinkAddress("\x33\x33\x0d\x0e\x0f\x1a"); got != want { - t.Fatalf("got EthernetAddressFromMulticastIPv6Address(%s) = %s, want = %s", addr, got, want) - } -} diff --git a/pkg/tcpip/header/header_state_autogen.go b/pkg/tcpip/header/header_state_autogen.go new file mode 100755 index 000000000..015d7e12a --- /dev/null +++ b/pkg/tcpip/header/header_state_autogen.go @@ -0,0 +1,42 @@ +// automatically generated by stateify. + +package header + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *SACKBlock) beforeSave() {} +func (x *SACKBlock) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) +} + +func (x *SACKBlock) afterLoad() {} +func (x *SACKBlock) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) +} + +func (x *TCPOptions) beforeSave() {} +func (x *TCPOptions) save(m state.Map) { + x.beforeSave() + m.Save("TS", &x.TS) + m.Save("TSVal", &x.TSVal) + m.Save("TSEcr", &x.TSEcr) + m.Save("SACKBlocks", &x.SACKBlocks) +} + +func (x *TCPOptions) afterLoad() {} +func (x *TCPOptions) load(m state.Map) { + m.Load("TS", &x.TS) + m.Load("TSVal", &x.TSVal) + m.Load("TSEcr", &x.TSEcr) + m.Load("SACKBlocks", &x.SACKBlocks) +} + +func init() { + state.Register("pkg/tcpip/header.SACKBlock", (*SACKBlock)(nil), state.Fns{Save: (*SACKBlock).save, Load: (*SACKBlock).load}) + state.Register("pkg/tcpip/header.TCPOptions", (*TCPOptions)(nil), state.Fns{Save: (*TCPOptions).save, Load: (*TCPOptions).load}) +} diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go deleted file mode 100644 index 426a873b1..000000000 --- a/pkg/tcpip/header/ipv6_test.go +++ /dev/null @@ -1,417 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "bytes" - "crypto/sha256" - "fmt" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") -) - -func TestEthernetAdddressToModifiedEUI64(t *testing.T) { - expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6} - - if diff := cmp.Diff(expectedIID, header.EthernetAddressToModifiedEUI64(linkAddr)); diff != "" { - t.Errorf("EthernetAddressToModifiedEUI64(%s) mismatch (-want +got):\n%s", linkAddr, diff) - } - - var buf [header.IIDSize]byte - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:]) - if diff := cmp.Diff(expectedIID, buf); diff != "" { - t.Errorf("EthernetAddressToModifiedEUI64IntoBuf(%s, _) mismatch (-want +got):\n%s", linkAddr, diff) - } -} - -func TestLinkLocalAddr(t *testing.T) { - if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want { - t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want) - } -} - -func TestAppendOpaqueInterfaceIdentifier(t *testing.T) { - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte - if n, err := rand.Read(secretKeyBuf[:]); err != nil { - t.Fatalf("rand.Read(_): %s", err) - } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want { - t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n) - } - - tests := []struct { - name string - prefix tcpip.Subnet - nicName string - dadCounter uint8 - secretKey []byte - }{ - { - name: "SecretKey of minimum size", - prefix: header.IPv6LinkLocalPrefix.Subnet(), - nicName: "eth0", - dadCounter: 0, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes], - }, - { - name: "SecretKey of less than minimum size", - prefix: func() tcpip.Subnet { - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: "\x01\x02\x03\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - PrefixLen: header.IIDOffsetInIPv6Address * 8, - } - return addrWithPrefix.Subnet() - }(), - nicName: "eth10", - dadCounter: 1, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2], - }, - { - name: "SecretKey of more than minimum size", - prefix: func() tcpip.Subnet { - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: "\x01\x02\x03\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - PrefixLen: header.IIDOffsetInIPv6Address * 8, - } - return addrWithPrefix.Subnet() - }(), - nicName: "eth11", - dadCounter: 2, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], - }, - { - name: "Nil SecretKey and empty nicName", - prefix: func() tcpip.Subnet { - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: "\x01\x02\x03\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - PrefixLen: header.IIDOffsetInIPv6Address * 8, - } - return addrWithPrefix.Subnet() - }(), - nicName: "", - dadCounter: 3, - secretKey: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - h := sha256.New() - h.Write([]byte(test.prefix.ID()[:header.IIDOffsetInIPv6Address])) - h.Write([]byte(test.nicName)) - h.Write([]byte{test.dadCounter}) - if k := test.secretKey; k != nil { - h.Write(k) - } - var hashSum [sha256.Size]byte - h.Sum(hashSum[:0]) - want := hashSum[:header.IIDSize] - - // Passing a nil buffer should result in a new buffer returned with the - // IID. - if got := header.AppendOpaqueInterfaceIdentifier(nil, test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) { - t.Errorf("got AppendOpaqueInterfaceIdentifier(nil, %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want) - } - - // Passing a buffer with sufficient capacity for the IID should populate - // the buffer provided. - var iidBuf [header.IIDSize]byte - if got := header.AppendOpaqueInterfaceIdentifier(iidBuf[:0], test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) { - t.Errorf("got AppendOpaqueInterfaceIdentifier(iidBuf[:0], %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want) - } - if got := iidBuf[:]; !bytes.Equal(got, want) { - t.Errorf("got iidBuf = %x, want = %x", got, want) - } - }) - } -} - -func TestLinkLocalAddrWithOpaqueIID(t *testing.T) { - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte - if n, err := rand.Read(secretKeyBuf[:]); err != nil { - t.Fatalf("rand.Read(_): %s", err) - } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want { - t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n) - } - - prefix := header.IPv6LinkLocalPrefix.Subnet() - - tests := []struct { - name string - prefix tcpip.Subnet - nicName string - dadCounter uint8 - secretKey []byte - }{ - { - name: "SecretKey of minimum size", - nicName: "eth0", - dadCounter: 0, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes], - }, - { - name: "SecretKey of less than minimum size", - nicName: "eth10", - dadCounter: 1, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2], - }, - { - name: "SecretKey of more than minimum size", - nicName: "eth11", - dadCounter: 2, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], - }, - { - name: "Nil SecretKey and empty nicName", - nicName: "", - dadCounter: 3, - secretKey: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - addrBytes := [header.IPv6AddressSize]byte{ - 0: 0xFE, - 1: 0x80, - } - - want := tcpip.Address(header.AppendOpaqueInterfaceIdentifier( - addrBytes[:header.IIDOffsetInIPv6Address], - prefix, - test.nicName, - test.dadCounter, - test.secretKey, - )) - - if got := header.LinkLocalAddrWithOpaqueIID(test.nicName, test.dadCounter, test.secretKey); got != want { - t.Errorf("got LinkLocalAddrWithOpaqueIID(%s, %d, %x) = %s, want = %s", test.nicName, test.dadCounter, test.secretKey, got, want) - } - }) - } -} - -func TestIsV6UniqueLocalAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - expected bool - }{ - { - name: "Valid Unique 1", - addr: uniqueLocalAddr1, - expected: true, - }, - { - name: "Valid Unique 2", - addr: uniqueLocalAddr1, - expected: true, - }, - { - name: "Link Local", - addr: linkLocalAddr, - expected: false, - }, - { - name: "Global", - addr: globalAddr, - expected: false, - }, - { - name: "IPv4", - addr: "\x01\x02\x03\x04", - expected: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := header.IsV6UniqueLocalAddress(test.addr); got != test.expected { - t.Errorf("got header.IsV6UniqueLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected) - } - }) - } -} - -func TestIsV6LinkLocalMulticastAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - expected bool - }{ - { - name: "Valid Link Local Multicast", - addr: linkLocalMulticastAddr, - expected: true, - }, - { - name: "Valid Link Local Multicast with flags", - addr: "\xff\xf2\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - expected: true, - }, - { - name: "Link Local Unicast", - addr: linkLocalAddr, - expected: false, - }, - { - name: "IPv4 Multicast", - addr: "\xe0\x00\x00\x01", - expected: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := header.IsV6LinkLocalMulticastAddress(test.addr); got != test.expected { - t.Errorf("got header.IsV6LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected) - } - }) - } -} - -func TestIsV6LinkLocalAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - expected bool - }{ - { - name: "Valid Link Local Unicast", - addr: linkLocalAddr, - expected: true, - }, - { - name: "Link Local Multicast", - addr: linkLocalMulticastAddr, - expected: false, - }, - { - name: "Unique Local", - addr: uniqueLocalAddr1, - expected: false, - }, - { - name: "Global", - addr: globalAddr, - expected: false, - }, - { - name: "IPv4 Link Local", - addr: "\xa9\xfe\x00\x01", - expected: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected { - t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected) - } - }) - } -} - -func TestScopeForIPv6Address(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - scope header.IPv6AddressScope - err *tcpip.Error - }{ - { - name: "Unique Local", - addr: uniqueLocalAddr1, - scope: header.UniqueLocalScope, - err: nil, - }, - { - name: "Link Local Unicast", - addr: linkLocalAddr, - scope: header.LinkLocalScope, - err: nil, - }, - { - name: "Link Local Multicast", - addr: linkLocalMulticastAddr, - scope: header.LinkLocalScope, - err: nil, - }, - { - name: "Global", - addr: globalAddr, - scope: header.GlobalScope, - err: nil, - }, - { - name: "IPv4", - addr: "\x01\x02\x03\x04", - scope: header.GlobalScope, - err: tcpip.ErrBadAddress, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got, err := header.ScopeForIPv6Address(test.addr) - if err != test.err { - t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (_, %v), want = (_, %v)", test.addr, err, test.err) - } - if got != test.scope { - t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope) - } - }) - } -} - -func TestSolicitedNodeAddr(t *testing.T) { - tests := []struct { - addr tcpip.Address - want tcpip.Address - }{ - { - addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\xa0", - want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0", - }, - { - addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x0e\x0f\xa0", - want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0", - }, - { - addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x01\x02\x03", - want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x01\x02\x03", - }, - } - - for _, test := range tests { - t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) { - if got := header.SolicitedNodeAddr(test.addr); got != test.want { - t.Fatalf("got header.SolicitedNodeAddr(%s) = %s, want = %s", test.addr, got, test.want) - } - }) - } -} diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go deleted file mode 100644 index b5540bf66..000000000 --- a/pkg/tcpip/header/ipversion_test.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestIPv4(t *testing.T) { - b := header.IPv4(make([]byte, header.IPv4MinimumSize)) - b.Encode(&header.IPv4Fields{}) - - const want = header.IPv4Version - if v := header.IPVersion(b); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} - -func TestIPv6(t *testing.T) { - b := header.IPv6(make([]byte, header.IPv6MinimumSize)) - b.Encode(&header.IPv6Fields{}) - - const want = header.IPv6Version - if v := header.IPVersion(b); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} - -func TestOtherVersion(t *testing.T) { - const want = header.IPv4Version + header.IPv6Version - b := make([]byte, 1) - b[0] = want << 4 - - if v := header.IPVersion(b); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} - -func TestTooShort(t *testing.T) { - b := make([]byte, 1) - b[0] = (header.IPv4Version + header.IPv6Version) << 4 - - // Get the version of a zero-length slice. - const want = -1 - if v := header.IPVersion(b[:0]); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } - - // Get the version of a nil slice. - if v := header.IPVersion(nil); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} diff --git a/pkg/tcpip/header/ndp_neighbor_advert.go b/pkg/tcpip/header/ndp_neighbor_advert.go index 505c92668..505c92668 100644..100755 --- a/pkg/tcpip/header/ndp_neighbor_advert.go +++ b/pkg/tcpip/header/ndp_neighbor_advert.go diff --git a/pkg/tcpip/header/ndp_neighbor_solicit.go b/pkg/tcpip/header/ndp_neighbor_solicit.go index 3a1b8e139..3a1b8e139 100644..100755 --- a/pkg/tcpip/header/ndp_neighbor_solicit.go +++ b/pkg/tcpip/header/ndp_neighbor_solicit.go diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index e6a6ad39b..e6a6ad39b 100644..100755 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go index bf7610863..bf7610863 100644..100755 --- a/pkg/tcpip/header/ndp_router_advert.go +++ b/pkg/tcpip/header/ndp_router_advert.go diff --git a/pkg/tcpip/header/ndp_router_solicit.go b/pkg/tcpip/header/ndp_router_solicit.go index 9e67ba95d..9e67ba95d 100644..100755 --- a/pkg/tcpip/header/ndp_router_solicit.go +++ b/pkg/tcpip/header/ndp_router_solicit.go diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go deleted file mode 100644 index 1cb9f5dc8..000000000 --- a/pkg/tcpip/header/ndp_test.go +++ /dev/null @@ -1,937 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header - -import ( - "bytes" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" -) - -// TestNDPNeighborSolicit tests the functions of NDPNeighborSolicit. -func TestNDPNeighborSolicit(t *testing.T) { - b := []byte{ - 0, 0, 0, 0, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - } - - // Test getting the Target Address. - ns := NDPNeighborSolicit(b) - addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") - if got := ns.TargetAddress(); got != addr { - t.Errorf("got ns.TargetAddress = %s, want %s", got, addr) - } - - // Test updating the Target Address. - addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") - ns.SetTargetAddress(addr2) - if got := ns.TargetAddress(); got != addr2 { - t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2) - } - // Make sure the address got updated in the backing buffer. - if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 { - t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) - } -} - -// TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert. -func TestNDPNeighborAdvert(t *testing.T) { - b := []byte{ - 160, 0, 0, 0, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - } - - // Test getting the Target Address. - na := NDPNeighborAdvert(b) - addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") - if got := na.TargetAddress(); got != addr { - t.Errorf("got TargetAddress = %s, want %s", got, addr) - } - - // Test getting the Router Flag. - if got := na.RouterFlag(); !got { - t.Errorf("got RouterFlag = false, want = true") - } - - // Test getting the Solicited Flag. - if got := na.SolicitedFlag(); got { - t.Errorf("got SolicitedFlag = true, want = false") - } - - // Test getting the Override Flag. - if got := na.OverrideFlag(); !got { - t.Errorf("got OverrideFlag = false, want = true") - } - - // Test updating the Target Address. - addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") - na.SetTargetAddress(addr2) - if got := na.TargetAddress(); got != addr2 { - t.Errorf("got TargetAddress = %s, want %s", got, addr2) - } - // Make sure the address got updated in the backing buffer. - if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 { - t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) - } - - // Test updating the Router Flag. - na.SetRouterFlag(false) - if got := na.RouterFlag(); got { - t.Errorf("got RouterFlag = true, want = false") - } - - // Test updating the Solicited Flag. - na.SetSolicitedFlag(true) - if got := na.SolicitedFlag(); !got { - t.Errorf("got SolicitedFlag = false, want = true") - } - - // Test updating the Override Flag. - na.SetOverrideFlag(false) - if got := na.OverrideFlag(); got { - t.Errorf("got OverrideFlag = true, want = false") - } - - // Make sure flags got updated in the backing buffer. - if got := b[ndpNAFlagsOffset]; got != 64 { - t.Errorf("got flags byte = %d, want = 64") - } -} - -func TestNDPRouterAdvert(t *testing.T) { - b := []byte{ - 64, 128, 1, 2, - 3, 4, 5, 6, - 7, 8, 9, 10, - } - - ra := NDPRouterAdvert(b) - - if got := ra.CurrHopLimit(); got != 64 { - t.Errorf("got ra.CurrHopLimit = %d, want = 64", got) - } - - if got := ra.ManagedAddrConfFlag(); !got { - t.Errorf("got ManagedAddrConfFlag = false, want = true") - } - - if got := ra.OtherConfFlag(); got { - t.Errorf("got OtherConfFlag = true, want = false") - } - - if got, want := ra.RouterLifetime(), time.Second*258; got != want { - t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want) - } - - if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want { - t.Errorf("got ra.ReachableTime = %d, want = %d", got, want) - } - - if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want { - t.Errorf("got ra.RetransTimer = %d, want = %d", got, want) - } -} - -// TestNDPSourceLinkLayerAddressOptionEthernetAddress tests getting the -// Ethernet address from an NDPSourceLinkLayerAddressOption. -func TestNDPSourceLinkLayerAddressOptionEthernetAddress(t *testing.T) { - tests := []struct { - name string - buf []byte - expected tcpip.LinkAddress - }{ - { - "ValidMAC", - []byte{1, 2, 3, 4, 5, 6}, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - }, - { - "SLLBodyTooShort", - []byte{1, 2, 3, 4, 5}, - tcpip.LinkAddress([]byte(nil)), - }, - { - "SLLBodyLargerThanNeeded", - []byte{1, 2, 3, 4, 5, 6, 7, 8}, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - sll := NDPSourceLinkLayerAddressOption(test.buf) - if got := sll.EthernetAddress(); got != test.expected { - t.Errorf("got sll.EthernetAddress = %s, want = %s", got, test.expected) - } - }) - } -} - -// TestNDPSourceLinkLayerAddressOptionSerialize tests serializing a -// NDPSourceLinkLayerAddressOption. -func TestNDPSourceLinkLayerAddressOptionSerialize(t *testing.T) { - tests := []struct { - name string - buf []byte - expectedBuf []byte - addr tcpip.LinkAddress - }{ - { - "Ethernet", - make([]byte, 8), - []byte{1, 1, 1, 2, 3, 4, 5, 6}, - "\x01\x02\x03\x04\x05\x06", - }, - { - "Padding", - []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, - "\x01\x02\x03\x04\x05\x06\x07\x08", - }, - { - "Empty", - nil, - nil, - "", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - serializer := NDPOptionsSerializer{ - NDPSourceLinkLayerAddressOption(test.addr), - } - if got, want := int(serializer.Length()), len(test.expectedBuf); got != want { - t.Fatalf("got Length = %d, want = %d", got, want) - } - opts.Serialize(serializer) - if !bytes.Equal(test.buf, test.expectedBuf) { - t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - if len(test.expectedBuf) > 0 { - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType { - t.Fatalf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType) - } - sll := next.(NDPSourceLinkLayerAddressOption) - if got, want := []byte(sll), test.expectedBuf[2:]; !bytes.Equal(got, want) { - t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - - if got, want := sll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { - t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want) - } - } - - // Iterator should not return anything else. - next, done, err := it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } - }) - } -} - -// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the -// Ethernet address from an NDPTargetLinkLayerAddressOption. -func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { - tests := []struct { - name string - buf []byte - expected tcpip.LinkAddress - }{ - { - "ValidMAC", - []byte{1, 2, 3, 4, 5, 6}, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - }, - { - "TLLBodyTooShort", - []byte{1, 2, 3, 4, 5}, - tcpip.LinkAddress([]byte(nil)), - }, - { - "TLLBodyLargerThanNeeded", - []byte{1, 2, 3, 4, 5, 6, 7, 8}, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - tll := NDPTargetLinkLayerAddressOption(test.buf) - if got := tll.EthernetAddress(); got != test.expected { - t.Errorf("got tll.EthernetAddress = %s, want = %s", got, test.expected) - } - }) - } -} - -// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a -// NDPTargetLinkLayerAddressOption. -func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { - tests := []struct { - name string - buf []byte - expectedBuf []byte - addr tcpip.LinkAddress - }{ - { - "Ethernet", - make([]byte, 8), - []byte{2, 1, 1, 2, 3, 4, 5, 6}, - "\x01\x02\x03\x04\x05\x06", - }, - { - "Padding", - []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, - "\x01\x02\x03\x04\x05\x06\x07\x08", - }, - { - "Empty", - nil, - nil, - "", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - serializer := NDPOptionsSerializer{ - NDPTargetLinkLayerAddressOption(test.addr), - } - if got, want := int(serializer.Length()), len(test.expectedBuf); got != want { - t.Fatalf("got Length = %d, want = %d", got, want) - } - opts.Serialize(serializer) - if !bytes.Equal(test.buf, test.expectedBuf) { - t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - if len(test.expectedBuf) > 0 { - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { - t.Fatalf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) - } - tll := next.(NDPTargetLinkLayerAddressOption) - if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) { - t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - - if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { - t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want) - } - } - - // Iterator should not return anything else. - next, done, err := it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } - }) - } -} - -// TestNDPPrefixInformationOption tests the field getters and serialization of a -// NDPPrefixInformation. -func TestNDPPrefixInformationOption(t *testing.T) { - b := []byte{ - 43, 127, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 5, 5, 5, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - - targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - opts := NDPOptions(targetBuf) - serializer := NDPOptionsSerializer{ - NDPPrefixInformation(b), - } - opts.Serialize(serializer) - expectedBuf := []byte{ - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - if !bytes.Equal(targetBuf, expectedBuf) { - t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPPrefixInformationType { - t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) - } - - pi := next.(NDPPrefixInformation) - - if got := pi.Type(); got != 3 { - t.Errorf("got Type = %d, want = 3", got) - } - - if got := pi.Length(); got != 30 { - t.Errorf("got Length = %d, want = 30", got) - } - - if got := pi.PrefixLength(); got != 43 { - t.Errorf("got PrefixLength = %d, want = 43", got) - } - - if pi.OnLinkFlag() { - t.Error("got OnLinkFlag = true, want = false") - } - - if !pi.AutonomousAddressConfigurationFlag() { - t.Error("got AutonomousAddressConfigurationFlag = false, want = true") - } - - if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want { - t.Errorf("got ValidLifetime = %d, want = %d", got, want) - } - - if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want { - t.Errorf("got PreferredLifetime = %d, want = %d", got, want) - } - - if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want { - t.Errorf("got Prefix = %s, want = %s", got, want) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - -func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) { - b := []byte{ - 9, 8, - 1, 2, 4, 8, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - } - targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - expected := []byte{ - 25, 3, 0, 0, - 1, 2, 4, 8, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - } - opts := NDPOptions(targetBuf) - serializer := NDPOptionsSerializer{ - NDPRecursiveDNSServer(b), - } - if got, want := opts.Serialize(serializer), len(expected); got != want { - t.Errorf("got Serialize = %d, want = %d", got, want) - } - if !bytes.Equal(targetBuf, expected) { - t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPRecursiveDNSServerOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) - } - - opt, ok := next.(NDPRecursiveDNSServer) - if !ok { - t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) - } - if got := opt.Type(); got != 25 { - t.Errorf("got Type = %d, want = 31", got) - } - if got := opt.Length(); got != 22 { - t.Errorf("got Length = %d, want = 22", got) - } - if got, want := opt.Lifetime(), 16909320*time.Second; got != want { - t.Errorf("got Lifetime = %s, want = %s", got, want) - } - want := []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - } - if got := opt.Addresses(); !cmp.Equal(got, want) { - t.Errorf("got Addresses = %v, want = %v", got, want) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - -func TestNDPRecursiveDNSServerOption(t *testing.T) { - tests := []struct { - name string - buf []byte - lifetime time.Duration - addrs []tcpip.Address - }{ - { - "Valid1Addr", - []byte{ - 25, 3, 0, 0, - 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - }, - 0, - []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - }, - }, - { - "Valid2Addr", - []byte{ - 25, 5, 0, 0, - 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, - }, - 0, - []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", - }, - }, - { - "Valid3Addr", - []byte{ - 25, 7, 0, 0, - 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, - 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17, - }, - 0, - []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", - "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11", - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - // Iterator should get our option. - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPRecursiveDNSServerOptionType { - t.Fatalf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) - } - - opt, ok := next.(NDPRecursiveDNSServer) - if !ok { - t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) - } - if got := opt.Lifetime(); got != test.lifetime { - t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) - } - if got := opt.Addresses(); !cmp.Equal(got, test.addrs) { - t.Errorf("got Addresses = %v, want = %v", got, test.addrs) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } - }) - } -} - -// TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions -// the iterator was returned for is malformed. -func TestNDPOptionsIterCheck(t *testing.T) { - tests := []struct { - name string - buf []byte - expected error - }{ - { - "ZeroLengthField", - []byte{0, 0, 0, 0, 0, 0, 0, 0}, - ErrNDPOptZeroLength, - }, - { - "ValidSourceLinkLayerAddressOption", - []byte{1, 1, 1, 2, 3, 4, 5, 6}, - nil, - }, - { - "TooSmallSourceLinkLayerAddressOption", - []byte{1, 1, 1, 2, 3, 4, 5}, - ErrNDPOptBufExhausted, - }, - { - "ValidTargetLinkLayerAddressOption", - []byte{2, 1, 1, 2, 3, 4, 5, 6}, - nil, - }, - { - "TooSmallTargetLinkLayerAddressOption", - []byte{2, 1, 1, 2, 3, 4, 5}, - ErrNDPOptBufExhausted, - }, - { - "ValidPrefixInformation", - []byte{ - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - }, - nil, - }, - { - "TooSmallPrefixInformation", - []byte{ - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, - }, - ErrNDPOptBufExhausted, - }, - { - "InvalidPrefixInformationLength", - []byte{ - 3, 3, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - }, - ErrNDPOptMalformedBody, - }, - { - "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation", - []byte{ - // Source Link-Layer Address. - 1, 1, 1, 2, 3, 4, 5, 6, - - // Target Link-Layer Address. - 2, 1, 7, 8, 9, 10, 11, 12, - - // Prefix information. - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - }, - nil, - }, - { - "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized", - []byte{ - // Source Link-Layer Address. - 1, 1, 1, 2, 3, 4, 5, 6, - - // Target Link-Layer Address. - 2, 1, 7, 8, 9, 10, 11, 12, - - // 255 is an unrecognized type. If 255 ends up - // being the type for some recognized type, - // update 255 to some other unrecognized value. - 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, - - // Prefix information. - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - }, - nil, - }, - { - "InvalidRecursiveDNSServerCutsOffAddress", - []byte{ - 25, 4, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, - 0, 1, 2, 3, 4, 5, 6, 7, - }, - ErrNDPOptMalformedBody, - }, - { - "InvalidRecursiveDNSServerInvalidLengthField", - []byte{ - 25, 2, 0, 0, - 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, - }, - ErrNDPInvalidLength, - }, - { - "RecursiveDNSServerTooSmall", - []byte{ - 25, 1, 0, 0, - 0, 0, 0, - }, - ErrNDPOptBufExhausted, - }, - { - "RecursiveDNSServerMulticast", - []byte{ - 25, 3, 0, 0, - 0, 0, 0, 0, - 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, - }, - ErrNDPOptMalformedBody, - }, - { - "RecursiveDNSServerUnspecified", - []byte{ - 25, 3, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }, - ErrNDPOptMalformedBody, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - - if _, err := opts.Iter(true); err != test.expected { - t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expected) - } - - // test.buf may be malformed but we chose not to check - // the iterator so it must return true. - if _, err := opts.Iter(false); err != nil { - t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err) - } - }) - } -} - -// TestNDPOptionsIter tests that we can iterator over a valid NDPOptions. Note, -// this test does not actually check any of the option's getters, it simply -// checks the option Type and Body. We have other tests that tests the option -// field gettings given an option body and don't need to duplicate those tests -// here. -func TestNDPOptionsIter(t *testing.T) { - buf := []byte{ - // Source Link-Layer Address. - 1, 1, 1, 2, 3, 4, 5, 6, - - // Target Link-Layer Address. - 2, 1, 7, 8, 9, 10, 11, 12, - - // 255 is an unrecognized type. If 255 ends up being the type - // for some recognized type, update 255 to some other - // unrecognized value. Note, this option should be skipped when - // iterating. - 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, - - // Prefix information. - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - - opts := NDPOptions(buf) - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - // Test the first (Source Link-Layer) option. - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got, want := []byte(next.(NDPSourceLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) { - t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType) - } - - // Test the next (Target Link-Layer) option. - next, done, err = it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[10:][:6]; !bytes.Equal(got, want) { - t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) - } - - // Test the next (Prefix Information) option. - // Note, the unrecognized option should be skipped. - next, done, err = it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got, want := next.(NDPPrefixInformation), buf[34:][:30]; !bytes.Equal(got, want) { - t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - if got := next.Type(); got != NDPPrefixInformationType { - t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} diff --git a/pkg/tcpip/header/tcp_test.go b/pkg/tcpip/header/tcp_test.go deleted file mode 100644 index 72563837b..000000000 --- a/pkg/tcpip/header/tcp_test.go +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestEncodeSACKBlocks(t *testing.T) { - testCases := []struct { - sackBlocks []header.SACKBlock - want []header.SACKBlock - bufSize int - }{ - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}}, - 40, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, - 30, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}}, - 20, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}}, - 10, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - nil, - 8, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}}, - 60, - }, - } - for _, tc := range testCases { - b := make([]byte, tc.bufSize) - t.Logf("testing: %v", tc) - header.EncodeSACKBlocks(tc.sackBlocks, b) - opts := header.ParseTCPOptions(b) - if got, want := opts.SACKBlocks, tc.want; !reflect.DeepEqual(got, want) { - t.Errorf("header.EncodeSACKBlocks(%v, %v), encoded blocks got: %v, want: %v", tc.sackBlocks, b, got, want) - } - } -} - -func TestTCPParseOptions(t *testing.T) { - type tsOption struct { - tsVal uint32 - tsEcr uint32 - } - - generateOptions := func(tsOpt *tsOption, sackBlocks []header.SACKBlock) []byte { - l := 0 - if tsOpt != nil { - l += 10 - } - if len(sackBlocks) != 0 { - l += len(sackBlocks)*8 + 2 - } - b := make([]byte, l) - offset := 0 - if tsOpt != nil { - offset = header.EncodeTSOption(tsOpt.tsVal, tsOpt.tsEcr, b) - } - header.EncodeSACKBlocks(sackBlocks, b[offset:]) - return b - } - - testCases := []struct { - b []byte - want header.TCPOptions - }{ - // Trivial cases. - {nil, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionEOL}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionEOL, header.TCPOptionTS, 10, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, - - // Test timestamp parsing. - {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, - - // Test malformed timestamp option. - {[]byte{header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, - - // Test SACKBlock parsing. - {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}}}}, - {[]byte{header.TCPOptionSACK, 18, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}, {11, 12}}}}, - - // Test malformed SACK option. - {[]byte{header.TCPOptionSACK, 0}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 8, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 17, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 10}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0}, header.TCPOptions{false, 0, 0, nil}}, - - // Test Timestamp + SACK block parsing. - {generateOptions(&tsOption{1, 1}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 1, []header.SACKBlock{{1, 10}, {11, 12}}}}, - {generateOptions(&tsOption{1, 2}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 2, []header.SACKBlock{{1, 10}, {11, 12}}}}, - {generateOptions(&tsOption{1, 3}, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}, {15, 16}}), header.TCPOptions{true, 1, 3, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}}}}, - - // Test valid timestamp + malformed SACK block parsing. - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10, 0, 0, 0}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionSACK, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 10, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{134873088, 65536}}}}, - {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{8, 167772160}}}}, - {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, - } - for _, tc := range testCases { - if got, want := header.ParseTCPOptions(tc.b), tc.want; !reflect.DeepEqual(got, want) { - t.Errorf("ParseTCPOptions(%v) = %v, want: %v", tc.b, got, tc.want) - } - } -} diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD deleted file mode 100644 index d1b73cfdf..000000000 --- a/pkg/tcpip/iptables/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "iptables", - srcs = [ - "iptables.go", - "targets.go", - "types.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/tcpip", - "//pkg/tcpip/header", - ], -) diff --git a/pkg/tcpip/iptables/iptables_state_autogen.go b/pkg/tcpip/iptables/iptables_state_autogen.go new file mode 100755 index 000000000..e75169fa7 --- /dev/null +++ b/pkg/tcpip/iptables/iptables_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package iptables diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD deleted file mode 100644 index b8b93e78e..000000000 --- a/pkg/tcpip/link/channel/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "channel", - srcs = ["channel.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 5944ba190..5944ba190 100644..100755 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go diff --git a/pkg/tcpip/link/channel/channel_state_autogen.go b/pkg/tcpip/link/channel/channel_state_autogen.go new file mode 100755 index 000000000..ce52482a2 --- /dev/null +++ b/pkg/tcpip/link/channel/channel_state_autogen.go @@ -0,0 +1,22 @@ +// automatically generated by stateify. + +package channel + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *NotificationHandle) beforeSave() {} +func (x *NotificationHandle) save(m state.Map) { + x.beforeSave() + m.Save("n", &x.n) +} + +func (x *NotificationHandle) afterLoad() {} +func (x *NotificationHandle) load(m state.Map) { + m.Load("n", &x.n) +} + +func init() { + state.Register("pkg/tcpip/link/channel.NotificationHandle", (*NotificationHandle)(nil), state.Fns{Save: (*NotificationHandle).save, Load: (*NotificationHandle).load}) +} diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD deleted file mode 100644 index abe725548..000000000 --- a/pkg/tcpip/link/fdbased/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "fdbased", - srcs = [ - "endpoint.go", - "endpoint_unsafe.go", - "mmap.go", - "mmap_stub.go", - "mmap_unsafe.go", - "packet_dispatchers.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/stack", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "fdbased_test", - size = "small", - srcs = ["endpoint_test.go"], - library = ":fdbased", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go deleted file mode 100644 index 2066987eb..000000000 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ /dev/null @@ -1,468 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -package fdbased - -import ( - "bytes" - "fmt" - "math/rand" - "reflect" - "syscall" - "testing" - "time" - "unsafe" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - mtu = 1500 - laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66") - raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc") - proto = 10 - csumOffset = 48 - gsoMSS = 500 -) - -type packetInfo struct { - raddr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - contents tcpip.PacketBuffer -} - -type context struct { - t *testing.T - fds [2]int - ep stack.LinkEndpoint - ch chan packetInfo - done chan struct{} -} - -func newContext(t *testing.T, opt *Options) *context { - fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) - if err != nil { - t.Fatalf("Socketpair failed: %v", err) - } - - done := make(chan struct{}, 1) - opt.ClosedFunc = func(*tcpip.Error) { - done <- struct{}{} - } - - opt.FDs = []int{fds[1]} - ep, err := New(opt) - if err != nil { - t.Fatalf("Failed to create FD endpoint: %v", err) - } - - c := &context{ - t: t, - fds: fds, - ep: ep, - ch: make(chan packetInfo, 100), - done: done, - } - - ep.Attach(c) - - return c -} - -func (c *context) cleanup() { - syscall.Close(c.fds[0]) - <-c.done - syscall.Close(c.fds[1]) -} - -func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { - c.ch <- packetInfo{remote, protocol, pkt} -} - -func TestNoEthernetProperties(t *testing.T) { - c := newContext(t, &Options{MTU: mtu}) - defer c.cleanup() - - if want, v := uint16(0), c.ep.MaxHeaderLength(); want != v { - t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) - } - - if want, v := uint32(mtu), c.ep.MTU(); want != v { - t.Fatalf("MTU() = %v, want %v", v, want) - } -} - -func TestEthernetProperties(t *testing.T) { - c := newContext(t, &Options{EthernetHeader: true, MTU: mtu}) - defer c.cleanup() - - if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v { - t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) - } - - if want, v := uint32(mtu), c.ep.MTU(); want != v { - t.Fatalf("MTU() = %v, want %v", v, want) - } -} - -func TestAddress(t *testing.T) { - addrs := []tcpip.LinkAddress{"", "abc", "def"} - for _, a := range addrs { - t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) { - c := newContext(t, &Options{Address: a, MTU: mtu}) - defer c.cleanup() - - if want, v := a, c.ep.LinkAddress(); want != v { - t.Fatalf("LinkAddress() = %v, want %v", v, want) - } - }) - } -} - -func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { - c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize}) - defer c.cleanup() - - r := &stack.Route{ - RemoteLinkAddress: raddr, - } - - // Build header. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) - b := hdr.Prepend(100) - for i := range b { - b[i] = uint8(rand.Intn(256)) - } - - // Build payload and write. - payload := make(buffer.View, plen) - for i := range payload { - payload[i] = uint8(rand.Intn(256)) - } - want := append(hdr.View(), payload...) - var gso *stack.GSO - if gsoMaxSize != 0 { - gso = &stack.GSO{ - Type: stack.GSOTCPv6, - NeedsCsum: true, - CsumOffset: csumOffset, - MSS: gsoMSS, - MaxSize: gsoMaxSize, - L3HdrLen: header.IPv4MaximumHeaderSize, - } - } - if err := c.ep.WritePacket(r, gso, proto, tcpip.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Read from fd, then compare with what we wrote. - b = make([]byte, mtu) - n, err := syscall.Read(c.fds[0], b) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - b = b[:n] - if gsoMaxSize != 0 { - vnetHdr := *(*virtioNetHdr)(unsafe.Pointer(&b[0])) - if vnetHdr.flags&_VIRTIO_NET_HDR_F_NEEDS_CSUM == 0 { - t.Fatalf("virtioNetHdr.flags %v doesn't contain %v", vnetHdr.flags, _VIRTIO_NET_HDR_F_NEEDS_CSUM) - } - csumStart := header.EthernetMinimumSize + gso.L3HdrLen - if vnetHdr.csumStart != csumStart { - t.Fatalf("vnetHdr.csumStart = %v, want %v", vnetHdr.csumStart, csumStart) - } - if vnetHdr.csumOffset != csumOffset { - t.Fatalf("vnetHdr.csumOffset = %v, want %v", vnetHdr.csumOffset, csumOffset) - } - gsoType := uint8(0) - if int(gso.MSS) < plen { - gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 - } - if vnetHdr.gsoType != gsoType { - t.Fatalf("vnetHdr.gsoType = %v, want %v", vnetHdr.gsoType, gsoType) - } - b = b[virtioNetHdrSize:] - } - if eth { - h := header.Ethernet(b) - b = b[header.EthernetMinimumSize:] - - if a := h.SourceAddress(); a != laddr { - t.Fatalf("SourceAddress() = %v, want %v", a, laddr) - } - - if a := h.DestinationAddress(); a != raddr { - t.Fatalf("DestinationAddress() = %v, want %v", a, raddr) - } - - if et := h.Type(); et != proto { - t.Fatalf("Type() = %v, want %v", et, proto) - } - } - if len(b) != len(want) { - t.Fatalf("Read returned %v bytes, want %v", len(b), len(want)) - } - if !bytes.Equal(b, want) { - t.Fatalf("Read returned %x, want %x", b, want) - } -} - -func TestWritePacket(t *testing.T) { - lengths := []int{0, 100, 1000} - eths := []bool{true, false} - gsos := []uint32{0, 32768} - - for _, eth := range eths { - for _, plen := range lengths { - for _, gso := range gsos { - t.Run( - fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso), - func(t *testing.T) { - testWritePacket(t, plen, eth, gso) - }, - ) - } - } - } -} - -func TestPreserveSrcAddress(t *testing.T) { - baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99") - - c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: true}) - defer c.cleanup() - - // Set LocalLinkAddress in route to the value of the bridged address. - r := &stack.Route{ - RemoteLinkAddress: raddr, - LocalLinkAddress: baddr, - } - - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, proto, tcpip.PacketBuffer{ - Header: hdr, - Data: buffer.VectorisedView{}, - }); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Read from the FD, then compare with what we wrote. - b := make([]byte, mtu) - n, err := syscall.Read(c.fds[0], b) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - b = b[:n] - h := header.Ethernet(b) - - if a := h.SourceAddress(); a != baddr { - t.Fatalf("SourceAddress() = %v, want %v", a, baddr) - } -} - -func TestDeliverPacket(t *testing.T) { - lengths := []int{100, 1000} - eths := []bool{true, false} - - for _, eth := range eths { - for _, plen := range lengths { - t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) { - c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth}) - defer c.cleanup() - - // Build packet. - b := make([]byte, plen) - all := b - for i := range b { - b[i] = uint8(rand.Intn(256)) - } - - var hdr header.Ethernet - if !eth { - // So that it looks like an IPv4 packet. - b[0] = 0x40 - } else { - hdr = make(header.Ethernet, header.EthernetMinimumSize) - hdr.Encode(&header.EthernetFields{ - SrcAddr: raddr, - DstAddr: laddr, - Type: proto, - }) - all = append(hdr, b...) - } - - // Write packet via the file descriptor. - if _, err := syscall.Write(c.fds[0], all); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Receive packet through the endpoint. - select { - case pi := <-c.ch: - want := packetInfo{ - raddr: raddr, - proto: proto, - contents: tcpip.PacketBuffer{ - Data: buffer.View(b).ToVectorisedView(), - LinkHeader: buffer.View(hdr), - }, - } - if !eth { - want.proto = header.IPv4ProtocolNumber - want.raddr = "" - } - // want.contents.Data will be a single - // view, so make pi do the same for the - // DeepEqual check. - pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView() - if !reflect.DeepEqual(want, pi) { - t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) - } - case <-time.After(10 * time.Second): - t.Fatalf("Timed out waiting for packet") - } - }) - } - } -} - -func TestBufConfigMaxLength(t *testing.T) { - got := 0 - for _, i := range BufConfig { - got += i - } - want := header.MaxIPPacketSize // maximum TCP packet size - if got < want { - t.Errorf("total buffer size is invalid: got %d, want >= %d", got, want) - } -} - -func TestBufConfigFirst(t *testing.T) { - // The stack assumes that the TCP/IP header is enterily contained in the first view. - // Therefore, the first view needs to be large enough to contain the maximum TCP/IP - // header, which is 120 bytes (60 bytes for IP + 60 bytes for TCP). - want := 120 - got := BufConfig[0] - if got < want { - t.Errorf("first view has an invalid size: got %d, want >= %d", got, want) - } -} - -var capLengthTestCases = []struct { - comment string - config []int - n int - wantUsed int - wantLengths []int -}{ - { - comment: "Single slice", - config: []int{2}, - n: 1, - wantUsed: 1, - wantLengths: []int{1}, - }, - { - comment: "Multiple slices", - config: []int{1, 2}, - n: 2, - wantUsed: 2, - wantLengths: []int{1, 1}, - }, - { - comment: "Entire buffer", - config: []int{1, 2}, - n: 3, - wantUsed: 2, - wantLengths: []int{1, 2}, - }, - { - comment: "Entire buffer but not on the last slice", - config: []int{1, 2, 3}, - n: 3, - wantUsed: 2, - wantLengths: []int{1, 2, 3}, - }, -} - -func TestReadVDispatcherCapLength(t *testing.T) { - for _, c := range capLengthTestCases { - // fd does not matter for this test. - d := readVDispatcher{fd: -1, e: &endpoint{}} - d.views = make([]buffer.View, len(c.config)) - d.iovecs = make([]syscall.Iovec, len(c.config)) - d.allocateViews(c.config) - - used := d.capViews(c.n, c.config) - if used != c.wantUsed { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) - } - lengths := make([]int, len(d.views)) - for i, v := range d.views { - lengths[i] = len(v) - } - if !reflect.DeepEqual(lengths, c.wantLengths) { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) - } - } -} - -func TestRecvMMsgDispatcherCapLength(t *testing.T) { - for _, c := range capLengthTestCases { - d := recvMMsgDispatcher{ - fd: -1, // fd does not matter for this test. - e: &endpoint{}, - views: make([][]buffer.View, 1), - iovecs: make([][]syscall.Iovec, 1), - msgHdrs: make([]rawfile.MMsgHdr, 1), - } - - for i, _ := range d.views { - d.views[i] = make([]buffer.View, len(c.config)) - } - for i := range d.iovecs { - d.iovecs[i] = make([]syscall.Iovec, len(c.config)) - } - for k, msgHdr := range d.msgHdrs { - msgHdr.Msg.Iov = &d.iovecs[k][0] - msgHdr.Msg.Iovlen = uint64(len(c.config)) - } - - d.allocateViews(c.config) - - used := d.capViews(0, c.n, c.config) - if used != c.wantUsed { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) - } - lengths := make([]int, len(d.views[0])) - for i, v := range d.views[0] { - lengths[i] = len(v) - } - if !reflect.DeepEqual(lengths, c.wantLengths) { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) - } - - } -} diff --git a/pkg/tcpip/link/fdbased/fdbased_state_autogen.go b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go new file mode 100755 index 000000000..97cb3958e --- /dev/null +++ b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go @@ -0,0 +1,10 @@ +// automatically generated by stateify. + +// +build linux +// +build linux +// +build linux,amd64 linux,arm64 +// +build !linux !amd64,!arm64 +// +build linux,amd64 linux,arm64 +// +build linux + +package fdbased diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD deleted file mode 100644 index 6bf3805b7..000000000 --- a/pkg/tcpip/link/loopback/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "loopback", - srcs = ["loopback.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/loopback/loopback_state_autogen.go b/pkg/tcpip/link/loopback/loopback_state_autogen.go new file mode 100755 index 000000000..c00fd9f19 --- /dev/null +++ b/pkg/tcpip/link/loopback/loopback_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package loopback diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD deleted file mode 100644 index 82b441b79..000000000 --- a/pkg/tcpip/link/muxed/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "muxed", - srcs = ["injectable.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "muxed_test", - size = "small", - srcs = ["injectable_test.go"], - library = ":muxed", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 445b22c17..445b22c17 100644..100755 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go deleted file mode 100644 index 63b249837..000000000 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package muxed - -import ( - "bytes" - "net" - "os" - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -func TestInjectableEndpointRawDispatch(t *testing.T) { - endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - - endpoint.InjectOutbound(dstIP, []byte{0xFA}) - - buf := make([]byte, ipv4.MaxTotalSize) - bytesRead, err := sock.Read(buf) - if err != nil { - t.Fatalf("Unable to read from socketpair: %v", err) - } - if got, want := buf[:bytesRead], []byte{0xFA}; !bytes.Equal(got, want) { - t.Fatalf("Read %v from the socketpair, wanted %v", got, want) - } -} - -func TestInjectableEndpointDispatch(t *testing.T) { - endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} - - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), - }) - - buf := make([]byte, 6500) - bytesRead, err := sock.Read(buf) - if err != nil { - t.Fatalf("Unable to read from socketpair: %v", err) - } - if got, want := buf[:bytesRead], []byte{0xFA, 0xFB}; !bytes.Equal(got, want) { - t.Fatalf("Read %v from the socketpair, wanted %v", got, want) - } -} - -func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { - endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buffer.NewView(0).ToVectorisedView(), - }) - buf := make([]byte, 6500) - bytesRead, err := sock.Read(buf) - if err != nil { - t.Fatalf("Unable to read from socketpair: %v", err) - } - if got, want := buf[:bytesRead], []byte{0xFA}; !bytes.Equal(got, want) { - t.Fatalf("Read %v from the socketpair, wanted %v", got, want) - } -} - -func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tcpip.Address) { - dstIP := tcpip.Address(net.ParseIP("1.2.3.4").To4()) - pair, err := syscall.Socketpair(syscall.AF_UNIX, - syscall.SOCK_SEQPACKET|syscall.SOCK_CLOEXEC|syscall.SOCK_NONBLOCK, 0) - if err != nil { - t.Fatal("Failed to create socket pair:", err) - } - underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) - routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint} - endpoint := NewInjectableEndpoint(routes) - return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP -} diff --git a/pkg/tcpip/link/muxed/muxed_state_autogen.go b/pkg/tcpip/link/muxed/muxed_state_autogen.go new file mode 100755 index 000000000..56330e2a5 --- /dev/null +++ b/pkg/tcpip/link/muxed/muxed_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package muxed diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD deleted file mode 100644 index 14b527bc2..000000000 --- a/pkg/tcpip/link/rawfile/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "rawfile", - srcs = [ - "blockingpoll_amd64.s", - "blockingpoll_arm64.s", - "blockingpoll_noyield_unsafe.go", - "blockingpoll_yield_unsafe.go", - "errors.go", - "rawfile_unsafe.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/tcpip/link/rawfile/rawfile_state_autogen.go b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go new file mode 100755 index 000000000..6b6816bae --- /dev/null +++ b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go @@ -0,0 +1,10 @@ +// automatically generated by stateify. + +// +build linux,!amd64,!arm64 +// +build linux,amd64 linux,arm64 +// +build go1.12 +// +build !go1.15 +// +build linux +// +build linux + +package rawfile diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD deleted file mode 100644 index 13243ebbb..000000000 --- a/pkg/tcpip/link/sharedmem/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "sharedmem", - srcs = [ - "rx.go", - "sharedmem.go", - "sharedmem_unsafe.go", - "tx.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/link/sharedmem/queue", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "sharedmem_test", - srcs = [ - "sharedmem_test.go", - ], - library = ":sharedmem", - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/sharedmem/pipe", - "//pkg/tcpip/link/sharedmem/queue", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD deleted file mode 100644 index 87020ec08..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "pipe", - srcs = [ - "pipe.go", - "pipe_unsafe.go", - "rx.go", - "tx.go", - ], - visibility = ["//visibility:public"], -) - -go_test( - name = "pipe_test", - srcs = [ - "pipe_test.go", - ], - library = ":pipe", - deps = ["//pkg/sync"], -) diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe.go b/pkg/tcpip/link/sharedmem/pipe/pipe.go index 74c9f0311..74c9f0311 100644..100755 --- a/pkg/tcpip/link/sharedmem/pipe/pipe.go +++ b/pkg/tcpip/link/sharedmem/pipe/pipe.go diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go b/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go new file mode 100755 index 000000000..d3b40feb4 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package pipe diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go deleted file mode 100644 index dc239a0d0..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ /dev/null @@ -1,518 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -import ( - "math/rand" - "reflect" - "runtime" - "testing" - - "gvisor.dev/gvisor/pkg/sync" -) - -func TestSimpleReadWrite(t *testing.T) { - // Check that a simple write can be properly read from the rx side. - tr := rand.New(rand.NewSource(99)) - rr := rand.New(rand.NewSource(99)) - - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - wb := tx.Push(10) - if wb == nil { - t.Fatalf("Push failed on empty pipe") - } - for i := range wb { - wb[i] = byte(tr.Intn(256)) - } - tx.Flush() - - var rx Rx - rx.Init(b) - rb := rx.Pull() - if len(rb) != 10 { - t.Fatalf("Bad buffer size returned: got %v, want %v", len(rb), 10) - } - - for i := range rb { - if v := byte(rr.Intn(256)); v != rb[i] { - t.Fatalf("Bad read buffer at index %v: got %v, want %v", i, rb[i], v) - } - } - rx.Flush() -} - -func TestEmptyRead(t *testing.T) { - // Check that pulling from an empty pipe fails. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } -} - -func TestTooLargeWrite(t *testing.T) { - // Check that writes that are too large are properly rejected. - b := make([]byte, 96) - var tx Tx - tx.Init(b) - - if wb := tx.Push(96); wb != nil { - t.Fatalf("Write of 96 bytes succeeded on 96-byte pipe") - } - - if wb := tx.Push(88); wb != nil { - t.Fatalf("Write of 88 bytes succeeded on 96-byte pipe") - } - - if wb := tx.Push(80); wb == nil { - t.Fatalf("Write of 80 bytes failed on 96-byte pipe") - } -} - -func TestFullWrite(t *testing.T) { - // Check that writes fail when the pipe is full. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(80); wb == nil { - t.Fatalf("Write of 80 bytes failed on 96-byte pipe") - } - - if wb := tx.Push(1); wb != nil { - t.Fatalf("Write succeeded on full pipe") - } -} - -func TestFullAndFlushedWrite(t *testing.T) { - // Check that writes fail when the pipe is full and has already been - // flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(80); wb == nil { - t.Fatalf("Write of 80 bytes failed on 96-byte pipe") - } - - tx.Flush() - - if wb := tx.Push(1); wb != nil { - t.Fatalf("Write succeeded on full pipe") - } -} - -func TestTxFlushTwice(t *testing.T) { - // Checks that a second consecutive tx flush is a no-op. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - // Make copy of original tx queue, flush it, then check that it didn't - // change. - orig := tx - tx.Flush() - - if !reflect.DeepEqual(orig, tx) { - t.Fatalf("Flush mutated tx pipe: got %v, want %v", tx, orig) - } -} - -func TestRxFlushTwice(t *testing.T) { - // Checks that a second consecutive rx flush is a no-op. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // Make copy of original rx queue, flush it, then check that it didn't - // change. - orig := rx - rx.Flush() - - if !reflect.DeepEqual(orig, rx) { - t.Fatalf("Flush mutated rx pipe: got %v, want %v", rx, orig) - } -} - -func TestWrapInMiddleOfTransaction(t *testing.T) { - // Check that writes are not flushed when we need to wrap the buffer - // around. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - // We haven't flushed yet, so pull must return nil. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on non-flushed pipe") - } - - tx.Flush() - - // The two buffers must be available now. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } -} - -func TestWriteAbort(t *testing.T) { - // Check that a read fails on a pipe that has had data pushed to it but - // has aborted the push. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(10); wb == nil { - t.Fatalf("Write failed on empty pipe") - } - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } - - tx.Abort() - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } -} - -func TestWrappedWriteAbort(t *testing.T) { - // Check that writes are properly aborted even if the writes wrap - // around. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - // We haven't flushed yet, so pull must return nil. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on non-flushed pipe") - } - - tx.Abort() - - // The pushes were aborted, so no data should be readable. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on non-flushed pipe") - } - - // Try the same transactions again, but flush this time. - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - tx.Flush() - - // The two buffers must be available now. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } -} - -func TestEmptyReadOnNonFlushedWrite(t *testing.T) { - // Check that a read fails on a pipe that has had data pushed to it - // but not yet flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(10); wb == nil { - t.Fatalf("Write failed on empty pipe") - } - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } - - tx.Flush() - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull on failed on non-empty pipe") - } -} - -func TestPullAfterPullingEntirePipe(t *testing.T) { - // Check that Pull fails when the pipe is full, but all of it has - // already been pulled but not yet flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 3 - // buffers that will fill the pipe. - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(20); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - if wb := tx.Push(24); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - tx.Flush() - - // The three buffers must be available now. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - // Fourth pull must fail. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } -} - -func TestNoRoomToWrapOnPush(t *testing.T) { - // Check that Push fails when it tries to allocate room to add a wrap - // message. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 20, - // which won't fit (64+20+8+padding = 96, which wouldn't leave room for - // the padding), so it wraps around. - if wb := tx.Push(20); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - tx.Flush() - - // Buffer offset is at 28. Try to write 70, which would require a wrap - // slot which cannot be created now. - if wb := tx.Push(70); wb != nil { - t.Fatalf("Push succeeded on pipe with no room for wrap message") - } -} - -func TestRxImplicitFlushOfWrapMessage(t *testing.T) { - // Check if the first read is that of a wrapping message, that it gets - // immediately flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - // This will cause a wrapping message to written. - if wb := tx.Push(60); wb != nil { - t.Fatalf("Push succeeded when there is no room in pipe") - } - - var rx Rx - rx.Init(b) - - // Read the first message. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // This should fail because of the wrapping message is taking up space. - if wb := tx.Push(60); wb != nil { - t.Fatalf("Push succeeded when there is no room in pipe") - } - - // Try to read the next one. This should consume the wrapping message. - rx.Pull() - - // This must now succeed. - if wb := tx.Push(60); wb == nil { - t.Fatalf("Push failed on empty pipe") - } -} - -func TestConcurrentReaderWriter(t *testing.T) { - // Push a million buffers of random sizes and random contents. Check - // that buffers read match what was written. - tr := rand.New(rand.NewSource(99)) - rr := rand.New(rand.NewSource(99)) - - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - var rx Rx - rx.Init(b) - - const count = 1000000 - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - runtime.Gosched() - for i := 0; i < count; i++ { - n := 1 + tr.Intn(80) - wb := tx.Push(uint64(n)) - for wb == nil { - wb = tx.Push(uint64(n)) - } - - for j := range wb { - wb[j] = byte(tr.Intn(256)) - } - - tx.Flush() - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - runtime.Gosched() - for i := 0; i < count; i++ { - n := 1 + rr.Intn(80) - rb := rx.Pull() - for rb == nil { - rb = rx.Pull() - } - - if n != len(rb) { - t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) - } - - for j := range rb { - if v := byte(rr.Intn(256)); v != rb[j] { - t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) - } - } - - rx.Flush() - } - }() - - wg.Wait() -} diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go index 62d17029e..62d17029e 100644..100755 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go diff --git a/pkg/tcpip/link/sharedmem/pipe/rx.go b/pkg/tcpip/link/sharedmem/pipe/rx.go index f22e533ac..f22e533ac 100644..100755 --- a/pkg/tcpip/link/sharedmem/pipe/rx.go +++ b/pkg/tcpip/link/sharedmem/pipe/rx.go diff --git a/pkg/tcpip/link/sharedmem/pipe/tx.go b/pkg/tcpip/link/sharedmem/pipe/tx.go index 9841eb231..9841eb231 100644..100755 --- a/pkg/tcpip/link/sharedmem/pipe/tx.go +++ b/pkg/tcpip/link/sharedmem/pipe/tx.go diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD deleted file mode 100644 index 3ba06af73..000000000 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "queue", - srcs = [ - "rx.go", - "tx.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/tcpip/link/sharedmem/pipe", - ], -) - -go_test( - name = "queue_test", - srcs = [ - "queue_test.go", - ], - library = ":queue", - deps = [ - "//pkg/tcpip/link/sharedmem/pipe", - ], -) diff --git a/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go b/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go new file mode 100755 index 000000000..563d4fbb4 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package queue diff --git a/pkg/tcpip/link/sharedmem/queue/queue_test.go b/pkg/tcpip/link/sharedmem/queue/queue_test.go deleted file mode 100644 index 9a0aad5d7..000000000 --- a/pkg/tcpip/link/sharedmem/queue/queue_test.go +++ /dev/null @@ -1,517 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package queue - -import ( - "encoding/binary" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" -) - -func TestBasicTxQueue(t *testing.T) { - // Tests that a basic transmit on a queue works, and that completion - // gets properly reported as well. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Enqueue two buffers. - b := []TxBuffer{ - {nil, 100, 60}, - {nil, 200, 40}, - } - - b[0].Next = &b[1] - - const usedID = 1002 - const usedTotalSize = 100 - if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) { - t.Fatalf("Enqueue failed on empty queue") - } - - // Check the contents of the pipe. - d := rxp.Pull() - if d == nil { - t.Fatalf("Tx pipe is empty after Enqueue") - } - - want := []byte{ - 234, 3, 0, 0, 0, 0, 0, 0, // id - 100, 0, 0, 0, // total size - 0, 0, 0, 0, // reserved - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 60, 0, 0, 0, // size 1 - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 40, 0, 0, 0, // size 2 - } - - if !reflect.DeepEqual(want, d) { - t.Fatalf("Bad posted packet: got %v, want %v", d, want) - } - - rxp.Flush() - - // Check that there are no completions yet. - if _, ok := q.CompletedPacket(); ok { - t.Fatalf("Packet reported as completed too soon") - } - - // Post a completion. - d = txp.Push(8) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - binary.LittleEndian.PutUint64(d, usedID) - txp.Flush() - - // Check that completion is properly reported. - id, ok := q.CompletedPacket() - if !ok { - t.Fatalf("Completion not reported") - } - - if id != usedID { - t.Fatalf("Bad completion id: got %v, want %v", id, usedID) - } -} - -func TestBasicRxQueue(t *testing.T) { - // Tests that a basic receive on a queue works. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Post two buffers. - b := []RxBuffer{ - {100, 60, 1077, 0}, - {200, 40, 2123, 0}, - } - - if !q.PostBuffers(b) { - t.Fatalf("PostBuffers failed on empty queue") - } - - // Check the contents of the pipe. - want := [][]byte{ - { - 100, 0, 0, 0, 0, 0, 0, 0, // Offset1 - 60, 0, 0, 0, // Size1 - 0, 0, 0, 0, // Remaining in group 1 - 0, 0, 0, 0, 0, 0, 0, 0, // User data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - }, - { - 200, 0, 0, 0, 0, 0, 0, 0, // Offset2 - 40, 0, 0, 0, // Size2 - 0, 0, 0, 0, // Remaining in group 2 - 0, 0, 0, 0, 0, 0, 0, 0, // User data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }, - } - - for i := range b { - d := rxp.Pull() - if d == nil { - t.Fatalf("Tx pipe is empty after PostBuffers") - } - - if !reflect.DeepEqual(want[i], d) { - t.Fatalf("Bad posted packet: got %v, want %v", d, want[i]) - } - - rxp.Flush() - } - - // Check that there are no completions. - if _, n := q.Dequeue(nil); n != 0 { - t.Fatalf("Packet reported as received too soon") - } - - // Post a completion. - d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 60, 0, 0, 0, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 40, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - - // Check that completion is properly reported. - bufs, n := q.Dequeue(nil) - if n != 100 { - t.Fatalf("Bad packet size: got %v, want %v", n, 100) - } - - if !reflect.DeepEqual(bufs, b) { - t.Fatalf("Bad returned buffers: got %v, want %v", bufs, b) - } -} - -func TestBadTxCompletion(t *testing.T) { - // Check that tx completions with bad sizes are properly ignored. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Post a completion that is too short, and check that it is ignored. - if d := txp.Push(7); d == nil { - t.Fatalf("Unable to push to rx pipe") - } - txp.Flush() - - if _, ok := q.CompletedPacket(); ok { - t.Fatalf("Bad completion not ignored") - } - - // Post a completion that is too long, and check that it is ignored. - if d := txp.Push(10); d == nil { - t.Fatalf("Unable to push to rx pipe") - } - txp.Flush() - - if _, ok := q.CompletedPacket(); ok { - t.Fatalf("Bad completion not ignored") - } -} - -func TestBadRxCompletion(t *testing.T) { - // Check that bad rx completions are properly ignored. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Post a completion that is too short, and check that it is ignored. - if d := txp.Push(7); d == nil { - t.Fatalf("Unable to push to rx pipe") - } - txp.Flush() - - if b, _ := q.Dequeue(nil); b != nil { - t.Fatalf("Bad completion not ignored") - } - - // Post a completion whose buffer sizes add up to less than the total - // size. - d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 10, 0, 0, 0, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 10, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - if b, _ := q.Dequeue(nil); b != nil { - t.Fatalf("Bad completion not ignored") - } - - // Post a completion whose buffer sizes will cause a 32-bit overflow, - // but adds up to the right number. - d = txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 255, 255, 255, 255, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 101, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - if b, _ := q.Dequeue(nil); b != nil { - t.Fatalf("Bad completion not ignored") - } -} - -func TestFillTxPipe(t *testing.T) { - // Check that transmitting a new buffer when the buffer pipe is full - // fails gracefully. - pb1 := make([]byte, 104) - pb2 := make([]byte, 104) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Transmit twice, which should fill the tx pipe. - b := []TxBuffer{ - {nil, 100, 60}, - {nil, 200, 40}, - } - - b[0].Next = &b[1] - - const usedID = 1002 - const usedTotalSize = 100 - for i := uint64(0); i < 2; i++ { - if !q.Enqueue(usedID+i, usedTotalSize, 2, &b[0]) { - t.Fatalf("Failed to transmit buffer") - } - } - - // Transmit another packet now that the tx pipe is full. - if q.Enqueue(usedID+2, usedTotalSize, 2, &b[0]) { - t.Fatalf("Enqueue succeeded when tx pipe is full") - } -} - -func TestFillRxPipe(t *testing.T) { - // Check that posting a new buffer when the buffer pipe is full fails - // gracefully. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Post a buffer twice, it should fill the tx pipe. - b := []RxBuffer{ - {100, 60, 1077, 0}, - } - - for i := 0; i < 2; i++ { - if !q.PostBuffers(b) { - t.Fatalf("PostBuffers failed on non-full queue") - } - } - - // Post another buffer now that the tx pipe is full. - if q.PostBuffers(b) { - t.Fatalf("PostBuffers succeeded on full queue") - } -} - -func TestLotsOfTransmissions(t *testing.T) { - // Make sure pipes are being properly flushed when transmitting packets. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Prepare packet with two buffers. - b := []TxBuffer{ - {nil, 100, 60}, - {nil, 200, 40}, - } - - b[0].Next = &b[1] - - const usedID = 1002 - const usedTotalSize = 100 - - // Post 100000 packets and completions. - for i := 100000; i > 0; i-- { - if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) { - t.Fatalf("Enqueue failed on non-full queue") - } - - if d := rxp.Pull(); d == nil { - t.Fatalf("Tx pipe is empty after Enqueue") - } - rxp.Flush() - - d := txp.Push(8) - if d == nil { - t.Fatalf("Unable to write to rx pipe") - } - binary.LittleEndian.PutUint64(d, usedID) - txp.Flush() - if _, ok := q.CompletedPacket(); !ok { - t.Fatalf("Completion not returned") - } - } -} - -func TestLotsOfReceptions(t *testing.T) { - // Make sure pipes are being properly flushed when receiving packets. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Prepare for posting two buffers. - b := []RxBuffer{ - {100, 60, 1077, 0}, - {200, 40, 2123, 0}, - } - - // Post 100000 buffers and completions. - for i := 100000; i > 0; i-- { - if !q.PostBuffers(b) { - t.Fatalf("PostBuffers failed on non-full queue") - } - - if d := rxp.Pull(); d == nil { - t.Fatalf("Tx pipe is empty after PostBuffers") - } - rxp.Flush() - - if d := rxp.Pull(); d == nil { - t.Fatalf("Tx pipe is empty after PostBuffers") - } - rxp.Flush() - - d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 60, 0, 0, 0, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 40, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - - if _, n := q.Dequeue(nil); n == 0 { - t.Fatalf("Dequeue failed when there is a completion") - } - } -} - -func TestRxEnableNotification(t *testing.T) { - // Check that enabling nofifications results in properly updated state. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var state uint32 - var q Rx - q.Init(pb1, pb2, &state) - - q.EnableNotification() - if state != eventFDEnabled { - t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDEnabled) - } -} - -func TestRxDisableNotification(t *testing.T) { - // Check that disabling nofifications results in properly updated state. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var state uint32 - var q Rx - q.Init(pb1, pb2, &state) - - q.DisableNotification() - if state != eventFDDisabled { - t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDDisabled) - } -} diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go index 696e6c9e5..696e6c9e5 100644..100755 --- a/pkg/tcpip/link/sharedmem/queue/rx.go +++ b/pkg/tcpip/link/sharedmem/queue/rx.go diff --git a/pkg/tcpip/link/sharedmem/queue/tx.go b/pkg/tcpip/link/sharedmem/queue/tx.go index beffe807b..beffe807b 100644..100755 --- a/pkg/tcpip/link/sharedmem/queue/tx.go +++ b/pkg/tcpip/link/sharedmem/queue/tx.go diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go index eec11e4cb..eec11e4cb 100644..100755 --- a/pkg/tcpip/link/sharedmem/rx.go +++ b/pkg/tcpip/link/sharedmem/rx.go diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 655e537c4..655e537c4 100644..100755 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go diff --git a/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go b/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go new file mode 100755 index 000000000..bc12017b2 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +// +build linux +// +build linux + +package sharedmem diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go deleted file mode 100644 index 5c729a439..000000000 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ /dev/null @@ -1,812 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -package sharedmem - -import ( - "bytes" - "io/ioutil" - "math/rand" - "os" - "strings" - "syscall" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - localLinkAddr = "\xde\xad\xbe\xef\x56\x78" - remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34" - - queueDataSize = 1024 * 1024 - queuePipeSize = 4096 -) - -type queueBuffers struct { - data []byte - rx pipe.Tx - tx pipe.Rx -} - -func initQueue(t *testing.T, q *queueBuffers, c *QueueConfig) { - // Prepare tx pipe. - b, err := getBuffer(c.TxPipeFD) - if err != nil { - t.Fatalf("getBuffer failed: %v", err) - } - q.tx.Init(b) - - // Prepare rx pipe. - b, err = getBuffer(c.RxPipeFD) - if err != nil { - t.Fatalf("getBuffer failed: %v", err) - } - q.rx.Init(b) - - // Get data slice. - q.data, err = getBuffer(c.DataFD) - if err != nil { - t.Fatalf("getBuffer failed: %v", err) - } -} - -func (q *queueBuffers) cleanup() { - syscall.Munmap(q.tx.Bytes()) - syscall.Munmap(q.rx.Bytes()) - syscall.Munmap(q.data) -} - -type packetInfo struct { - addr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - vv buffer.VectorisedView - linkHeader buffer.View -} - -type testContext struct { - t *testing.T - ep *endpoint - txCfg QueueConfig - rxCfg QueueConfig - txq queueBuffers - rxq queueBuffers - - packetCh chan struct{} - mu sync.Mutex - packets []packetInfo -} - -func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress) *testContext { - var err error - c := &testContext{ - t: t, - packetCh: make(chan struct{}, 1000000), - } - c.txCfg = createQueueFDs(t, queueSizes{ - dataSize: queueDataSize, - txPipeSize: queuePipeSize, - rxPipeSize: queuePipeSize, - sharedDataSize: 4096, - }) - - c.rxCfg = createQueueFDs(t, queueSizes{ - dataSize: queueDataSize, - txPipeSize: queuePipeSize, - rxPipeSize: queuePipeSize, - sharedDataSize: 4096, - }) - - initQueue(t, &c.txq, &c.txCfg) - initQueue(t, &c.rxq, &c.rxCfg) - - ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - c.ep = ep.(*endpoint) - c.ep.Attach(c) - - return c -} - -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { - c.mu.Lock() - c.packets = append(c.packets, packetInfo{ - addr: remoteLinkAddr, - proto: proto, - vv: pkt.Data.Clone(nil), - }) - c.mu.Unlock() - - c.packetCh <- struct{}{} -} - -func (c *testContext) cleanup() { - c.ep.Close() - closeFDs(&c.txCfg) - closeFDs(&c.rxCfg) - c.txq.cleanup() - c.rxq.cleanup() -} - -func (c *testContext) waitForPackets(n int, to <-chan time.Time, errorStr string) { - for i := 0; i < n; i++ { - select { - case <-c.packetCh: - case <-to: - c.t.Fatalf(errorStr) - } - } -} - -func (c *testContext) pushRxCompletion(size uint32, bs []queue.RxBuffer) { - b := c.rxq.rx.Push(queue.RxCompletionSize(len(bs))) - queue.EncodeRxCompletion(b, size, 0) - for i := range bs { - queue.EncodeRxCompletionBuffer(b, i, queue.RxBuffer{ - Offset: bs[i].Offset, - Size: bs[i].Size, - ID: bs[i].ID, - }) - } -} - -func randomFill(b []byte) { - for i := range b { - b[i] = byte(rand.Intn(256)) - } -} - -func shuffle(b []int) { - for i := len(b) - 1; i >= 0; i-- { - j := rand.Intn(i + 1) - b[i], b[j] = b[j], b[i] - } -} - -func createFile(t *testing.T, size int64, initQueue bool) int { - tmpDir := os.Getenv("TEST_TMPDIR") - if tmpDir == "" { - tmpDir = os.Getenv("TMPDIR") - } - f, err := ioutil.TempFile(tmpDir, "sharedmem_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - defer f.Close() - syscall.Unlink(f.Name()) - - if initQueue { - // Write the "slot-free" flag in the initial queue. - _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0) - if err != nil { - t.Fatalf("WriteAt failed: %v", err) - } - } - - fd, err := syscall.Dup(int(f.Fd())) - if err != nil { - t.Fatalf("Dup failed: %v", err) - } - - if err := syscall.Ftruncate(fd, size); err != nil { - syscall.Close(fd) - t.Fatalf("Ftruncate failed: %v", err) - } - - return fd -} - -func closeFDs(c *QueueConfig) { - syscall.Close(c.DataFD) - syscall.Close(c.EventFD) - syscall.Close(c.TxPipeFD) - syscall.Close(c.RxPipeFD) - syscall.Close(c.SharedDataFD) -} - -type queueSizes struct { - dataSize int64 - txPipeSize int64 - rxPipeSize int64 - sharedDataSize int64 -} - -func createQueueFDs(t *testing.T, s queueSizes) QueueConfig { - fd, _, err := syscall.RawSyscall(syscall.SYS_EVENTFD2, 0, 0, 0) - if err != 0 { - t.Fatalf("eventfd failed: %v", error(err)) - } - - return QueueConfig{ - EventFD: int(fd), - DataFD: createFile(t, s.dataSize, false), - TxPipeFD: createFile(t, s.txPipeSize, true), - RxPipeFD: createFile(t, s.rxPipeSize, true), - SharedDataFD: createFile(t, s.sharedDataSize, false), - } -} - -// TestSimpleSend sends 1000 packets with random header and payload sizes, -// then checks that the right payload is received on the shared memory queues. -func TestSimpleSend(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - // Prepare route. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } - - for iters := 1000; iters > 0; iters-- { - func() { - // Prepare and send packet. - n := rand.Intn(10000) - hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength())) - hdrBuf := hdr.Prepend(n) - randomFill(hdrBuf) - - n = rand.Intn(10000) - buf := buffer.NewView(n) - randomFill(buf) - - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Receive packet. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if pi.Reserved != 0 { - t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved) - } - contents := make([]byte, 0, pi.Size) - for i := 0; i < pi.BufferCount; i++ { - bi := queue.DecodeTxBufferHeader(desc, i) - contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...) - } - c.txq.tx.Flush() - - defer func() { - // Tell the endpoint about the completion of the write. - b := c.txq.rx.Push(8) - queue.EncodeTxCompletion(b, pi.ID) - c.txq.rx.Flush() - }() - - // Check the ethernet header. - ethTemplate := make(header.Ethernet, header.EthernetMinimumSize) - ethTemplate.Encode(&header.EthernetFields{ - SrcAddr: localLinkAddr, - DstAddr: remoteLinkAddr, - Type: proto, - }) - if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) { - t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate) - } - - // Compare contents skipping the ethernet header added by the - // endpoint. - merged := append(hdrBuf, buf...) - if uint32(len(contents)) < pi.Size { - t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size) - } - contents = contents[:pi.Size][header.EthernetMinimumSize:] - - if !bytes.Equal(contents, merged) { - t.Fatalf("Buffers are different: got %x (%v bytes), want %x (%v bytes)", contents, len(contents), merged, len(merged)) - } - }() - } -} - -// TestPreserveSrcAddressInSend calls WritePacket once with LocalLinkAddress -// set in Route (using much of the same code as TestSimpleSend), then checks -// that the encoded ethernet header received includes the correct SrcAddr. -func TestPreserveSrcAddressInSend(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6)) - // Set both remote and local link address in route. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - LocalLinkAddress: newLocalLinkAddress, - } - - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) - - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ - Header: hdr, - }); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Receive packet. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if pi.Reserved != 0 { - t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved) - } - contents := make([]byte, 0, pi.Size) - for i := 0; i < pi.BufferCount; i++ { - bi := queue.DecodeTxBufferHeader(desc, i) - contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...) - } - c.txq.tx.Flush() - - defer func() { - // Tell the endpoint about the completion of the write. - b := c.txq.rx.Push(8) - queue.EncodeTxCompletion(b, pi.ID) - c.txq.rx.Flush() - }() - - // Check that the ethernet header contains the expected SrcAddr. - ethTemplate := make(header.Ethernet, header.EthernetMinimumSize) - ethTemplate.Encode(&header.EthernetFields{ - SrcAddr: newLocalLinkAddress, - DstAddr: remoteLinkAddr, - Type: proto, - }) - if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) { - t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate) - } -} - -// TestFillTxQueue sends packets until the queue is full. -func TestFillTxQueue(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } - - buf := buffer.NewView(100) - - // Each packet is uses no more than 40 bytes, so write that many packets - // until the tx queue if full. - ids := make(map[uint64]struct{}) - for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Check that they have different IDs. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if _, ok := ids[pi.ID]; ok { - t.Fatalf("ID (%v) reused", pi.ID) - } - ids[pi.ID] = struct{}{} - } - - // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) - } -} - -// TestFillTxQueueAfterBadCompletion sends a bad completion, then sends packets -// until the queue is full. -func TestFillTxQueueAfterBadCompletion(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - // Send a bad completion. - queue.EncodeTxCompletion(c.txq.rx.Push(8), 1) - c.txq.rx.Flush() - - // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } - - buf := buffer.NewView(100) - - // Send two packets so that the id slice has at least two slots. - for i := 2; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - } - - // Complete the two writes twice. - for i := 2; i > 0; i-- { - pi := queue.DecodeTxPacketHeader(c.txq.tx.Pull()) - - queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID) - queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID) - c.txq.rx.Flush() - } - c.txq.tx.Flush() - - // Each packet is uses no more than 40 bytes, so write that many packets - // until the tx queue if full. - ids := make(map[uint64]struct{}) - for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Check that they have different IDs. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if _, ok := ids[pi.ID]; ok { - t.Fatalf("ID (%v) reused", pi.ID) - } - ids[pi.ID] = struct{}{} - } - - // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) - } -} - -// TestFillTxMemory sends packets until the we run out of shared memory. -func TestFillTxMemory(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } - - buf := buffer.NewView(100) - - // Each packet is uses up one buffer, so write as many as possible until - // we fill the memory. - ids := make(map[uint64]struct{}) - for i := queueDataSize / bufferSize; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Check that they have different IDs. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if _, ok := ids[pi.ID]; ok { - t.Fatalf("ID (%v) reused", pi.ID) - } - ids[pi.ID] = struct{}{} - c.txq.tx.Flush() - } - - // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }) - if want := tcpip.ErrWouldBlock; err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) - } -} - -// TestFillTxMemoryWithMultiBuffer sends packets until the we run out of -// shared memory for a 2-buffer packet, but still with room for a 1-buffer -// packet. -func TestFillTxMemoryWithMultiBuffer(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } - - buf := buffer.NewView(100) - - // Each packet is uses up one buffer, so write as many as possible - // until there is only one buffer left. - for i := queueDataSize/bufferSize - 1; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Pull the posted buffer. - c.txq.tx.Pull() - c.txq.tx.Flush() - } - - // Attempt to write a two-buffer packet. It must fail. - { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: uu, - }); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) - } - } - - // Attempt to write the one-buffer packet again. It must succeed. - { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - } -} - -func pollPull(t *testing.T, p *pipe.Rx, to <-chan time.Time, errStr string) []byte { - t.Helper() - - for { - b := p.Pull() - if b != nil { - return b - } - - select { - case <-time.After(10 * time.Millisecond): - case <-to: - t.Fatal(errStr) - } - } -} - -// TestSimpleReceive completes 1000 different receives with random payload and -// random number of buffers. It checks that the contents match the expected -// values. -func TestSimpleReceive(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Check that buffers have been posted. - limit := c.ep.rx.q.PostedBuffersLimit() - for i := uint64(0); i < limit; i++ { - timeout := time.After(2 * time.Second) - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers to be posted")) - - if want := i * bufferSize; want != bi.Offset { - t.Fatalf("Bad posted offset: got %v, want %v", bi.Offset, want) - } - - if want := i; want != bi.ID { - t.Fatalf("Bad posted ID: got %v, want %v", bi.ID, want) - } - - if bufferSize != bi.Size { - t.Fatalf("Bad posted bufferSize: got %v, want %v", bi.Size, bufferSize) - } - } - c.rxq.tx.Flush() - - // Create a slice with the indices 0..limit-1. - idx := make([]int, limit) - for i := range idx { - idx[i] = i - } - - // Complete random packets 1000 times. - for iters := 1000; iters > 0; iters-- { - timeout := time.After(2 * time.Second) - // Prepare a random packet. - shuffle(idx) - n := 1 + rand.Intn(10) - bufs := make([]queue.RxBuffer, n) - contents := make([]byte, bufferSize*n-rand.Intn(500)) - randomFill(contents) - for i := range bufs { - j := idx[i] - bufs[i].Size = bufferSize - bufs[i].Offset = uint64(bufferSize * j) - bufs[i].ID = uint64(j) - - copy(c.rxq.data[bufs[i].Offset:][:bufferSize], contents[i*bufferSize:]) - } - - // Push completion. - c.pushRxCompletion(uint32(len(contents)), bufs) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for packet to be received, then check it. - c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") - c.mu.Lock() - rcvd := []byte(c.packets[0].vv.First()) - c.packets = c.packets[:0] - c.mu.Unlock() - - if contents := contents[header.EthernetMinimumSize:]; !bytes.Equal(contents, rcvd) { - t.Fatalf("Unexpected buffer contents: got %x, want %x", rcvd, contents) - } - - // Check that buffers have been reposted. - for i := range bufs { - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffers to be reposted")) - if bi != bufs[i] { - t.Fatalf("Unexpected buffer reposted: got %x, want %x", bi, bufs[i]) - } - } - c.rxq.tx.Flush() - } -} - -// TestRxBuffersReposted tests that rx buffers get reposted after they have been -// completed. -func TestRxBuffersReposted(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Receive all posted buffers. - limit := c.ep.rx.q.PostedBuffersLimit() - buffers := make([]queue.RxBuffer, 0, limit) - for i := limit; i > 0; i-- { - timeout := time.After(2 * time.Second) - buffers = append(buffers, queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers"))) - } - c.rxq.tx.Flush() - - // Check that all buffers are reposted when individually completed. - for i := range buffers { - timeout := time.After(2 * time.Second) - // Complete the buffer. - c.pushRxCompletion(buffers[i].Size, buffers[i:][:1]) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for it to be reposted. - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted")) - if bi != buffers[i] { - t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[i]) - } - } - c.rxq.tx.Flush() - - // Check that all buffers are reposted when completed in pairs. - for i := 0; i < len(buffers)/2; i++ { - timeout := time.After(2 * time.Second) - // Complete with two buffers. - c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2]) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for them to be reposted. - for j := 0; j < 2; j++ { - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted")) - if bi != buffers[2*i+j] { - t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[2*i+j]) - } - } - } - c.rxq.tx.Flush() -} - -// TestReceivePostingIsFull checks that the endpoint will properly handle the -// case when a received buffer cannot be immediately reposted because it hasn't -// been pulled from the tx pipe yet. -func TestReceivePostingIsFull(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Complete first posted buffer before flushing it from the tx pipe. - first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted")) - c.pushRxCompletion(first.Size, []queue.RxBuffer{first}) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Check that packet is received. - c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") - - // Complete another buffer. - second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted")) - c.pushRxCompletion(second.Size, []queue.RxBuffer{second}) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Check that no packet is received yet, as the worker is blocked trying - // to repost. - select { - case <-time.After(500 * time.Millisecond): - case <-c.packetCh: - t.Fatalf("Unexpected packet received") - } - - // Flush tx queue, which will allow the first buffer to be reposted, - // and the second completion to be pulled. - c.rxq.tx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Check that second packet completes. - c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet") -} - -// TestCloseWhileWaitingToPost closes the endpoint while it is waiting to -// repost a buffer. Make sure it backs out. -func TestCloseWhileWaitingToPost(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - cleaned := false - defer func() { - if !cleaned { - c.cleanup() - } - }() - - // Complete first posted buffer before flushing it from the tx pipe. - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted")) - c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi}) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for packet to be indicated. - c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") - - // Cleanup and wait for worker to complete. - c.cleanup() - cleaned = true - c.ep.Wait() -} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go index f7e816a41..f7e816a41 100644..100755 --- a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go index 6b8d7859d..6b8d7859d 100644..100755 --- a/pkg/tcpip/link/sharedmem/tx.go +++ b/pkg/tcpip/link/sharedmem/tx.go diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD deleted file mode 100644 index 230a8d53a..000000000 --- a/pkg/tcpip/link/sniffer/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "sniffer", - srcs = [ - "pcap.go", - "sniffer.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/sniffer/sniffer_state_autogen.go b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go new file mode 100755 index 000000000..8d79defea --- /dev/null +++ b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package sniffer diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD deleted file mode 100644 index e0db6cf54..000000000 --- a/pkg/tcpip/link/tun/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "tun", - srcs = [ - "device.go", - "protocol.go", - "tun_unsafe.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/abi/linux", - "//pkg/refs", - "//pkg/sync", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/stack", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 6ff47a742..6ff47a742 100644..100755 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go diff --git a/pkg/tcpip/link/tun/protocol.go b/pkg/tcpip/link/tun/protocol.go index 89d9d91a9..89d9d91a9 100644..100755 --- a/pkg/tcpip/link/tun/protocol.go +++ b/pkg/tcpip/link/tun/protocol.go diff --git a/pkg/tcpip/link/tun/tun_state_autogen.go b/pkg/tcpip/link/tun/tun_state_autogen.go new file mode 100755 index 000000000..8b56175e4 --- /dev/null +++ b/pkg/tcpip/link/tun/tun_state_autogen.go @@ -0,0 +1,29 @@ +// automatically generated by stateify. + +// +build linux + +package tun + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Device) save(m state.Map) { + x.beforeSave() + m.Save("Queue", &x.Queue) + m.Save("endpoint", &x.endpoint) + m.Save("notifyHandle", &x.notifyHandle) + m.Save("flags", &x.flags) +} + +func (x *Device) afterLoad() {} +func (x *Device) load(m state.Map) { + m.Load("Queue", &x.Queue) + m.Load("endpoint", &x.endpoint) + m.Load("notifyHandle", &x.notifyHandle) + m.Load("flags", &x.flags) +} + +func init() { + state.Register("pkg/tcpip/link/tun.Device", (*Device)(nil), state.Fns{Save: (*Device).save, Load: (*Device).load}) +} diff --git a/pkg/tcpip/link/tun/tun_unsafe.go b/pkg/tcpip/link/tun/tun_unsafe.go index 09ca9b527..09ca9b527 100644..100755 --- a/pkg/tcpip/link/tun/tun_unsafe.go +++ b/pkg/tcpip/link/tun/tun_unsafe.go diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD deleted file mode 100644 index 0956d2c65..000000000 --- a/pkg/tcpip/link/waitable/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "waitable", - srcs = [ - "waitable.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/gate", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "waitable_test", - srcs = [ - "waitable_test.go", - ], - library = ":waitable", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index a8de38979..a8de38979 100644..100755 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go diff --git a/pkg/tcpip/link/waitable/waitable_state_autogen.go b/pkg/tcpip/link/waitable/waitable_state_autogen.go new file mode 100755 index 000000000..059424fa0 --- /dev/null +++ b/pkg/tcpip/link/waitable/waitable_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package waitable diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go deleted file mode 100644 index 31b11a27a..000000000 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package waitable - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type countedEndpoint struct { - dispatchCount int - writeCount int - attachCount int - - mtu uint32 - capabilities stack.LinkEndpointCapabilities - hdrLen uint16 - linkAddr tcpip.LinkAddress - - dispatcher stack.NetworkDispatcher -} - -func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { - e.dispatchCount++ -} - -func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.attachCount++ - e.dispatcher = dispatcher -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *countedEndpoint) IsAttached() bool { - return e.dispatcher != nil -} - -func (e *countedEndpoint) MTU() uint32 { - return e.mtu -} - -func (e *countedEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.capabilities -} - -func (e *countedEndpoint) MaxHeaderLength() uint16 { - return e.hdrLen -} - -func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { - return e.linkAddr -} - -func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { - e.writeCount++ - return nil -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - e.writeCount += len(pkts) - return len(pkts), nil -} - -func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - e.writeCount++ - return nil -} - -// Wait implements stack.LinkEndpoint.Wait. -func (*countedEndpoint) Wait() {} - -func TestWaitWrite(t *testing.T) { - ep := &countedEndpoint{} - wep := New(ep) - - // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) - if want := 1; ep.writeCount != want { - t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) - } - - // Wait on dispatches, then try to write. It must go through. - wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) - if want := 2; ep.writeCount != want { - t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) - } - - // Wait on writes, then try to write. It must not go through. - wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) - if want := 2; ep.writeCount != want { - t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) - } -} - -func TestWaitDispatch(t *testing.T) { - ep := &countedEndpoint{} - wep := New(ep) - - // Check that attach happens. - wep.Attach(ep) - if want := 1; ep.attachCount != want { - t.Fatalf("Unexpected attachCount: got=%v, want=%v", ep.attachCount, want) - } - - // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) - if want := 1; ep.dispatchCount != want { - t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) - } - - // Wait on writes, then try to dispatch. It must go through. - wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) - if want := 2; ep.dispatchCount != want { - t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) - } - - // Wait on dispatches, then try to dispatch. It must not go through. - wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) - if want := 2; ep.dispatchCount != want { - t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) - } -} - -func TestOtherMethods(t *testing.T) { - const ( - mtu = 0xdead - capabilities = 0xbeef - hdrLen = 0x1234 - linkAddr = "test address" - ) - ep := &countedEndpoint{ - mtu: mtu, - capabilities: capabilities, - hdrLen: hdrLen, - linkAddr: linkAddr, - } - wep := New(ep) - - if v := wep.MTU(); v != mtu { - t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu) - } - - if v := wep.Capabilities(); v != capabilities { - t.Fatalf("Unexpected capabilities: got=%v, want=%v", v, capabilities) - } - - if v := wep.MaxHeaderLength(); v != hdrLen { - t.Fatalf("Unexpected MaxHeaderLength: got=%v, want=%v", v, hdrLen) - } - - if v := wep.LinkAddress(); v != linkAddr { - t.Fatalf("Unexpected LinkAddress: got=%q, want=%q", v, linkAddr) - } -} diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD deleted file mode 100644 index 6a4839fb8..000000000 --- a/pkg/tcpip/network/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//tools:defs.bzl", "go_test") - -package(licenses = ["notice"]) - -go_test( - name = "ip_test", - size = "small", - srcs = [ - "ip_test.go", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - ], -) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD deleted file mode 100644 index eddf7b725..000000000 --- a/pkg/tcpip/network/arp/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "arp", - srcs = ["arp.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "arp_test", - size = "small", - srcs = ["arp_test.go"], - deps = [ - ":arp", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - ], -) diff --git a/pkg/tcpip/network/arp/arp_state_autogen.go b/pkg/tcpip/network/arp/arp_state_autogen.go new file mode 100755 index 000000000..5cd8535e3 --- /dev/null +++ b/pkg/tcpip/network/arp/arp_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package arp diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go deleted file mode 100644 index 03cf03b6d..000000000 --- a/pkg/tcpip/network/arp/arp_test.go +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package arp_test - -import ( - "context" - "strconv" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" -) - -const ( - stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") - stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") - stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") - stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") -) - -type testContext struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack -} - -func newTestContext(t *testing.T) *testContext { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, - }) - - const defaultMTU = 65536 - ep := channel.New(256, defaultMTU, stackLinkAddr) - wep := stack.LinkEndpoint(ep) - - if testing.Verbose() { - wep = sniffer.New(ep) - } - if err := s.CreateNIC(1, wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) - } - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) - } - if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("AddAddress for arp failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: 1, - }}) - - return &testContext{ - t: t, - s: s, - linkEP: ep, - } -} - -func (c *testContext) cleanup() { - c.linkEP.Close() -} - -func TestDirectRequest(t *testing.T) { - c := newTestContext(t) - defer c.cleanup() - - const senderMAC = "\x01\x02\x03\x04\x05\x06" - const senderIPv4 = "\x0a\x00\x00\x02" - - v := make(buffer.View, header.ARPSize) - h := header.ARP(v) - h.SetIPv4OverEthernet() - h.SetOp(header.ARPRequest) - copy(h.HardwareAddressSender(), senderMAC) - copy(h.ProtocolAddressSender(), senderIPv4) - - inject := func(addr tcpip.Address) { - copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, tcpip.PacketBuffer{ - Data: v.ToVectorisedView(), - }) - } - - for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { - t.Run(strconv.Itoa(i), func(t *testing.T) { - inject(address) - pi, _ := c.linkEP.ReadContext(context.Background()) - if pi.Proto != arp.ProtocolNumber { - t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) - } - rep := header.ARP(pi.Pkt.Header.View()) - if !rep.IsValid() { - t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { - t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) - } - if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { - t.Errorf("got ProtocolAddressSender = %s, want = %s", got, want) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want { - t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want) - } - if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want { - t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want) - } - }) - } - - inject(stackAddrBad) - // Sleep tests are gross, but this will only potentially flake - // if there's a bug. If there is no bug this will reliably - // succeed. - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) - if pkt, ok := c.linkEP.ReadContext(ctx); ok { - t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) - } -} diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD deleted file mode 100644 index d1c728ccf..000000000 --- a/pkg/tcpip/network/fragmentation/BUILD +++ /dev/null @@ -1,45 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "reassembler_list", - out = "reassembler_list.go", - package = "fragmentation", - prefix = "reassembler", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*reassembler", - "Linker": "*reassembler", - }, -) - -go_library( - name = "fragmentation", - srcs = [ - "frag_heap.go", - "fragmentation.go", - "reassembler.go", - "reassembler_list.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - ], -) - -go_test( - name = "fragmentation_test", - size = "small", - srcs = [ - "frag_heap_test.go", - "fragmentation_test.go", - "reassembler_test.go", - ], - library = ":fragmentation", - deps = ["//pkg/tcpip/buffer"], -) diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go deleted file mode 100644 index 9ececcb9f..000000000 --- a/pkg/tcpip/network/fragmentation/frag_heap_test.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "container/heap" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -var reassambleTestCases = []struct { - comment string - in []fragment - want buffer.VectorisedView -}{ - { - comment: "Non-overlapping in-order", - in: []fragment{ - {offset: 0, vv: vv(1, "0")}, - {offset: 1, vv: vv(1, "1")}, - }, - want: vv(2, "0", "1"), - }, - { - comment: "Non-overlapping out-of-order", - in: []fragment{ - {offset: 1, vv: vv(1, "1")}, - {offset: 0, vv: vv(1, "0")}, - }, - want: vv(2, "0", "1"), - }, - { - comment: "Duplicated packets", - in: []fragment{ - {offset: 0, vv: vv(1, "0")}, - {offset: 0, vv: vv(1, "0")}, - }, - want: vv(1, "0"), - }, - { - comment: "Overlapping in-order", - in: []fragment{ - {offset: 0, vv: vv(2, "01")}, - {offset: 1, vv: vv(2, "12")}, - }, - want: vv(3, "01", "2"), - }, - { - comment: "Overlapping out-of-order", - in: []fragment{ - {offset: 1, vv: vv(2, "12")}, - {offset: 0, vv: vv(2, "01")}, - }, - want: vv(3, "01", "2"), - }, - { - comment: "Overlapping subset in-order", - in: []fragment{ - {offset: 0, vv: vv(3, "012")}, - {offset: 1, vv: vv(1, "1")}, - }, - want: vv(3, "012"), - }, - { - comment: "Overlapping subset out-of-order", - in: []fragment{ - {offset: 1, vv: vv(1, "1")}, - {offset: 0, vv: vv(3, "012")}, - }, - want: vv(3, "012"), - }, -} - -func TestReassamble(t *testing.T) { - for _, c := range reassambleTestCases { - t.Run(c.comment, func(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - for _, f := range c.in { - heap.Push(&h, f) - } - got, err := h.reassemble() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got, c.want) { - t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want) - } - }) - } -} - -func TestReassambleFailsForNonZeroOffset(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")}) - _, err := h.reassemble() - if err == nil { - t.Errorf("reassemble() did not fail when the first packet had offset != 0") - } -} - -func TestReassambleFailsForHoles(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")}) - heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")}) - _, err := h.reassemble() - if err == nil { - t.Errorf("reassemble() did not fail when there was a hole in the packet") - } -} diff --git a/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go b/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go new file mode 100755 index 000000000..cbaecdaa7 --- /dev/null +++ b/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go @@ -0,0 +1,38 @@ +// automatically generated by stateify. + +package fragmentation + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *reassemblerList) beforeSave() {} +func (x *reassemblerList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *reassemblerList) afterLoad() {} +func (x *reassemblerList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *reassemblerEntry) beforeSave() {} +func (x *reassemblerEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *reassemblerEntry) afterLoad() {} +func (x *reassemblerEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("pkg/tcpip/network/fragmentation.reassemblerList", (*reassemblerList)(nil), state.Fns{Save: (*reassemblerList).save, Load: (*reassemblerList).load}) + state.Register("pkg/tcpip/network/fragmentation.reassemblerEntry", (*reassemblerEntry)(nil), state.Fns{Save: (*reassemblerEntry).save, Load: (*reassemblerEntry).load}) +} diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go deleted file mode 100644 index 72c0f53be..000000000 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "reflect" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -// vv is a helper to build VectorisedView from different strings. -func vv(size int, pieces ...string) buffer.VectorisedView { - views := make([]buffer.View, len(pieces)) - for i, p := range pieces { - views[i] = []byte(p) - } - - return buffer.NewVectorisedView(size, views) -} - -type processInput struct { - id uint32 - first uint16 - last uint16 - more bool - vv buffer.VectorisedView -} - -type processOutput struct { - vv buffer.VectorisedView - done bool -} - -var processTestCases = []struct { - comment string - in []processInput - out []processOutput -}{ - { - comment: "One ID", - in: []processInput{ - {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "01", "23"), done: true}, - }, - }, - { - comment: "Two IDs", - in: []processInput{ - {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")}, - {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")}, - {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "ab", "cd"), done: true}, - {vv: vv(4, "01", "23"), done: true}, - }, - }, -} - -func TestFragmentationProcess(t *testing.T) { - for _, c := range processTestCases { - t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(1024, 512, DefaultReassembleTimeout) - for i, in := range c.in { - vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv) - if err != nil { - t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err) - } - if !reflect.DeepEqual(vv, c.out[i].vv) { - t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv) - } - if done != c.out[i].done { - t.Errorf("got Process(%d) = %+v, want = %+v", i, done, c.out[i].done) - } - if c.out[i].done { - if _, ok := f.reassemblers[in.id]; ok { - t.Errorf("Process(%d) did not remove buffer from reassemblers", i) - } - for n := f.rList.Front(); n != nil; n = n.Next() { - if n.id == in.id { - t.Errorf("Process(%d) did not remove buffer from rList", i) - } - } - } - } - }) - } -} - -func TestReassemblingTimeout(t *testing.T) { - timeout := time.Millisecond - f := NewFragmentation(1024, 512, timeout) - // Send first fragment with id = 0, first = 0, last = 0, and more = true. - f.Process(0, 0, 0, true, vv(1, "0")) - // Sleep more than the timeout. - time.Sleep(2 * timeout) - // Send another fragment that completes a packet. - // However, no packet should be reassembled because the fragment arrived after the timeout. - _, done, err := f.Process(0, 1, 1, false, vv(1, "1")) - if err != nil { - t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err) - } - if done { - t.Errorf("Fragmentation does not respect the reassembling timeout.") - } -} - -func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(3, 1, DefaultReassembleTimeout) - // Send first fragment with id = 0. - f.Process(0, 0, 0, true, vv(1, "0")) - // Send first fragment with id = 1. - f.Process(1, 0, 0, true, vv(1, "1")) - // Send first fragment with id = 2. - f.Process(2, 0, 0, true, vv(1, "2")) - - // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be - // evicted. - f.Process(3, 0, 0, true, vv(1, "3")) - - if _, ok := f.reassemblers[0]; ok { - t.Errorf("Memory limits are not respected: id=0 has not been evicted.") - } - if _, ok := f.reassemblers[1]; ok { - t.Errorf("Memory limits are not respected: id=1 has not been evicted.") - } - if _, ok := f.reassemblers[3]; !ok { - t.Errorf("Implementation of memory limits is wrong: id=3 is not present.") - } -} - -func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(1, 0, DefaultReassembleTimeout) - // Send first fragment with id = 0. - f.Process(0, 0, 0, true, vv(1, "0")) - // Send the same packet again. - f.Process(0, 0, 0, true, vv(1, "0")) - - got := f.size - want := 1 - if got != want { - t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) - } -} diff --git a/pkg/tcpip/network/fragmentation/reassembler_list.go b/pkg/tcpip/network/fragmentation/reassembler_list.go new file mode 100755 index 000000000..a48422c97 --- /dev/null +++ b/pkg/tcpip/network/fragmentation/reassembler_list.go @@ -0,0 +1,186 @@ +package fragmentation + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type reassemblerElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type reassemblerList struct { + head *reassembler + tail *reassembler +} + +// Reset resets list l to the empty state. +func (l *reassemblerList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *reassemblerList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *reassemblerList) Front() *reassembler { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *reassemblerList) Back() *reassembler { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *reassemblerList) PushFront(e *reassembler) { + linker := reassemblerElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *reassemblerList) PushBack(e *reassembler) { + linker := reassemblerElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *reassemblerList) PushBackList(m *reassemblerList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + reassemblerElementMapper{}.linkerFor(l.tail).SetNext(m.head) + reassemblerElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *reassemblerList) InsertAfter(b, e *reassembler) { + bLinker := reassemblerElementMapper{}.linkerFor(b) + eLinker := reassemblerElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + reassemblerElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *reassemblerList) InsertBefore(a, e *reassembler) { + aLinker := reassemblerElementMapper{}.linkerFor(a) + eLinker := reassemblerElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + reassemblerElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *reassemblerList) Remove(e *reassembler) { + linker := reassemblerElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + reassemblerElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + reassemblerElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type reassemblerEntry struct { + next *reassembler + prev *reassembler +} + +// Next returns the entry that follows e in the list. +func (e *reassemblerEntry) Next() *reassembler { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *reassemblerEntry) Prev() *reassembler { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *reassemblerEntry) SetNext(elem *reassembler) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *reassemblerEntry) SetPrev(elem *reassembler) { + e.prev = elem +} diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go deleted file mode 100644 index 7eee0710d..000000000 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "math" - "reflect" - "testing" -) - -type updateHolesInput struct { - first uint16 - last uint16 - more bool -} - -var holesTestCases = []struct { - comment string - in []updateHolesInput - want []hole -}{ - { - comment: "No fragments. Expected holes: {[0 -> inf]}.", - in: []updateHolesInput{}, - want: []hole{{first: 0, last: math.MaxUint16, deleted: false}}, - }, - { - comment: "One fragment at beginning. Expected holes: {[2, inf]}.", - in: []updateHolesInput{{first: 0, last: 1, more: true}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 2, last: math.MaxUint16, deleted: false}, - }, - }, - { - comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.", - in: []updateHolesInput{{first: 1, last: 2, more: true}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 0, last: 0, deleted: false}, - {first: 3, last: math.MaxUint16, deleted: false}, - }, - }, - { - comment: "One fragment at the end. Expected holes: {[0, 0]}.", - in: []updateHolesInput{{first: 1, last: 2, more: false}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 0, last: 0, deleted: false}, - }, - }, - { - comment: "One fragment completing a packet. Expected holes: {}.", - in: []updateHolesInput{{first: 0, last: 1, more: false}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - }, - }, - { - comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.", - in: []updateHolesInput{ - {first: 0, last: 1, more: true}, - {first: 2, last: 3, more: false}, - }, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 2, last: math.MaxUint16, deleted: true}, - }, - }, - { - comment: "Two overlapping fragments completing a packet. Expected holes: {}.", - in: []updateHolesInput{ - {first: 0, last: 2, more: true}, - {first: 2, last: 3, more: false}, - }, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 3, last: math.MaxUint16, deleted: true}, - }, - }, -} - -func TestUpdateHoles(t *testing.T) { - for _, c := range holesTestCases { - r := newReassembler(0) - for _, i := range c.in { - r.updateHoles(i.first, i.last, i.more) - } - if !reflect.DeepEqual(r.holes, c.want) { - t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want) - } - } -} diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD deleted file mode 100644 index 872165866..000000000 --- a/pkg/tcpip/network/hash/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "hash", - srcs = ["hash.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/rand", - "//pkg/tcpip/header", - ], -) diff --git a/pkg/tcpip/network/hash/hash_state_autogen.go b/pkg/tcpip/network/hash/hash_state_autogen.go new file mode 100755 index 000000000..9467fe298 --- /dev/null +++ b/pkg/tcpip/network/hash/hash_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package hash diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go deleted file mode 100644 index f4d78f8c6..000000000 --- a/pkg/tcpip/network/ip_test.go +++ /dev/null @@ -1,655 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -const ( - localIpv4Addr = "\x0a\x00\x00\x01" - localIpv4PrefixLen = 24 - remoteIpv4Addr = "\x0a\x00\x00\x02" - ipv4SubnetAddr = "\x0a\x00\x00\x00" - ipv4SubnetMask = "\xff\xff\xff\x00" - ipv4Gateway = "\x0a\x00\x00\x03" - localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - localIpv6PrefixLen = 120 - remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" - ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" -) - -// testObject implements two interfaces: LinkEndpoint and TransportDispatcher. -// The former is used to pretend that it's a link endpoint so that we can -// inspect packets written by the network endpoints. The latter is used to -// pretend that it's the network stack so that it can inspect incoming packets -// that have been handled by the network endpoints. -// -// Packets are checked by comparing their fields/values against the expected -// values stored in the test object itself. -type testObject struct { - t *testing.T - protocol tcpip.TransportProtocolNumber - contents []byte - srcAddr tcpip.Address - dstAddr tcpip.Address - v4 bool - typ stack.ControlType - extra uint32 - - dataCalls int - controlCalls int -} - -// checkValues verifies that the transport protocol, data contents, src & dst -// addresses of a packet match what's expected. If any field doesn't match, the -// test fails. -func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) { - v := vv.ToView() - if protocol != t.protocol { - t.t.Errorf("protocol = %v, want %v", protocol, t.protocol) - } - - if srcAddr != t.srcAddr { - t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr) - } - - if dstAddr != t.dstAddr { - t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr) - } - - if len(v) != len(t.contents) { - t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents)) - } - - for i := range t.contents { - if t.contents[i] != v[i] { - t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i]) - } - } -} - -// DeliverTransportPacket is called by network endpoints after parsing incoming -// packets. This is used by the test object to verify that the results of the -// parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { - t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) - t.dataCalls++ -} - -// DeliverTransportControlPacket is called by network endpoints after parsing -// incoming control (ICMP) packets. This is used by the test object to verify -// that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { - t.checkValues(trans, pkt.Data, remote, local) - if typ != t.typ { - t.t.Errorf("typ = %v, want %v", typ, t.typ) - } - if extra != t.extra { - t.t.Errorf("extra = %v, want %v", extra, t.extra) - } - t.controlCalls++ -} - -// Attach is only implemented to satisfy the LinkEndpoint interface. -func (*testObject) Attach(stack.NetworkDispatcher) {} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (*testObject) IsAttached() bool { - return true -} - -// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that -// matches the linux loopback MTU. -func (*testObject) MTU() uint32 { - return 65536 -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. -func (*testObject) Capabilities() stack.LinkEndpointCapabilities { - return 0 -} - -// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface. -func (*testObject) MaxHeaderLength() uint16 { - return 0 -} - -// LinkAddress returns the link address of this endpoint. -func (*testObject) LinkAddress() tcpip.LinkAddress { - return "" -} - -// Wait implements stack.LinkEndpoint.Wait. -func (*testObject) Wait() {} - -// WritePacket is called by network endpoints after producing a packet and -// writing it to the link endpoint. This is used by the test object to verify -// that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { - var prot tcpip.TransportProtocolNumber - var srcAddr tcpip.Address - var dstAddr tcpip.Address - - if t.v4 { - h := header.IPv4(pkt.Header.View()) - prot = tcpip.TransportProtocolNumber(h.Protocol()) - srcAddr = h.SourceAddress() - dstAddr = h.DestinationAddress() - - } else { - h := header.IPv6(pkt.Header.View()) - prot = tcpip.TransportProtocolNumber(h.NextHeader()) - srcAddr = h.SourceAddress() - dstAddr = h.DestinationAddress() - } - t.checkValues(prot, pkt.Data, srcAddr, dstAddr) - return nil -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - panic("not implemented") -} - -func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { - return tcpip.ErrNotSupported -} - -func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, - }) - s.CreateNIC(1, loopback.New()) - s.AddAddress(1, ipv4.ProtocolNumber, local) - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - Gateway: ipv4Gateway, - NIC: 1, - }}) - - return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) -} - -func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, - }) - s.CreateNIC(1, loopback.New()) - s.AddAddress(1, ipv6.ProtocolNumber, local) - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: ipv6Gateway, - NIC: 1, - }}) - - return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) -} - -func buildDummyStack() *stack.Stack { - return stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, - }) -} - -func TestIPv4Send(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Allocate and initialize the payload view. - payload := buffer.NewView(100) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) - } - - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) - - // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv4Addr - o.dstAddr = remoteIpv4Addr - o.contents = payload - - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } -} - -func TestIPv4Receive(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - totalLen := header.IPv4MinimumSize + 30 - view := buffer.NewView(totalLen) - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, - }) - - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - view[i] = uint8(i) - } - - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[header.IPv4MinimumSize:totalLen] - - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - ep.HandlePacket(&r, tcpip.PacketBuffer{ - Data: view.ToVectorisedView(), - }) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) - } -} - -func TestIPv4ReceiveControl(t *testing.T) { - const mtu = 0xbeef - header.IPv4MinimumSize - cases := []struct { - name string - expectedCount int - fragmentOffset uint16 - code uint8 - expectedTyp stack.ControlType - expectedExtra uint32 - trunc int - }{ - {"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0}, - {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10}, - {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8}, - {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8}, - {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4MinimumSize + header.IPv4MinimumSize + 8}, - {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, - } - r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb") - if err != nil { - t.Fatal(err) - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize - view := buffer.NewView(dataOffset + 8) - - // Create the outer IPv4 header. - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(len(view) - c.trunc), - TTL: 20, - Protocol: uint8(header.ICMPv4ProtocolNumber), - SrcAddr: "\x0a\x00\x00\xbb", - DstAddr: localIpv4Addr, - }) - - // Create the ICMP header. - icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) - icmp.SetType(header.ICMPv4DstUnreachable) - icmp.SetCode(c.code) - icmp.SetIdent(0xdead) - icmp.SetSequence(0xbeef) - - // Create the inner IPv4 header. - ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:]) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: 100, - TTL: 20, - Protocol: 10, - FragmentOffset: c.fragmentOffset, - SrcAddr: localIpv4Addr, - DstAddr: remoteIpv4Addr, - }) - - // Make payload be non-zero. - for i := dataOffset; i < len(view); i++ { - view[i] = uint8(i) - } - - // Give packet to IPv4 endpoint, dispatcher will validate that - // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra - - vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, tcpip.PacketBuffer{ - Data: vv, - }) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) - } - }) - } -} - -func TestIPv4FragmentationReceive(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - totalLen := header.IPv4MinimumSize + 24 - - frag1 := buffer.NewView(totalLen) - ip1 := header.IPv4(frag1) - ip1.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - FragmentOffset: 0, - Flags: header.IPv4FlagMoreFragments, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, - }) - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - frag1[i] = uint8(i) - } - - frag2 := buffer.NewView(totalLen) - ip2 := header.IPv4(frag2) - ip2.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - FragmentOffset: 24, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, - }) - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - frag2[i] = uint8(i) - } - - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) - - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - - // Send first segment. - ep.HandlePacket(&r, tcpip.PacketBuffer{ - Data: frag1.ToVectorisedView(), - }) - if o.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) - } - - // Send second segment. - ep.HandlePacket(&r, tcpip.PacketBuffer{ - Data: frag2.ToVectorisedView(), - }) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) - } -} - -func TestIPv6Send(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Allocate and initialize the payload view. - payload := buffer.NewView(100) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) - } - - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) - - // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv6Addr - o.dstAddr = remoteIpv6Addr - o.contents = payload - - r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } -} - -func TestIPv6Receive(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - totalLen := header.IPv6MinimumSize + 30 - view := buffer.NewView(totalLen) - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(totalLen - header.IPv6MinimumSize), - NextHeader: 10, - HopLimit: 20, - SrcAddr: remoteIpv6Addr, - DstAddr: localIpv6Addr, - }) - - // Make payload be non-zero. - for i := header.IPv6MinimumSize; i < totalLen; i++ { - view[i] = uint8(i) - } - - // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[header.IPv6MinimumSize:totalLen] - - r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - - ep.HandlePacket(&r, tcpip.PacketBuffer{ - Data: view.ToVectorisedView(), - }) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) - } -} - -func TestIPv6ReceiveControl(t *testing.T) { - newUint16 := func(v uint16) *uint16 { return &v } - - const mtu = 0xffff - const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" - cases := []struct { - name string - expectedCount int - fragmentOffset *uint16 - typ header.ICMPv6Type - code uint8 - expectedTyp stack.ControlType - expectedExtra uint32 - trunc int - }{ - {"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0}, - {"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10}, - {"Truncated (missing IPv6 header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8}, - {"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8}, - {"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8}, - {"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Truncated DstUnreachable (missing 'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8}, - {"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, - } - r, err := buildIPv6Route( - localIpv6Addr, - "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", - ) - if err != nil { - t.Fatal(err) - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - defer ep.Close() - - dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize - if c.fragmentOffset != nil { - dataOffset += header.IPv6FragmentHeaderSize - } - view := buffer.NewView(dataOffset + 8) - - // Create the outer IPv6 header. - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 20, - SrcAddr: outerSrcAddr, - DstAddr: localIpv6Addr, - }) - - // Create the ICMP header. - icmp := header.ICMPv6(view[header.IPv6MinimumSize:]) - icmp.SetType(c.typ) - icmp.SetCode(c.code) - icmp.SetIdent(0xdead) - icmp.SetSequence(0xbeef) - - // Create the inner IPv6 header. - ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) - ip.Encode(&header.IPv6Fields{ - PayloadLength: 100, - NextHeader: 10, - HopLimit: 20, - SrcAddr: localIpv6Addr, - DstAddr: remoteIpv6Addr, - }) - - // Build the fragmentation header if needed. - if c.fragmentOffset != nil { - ip.SetNextHeader(header.IPv6FragmentHeader) - frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:]) - frag.Encode(&header.IPv6FragmentFields{ - NextHeader: 10, - FragmentOffset: *c.fragmentOffset, - M: true, - Identification: 0x12345678, - }) - } - - // Make payload be non-zero. - for i := dataOffset; i < len(view); i++ { - view[i] = uint8(i) - } - - // Give packet to IPv6 endpoint, dispatcher will validate that - // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra - - // Set ICMPv6 checksum. - icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) - - ep.HandlePacket(&r, tcpip.PacketBuffer{ - Data: view[:len(view)-c.trunc].ToVectorisedView(), - }) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD deleted file mode 100644 index 0fef2b1f1..000000000 --- a/pkg/tcpip/network/ipv4/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ipv4", - srcs = [ - "icmp.go", - "ipv4.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/network/fragmentation", - "//pkg/tcpip/network/hash", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ipv4_test", - size = "small", - srcs = ["ipv4_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/network/ipv4/ipv4_state_autogen.go b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go new file mode 100755 index 000000000..250b2128e --- /dev/null +++ b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package ipv4 diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go deleted file mode 100644 index e900f1b45..000000000 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ /dev/null @@ -1,475 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv4_test - -import ( - "bytes" - "encoding/hex" - "math/rand" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestExcludeBroadcast(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - }) - - const defaultMTU = 65536 - ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) - if testing.Verbose() { - ep = sniffer.New(ep) - } - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: 1, - }}) - - randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53} - - var wq waiter.Queue - t.Run("WithoutPrimaryAddress", func(t *testing.T) { - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - // Cannot connect using a broadcast address as the source. - if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute { - t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute) - } - - // However, we can bind to a broadcast address to listen. - if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil { - t.Errorf("Bind failed: %v", err) - } - }) - - t.Run("WithPrimaryAddress", func(t *testing.T) { - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - // Add a valid primary endpoint address, now we can connect. - if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - if err := ep.Connect(randomAddr); err != nil { - t.Errorf("Connect failed: %v", err) - } - }) -} - -// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much -// data should already be in the header before WritePacket. extraLength -// indicates how much extra space should be in the header. The payload is made -// from many Views of the sizes listed in viewSizes. -func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) { - hdr := buffer.NewPrependable(hdrLength + extraLength) - hdr.Prepend(hdrLength) - rand.Read(hdr.View()) - - var views []buffer.View - totalLength := 0 - for _, s := range viewSizes { - newView := buffer.NewView(s) - rand.Read(newView) - views = append(views, newView) - totalLength += s - } - payload := buffer.NewVectorisedView(totalLength, views) - return hdr, payload -} - -// comparePayloads compared the contents of all the packets against the contents -// of the source packet. -func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketInfo tcpip.PacketBuffer, mtu uint32) { - t.Helper() - // Make a complete array of the sourcePacketInfo packet. - source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) - source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Data.ToView()...) - - // Make a copy of the IP header, which will be modified in some fields to make - // an expected header. - sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...)) - sourceCopy.SetChecksum(0) - sourceCopy.SetFlagsFragmentOffset(0, 0) - sourceCopy.SetTotalLength(0) - var offset uint16 - // Build up an array of the bytes sent. - var reassembledPayload []byte - for i, packet := range packets { - // Confirm that the packet is valid. - allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Data) - ip := header.IPv4(allBytes.ToView()) - if !ip.IsValid(len(ip)) { - t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) - } - if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want { - t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want) - } - if got, want := len(ip), int(mtu); got > want { - t.Errorf("fragment is too large, got %d want %d", got, want) - } - if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want { - t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) - } - if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { - t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) - } - if i < len(packets)-1 { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) - } else { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset) - } - reassembledPayload = append(reassembledPayload, ip.Payload()...) - offset += ip.TotalLength() - uint16(ip.HeaderLength()) - // Clear out the checksum and length from the ip because we can't compare - // it. - sourceCopy.SetTotalLength(uint16(len(ip))) - sourceCopy.SetChecksum(0) - sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) - if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) { - t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()])) - } - } - expected := source[source.HeaderLength():] - if !bytes.Equal(reassembledPayload, expected) { - t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected)) - } -} - -type errorChannel struct { - *channel.Endpoint - Ch chan tcpip.PacketBuffer - packetCollectorErrors []*tcpip.Error -} - -// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket -// will return successive errors from packetCollectorErrors until the list is -// empty and then return nil each time. -func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { - return &errorChannel{ - Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan tcpip.PacketBuffer, size), - packetCollectorErrors: packetCollectorErrors, - } -} - -// Drain removes all outbound packets from the channel and counts them. -func (e *errorChannel) Drain() int { - c := 0 - for { - select { - case <-e.Ch: - c++ - default: - return c - } - } -} - -// WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { - select { - case e.Ch <- pkt: - default: - } - - nextError := (*tcpip.Error)(nil) - if len(e.packetCollectorErrors) > 0 { - nextError = e.packetCollectorErrors[0] - e.packetCollectorErrors = e.packetCollectorErrors[1:] - } - return nextError -} - -type context struct { - stack.Route - linkEP *errorChannel -} - -func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { - // Make the packet and write it. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - }) - ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) - s.CreateNIC(1, ep) - const ( - src = "\x10\x00\x00\x01" - dst = "\x10\x00\x00\x02" - ) - s.AddAddress(1, ipv4.ProtocolNumber, src) - { - subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}) - } - r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute got %v, want %v", err, nil) - } - return context{ - Route: r, - linkEP: ep, - } -} - -func TestFragmentation(t *testing.T) { - var manyPayloadViewsSizes [1000]int - for i := range manyPayloadViewsSizes { - manyPayloadViewsSizes[i] = 7 - } - fragTests := []struct { - description string - mtu uint32 - gso *stack.GSO - hdrLength int - extraLength int - payloadViewsSizes []int - expectedFrags int - }{ - {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, - {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, - {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, - {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, - {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, - {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, - } - - for _, ft := range fragTests { - t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := tcpip.PacketBuffer{ - Header: hdr, - // Save the source payload because WritePacket will modify it. - Data: payload.Clone(nil), - } - c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ - Header: hdr, - Data: payload, - }) - if err != nil { - t.Errorf("err got %v, want %v", err, nil) - } - - var results []tcpip.PacketBuffer - L: - for { - select { - case pi := <-c.linkEP.Ch: - results = append(results, pi) - default: - break L - } - } - - if got, want := len(results), ft.expectedFrags; got != want { - t.Errorf("len(result) got %d, want %d", got, want) - } - if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want { - t.Errorf("no errors yet len(result) got %d, want %d", got, want) - } - compareFragments(t, results, source, ft.mtu) - }) - } -} - -// TestFragmentationErrors checks that errors are returned from write packet -// correctly. -func TestFragmentationErrors(t *testing.T) { - fragTests := []struct { - description string - mtu uint32 - hdrLength int - payloadViewsSizes []int - packetCollectorErrors []*tcpip.Error - }{ - {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}}, - {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, - } - - for _, ft := range fragTests { - t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) - c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ - Header: hdr, - Data: payload, - }) - for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { - if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { - t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) - } - } - // We only need to check that last error because all the ones before are - // nil. - if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want { - t.Errorf("err got %v, want %v", got, want) - } - if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { - t.Errorf("after linkEP error len(result) got %d, want %d", got, want) - } - }) - } -} - -func TestInvalidFragments(t *testing.T) { - // These packets have both IHL and TotalLength set to 0. - testCases := []struct { - name string - packets [][]byte - wantMalformedIPPackets uint64 - wantMalformedFragments uint64 - }{ - { - "ihl_totallen_zero_valid_frag_offset", - [][]byte{ - {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 1, - 0, - }, - { - "ihl_totallen_zero_invalid_frag_offset", - [][]byte{ - {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 1, - 0, - }, - { - // Total Length of 37(20 bytes IP header + 17 bytes of - // payload) - // Frag Offset of 0x1ffe = 8190*8 = 65520 - // Leading to the fragment end to be past 65535. - "ihl_totallen_valid_invalid_frag_offset_1", - [][]byte{ - {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 1, - 1, - }, - // The following 3 tests were found by running a fuzzer and were - // triggering a panic in the IPv4 reassembler code. - { - "ihl_less_than_ipv4_minimum_size_1", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 2, - 0, - }, - { - "ihl_less_than_ipv4_minimum_size_2", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 2, - 0, - }, - { - "ihl_less_than_ipv4_minimum_size_3", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 2, - 0, - }, - { - "fragment_with_short_total_len_extra_payload", - [][]byte{ - {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 1, - 1, - }, - { - "multiple_fragments_with_more_fragments_set_to_false", - [][]byte{ - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - }, - 1, - 1, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - const nicID tcpip.NICID = 42 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - ipv4.NewProtocol(), - }, - }) - - var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30}) - var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31}) - ep := channel.New(10, 1500, linkAddr) - s.CreateNIC(nicID, sniffer.New(ep)) - - for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{ - Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), - }) - } - - if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { - t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) - } - if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want { - t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD deleted file mode 100644 index fb11874c6..000000000 --- a/pkg/tcpip/network/ipv6/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ipv6", - srcs = [ - "icmp.go", - "ipv6.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ipv6_test", - size = "small", - srcs = [ - "icmp_test.go", - "ipv6_test.go", - "ndp_test.go", - ], - library = ":ipv6", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go deleted file mode 100644 index 50c4b6474..000000000 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ /dev/null @@ -1,958 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "context" - "reflect" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") -) - -var ( - lladdr0 = header.LinkLocalAddr(linkAddr0) - lladdr1 = header.LinkLocalAddr(linkAddr1) -) - -type stubLinkEndpoint struct { - stack.LinkEndpoint -} - -func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return 0 -} - -func (*stubLinkEndpoint) MaxHeaderLength() uint16 { - return 0 -} - -func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { - return "" -} - -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, tcpip.PacketBuffer) *tcpip.Error { - return nil -} - -func (*stubLinkEndpoint) Attach(stack.NetworkDispatcher) {} - -type stubDispatcher struct { - stack.TransportDispatcher -} - -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) { -} - -type stubLinkAddressCache struct { - stack.LinkAddressCache -} - -func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.NICID { - return 0 -} - -func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) { -} - -func TestICMPCounts(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, - }) - { - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) - if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) - } - - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - defer r.Release() - - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - - types := []struct { - typ header.ICMPv6Type - size int - extraData []byte - }{ - { - typ: header.ICMPv6DstUnreachable, - size: header.ICMPv6DstUnreachableMinimumSize, - }, - { - typ: header.ICMPv6PacketTooBig, - size: header.ICMPv6PacketTooBigMinimumSize, - }, - { - typ: header.ICMPv6TimeExceeded, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6ParamProblem, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6EchoRequest, - size: header.ICMPv6EchoMinimumSize, - }, - { - typ: header.ICMPv6EchoReply, - size: header.ICMPv6EchoMinimumSize, - }, - { - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - }, - { - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize}, - { - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - }, - { - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - }, - } - - handleIPv6Payload := func(hdr buffer.Prependable) { - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - }) - ep.HandlePacket(&r, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - } - - for _, typ := range types { - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) - - handleIPv6Payload(hdr) - } - - // Construct an empty ICMP packet so that - // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize)) - - icmpv6Stats := s.Stats().ICMP.V6PacketsReceived - visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { - if got, want := s.Value(), uint64(1); got != want { - t.Errorf("got %s = %d, want = %d", name, got, want) - } - }) - if t.Failed() { - t.Logf("stats:\n%+v", s.Stats()) - } -} - -func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) { - t := v.Type() - for i := 0; i < v.NumField(); i++ { - v := v.Field(i) - if s, ok := v.Interface().(*tcpip.StatCounter); ok { - f(t.Field(i).Name, s) - } else { - visitStats(v, f) - } - } -} - -type testContext struct { - s0 *stack.Stack - s1 *stack.Stack - - linkEP0 *channel.Endpoint - linkEP1 *channel.Endpoint -} - -type endpointWithResolutionCapability struct { - stack.LinkEndpoint -} - -func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapabilities { - return e.LinkEndpoint.Capabilities() | stack.CapabilityResolutionRequired -} - -func newTestContext(t *testing.T) *testContext { - c := &testContext{ - s0: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, - }), - s1: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, - }), - } - - const defaultMTU = 65536 - c.linkEP0 = channel.New(256, defaultMTU, linkAddr0) - - wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) - if testing.Verbose() { - wrappedEP0 = sniffer.New(wrappedEP0) - } - if err := c.s0.CreateNIC(1, wrappedEP0); err != nil { - t.Fatalf("CreateNIC s0: %v", err) - } - if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress lladdr0: %v", err) - } - - c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) - wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) - if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { - t.Fatalf("AddAddress lladdr1: %v", err) - } - - subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - c.s0.SetRouteTable( - []tcpip.Route{{ - Destination: subnet0, - NIC: 1, - }}, - ) - subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) - if err != nil { - t.Fatal(err) - } - c.s1.SetRouteTable( - []tcpip.Route{{ - Destination: subnet1, - NIC: 1, - }}, - ) - - return c -} - -func (c *testContext) cleanup() { - c.linkEP0.Close() - c.linkEP1.Close() -} - -type routeArgs struct { - src, dst *channel.Endpoint - typ header.ICMPv6Type - remoteLinkAddr tcpip.LinkAddress -} - -func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { - t.Helper() - - pi, _ := args.src.ReadContext(context.Background()) - - { - views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} - size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() - vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{ - Data: vv, - }) - } - - if pi.Proto != ProtocolNumber { - t.Errorf("unexpected protocol number %d", pi.Proto) - return - } - - if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress { - t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) - } - - ipv6 := header.IPv6(pi.Pkt.Header.View()) - transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) - if transProto != header.ICMPv6ProtocolNumber { - t.Errorf("unexpected transport protocol number %d", transProto) - return - } - icmpv6 := header.ICMPv6(ipv6.Payload()) - if got, want := icmpv6.Type(), args.typ; got != want { - t.Errorf("got ICMPv6 type = %d, want = %d", got, want) - return - } - if fn != nil { - fn(t, icmpv6) - } -} - -func TestLinkResolution(t *testing.T) { - c := newTestContext(t) - defer c.cleanup() - - r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - defer r.Release() - - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - payload := tcpip.SlicePayload(hdr.View()) - - // We can't send our payload directly over the route because that - // doesn't provoke NDP discovery. - var wq waiter.Queue - ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) - } - - for { - _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}}) - if resCh != nil { - if err != tcpip.ErrNoLinkAddress { - t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err) - } - for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, - {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, - } { - routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { - if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want { - t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want) - } - }) - } - <-resCh - continue - } - if err != nil { - t.Fatalf("ep.Write(_) = _, _, %s", err) - } - break - } - - for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest}, - {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply}, - } { - routeICMPv6Packet(t, args, nil) - } -} - -func TestICMPChecksumValidationSimple(t *testing.T) { - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - - types := []struct { - name string - typ header.ICMPv6Type - size int - extraData []byte - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - }{ - { - name: "DstUnreachable", - typ: header.ICMPv6DstUnreachable, - size: header.ICMPv6DstUnreachableMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.DstUnreachable - }, - }, - { - name: "PacketTooBig", - typ: header.ICMPv6PacketTooBig, - size: header.ICMPv6PacketTooBigMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.PacketTooBig - }, - }, - { - name: "TimeExceeded", - typ: header.ICMPv6TimeExceeded, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.TimeExceeded - }, - }, - { - name: "ParamProblem", - typ: header.ICMPv6ParamProblem, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.ParamProblem - }, - }, - { - name: "EchoRequest", - typ: header.ICMPv6EchoRequest, - size: header.ICMPv6EchoMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoRequest - }, - }, - { - name: "EchoReply", - typ: header.ICMPv6EchoReply, - size: header.ICMPv6EchoMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoReply - }, - }, - { - name: "RouterSolicit", - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }, - }, - { - name: "RouterAdvert", - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }, - }, - { - name: "NeighborSolicit", - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }, - }, - { - name: "NeighborAdvert", - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }, - }, - { - name: "RedirectMsg", - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg - }, - }, - } - - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr0) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - handleIPv6Payload := func(checksum bool) { - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView())) - } - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(typ.size + extraDataLen), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - } - - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - }) - } -} - -func TestICMPChecksumValidationWithPayload(t *testing.T) { - const simpleBodySize = 64 - simpleBody := func(view buffer.View) { - for i := 0; i < simpleBodySize; i++ { - view[i] = uint8(i) - } - } - - const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize - errorICMPBody := func(view buffer.View) { - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - NextHeader: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, - }) - simpleBody(view[header.IPv6MinimumSize:]) - } - - types := []struct { - name string - typ header.ICMPv6Type - size int - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - payloadSize int - payload func(buffer.View) - }{ - { - "DstUnreachable", - header.ICMPv6DstUnreachable, - header.ICMPv6DstUnreachableMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.DstUnreachable - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "PacketTooBig", - header.ICMPv6PacketTooBig, - header.ICMPv6PacketTooBigMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.PacketTooBig - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "TimeExceeded", - header.ICMPv6TimeExceeded, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.TimeExceeded - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "ParamProblem", - header.ICMPv6ParamProblem, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.ParamProblem - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "EchoRequest", - header.ICMPv6EchoRequest, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoRequest - }, - simpleBodySize, - simpleBody, - }, - { - "EchoReply", - header.ICMPv6EchoReply, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoReply - }, - simpleBodySize, - simpleBody, - }, - } - - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr0) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { - icmpSize := size + payloadSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(typ) - payloadFn(pkt.Payload()) - - if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) - } - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(icmpSize), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - } - - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - }) - } -} - -func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { - const simpleBodySize = 64 - simpleBody := func(view buffer.View) { - for i := 0; i < simpleBodySize; i++ { - view[i] = uint8(i) - } - } - - const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize - errorICMPBody := func(view buffer.View) { - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - NextHeader: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, - }) - simpleBody(view[header.IPv6MinimumSize:]) - } - - types := []struct { - name string - typ header.ICMPv6Type - size int - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - payloadSize int - payload func(buffer.View) - }{ - { - "DstUnreachable", - header.ICMPv6DstUnreachable, - header.ICMPv6DstUnreachableMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.DstUnreachable - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "PacketTooBig", - header.ICMPv6PacketTooBig, - header.ICMPv6PacketTooBigMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.PacketTooBig - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "TimeExceeded", - header.ICMPv6TimeExceeded, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.TimeExceeded - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "ParamProblem", - header.ICMPv6ParamProblem, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.ParamProblem - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "EchoRequest", - header.ICMPv6EchoRequest, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoRequest - }, - simpleBodySize, - simpleBody, - }, - { - "EchoReply", - header.ICMPv6EchoReply, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoReply - }, - simpleBodySize, - simpleBody, - }, - } - - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr0) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) - pkt := header.ICMPv6(hdr.Prepend(size)) - pkt.SetType(typ) - - payload := buffer.NewView(payloadSize) - payloadFn(payload) - - if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView())) - } - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(size + payloadSize), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ - Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), - }) - } - - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv6/ipv6_state_autogen.go b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go new file mode 100755 index 000000000..40c67d440 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package ipv6 diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go deleted file mode 100644 index 1cbfa7278..000000000 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ /dev/null @@ -1,270 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - // The least significant 3 bytes are the same as addr2 so both addr2 and - // addr3 will have the same solicited-node address. - addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" -) - -// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the -// expected Neighbor Advertisement received count after receiving the packet. -func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { - t.Helper() - - // Receive ICMP packet. - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, - }) - - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - - stats := s.Stats().ICMP.V6PacketsReceived - - if got := stats.NeighborAdvert.Value(); got != want { - t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) - } -} - -// testReceiveUDP tests receiving a UDP packet from src to dst. want is the -// expected UDP received count after receiving the packet. -func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { - t.Helper() - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - - ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } - - // Receive UDP Packet. - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) - u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: header.UDPMinimumSize, - }) - - // UDP pseudo-header checksum. - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) - - // UDP checksum - sum = header.Checksum(header.UDP([]byte{}), sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, - }) - - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - - stat := s.Stats().UDP.PacketsReceived - - if got := stat.Value(); got != want { - t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want) - } -} - -// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and -// UDP packets destined to the IPv6 link-local all-nodes multicast address. -func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { - tests := []struct { - name string - protocolFactory stack.TransportProtocol - rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) - }{ - {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, - {"UDP", udp.NewProtocol(), testReceiveUDP}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, - }) - e := channel.New(10, 1280, linkAddr1) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - // Should receive a packet destined to the all-nodes - // multicast address. - test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1) - }) - } -} - -// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP -// packets destined to the IPv6 solicited-node address of an assigned IPv6 -// address. -func TestReceiveOnSolicitedNodeAddr(t *testing.T) { - tests := []struct { - name string - protocolFactory stack.TransportProtocol - rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) - }{ - {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, - {"UDP", udp.NewProtocol(), testReceiveUDP}, - } - - snmc := header.SolicitedNodeAddr(addr2) - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, - }) - e := channel.New(10, 1280, linkAddr1) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - // Should not receive a packet destined to the solicited - // node address of addr2/addr3 yet as we haven't added - // those addresses. - test.rxf(t, s, e, addr1, snmc, 0) - - if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err) - } - - // Should receive a packet destined to the solicited - // node address of addr2/addr3 now that we have added - // added addr2. - test.rxf(t, s, e, addr1, snmc, 1) - - if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err) - } - - // Should still receive a packet destined to the - // solicited node address of addr2/addr3 now that we - // have added addr3. - test.rxf(t, s, e, addr1, snmc, 2) - - if err := s.RemoveAddress(1, addr2); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err) - } - - // Should still receive a packet destined to the - // solicited node address of addr2/addr3 now that we - // have removed addr2. - test.rxf(t, s, e, addr1, snmc, 3) - - if err := s.RemoveAddress(1, addr3); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err) - } - - // Should not receive a packet destined to the solicited - // node address of addr2/addr3 yet as both of them got - // removed. - test.rxf(t, s, e, addr1, snmc, 3) - }) - } -} - -// TestAddIpv6Address tests adding IPv6 addresses. -func TestAddIpv6Address(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - }{ - // This test is in response to b/140943433. - { - "Nil", - tcpip.Address([]byte(nil)), - }, - { - "ValidUnicast", - addr1, - }, - { - "ValidLinkLocalUnicast", - lladdr0, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil { - t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err) - } - - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - if addr.Address != test.addr { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go deleted file mode 100644 index c9395de52..000000000 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ /dev/null @@ -1,613 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" -) - -// setupStackAndEndpoint creates a stack with a single NIC with a link-local -// address llladdr and an IPv6 endpoint to a remote with link-local address -// rlladdr -func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, - }) - - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) - } - - { - subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - - ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) - if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) - } - - return s, ep -} - -// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a -// valid NDP NS message with the Source Link Layer Address option results in a -// new entry in the link address cache for the sender of the message. -func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - optsBuf []byte - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "Valid", - optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7}, - expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", - }, - { - name: "Too Small", - optsBuf: []byte{1, 1, 2, 3, 4, 5, 6}, - }, - { - name: "Invalid Length", - optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - e := channel.New(0, 1280, linkAddr0) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) - ns.SetTargetAddress(lladdr0) - opts := ns.Options() - copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - - linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) - if linkAddr != test.expectedLinkAddr { - t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) - } - - if test.expectedLinkAddr != "" { - if err != nil { - t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) - } - if c != nil { - t.Errorf("got unexpected channel") - } - - // Invalid count should not have increased. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - } else { - if err != tcpip.ErrWouldBlock { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) - } - if c == nil { - t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) - } - - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Errorf("got invalid = %d, want = 1", got) - } - } - }) - } -} - -// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a -// valid NDP NA message with the Target Link Layer Address option results in a -// new entry in the link address cache for the target of the message. -func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - optsBuf []byte - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "Valid", - optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7}, - expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", - }, - { - name: "Too Small", - optsBuf: []byte{2, 1, 2, 3, 4, 5, 6}, - }, - { - name: "Invalid Length", - optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - e := channel.New(0, 1280, linkAddr0) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) - pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.NDPPayload()) - ns.SetTargetAddress(lladdr1) - opts := ns.Options() - copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - - linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) - if linkAddr != test.expectedLinkAddr { - t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) - } - - if test.expectedLinkAddr != "" { - if err != nil { - t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) - } - if c != nil { - t.Errorf("got unexpected channel") - } - - // Invalid count should not have increased. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - } else { - if err != tcpip.ErrWouldBlock { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) - } - if c == nil { - t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) - } - - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Errorf("got invalid = %d, want = 1", got) - } - } - }) - } -} - -// TestHopLimitValidation is a test that makes sure that NDP packets are only -// received if their IP header's hop limit is set to 255. -func TestHopLimitValidation(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { - t.Helper() - - // Create a stack with the assigned link-local address lladdr0 - // and an endpoint to lladdr1. - s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) - - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - - return s, ep, r - } - - handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, ep stack.NetworkEndpoint, r *stack.Route) { - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: hopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - }) - ep.HandlePacket(r, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - } - - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - - types := []struct { - name string - typ header.ICMPv6Type - size int - extraData []byte - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - }{ - { - name: "RouterSolicit", - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }, - }, - { - name: "RouterAdvert", - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }, - }, - { - name: "NeighborSolicit", - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }, - }, - { - name: "NeighborAdvert", - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }, - }, - { - name: "RedirectMsg", - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg - }, - }, - } - - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - s, ep, r := setup(t) - defer r.Release() - - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - // Should not have received any ICMPv6 packets with - // type = typ.typ. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Receive the NDP packet with an invalid hop limit - // value. - handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r) - - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - - // Rx count of NDP packet of type typ.typ should not - // have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Receive the NDP packet with a valid hop limit value. - handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r) - - // Rx count of NDP packet of type typ.typ should have - // increased. - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - }) - } -} - -// TestRouterAdvertValidation tests that when the NIC is configured to handle -// NDP Router Advertisement packets, it validates the Router Advertisement -// properly before handling them. -func TestRouterAdvertValidation(t *testing.T) { - tests := []struct { - name string - src tcpip.Address - hopLimit uint8 - code uint8 - ndpPayload []byte - expectedSuccess bool - }{ - { - "OK", - lladdr0, - 255, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - true, - }, - { - "NonLinkLocalSourceAddr", - addr1, - 255, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - false, - }, - { - "HopLimitNot255", - lladdr0, - 254, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - false, - }, - { - "NonZeroCode", - lladdr0, - 255, - 1, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - false, - }, - { - "NDPPayloadTooSmall", - lladdr0, - 255, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, - }, - false, - }, - { - "OKWithOptions", - lladdr0, - 255, - 0, - []byte{ - // RA payload - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - - // Option #1 (TargetLinkLayerAddress) - 2, 1, 0, 0, 0, 0, 0, 0, - - // Option #2 (unrecognized) - 255, 1, 0, 0, 0, 0, 0, 0, - - // Option #3 (PrefixInformation) - 3, 4, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }, - true, - }, - { - "OptionWithZeroLength", - lladdr0, - 255, - 0, - []byte{ - // RA payload - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - - // Option #1 (TargetLinkLayerAddress) - // Invalid as it has 0 length. - 2, 0, 0, 0, 0, 0, 0, 0, - - // Option #2 (unrecognized) - 255, 1, 0, 0, 0, 0, 0, 0, - - // Option #3 (PrefixInformation) - 3, 4, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }, - false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(header.ICMPv6RouterAdvert) - pkt.SetCode(test.code) - copy(pkt.NDPPayload(), test.ndpPayload) - payloadLength := hdr.UsedLength() - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: test.hopLimit, - SrcAddr: test.src, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - rxRA := stats.RouterAdvert - - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := rxRA.Value(); got != 0 { - t.Fatalf("got rxRA = %d, want = 0", got) - } - - e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - - if test.expectedSuccess { - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := rxRA.Value(); got != 1 { - t.Fatalf("got rxRA = %d, want = 1", got) - } - - } else { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - if got := rxRA.Value(); got != 0 { - t.Fatalf("got rxRA = %d, want = 0", got) - } - } - }) - } -} diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go index ab24372e7..ab24372e7 100644..100755 --- a/pkg/tcpip/packet_buffer.go +++ b/pkg/tcpip/packet_buffer.go diff --git a/pkg/tcpip/packet_buffer_state.go b/pkg/tcpip/packet_buffer_state.go index ad3cc24fa..ad3cc24fa 100644..100755 --- a/pkg/tcpip/packet_buffer_state.go +++ b/pkg/tcpip/packet_buffer_state.go diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD deleted file mode 100644 index 2bad05a2e..000000000 --- a/pkg/tcpip/ports/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ports", - srcs = ["ports.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - ], -) - -go_test( - name = "ports_test", - srcs = ["ports_test.go"], - library = ":ports", - deps = [ - "//pkg/tcpip", - ], -) diff --git a/pkg/tcpip/ports/ports_state_autogen.go b/pkg/tcpip/ports/ports_state_autogen.go new file mode 100755 index 000000000..f0ee1bb11 --- /dev/null +++ b/pkg/tcpip/ports/ports_state_autogen.go @@ -0,0 +1,24 @@ +// automatically generated by stateify. + +package ports + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Flags) beforeSave() {} +func (x *Flags) save(m state.Map) { + x.beforeSave() + m.Save("MostRecent", &x.MostRecent) + m.Save("LoadBalanced", &x.LoadBalanced) +} + +func (x *Flags) afterLoad() {} +func (x *Flags) load(m state.Map) { + m.Load("MostRecent", &x.MostRecent) + m.Load("LoadBalanced", &x.LoadBalanced) +} + +func init() { + state.Register("pkg/tcpip/ports.Flags", (*Flags)(nil), state.Fns{Save: (*Flags).save, Load: (*Flags).load}) +} diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go deleted file mode 100644 index d6969d050..000000000 --- a/pkg/tcpip/ports/ports_test.go +++ /dev/null @@ -1,402 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ports - -import ( - "math/rand" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" -) - -const ( - fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeNetworkNumber tcpip.NetworkProtocolNumber = 2 - - fakeIPAddress = tcpip.Address("\x08\x08\x08\x08") - fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09") -) - -type portReserveTestAction struct { - port uint16 - ip tcpip.Address - want *tcpip.Error - flags Flags - release bool - device tcpip.NICID -} - -func TestPortReservation(t *testing.T) { - for _, test := range []struct { - tname string - actions []portReserveTestAction - }{ - { - tname: "bind to ip", - actions: []portReserveTestAction{ - {port: 80, ip: fakeIPAddress, want: nil}, - {port: 80, ip: fakeIPAddress1, want: nil}, - /* N.B. Order of tests matters! */ - {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse}, - {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, - }, - }, - { - tname: "bind to inaddr any", - actions: []portReserveTestAction{ - {port: 22, ip: anyIPAddress, want: nil}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - /* release fakeIPAddress, but anyIPAddress is still inuse */ - {port: 22, ip: fakeIPAddress, release: true}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, - /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */ - {port: 22, ip: anyIPAddress, want: nil, release: true}, - {port: 22, ip: fakeIPAddress, want: nil}, - }, - }, { - tname: "bind to zero port", - actions: []portReserveTestAction{ - {port: 00, ip: fakeIPAddress, want: nil}, - {port: 00, ip: fakeIPAddress, want: nil}, - {port: 00, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - }, - }, { - tname: "bind to ip with reuseport", - actions: []portReserveTestAction{ - {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - - {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - - {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - }, - }, { - tname: "bind to inaddr any with reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - - {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil}, - - {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, - {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - - {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, - {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil}, - }, - }, { - tname: "bind twice with device fails", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 3, want: nil}, - {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind to device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 1, want: nil}, - {port: 24, ip: fakeIPAddress, device: 2, want: nil}, - }, - }, { - tname: "bind to device and then without device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind without device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind with device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, want: nil}, - {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind with reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, - }, - }, { - tname: "binding with reuseport and device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "mixing reuseport and not reuseport by binding to device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 999, want: nil}, - }, - }, { - tname: "can't bind to 0 after mixing reuseport and not reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind and release", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, - - // Release the bind to device 0 and try again. - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil}, - }, - }, { - tname: "bind twice with reuseport once", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "release an unreserved device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil}, - // The below don't exist. - {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil, release: true}, - {port: 9999, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, - // Release all. - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil, release: true}, - }, - }, { - tname: "bind with reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil}, - }, - }, { - tname: "bind twice with reuseaddr once", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind with reuseaddr and reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - }, - }, { - tname: "bind with reuseaddr and reuseport, and then reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind with reuseaddr and reuseport, and then reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind with reuseaddr and reuseport twice, and then reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - }, - }, { - tname: "bind with reuseaddr and reuseport twice, and then reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - }, - }, { - tname: "bind with reuseaddr, and then reuseaddr and reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind with reuseport, and then reuseaddr and reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, - }, - }, - } { - t.Run(test.tname, func(t *testing.T) { - pm := NewPortManager() - net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber} - - for _, test := range test.actions { - if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) - continue - } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) - if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d) = %v, want %v", test.ip, test.port, test.flags, test.device, err, test.want) - } - if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { - t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) - } - } - }) - - } -} - -func TestPickEphemeralPort(t *testing.T) { - customErr := &tcpip.Error{} - for _, test := range []struct { - name string - f func(port uint16) (bool, *tcpip.Error) - wantErr *tcpip.Error - wantPort uint16 - }{ - { - name: "no-port-available", - f: func(port uint16) (bool, *tcpip.Error) { - return false, nil - }, - wantErr: tcpip.ErrNoPortAvailable, - }, - { - name: "port-tester-error", - f: func(port uint16) (bool, *tcpip.Error) { - return false, customErr - }, - wantErr: customErr, - }, - { - name: "only-port-16042-available", - f: func(port uint16) (bool, *tcpip.Error) { - if port == FirstEphemeral+42 { - return true, nil - } - return false, nil - }, - wantPort: FirstEphemeral + 42, - }, - { - name: "only-port-under-16000-available", - f: func(port uint16) (bool, *tcpip.Error) { - if port < FirstEphemeral { - return true, nil - } - return false, nil - }, - wantErr: tcpip.ErrNoPortAvailable, - }, - } { - t.Run(test.name, func(t *testing.T) { - pm := NewPortManager() - if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr { - t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) - } - }) - } -} - -func TestPickEphemeralPortStable(t *testing.T) { - customErr := &tcpip.Error{} - for _, test := range []struct { - name string - f func(port uint16) (bool, *tcpip.Error) - wantErr *tcpip.Error - wantPort uint16 - }{ - { - name: "no-port-available", - f: func(port uint16) (bool, *tcpip.Error) { - return false, nil - }, - wantErr: tcpip.ErrNoPortAvailable, - }, - { - name: "port-tester-error", - f: func(port uint16) (bool, *tcpip.Error) { - return false, customErr - }, - wantErr: customErr, - }, - { - name: "only-port-16042-available", - f: func(port uint16) (bool, *tcpip.Error) { - if port == FirstEphemeral+42 { - return true, nil - } - return false, nil - }, - wantPort: FirstEphemeral + 42, - }, - { - name: "only-port-under-16000-available", - f: func(port uint16) (bool, *tcpip.Error) { - if port < FirstEphemeral { - return true, nil - } - return false, nil - }, - wantErr: tcpip.ErrNoPortAvailable, - }, - } { - t.Run(test.name, func(t *testing.T) { - pm := NewPortManager() - portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) - if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr { - t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) - } - }) - } -} diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD deleted file mode 100644 index cf0a5fefe..000000000 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//tools:defs.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "tun_tcp_connect", - srcs = ["main.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/link/tun", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go deleted file mode 100644 index 0ab089208..000000000 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -// This sample creates a stack with TCP and IPv4 protocols on top of a TUN -// device, and connects to a peer. Similar to "nc <address> <port>". While the -// sample is running, attempts to connect to its IPv4 address will result in -// a RST segment. -// -// As an example of how to run it, a TUN device can be created and enabled on -// a linux host as follows (this only needs to be done once per boot): -// -// [sudo] ip tuntap add user <username> mode tun <device-name> -// [sudo] ip link set <device-name> up -// [sudo] ip addr add <ipv4-address>/<mask-length> dev <device-name> -// -// A concrete example: -// -// $ sudo ip tuntap add user wedsonaf mode tun tun0 -// $ sudo ip link set tun0 up -// $ sudo ip addr add 192.168.1.1/24 dev tun0 -// -// Then one can run tun_tcp_connect as such: -// -// $ ./tun/tun_tcp_connect tun0 192.168.1.2 0 192.168.1.1 1234 -// -// This will attempt to connect to the linux host's stack. One can run nc in -// listen mode to accept a connect from tun_tcp_connect and exchange data. -package main - -import ( - "bufio" - "fmt" - "log" - "math/rand" - "net" - "os" - "strconv" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/link/tun" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// writer reads from standard input and writes to the endpoint until standard -// input is closed. It signals that it's done by closing the provided channel. -func writer(ch chan struct{}, ep tcpip.Endpoint) { - defer func() { - ep.Shutdown(tcpip.ShutdownWrite) - close(ch) - }() - - r := bufio.NewReader(os.Stdin) - for { - v := buffer.NewView(1024) - n, err := r.Read(v) - if err != nil { - return - } - - v.CapLength(n) - for len(v) > 0 { - n, _, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - if err != nil { - fmt.Println("Write failed:", err) - return - } - - v.TrimFront(int(n)) - } - } -} - -func main() { - if len(os.Args) != 6 { - log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-ipv4-address> <local-port> <remote-ipv4-address> <remote-port>") - } - - tunName := os.Args[1] - addrName := os.Args[2] - portName := os.Args[3] - remoteAddrName := os.Args[4] - remotePortName := os.Args[5] - - rand.Seed(time.Now().UnixNano()) - - addr := tcpip.Address(net.ParseIP(addrName).To4()) - remote := tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(net.ParseIP(remoteAddrName).To4()), - } - - var localPort uint16 - if v, err := strconv.Atoi(portName); err != nil { - log.Fatalf("Unable to convert port %v: %v", portName, err) - } else { - localPort = uint16(v) - } - - if v, err := strconv.Atoi(remotePortName); err != nil { - log.Fatalf("Unable to convert port %v: %v", remotePortName, err) - } else { - remote.Port = uint16(v) - } - - // Create the stack with ipv4 and tcp protocols, then add a tun-based - // NIC and ipv4 address. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, - }) - - mtu, err := rawfile.GetMTU(tunName) - if err != nil { - log.Fatal(err) - } - - fd, err := tun.Open(tunName) - if err != nil { - log.Fatal(err) - } - - linkEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu}) - if err != nil { - log.Fatal(err) - } - if err := s.CreateNIC(1, sniffer.New(linkEP)); err != nil { - log.Fatal(err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil { - log.Fatal(err) - } - - // Add default route. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - }) - - // Create TCP endpoint. - var wq waiter.Queue - ep, e := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if e != nil { - log.Fatal(e) - } - - // Bind if a port is specified. - if localPort != 0 { - if err := ep.Bind(tcpip.FullAddress{0, "", localPort}); err != nil { - log.Fatal("Bind failed: ", err) - } - } - - // Issue connect request and wait for it to complete. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventOut) - terr := ep.Connect(remote) - if terr == tcpip.ErrConnectStarted { - fmt.Println("Connect is pending...") - <-notifyCh - terr = ep.GetSockOpt(tcpip.ErrorOption{}) - } - wq.EventUnregister(&waitEntry) - - if terr != nil { - log.Fatal("Unable to connect: ", terr) - } - - fmt.Println("Connected") - - // Start the writer in its own goroutine. - writerCompletedCh := make(chan struct{}) - go writer(writerCompletedCh, ep) // S/R-SAFE: sample code. - - // Read data and write to standard output until the peer closes the - // connection from its side. - wq.EventRegister(&waitEntry, waiter.EventIn) - for { - v, _, err := ep.Read(nil) - if err != nil { - if err == tcpip.ErrClosedForReceive { - break - } - - if err == tcpip.ErrWouldBlock { - <-notifyCh - continue - } - - log.Fatal("Read() failed:", err) - } - - os.Stdout.Write(v) - } - wq.EventUnregister(&waitEntry) - - // The reader has completed. Now wait for the writer as well. - <-writerCompletedCh - - ep.Close() -} diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD deleted file mode 100644 index 43264b76d..000000000 --- a/pkg/tcpip/sample/tun_tcp_echo/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "tun_tcp_echo", - srcs = ["main.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/link/tun", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go deleted file mode 100644 index 9e37cab18..000000000 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -// This sample creates a stack with TCP and IPv4 protocols on top of a TUN -// device, and listens on a port. Data received by the server in the accepted -// connections is echoed back to the clients. -package main - -import ( - "flag" - "log" - "math/rand" - "net" - "os" - "strconv" - "strings" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" - "gvisor.dev/gvisor/pkg/tcpip/link/tun" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -var tap = flag.Bool("tap", false, "use tap istead of tun") -var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device") - -func echo(wq *waiter.Queue, ep tcpip.Endpoint) { - defer ep.Close() - - // Create wait queue entry that notifies a channel. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - - wq.EventRegister(&waitEntry, waiter.EventIn) - defer wq.EventUnregister(&waitEntry) - - for { - v, _, err := ep.Read(nil) - if err != nil { - if err == tcpip.ErrWouldBlock { - <-notifyCh - continue - } - - return - } - - ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - } -} - -func main() { - flag.Parse() - if len(flag.Args()) != 3 { - log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-address> <local-port>") - } - - tunName := flag.Arg(0) - addrName := flag.Arg(1) - portName := flag.Arg(2) - - rand.Seed(time.Now().UnixNano()) - - // Parse the mac address. - maddr, err := net.ParseMAC(*mac) - if err != nil { - log.Fatalf("Bad MAC address: %v", *mac) - } - - // Parse the IP address. Support both ipv4 and ipv6. - parsedAddr := net.ParseIP(addrName) - if parsedAddr == nil { - log.Fatalf("Bad IP address: %v", addrName) - } - - var addr tcpip.Address - var proto tcpip.NetworkProtocolNumber - if parsedAddr.To4() != nil { - addr = tcpip.Address(parsedAddr.To4()) - proto = ipv4.ProtocolNumber - } else if parsedAddr.To16() != nil { - addr = tcpip.Address(parsedAddr.To16()) - proto = ipv6.ProtocolNumber - } else { - log.Fatalf("Unknown IP type: %v", addrName) - } - - localPort, err := strconv.Atoi(portName) - if err != nil { - log.Fatalf("Unable to convert port %v: %v", portName, err) - } - - // Create the stack with ip and tcp protocols, then add a tun-based - // NIC and address. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, - }) - - mtu, err := rawfile.GetMTU(tunName) - if err != nil { - log.Fatal(err) - } - - var fd int - if *tap { - fd, err = tun.OpenTAP(tunName) - } else { - fd, err = tun.Open(tunName) - } - if err != nil { - log.Fatal(err) - } - - linkEP, err := fdbased.New(&fdbased.Options{ - FDs: []int{fd}, - MTU: mtu, - EthernetHeader: *tap, - Address: tcpip.LinkAddress(maddr), - }) - if err != nil { - log.Fatal(err) - } - if err := s.CreateNIC(1, linkEP); err != nil { - log.Fatal(err) - } - - if err := s.AddAddress(1, proto, addr); err != nil { - log.Fatal(err) - } - - if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - log.Fatal(err) - } - - subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr)))) - if err != nil { - log.Fatal(err) - } - - // Add default route. - s.SetRouteTable([]tcpip.Route{ - { - Destination: subnet, - NIC: 1, - }, - }) - - // Create TCP endpoint, bind it, then start listening. - var wq waiter.Queue - ep, e := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq) - if e != nil { - log.Fatal(e) - } - - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{0, "", uint16(localPort)}); err != nil { - log.Fatal("Bind failed: ", err) - } - - if err := ep.Listen(10); err != nil { - log.Fatal("Listen failed: ", err) - } - - // Wait for connections to appear. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventIn) - defer wq.EventUnregister(&waitEntry) - - for { - n, wq, err := ep.Accept() - if err != nil { - if err == tcpip.ErrWouldBlock { - <-notifyCh - continue - } - - log.Fatal("Accept() failed:", err) - } - - go echo(wq, n) // S/R-SAFE: sample code. - } -} diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD deleted file mode 100644 index 45f503845..000000000 --- a/pkg/tcpip/seqnum/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "seqnum", - srcs = ["seqnum.go"], - visibility = ["//visibility:public"], -) diff --git a/pkg/tcpip/seqnum/seqnum_state_autogen.go b/pkg/tcpip/seqnum/seqnum_state_autogen.go new file mode 100755 index 000000000..23e79811d --- /dev/null +++ b/pkg/tcpip/seqnum/seqnum_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package seqnum diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD deleted file mode 100644 index 6c029b2fb..000000000 --- a/pkg/tcpip/stack/BUILD +++ /dev/null @@ -1,95 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "linkaddrentry_list", - out = "linkaddrentry_list.go", - package = "stack", - prefix = "linkAddrEntry", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*linkAddrEntry", - "Linker": "*linkAddrEntry", - }, -) - -go_library( - name = "stack", - srcs = [ - "dhcpv6configurationfromndpra_string.go", - "forwarder.go", - "icmp_rate_limit.go", - "linkaddrcache.go", - "linkaddrentry_list.go", - "ndp.go", - "nic.go", - "registration.go", - "route.go", - "stack.go", - "stack_global_state.go", - "transport_demuxer.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/ilist", - "//pkg/rand", - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/hash/jenkins", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/ports", - "//pkg/tcpip/seqnum", - "//pkg/waiter", - "@org_golang_x_time//rate:go_default_library", - ], -) - -go_test( - name = "stack_x_test", - size = "medium", - srcs = [ - "ndp_test.go", - "stack_test.go", - "transport_demuxer_test.go", - "transport_test.go", - ], - deps = [ - ":stack", - "//pkg/rand", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", - ], -) - -go_test( - name = "stack_test", - size = "small", - srcs = [ - "forwarder_test.go", - "linkaddrcache_test.go", - "nic_test.go", - ], - library = ":stack", - deps = [ - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - ], -) diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go index 8b4213eec..8b4213eec 100644..100755 --- a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go +++ b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go index 631953935..631953935 100644..100755 --- a/pkg/tcpip/stack/forwarder.go +++ b/pkg/tcpip/stack/forwarder.go diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go deleted file mode 100644 index 321b7524d..000000000 --- a/pkg/tcpip/stack/forwarder_test.go +++ /dev/null @@ -1,635 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "encoding/binary" - "math" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -const ( - fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - fwdTestNetHeaderLen = 12 - fwdTestNetDefaultPrefixLen = 8 - - // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, - // except where another value is explicitly used. It is chosen to match - // the MTU of loopback interfaces on linux systems. - fwdTestNetDefaultMTU = 65536 -) - -// fwdTestNetworkEndpoint is a network-layer protocol endpoint. -// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only -// use the first three: destination address, source address, and transport -// protocol. They're all one byte fields to simplify parsing. -type fwdTestNetworkEndpoint struct { - nicID tcpip.NICID - id NetworkEndpointID - prefixLen int - proto *fwdTestNetworkProtocol - dispatcher TransportDispatcher - ep LinkEndpoint -} - -func (f *fwdTestNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) -} - -func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID { - return f.nicID -} - -func (f *fwdTestNetworkEndpoint) PrefixLen() int { - return f.prefixLen -} - -func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { - return 123 -} - -func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { - return &f.id -} - -func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt tcpip.PacketBuffer) { - // Consume the network header. - b := pkt.Data.First() - pkt.Data.TrimFront(fwdTestNetHeaderLen) - - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) -} - -func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { - return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen -} - -func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - -func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities { - return f.ep.Capabilities() -} - -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { - // Add the protocol's header to the packet and send it to the link - // endpoint. - b := pkt.Header.Prepend(fwdTestNetHeaderLen) - b[0] = r.RemoteAddress[0] - b[1] = f.id.LocalAddress[0] - b[2] = byte(params.Protocol) - - return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) -} - -// WritePackets implements LinkEndpoint.WritePackets. -func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { - panic("not implemented") -} - -func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported -} - -func (*fwdTestNetworkEndpoint) Close() {} - -// fwdTestNetworkProtocol is a network-layer protocol that implements Address -// resolution. -type fwdTestNetworkProtocol struct { - addrCache *linkAddrCache - addrResolveDelay time.Duration - onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address) - onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) -} - -func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { - return fwdTestNetNumber -} - -func (f *fwdTestNetworkProtocol) MinimumPacketSize() int { - return fwdTestNetHeaderLen -} - -func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int { - return fwdTestNetDefaultPrefixLen -} - -func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) -} - -func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { - return &fwdTestNetworkEndpoint{ - nicID: nicID, - id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, - proto: f, - dispatcher: dispatcher, - ep: ep, - }, nil -} - -func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption -} - -func (f *fwdTestNetworkProtocol) Option(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption -} - -func (f *fwdTestNetworkProtocol) Close() {} - -func (f *fwdTestNetworkProtocol) Wait() {} - -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error { - if f.addrCache != nil && f.onLinkAddressResolved != nil { - time.AfterFunc(f.addrResolveDelay, func() { - f.onLinkAddressResolved(f.addrCache, addr) - }) - } - return nil -} - -func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if f.onResolveStaticAddress != nil { - return f.onResolveStaticAddress(addr) - } - return "", false -} - -func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return fwdTestNetNumber -} - -// fwdTestPacketInfo holds all the information about an outbound packet. -type fwdTestPacketInfo struct { - RemoteLinkAddress tcpip.LinkAddress - LocalLinkAddress tcpip.LinkAddress - Pkt tcpip.PacketBuffer -} - -type fwdTestLinkEndpoint struct { - dispatcher NetworkDispatcher - mtu uint32 - linkAddr tcpip.LinkAddress - - // C is where outbound packets are queued. - C chan fwdTestPacketInfo -} - -// InjectInbound injects an inbound packet. -func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { - e.InjectLinkAddr(protocol, "", pkt) -} - -// InjectLinkAddr injects an inbound packet with a remote link address. -func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) -} - -// Attach saves the stack network-layer dispatcher for use later when packets -// are injected. -func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) { - e.dispatcher = dispatcher -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *fwdTestLinkEndpoint) IsAttached() bool { - return e.dispatcher != nil -} - -// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized -// during construction. -func (e *fwdTestLinkEndpoint) MTU() uint32 { - return e.mtu -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. -func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities { - caps := LinkEndpointCapabilities(0) - return caps | CapabilityResolutionRequired -} - -// GSOMaxSize returns the maximum GSO packet size. -func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 { - return 1 << 15 -} - -// MaxHeaderLength returns the maximum size of the link layer header. Given it -// doesn't have a header, it just returns 0. -func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 { - return 0 -} - -// LinkAddress returns the link address of this endpoint. -func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { - return e.linkAddr -} - -func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { - p := fwdTestPacketInfo{ - RemoteLinkAddress: r.RemoteLinkAddress, - LocalLinkAddress: r.LocalLinkAddress, - Pkt: pkt, - } - - select { - case e.C <- p: - default: - } - - return nil -} - -// WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - n := 0 - for _, pkt := range pkts { - e.WritePacket(r, gso, protocol, pkt) - n++ - } - - return n, nil -} - -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - p := fwdTestPacketInfo{ - Pkt: tcpip.PacketBuffer{Data: vv}, - } - - select { - case e.C <- p: - default: - } - - return nil -} - -// Wait implements stack.LinkEndpoint.Wait. -func (*fwdTestLinkEndpoint) Wait() {} - -func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { - // Create a stack with the network protocol and two NICs. - s := New(Options{ - NetworkProtocols: []NetworkProtocol{proto}, - }) - - proto.addrCache = s.linkAddrCache - - // Enable forwarding. - s.SetForwarding(true) - - // NIC 1 has the link address "a", and added the network address 1. - ep1 = &fwdTestLinkEndpoint{ - C: make(chan fwdTestPacketInfo, 300), - mtu: fwdTestNetDefaultMTU, - linkAddr: "a", - } - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC #1 failed:", err) - } - if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) - } - - // NIC 2 has the link address "b", and added the network address 2. - ep2 = &fwdTestLinkEndpoint{ - C: make(chan fwdTestPacketInfo, 300), - mtu: fwdTestNetDefaultMTU, - linkAddr: "b", - } - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC #2 failed:", err) - } - if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) - } - - // Route all packets to NIC 2. - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}}) - } - - return ep1, ep2 -} - -func TestForwardingWithStaticResolver(t *testing.T) { - // Create a network protocol with a static resolver. - proto := &fwdTestNetworkProtocol{ - onResolveStaticAddress: - // The network address 3 is resolved to the link address "c". - func(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "\x03" { - return "c", true - } - return "", false - }, - } - - ep1, ep2 := fwdTestNetFactory(t, proto) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } - - // Test that the static address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } -} - -func TestForwardingWithFakeResolver(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { - // Any address will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") - }, - } - - ep1, ep2 := fwdTestNetFactory(t, proto) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } -} - -func TestForwardingWithNoResolver(t *testing.T) { - // Create a network protocol without a resolver. - proto := &fwdTestNetworkProtocol{} - - ep1, ep2 := fwdTestNetFactory(t, proto) - - // inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - - select { - case <-ep2.C: - t.Fatal("Packet should not be forwarded") - case <-time.After(time.Second): - } -} - -func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { - // Only packets to address 3 will be resolved to the - // link address "c". - if addr == "\x03" { - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") - } - }, - } - - ep1, ep2 := fwdTestNetFactory(t, proto) - - // Inject an inbound packet to address 4 on NIC 1. This packet should - // not be forwarded. - buf := buffer.NewView(30) - buf[0] = 4 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf = buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - b := p.Pkt.Header.View() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } -} - -func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { - // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") - }, - } - - ep1, ep2 := fwdTestNetFactory(t, proto) - - // Inject two inbound packets to address 3 on NIC 1. - for i := 0; i < 2; i++ { - buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - } - - for i := 0; i < 2; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - b := p.Pkt.Header.View() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } -} - -func TestForwardingWithFakeResolverManyPackets(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { - // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") - }, - } - - ep1, ep2 := fwdTestNetFactory(t, proto) - - for i := 0; i < maxPendingPacketsPerResolution+5; i++ { - // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. - buf := buffer.NewView(30) - buf[0] = 3 - // Set the packet sequence number. - binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - } - - for i := 0; i < maxPendingPacketsPerResolution; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - b := p.Pkt.Header.View() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) - } - // The first 5 packets should not be forwarded so the the - // sequemnce number should start with 5. - want := uint16(i + 5) - if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want { - t.Fatalf("got the packet #%d, want = #%d", n, want) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } -} - -func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { - // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") - }, - } - - ep1, ep2 := fwdTestNetFactory(t, proto) - - for i := 0; i < maxPendingResolutions+5; i++ { - // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. - // Each packet has a different destination address (3 to - // maxPendingResolutions + 7). - buf := buffer.NewView(30) - buf[0] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - } - - for i := 0; i < maxPendingResolutions; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - // The first 5 packets (address 3 to 7) should not be forwarded - // because their address resolutions are interrupted. - b := p.Pkt.Header.View() - if b[0] < 8 { - t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } -} diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go deleted file mode 100644 index 1baa498d0..000000000 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ /dev/null @@ -1,277 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "fmt" - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sleep" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" -) - -type testaddr struct { - addr tcpip.FullAddress - linkAddr tcpip.LinkAddress -} - -var testAddrs = func() []testaddr { - var addrs []testaddr - for i := 0; i < 4*linkAddrCacheSize; i++ { - addr := fmt.Sprintf("Addr%06d", i) - addrs = append(addrs, testaddr{ - addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, - linkAddr: tcpip.LinkAddress("Link" + addr), - }) - } - return addrs -}() - -type testLinkAddressResolver struct { - cache *linkAddrCache - delay time.Duration - onLinkAddressRequest func() -} - -func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { - time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) - if f := r.onLinkAddressRequest; f != nil { - f() - } - return nil -} - -func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { - for _, ta := range testAddrs { - if ta.addr.Addr == addr { - r.cache.add(ta.addr, ta.linkAddr) - break - } - } -} - -func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "broadcast" { - return "mac_broadcast", true - } - return "", false -} - -func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return 1 -} - -func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) { - w := sleep.Waker{} - s := sleep.Sleeper{} - s.AddWaker(&w, 123) - defer s.Done() - - for { - if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - return got, err - } - s.Fetch(true) - } -} - -func TestCacheOverflow(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - for i := len(testAddrs) - 1; i >= 0; i-- { - e := testAddrs[i] - c.add(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil, nil) - if err != nil { - t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) - } - } - // Expect to find at least half of the most recent entries. - for i := 0; i < linkAddrCacheSize/2; i++ { - e := testAddrs[i] - got, _, err := c.get(e.addr, nil, "", nil, nil) - if err != nil { - t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) - } - } - // The earliest entries should no longer be in the cache. - for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { - e := testAddrs[i] - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) - } - } -} - -func TestCacheConcurrent(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - - var wg sync.WaitGroup - for r := 0; r < 16; r++ { - wg.Add(1) - go func() { - for _, e := range testAddrs { - c.add(e.addr, e.linkAddr) - c.get(e.addr, nil, "", nil, nil) // make work for gotsan - } - wg.Done() - }() - } - wg.Wait() - - // All goroutines add in the same order and add more values than - // can fit in the cache, so our eviction strategy requires that - // the last entry be present and the first be missing. - e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, nil, "", nil, nil) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) - } - - e = testAddrs[0] - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) - } -} - -func TestCacheAgeLimit(t *testing.T) { - c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) - e := testAddrs[0] - c.add(e.addr, e.linkAddr) - time.Sleep(50 * time.Millisecond) - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) - } -} - -func TestCacheReplace(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - e := testAddrs[0] - l2 := e.linkAddr + "2" - c.add(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil, nil) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) - } - - c.add(e.addr, l2) - got, _, err = c.get(e.addr, nil, "", nil, nil) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) - } - if got != l2 { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2) - } -} - -func TestCacheResolution(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1) - linkRes := &testLinkAddressResolver{cache: c} - for i, ta := range testAddrs { - got, err := getBlocking(c, ta.addr, linkRes) - if err != nil { - t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err) - } - if got != ta.linkAddr { - t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr) - } - } - - // Check that after resolved, address stays in the cache and never returns WouldBlock. - for i := 0; i < 10; i++ { - e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, linkRes, "", nil, nil) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) - } - } -} - -func TestCacheResolutionFailed(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5) - linkRes := &testLinkAddressResolver{cache: c} - - var requestCount uint32 - linkRes.onLinkAddressRequest = func() { - atomic.AddUint32(&requestCount, 1) - } - - // First, sanity check that resolution is working... - e := testAddrs[0] - got, err := getBlocking(c, e.addr, linkRes) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) - } - - before := atomic.LoadUint32(&requestCount) - - e.addr.Addr += "2" - if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) - } - - if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { - t.Errorf("got link address request count = %d, want = %d", got, want) - } -} - -func TestCacheResolutionTimeout(t *testing.T) { - resolverDelay := 500 * time.Millisecond - expiration := resolverDelay / 10 - c := newLinkAddrCache(expiration, 1*time.Millisecond, 3) - linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} - - e := testAddrs[0] - if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) - } -} - -// TestStaticResolution checks that static link addresses are resolved immediately and don't -// send resolution requests. -func TestStaticResolution(t *testing.T) { - c := newLinkAddrCache(1<<63-1, time.Millisecond, 1) - linkRes := &testLinkAddressResolver{cache: c, delay: time.Minute} - - addr := tcpip.Address("broadcast") - want := tcpip.LinkAddress("mac_broadcast") - got, _, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err) - } - if got != want { - t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want)) - } -} diff --git a/pkg/tcpip/stack/linkaddrentry_list.go b/pkg/tcpip/stack/linkaddrentry_list.go new file mode 100755 index 000000000..6697281cd --- /dev/null +++ b/pkg/tcpip/stack/linkaddrentry_list.go @@ -0,0 +1,186 @@ +package stack + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type linkAddrEntryElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (linkAddrEntryElementMapper) linkerFor(elem *linkAddrEntry) *linkAddrEntry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type linkAddrEntryList struct { + head *linkAddrEntry + tail *linkAddrEntry +} + +// Reset resets list l to the empty state. +func (l *linkAddrEntryList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *linkAddrEntryList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *linkAddrEntryList) Front() *linkAddrEntry { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *linkAddrEntryList) Back() *linkAddrEntry { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *linkAddrEntryList) PushFront(e *linkAddrEntry) { + linker := linkAddrEntryElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + linkAddrEntryElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *linkAddrEntryList) PushBack(e *linkAddrEntry) { + linker := linkAddrEntryElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + linkAddrEntryElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *linkAddrEntryList) PushBackList(m *linkAddrEntryList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + linkAddrEntryElementMapper{}.linkerFor(l.tail).SetNext(m.head) + linkAddrEntryElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *linkAddrEntryList) InsertAfter(b, e *linkAddrEntry) { + bLinker := linkAddrEntryElementMapper{}.linkerFor(b) + eLinker := linkAddrEntryElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + linkAddrEntryElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *linkAddrEntryList) InsertBefore(a, e *linkAddrEntry) { + aLinker := linkAddrEntryElementMapper{}.linkerFor(a) + eLinker := linkAddrEntryElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + linkAddrEntryElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *linkAddrEntryList) Remove(e *linkAddrEntry) { + linker := linkAddrEntryElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + linkAddrEntryElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + linkAddrEntryElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type linkAddrEntryEntry struct { + next *linkAddrEntry + prev *linkAddrEntry +} + +// Next returns the entry that follows e in the list. +func (e *linkAddrEntryEntry) Next() *linkAddrEntry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *linkAddrEntryEntry) Prev() *linkAddrEntry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *linkAddrEntryEntry) SetNext(elem *linkAddrEntry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *linkAddrEntryEntry) SetPrev(elem *linkAddrEntry) { + e.prev = elem +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index d689a006d..d689a006d 100644..100755 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go deleted file mode 100644 index 4368c236c..000000000 --- a/pkg/tcpip/stack/ndp_test.go +++ /dev/null @@ -1,3759 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "context" - "encoding/binary" - "fmt" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") - linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") - linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") - defaultTimeout = 100 * time.Millisecond - defaultAsyncEventTimeout = time.Second -) - -var ( - llAddr1 = header.LinkLocalAddr(linkAddr1) - llAddr2 = header.LinkLocalAddr(linkAddr2) - llAddr3 = header.LinkLocalAddr(linkAddr3) - llAddr4 = header.LinkLocalAddr(linkAddr4) - dstAddr = tcpip.FullAddress{ - Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - Port: 25, - } -) - -func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix { - if !header.IsValidUnicastEthernetAddress(linkAddr) { - return tcpip.AddressWithPrefix{} - } - - addrBytes := []byte(subnet.ID()) - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) - return tcpip.AddressWithPrefix{ - Address: tcpip.Address(addrBytes), - PrefixLen: 64, - } -} - -// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent -// tcpip.Subnet, and an address where the lower half of the address is composed -// of the EUI-64 of linkAddr if it is a valid unicast ethernet address. -func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) { - prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0} - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address(prefixBytes), - PrefixLen: 64, - } - - subnet := prefix.Subnet() - - return prefix, subnet, addrForSubnet(subnet, linkAddr) -} - -// ndpDADEvent is a set of parameters that was passed to -// ndpDispatcher.OnDuplicateAddressDetectionStatus. -type ndpDADEvent struct { - nicID tcpip.NICID - addr tcpip.Address - resolved bool - err *tcpip.Error -} - -type ndpRouterEvent struct { - nicID tcpip.NICID - addr tcpip.Address - // true if router was discovered, false if invalidated. - discovered bool -} - -type ndpPrefixEvent struct { - nicID tcpip.NICID - prefix tcpip.Subnet - // true if prefix was discovered, false if invalidated. - discovered bool -} - -type ndpAutoGenAddrEventType int - -const ( - newAddr ndpAutoGenAddrEventType = iota - deprecatedAddr - invalidatedAddr -) - -type ndpAutoGenAddrEvent struct { - nicID tcpip.NICID - addr tcpip.AddressWithPrefix - eventType ndpAutoGenAddrEventType -} - -type ndpRDNSS struct { - addrs []tcpip.Address - lifetime time.Duration -} - -type ndpRDNSSEvent struct { - nicID tcpip.NICID - rdnss ndpRDNSS -} - -type ndpDHCPv6Event struct { - nicID tcpip.NICID - configuration stack.DHCPv6ConfigurationFromNDPRA -} - -var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) - -// ndpDispatcher implements NDPDispatcher so tests can know when various NDP -// related events happen for test purposes. -type ndpDispatcher struct { - dadC chan ndpDADEvent - routerC chan ndpRouterEvent - rememberRouter bool - prefixC chan ndpPrefixEvent - rememberPrefix bool - autoGenAddrC chan ndpAutoGenAddrEvent - rdnssC chan ndpRDNSSEvent - dhcpv6ConfigurationC chan ndpDHCPv6Event -} - -// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. -func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { - if n.dadC != nil { - n.dadC <- ndpDADEvent{ - nicID, - addr, - resolved, - err, - } - } -} - -// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered. -func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { - if c := n.routerC; c != nil { - c <- ndpRouterEvent{ - nicID, - addr, - true, - } - } - - return n.rememberRouter -} - -// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated. -func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { - if c := n.routerC; c != nil { - c <- ndpRouterEvent{ - nicID, - addr, - false, - } - } -} - -// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered. -func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { - if c := n.prefixC; c != nil { - c <- ndpPrefixEvent{ - nicID, - prefix, - true, - } - } - - return n.rememberPrefix -} - -// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated. -func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) { - if c := n.prefixC; c != nil { - c <- ndpPrefixEvent{ - nicID, - prefix, - false, - } - } -} - -func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { - if c := n.autoGenAddrC; c != nil { - c <- ndpAutoGenAddrEvent{ - nicID, - addr, - newAddr, - } - } - return true -} - -func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { - if c := n.autoGenAddrC; c != nil { - c <- ndpAutoGenAddrEvent{ - nicID, - addr, - deprecatedAddr, - } - } -} - -func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { - if c := n.autoGenAddrC; c != nil { - c <- ndpAutoGenAddrEvent{ - nicID, - addr, - invalidatedAddr, - } - } -} - -// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption. -func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { - if c := n.rdnssC; c != nil { - c <- ndpRDNSSEvent{ - nicID, - ndpRDNSS{ - addrs, - lifetime, - }, - } - } -} - -// Implements stack.NDPDispatcher.OnDHCPv6Configuration. -func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) { - if c := n.dhcpv6ConfigurationC; c != nil { - c <- ndpDHCPv6Event{ - nicID, - configuration, - } - } -} - -// channelLinkWithHeaderLength is a channel.Endpoint with a configurable -// header length. -type channelLinkWithHeaderLength struct { - *channel.Endpoint - headerLength uint16 -} - -func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 { - return l.headerLength -} - -// Check e to make sure that the event is for addr on nic with ID 1, and the -// resolved flag set to resolved with the specified err. -func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string { - return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e)) -} - -// TestDADDisabled tests that an address successfully resolves immediately -// when DAD is not enabled (the default for an empty stack.Options). -func TestDADDisabled(t *testing.T) { - const nicID = 1 - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - } - - e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } - - // Should get the address immediately since we should not have performed - // DAD on it. - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DAD event") - } - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, %d) err = %s", nicID, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) - } - - // We should not have sent any NDP NS messages. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 { - t.Fatalf("got NeighborSolicit = %d, want = 0", got) - } -} - -// TestDADResolve tests that an address successfully resolves after performing -// DAD for various values of DupAddrDetectTransmits and RetransmitTimer. -// Included in the subtests is a test to make sure that an invalid -// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. -// This tests also validates the NDP NS packet that is transmitted. -func TestDADResolve(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - linkHeaderLen uint16 - dupAddrDetectTransmits uint8 - retransTimer time.Duration - expectedRetransmitTimer time.Duration - }{ - { - name: "1:1s:1s", - dupAddrDetectTransmits: 1, - retransTimer: time.Second, - expectedRetransmitTimer: time.Second, - }, - { - name: "2:1s:1s", - linkHeaderLen: 1, - dupAddrDetectTransmits: 2, - retransTimer: time.Second, - expectedRetransmitTimer: time.Second, - }, - { - name: "1:2s:2s", - linkHeaderLen: 2, - dupAddrDetectTransmits: 1, - retransTimer: 2 * time.Second, - expectedRetransmitTimer: 2 * time.Second, - }, - // 0s is an invalid RetransmitTimer timer and will be fixed to - // the default RetransmitTimer value of 1s. - { - name: "1:0s:1s", - linkHeaderLen: 3, - dupAddrDetectTransmits: 1, - retransTimer: 0, - expectedRetransmitTimer: time.Second, - }, - } - - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), - } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - } - opts.NDPConfigs.RetransmitTimer = test.retransTimer - opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits - - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1), - headerLength: test.linkHeaderLen, - } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(opts) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } - - // Address should not be considered bound to the NIC yet - // (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) - } - - // Wait for the remaining time - some delta (500ms), to - // make sure the address is still not resolved. - const delta = 500 * time.Millisecond - time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) - } - - // Wait for DAD to resolve. - select { - case <-time.After(2 * delta): - // We should get a resolution event after 500ms - // (delta) since we wait for 500ms less than the - // expected resolution time above to make sure - // that the address did not yet resolve. Waiting - // for 1s (2x delta) without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) - } - - // Should not have sent any more NS messages. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { - t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits) - } - - // Validate the sent Neighbor Solicitation messages. - for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { - p, _ := e.ReadContext(context.Background()) - - // Make sure its an IPv6 packet. - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - - // Make sure the right remote link address is used. - snmc := header.SolicitedNodeAddr(addr1) - if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } - - // Check NDP NS packet. - // - // As per RFC 4861 section 4.3, a possible option is the Source Link - // Layer option, but this option MUST NOT be included when the source - // address of the packet is the unspecified address. - checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(addr1), - checker.NDPNSOptions(nil), - )) - - if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) - } - } - }) - } -} - -// TestDADFail tests to make sure that the DAD process fails if another node is -// detected to be performing DAD on the same address (receive an NS message from -// a node doing DAD for the same address), or if another node is detected to own -// the address already (receive an NA message for the tentative address). -func TestDADFail(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - makeBuf func(tgt tcpip.Address) buffer.Prependable - getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - }{ - { - "RxSolicit", - func(tgt tcpip.Address) buffer.Prependable { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) - ns.SetTargetAddress(tgt) - snmc := header.SolicitedNodeAddr(tgt) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: header.IPv6Any, - DstAddr: snmc, - }) - - return hdr - - }, - func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return s.NeighborSolicit - }, - }, - { - "RxAdvert", - func(tgt tcpip.Address) buffer.Prependable { - naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) - pkt := header.ICMPv6(hdr.Prepend(naSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(tgt) - na.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: tgt, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - - return hdr - - }, - func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return s.NeighborAdvert - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - ndpConfigs := stack.DefaultNDPConfigurations() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - } - opts.NDPConfigs.RetransmitTimer = time.Second * 2 - - e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } - - // Address should not be considered bound to the NIC yet - // (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) - } - - // Receive a packet to simulate multiple nodes owning or - // attempting to own the same address. - hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - - stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) - if got := stat.Value(); got != 1 { - t.Fatalf("got stat = %d, want = 1", got) - } - - // Wait for DAD to fail and make sure the address did - // not get resolved. - select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the - // expected resolution time + extra 1s buffer, - // something is wrong. - t.Fatal("timed out waiting for DAD failure") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) - } - }) - } -} - -func TestDADStop(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - stopFn func(t *testing.T, s *stack.Stack) - skipFinalAddrCheck bool - }{ - // Tests to make sure that DAD stops when an address is removed. - { - name: "Remove address", - stopFn: func(t *testing.T, s *stack.Stack) { - if err := s.RemoveAddress(nicID, addr1); err != nil { - t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err) - } - }, - }, - - // Tests to make sure that DAD stops when the NIC is disabled. - { - name: "Disable NIC", - stopFn: func(t *testing.T, s *stack.Stack) { - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("DisableNIC(%d): %s", nicID, err) - } - }, - }, - - // Tests to make sure that DAD stops when the NIC is removed. - { - name: "Remove NIC", - stopFn: func(t *testing.T, s *stack.Stack) { - if err := s.RemoveNIC(nicID); err != nil { - t.Fatalf("RemoveNIC(%d): %s", nicID, err) - } - }, - // The NIC is removed so we can't check its addresses after calling - // stopFn. - skipFinalAddrCheck: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - ndpConfigs := stack.NDPConfigurations{ - RetransmitTimer: time.Second, - DupAddrDetectTransmits: 2, - } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - NDPConfigs: ndpConfigs, - } - - e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } - - // Address should not be considered bound to the NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) - } - - test.stopFn(t, s) - - // Wait for DAD to fail (since the address was removed during DAD). - select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the expected resolution - // time + extra 1s buffer, something is wrong. - t.Fatal("timed out waiting for DAD failure") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - } - - if !test.skipFinalAddrCheck { - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) - } - } - - // Should not have sent more than 1 NS message. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { - t.Errorf("got NeighborSolicit = %d, want <= 1", got) - } - }) - } -} - -// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if -// we attempt to update NDP configurations using an invalid NICID. -func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - }) - - // No NIC with ID 1 yet. - if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID { - t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID) - } -} - -// TestSetNDPConfigurations tests that we can update and use per-interface NDP -// configurations without affecting the default NDP configurations or other -// interfaces' configurations. -func TestSetNDPConfigurations(t *testing.T) { - const nicID1 = 1 - const nicID2 = 2 - const nicID3 = 3 - - tests := []struct { - name string - dupAddrDetectTransmits uint8 - retransmitTimer time.Duration - expectedRetransmitTimer time.Duration - }{ - { - "OK", - 1, - time.Second, - time.Second, - }, - { - "Invalid Retransmit Timer", - 1, - 0, - time.Second, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - }) - - expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatalf("expected DAD event for %s", addr) - } - } - - // This NIC(1)'s NDP configurations will be updated to - // be different from the default. - if err := s.CreateNIC(nicID1, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) - } - - // Created before updating NIC(1)'s NDP configurations - // but updating NIC(1)'s NDP configurations should not - // affect other existing NICs. - if err := s.CreateNIC(nicID2, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) - } - - // Update the NDP configurations on NIC(1) to use DAD. - configs := stack.NDPConfigurations{ - DupAddrDetectTransmits: test.dupAddrDetectTransmits, - RetransmitTimer: test.retransmitTimer, - } - if err := s.SetNDPConfigurations(nicID1, configs); err != nil { - t.Fatalf("got SetNDPConfigurations(%d, _) = %s", nicID1, err) - } - - // Created after updating NIC(1)'s NDP configurations - // but the stack's default NDP configurations should not - // have been updated. - if err := s.CreateNIC(nicID3, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID3, err) - } - - // Add addresses for each NIC. - if err := s.AddAddress(nicID1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addr1, err) - } - if err := s.AddAddress(nicID2, header.IPv6ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addr2, err) - } - expectDADEvent(nicID2, addr2) - if err := s.AddAddress(nicID3, header.IPv6ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addr3, err) - } - expectDADEvent(nicID3, addr3) - - // Address should not be considered bound to NIC(1) yet - // (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want) - } - - // Should get the address on NIC(2) and NIC(3) - // immediately since we should not have performed DAD on - // it as the stack was configured to not do DAD by - // default and we only updated the NDP configurations on - // NIC(1). - addr, err = s.GetMainNICAddress(nicID2, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID2, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr2 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID2, header.IPv6ProtocolNumber, addr, addr2) - } - addr, err = s.GetMainNICAddress(nicID3, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID3, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr3 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID3, header.IPv6ProtocolNumber, addr, addr3) - } - - // Sleep until right (500ms before) before resolution to - // make sure the address didn't resolve on NIC(1) yet. - const delta = 500 * time.Millisecond - time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) - addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want) - } - - // Wait for DAD to resolve. - select { - case <-time.After(2 * delta): - // We should get a resolution event after 500ms - // (delta) since we wait for 500ms less than the - // expected resolution time above to make sure - // that the address did not yet resolve. Waiting - // for 1s (2x delta) without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID1, addr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - } - addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID1, header.IPv6ProtocolNumber, addr, addr1) - } - }) - } -} - -// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options -// and DHCPv6 configurations specified. -func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { - icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(header.ICMPv6RouterAdvert) - pkt.SetCode(0) - raPayload := pkt.NDPPayload() - ra := header.NDPRouterAdvert(raPayload) - // Populate the Router Lifetime. - binary.BigEndian.PutUint16(raPayload[2:], rl) - // Populate the Managed Address flag field. - if managedAddress { - // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing) - // of the RA payload. - raPayload[1] |= (1 << 7) - } - // Populate the Other Configurations flag field. - if otherConfigurations { - // The Other Configurations flag field is the 6th bit of byte #1 - // (0-indexing) of the RA payload. - raPayload[1] |= (1 << 6) - } - opts := ra.Options() - opts.Serialize(optSer) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - iph.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: header.NDPHopLimit, - SrcAddr: ip, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - - return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()} -} - -// raBufWithOpts returns a valid NDP Router Advertisement with options. -// -// Note, raBufWithOpts does not populate any of the RA fields other than the -// Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { - return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) -} - -// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related -// fields set. -// -// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the -// DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) tcpip.PacketBuffer { - return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) -} - -// raBuf returns a valid NDP Router Advertisement. -// -// Note, raBuf does not populate any of the RA fields other than the -// Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { - return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) -} - -// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix -// Information option. -// -// Note, raBufWithPI does not populate any of the RA fields other than the -// Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer { - flags := uint8(0) - if onLink { - // The OnLink flag is the 7th bit in the flags byte. - flags |= 1 << 7 - } - if auto { - // The Address Auto-Configuration flag is the 6th bit in the - // flags byte. - flags |= 1 << 6 - } - - // A valid header.NDPPrefixInformation must be 30 bytes. - buf := [30]byte{} - // The first byte in a header.NDPPrefixInformation is the Prefix Length - // field. - buf[0] = uint8(prefix.PrefixLen) - // The 2nd byte within a header.NDPPrefixInformation is the Flags field. - buf[1] = flags - // The Valid Lifetime field starts after the 2nd byte within a - // header.NDPPrefixInformation. - binary.BigEndian.PutUint32(buf[2:], vl) - // The Preferred Lifetime field starts after the 6th byte within a - // header.NDPPrefixInformation. - binary.BigEndian.PutUint32(buf[6:], pl) - // The Prefix Address field starts after the 14th byte within a - // header.NDPPrefixInformation. - copy(buf[14:], prefix.Address) - return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{ - header.NDPPrefixInformation(buf[:]), - }) -} - -// TestNoRouterDiscovery tests that router discovery will not be performed if -// configured not to. -func TestNoRouterDiscovery(t *testing.T) { - // Being configured to discover routers means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // router discovery) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - DiscoverDefaultRouters: discover, - }, - NDPDisp: &ndpDisp, - }) - s.SetForwarding(forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router when configured not to") - default: - } - }) - } -} - -// Check e to make sure that the event is for addr on nic with ID 1, and the -// discovered flag set to discovered. -func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { - return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) -} - -// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered router when the dispatcher asks it not to. -func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA for a router we should not remember. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds)) - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr2, true); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected router discovery event") - } - - // Wait for the invalidation time plus some buffer to make sure we do - // not actually receive any invalidation events as we should not have - // remembered the router in the first place. - select { - case <-ndpDisp.routerC: - t.Fatal("should not have received any router events") - case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): - } -} - -func TestRouterDiscovery(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - }) - - expectRouterEvent := func(addr tcpip.Address, discovered bool) { - t.Helper() - - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, discovered); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected router discovery event") - } - } - - expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { - t.Helper() - - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, false); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) - } - case <-time.After(timeout): - t.Fatal("timed out waiting for router discovery event") - } - } - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA from lladdr2 with zero lifetime. It should not be - // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router with 0 lifetime") - default: - } - - // Rx an RA from lladdr2 with a huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) - - // Rx an RA from another router (lladdr3) with non-zero lifetime. - const l3LifetimeSeconds = 6 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) - expectRouterEvent(llAddr3, true) - - // Rx an RA from lladdr2 with lesser lifetime. - const l2LifetimeSeconds = 2 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) - select { - case <-ndpDisp.routerC: - t.Fatal("Should not receive a router event when updating lifetimes for known routers") - default: - } - - // Wait for lladdr2's router invalidation timer to fire. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncEventTimeout) - - // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) - - // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - expectRouterEvent(llAddr2, false) - - // Wait for lladdr3's router invalidation timer to fire. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncEventTimeout) -} - -// TestRouterDiscoveryMaxRouters tests that only -// stack.MaxDiscoveredDefaultRouters discovered routers are remembered. -func TestRouterDiscoveryMaxRouters(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA from 2 more than the max number of discovered routers. - for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ { - linkAddr := []byte{2, 2, 3, 4, 5, 0} - linkAddr[5] = byte(i) - llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) - - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) - - if i <= stack.MaxDiscoveredDefaultRouters { - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr, true); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected router discovery event") - } - - } else { - select { - case <-ndpDisp.routerC: - t.Fatal("should not have discovered a new router after we already discovered the max number of routers") - default: - } - } - } -} - -// TestNoPrefixDiscovery tests that prefix discovery will not be performed if -// configured not to. -func TestNoPrefixDiscovery(t *testing.T) { - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - - // Being configured to discover prefixes means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // prefix discovery) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - DiscoverOnLinkPrefixes: discover, - }, - NDPDisp: &ndpDisp, - }) - s.SetForwarding(forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0)) - - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix when configured not to") - default: - } - }) - } -} - -// Check e to make sure that the event is for prefix on nic with ID 1, and the -// discovered flag set to discovered. -func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { - return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) -} - -// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered on-link prefix when the dispatcher asks it not to. -func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { - t.Parallel() - - prefix, subnet, _ := prefixSubnetAddr(0, "") - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix that we should not remember. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0)) - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet, true); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - - // Wait for the invalidation time plus some buffer to make sure we do - // not actually receive any invalidation events as we should not have - // remembered the prefix in the first place. - select { - case <-ndpDisp.prefixC: - t.Fatal("should not have received any prefix events") - case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): - } -} - -func TestPrefixDiscovery(t *testing.T) { - t.Parallel() - - prefix1, subnet1, _ := prefixSubnetAddr(0, "") - prefix2, subnet2, _ := prefixSubnetAddr(1, "") - prefix3, subnet3, _ := prefixSubnetAddr(2, "") - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { - t.Helper() - - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix with 0 lifetime") - default: - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) - expectPrefixEvent(subnet1, true) - - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) - expectPrefixEvent(subnet2, true) - - // Receive an RA with prefix3 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) - expectPrefixEvent(subnet3, true) - - // Receive an RA with prefix1 in a PI with lifetime = 0. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - expectPrefixEvent(subnet1, false) - - // Receive an RA with prefix2 in a PI with lesser lifetime. - lifetime := uint32(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly received prefix event when updating lifetime") - default: - } - - // Wait for prefix2's most recent invalidation timer plus some buffer to - // expire. - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet2, false); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncEventTimeout): - t.Fatal("timed out waiting for prefix discovery event") - } - - // Receive RA to invalidate prefix3. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) - expectPrefixEvent(subnet3, false) -} - -func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { - // Update the infinite lifetime value to a smaller value so we can test - // that when we receive a PI with such a lifetime value, we do not - // invalidate the prefix. - const testInfiniteLifetimeSeconds = 2 - const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second - saved := header.NDPInfiniteLifetime - header.NDPInfiniteLifetime = testInfiniteLifetime - defer func() { - header.NDPInfiniteLifetime = saved - }() - - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - subnet := prefix.Subnet() - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { - t.Helper() - - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - } - - // Receive an RA with prefix in an NDP Prefix Information option (PI) - // with infinite valid lifetime which should not get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) - expectPrefixEvent(subnet, true) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultTimeout): - } - - // Receive an RA with finite lifetime. - // The prefix should get invalidated after 1s. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet, false); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - case <-time.After(testInfiniteLifetime): - t.Fatal("timed out waiting for prefix discovery event") - } - - // Receive an RA with finite lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) - expectPrefixEvent(subnet, true) - - // Receive an RA with prefix with an infinite lifetime. - // The prefix should not be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultTimeout): - } - - // Receive an RA with a prefix with a lifetime value greater than the - // set infinite lifetime value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultTimeout): - } - - // Receive an RA with 0 lifetime. - // The prefix should get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0)) - expectPrefixEvent(subnet, false) -} - -// TestPrefixDiscoveryMaxRouters tests that only -// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. -func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), - rememberPrefix: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2) - prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} - - // Receive an RA with 2 more than the max number of discovered on-link - // prefixes. - for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { - prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0} - prefixAddr[7] = byte(i) - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address(prefixAddr[:]), - PrefixLen: 64, - } - prefixes[i] = prefix.Subnet() - buf := [30]byte{} - buf[0] = uint8(prefix.PrefixLen) - buf[1] = 128 - binary.BigEndian.PutUint32(buf[2:], 10) - copy(buf[14:], prefix.Address) - - optSer[i] = header.NDPPrefixInformation(buf[:]) - } - - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) - for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { - if i < stack.MaxDiscoveredOnLinkPrefixes { - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - } else { - select { - case <-ndpDisp.prefixC: - t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes") - default: - } - } - } -} - -// Checks to see if list contains an IPv6 address, item. -func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: item, - } - - for _, i := range list { - if i == protocolAddress { - return true - } - } - - return false -} - -// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. -func TestNoAutoGenAddr(t *testing.T) { - prefix, _, _ := prefixSubnetAddr(0, "") - - // Being configured to auto-generate addresses means handle and - // autogen are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, autogen = - // true and forwarding = false (the required configuration to do - // SLAAC) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - autogen := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - AutoGenGlobalAddresses: autogen, - }, - NDPDisp: &ndpDisp, - }) - s.SetForwarding(forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0)) - - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when configured not to") - default: - } - }) - } -} - -// Check e to make sure that the event is for addr on nic with ID 1, and the -// event type is set to eventType. -func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { - return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e)) -} - -// TestAutoGenAddr tests that an address is properly generated and invalidated -// when configured to do so. -func TestAutoGenAddr(t *testing.T) { - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - saved := stack.MinPrefixInformationValidLifetimeForUpdate - defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved - }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with 0 lifetime") - default: - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") - default: - } - - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } - - // Refresh valid lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") - default: - } - - // Wait for addr of prefix1 to be invalidated. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } -} - -// stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher, -// channel.Endpoint and stack.Stack. -// -// stack.Stack will have a default route through the router (llAddr3) installed -// and a static link-address (linkAddr3) added to the link address cache for the -// router. -func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { - t.Helper() - ndpDisp := &ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: ndpDisp, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: llAddr3, - NIC: nicID, - }}) - s.AddLinkAddress(nicID, llAddr3, linkAddr3) - return ndpDisp, e, s -} - -// addrForNewConnectionTo returns the local address used when creating a new -// connection to addr. -func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { - t.Helper() - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) - } - defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } - if err := ep.Connect(addr); err != nil { - t.Fatalf("ep.Connect(%+v): %s", addr, err) - } - got, err := ep.GetLocalAddress() - if err != nil { - t.Fatalf("ep.GetLocalAddress(): %s", err) - } - return got.Addr -} - -// addrForNewConnection returns the local address used when creating a new -// connection. -func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { - t.Helper() - - return addrForNewConnectionTo(t, s, dstAddr) -} - -// addrForNewConnectionWithAddr returns the local address used when creating a -// new connection with a specific local address. -func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { - t.Helper() - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) - } - defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } - if err := ep.Bind(addr); err != nil { - t.Fatalf("ep.Bind(%+v): %s", addr, err) - } - if err := ep.Connect(dstAddr); err != nil { - t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) - } - got, err := ep.GetLocalAddress() - if err != nil { - t.Fatalf("ep.GetLocalAddress(): %s", err) - } - return got.Addr -} - -// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when -// receiving a PI with 0 preferred lifetime. -func TestAutoGenAddrDeprecateFromPI(t *testing.T) { - const nicID = 1 - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() - - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) - } - - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } - - // Receive PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - expectPrimaryAddr(addr1) - - // Deprecate addr for prefix1 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - // addr should still be the primary endpoint as there are no other addresses. - expectPrimaryAddr(addr1) - - // Refresh lifetimes of addr generated from prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) - - // Deprecate addr for prefix2 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr1 should be the primary endpoint now since addr2 is deprecated but - // addr1 is not. - expectPrimaryAddr(addr1) - // addr2 is deprecated but if explicitly requested, it should be used. - fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) - } - - // Another PI w/ 0 preferred lifetime should not result in a deprecation - // event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) - } - - // Refresh lifetimes of addr generated from prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) -} - -// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated -// when its preferred lifetime expires. -func TestAutoGenAddrTimerDeprecation(t *testing.T) { - const nicID = 1 - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - saved := stack.MinPrefixInformationValidLifetimeForUpdate - defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved - }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(timeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } - - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() - - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) - } - - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } - - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) - - // Receive a PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr1) - - // Refresh lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since addr1 is deprecated but - // addr2 is not. - expectPrimaryAddr(addr2) - // addr1 is deprecated but if explicitly requested, it should be used. - fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) - } - - // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make - // sure we do not get a deprecation event again. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) - } - - // Refresh lifetimes for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - // addr1 is the primary endpoint again since it is non-deprecated now. - expectPrimaryAddr(addr1) - - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since it is not deprecated. - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) - } - - // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout) - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) - - // Refresh both lifetimes for addr of prefix2 to the same value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - - // Wait for a deprecation then invalidation events, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation handlers could be handled in - // either deprecation then invalidation, or invalidation then deprecation - // (which should be cancelled by the invalidation handler). - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we should not get a deprecation - // event after. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultTimeout): - } - } else { - t.Fatalf("got unexpected auto-generated event") - } - - case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should not have %s in the list of addresses", addr2) - } - // Should not have any primary endpoints. - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); got != want { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want) - } - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) - } - defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } - - if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { - t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute) - } -} - -// Tests transitioning a SLAAC address's valid lifetime between finite and -// infinite values. -func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { - const infiniteVLSeconds = 2 - const minVLSeconds = 1 - savedIL := header.NDPInfiniteLifetime - savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate - defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL - header.NDPInfiniteLifetime = savedIL - }() - stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second - header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - tests := []struct { - name string - infiniteVL uint32 - }{ - { - name: "EqualToInfiniteVL", - infiniteVL: infiniteVLSeconds, - }, - // Our implementation supports changing header.NDPInfiniteLifetime for tests - // such that a packet can be received where the lifetime field has a value - // greater than header.NDPInfiniteLifetime. Because of this, we test to make - // sure that receiving a value greater than header.NDPInfiniteLifetime is - // handled the same as when receiving a value equal to - // header.NDPInfiniteLifetime. - { - name: "MoreThanInfiniteVL", - infiniteVL: infiniteVLSeconds + 1, - }, - } - - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with finite prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - - default: - t.Fatal("expected addr auto gen event") - } - - // Receive an new RA with prefix with infinite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) - - // Receive a new RA with prefix with finite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - - case <-time.After(minVLSeconds*time.Second + defaultAsyncEventTimeout): - t.Fatal("timeout waiting for addr auto gen event") - } - }) - } - }) -} - -// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an -// auto-generated address only gets updated when required to, as specified in -// RFC 4862 section 5.5.3.e. -func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { - const infiniteVL = 4294967295 - const newMinVL = 4 - saved := stack.MinPrefixInformationValidLifetimeForUpdate - defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved - }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - tests := []struct { - name string - ovl uint32 - nvl uint32 - evl uint32 - }{ - // Should update the VL to the minimum VL for updating if the - // new VL is less than newMinVL but was originally greater than - // it. - { - "LargeVLToVLLessThanMinVLForUpdate", - 9999, - 1, - newMinVL, - }, - { - "LargeVLTo0", - 9999, - 0, - newMinVL, - }, - { - "InfiniteVLToVLLessThanMinVLForUpdate", - infiniteVL, - 1, - newMinVL, - }, - { - "InfiniteVLTo0", - infiniteVL, - 0, - newMinVL, - }, - - // Should not update VL if original VL was less than newMinVL - // and the new VL is also less than newMinVL. - { - "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate", - newMinVL - 1, - newMinVL - 3, - newMinVL - 1, - }, - - // Should take the new VL if the new VL is greater than the - // remaining time or is greater than newMinVL. - { - "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate", - newMinVL + 5, - newMinVL + 3, - newMinVL + 3, - }, - { - "SmallVLToGreaterVLButStillLessThanMinVLForUpdate", - newMinVL - 3, - newMinVL - 1, - newMinVL - 1, - }, - { - "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate", - newMinVL - 3, - newMinVL + 1, - newMinVL + 1, - }, - } - - const delta = 500 * time.Millisecond - - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix with initial VL, - // test.ovl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - - // Receive an new RA with prefix with new VL, - // test.nvl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) - - // - // Validate that the VL for the address got set - // to test.evl. - // - - // Make sure we do not get any invalidation - // events until atleast 500ms (delta) before - // test.evl. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(time.Duration(test.evl)*time.Second - delta): - } - - // Wait for another second (2x delta), but now - // we expect the invalidation event. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - - case <-time.After(2 * delta): - t.Fatal("timeout waiting for addr auto gen event") - } - }) - } - }) -} - -// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed -// by the user, its resources will be cleaned up and an invalidation event will -// be sent to the integrator. -func TestAutoGenAddrRemoval(t *testing.T) { - t.Parallel() - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - // Receive a PI to auto-generate an address. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) - expectAutoGenAddrEvent(addr, newAddr) - - // Removing the address should result in an invalidation event - // immediately. - if err := s.RemoveAddress(1, addr.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err) - } - expectAutoGenAddrEvent(addr, invalidatedAddr) - - // Wait for the original valid lifetime to make sure the original timer - // got stopped/cleaned up. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): - } -} - -// TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously -// assigned to the NIC but is in the permanentExpired state. -func TestAutoGenAddrAfterRemoval(t *testing.T) { - t.Parallel() - - const nicID = 1 - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() - - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) - } - - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } - - // Receive a PI to auto-generate addr1 with a large valid and preferred - // lifetime. - const largeLifetimeSeconds = 999 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr1, newAddr) - expectPrimaryAddr(addr1) - - // Add addr2 as a static address. - protoAddr2 := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addr2, - } - if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d, %s) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) - } - // addr2 should be more preferred now since it is at the front of the primary - // list. - expectPrimaryAddr(addr2) - - // Get a route using addr2 to increment its reference count then remove it - // to leave it in the permanentExpired state. - r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) - } - defer r.Release() - if err := s.RemoveAddress(nicID, addr2.Address); err != nil { - t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) - } - // addr1 should be preferred again since addr2 is in the expired state. - expectPrimaryAddr(addr1) - - // Receive a PI to auto-generate addr2 as valid and preferred. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr2 should be more preferred now that it is closer to the front of the - // primary list and not deprecated. - expectPrimaryAddr(addr2) - - // Removing the address should result in an invalidation event immediately. - // It should still be in the permanentExpired state because r is still held. - // - // We remove addr2 here to make sure addr2 was marked as a SLAAC address - // (it was previously marked as a static address). - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - // addr1 should be more preferred since addr2 is in the expired state. - expectPrimaryAddr(addr1) - - // Receive a PI to auto-generate addr2 as valid and deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr1 should still be more preferred since addr2 is deprecated, even though - // it is closer to the front of the primary list. - expectPrimaryAddr(addr1) - - // Receive a PI to refresh addr2's preferred lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto gen addr event") - default: - } - // addr2 should be more preferred now that it is not deprecated. - expectPrimaryAddr(addr2) - - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - expectPrimaryAddr(addr1) -} - -// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that -// is already assigned to the NIC, the static address remains. -func TestAutoGenAddrStaticConflict(t *testing.T) { - t.Parallel() - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Add the address as a static address before SLAAC tries to add it. - if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - - // Receive a PI where the generated address will be the same as the one - // that we already have assigned statically. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically") - default: - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - - // Should not get an invalidation event after the PI's invalidation - // time. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } -} - -// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use -// opaque interface identifiers when configured to do so. -func TestAutoGenAddrWithOpaqueIID(t *testing.T) { - t.Parallel() - - const nicID = 1 - const nicName = "nic1" - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } - - prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1) - prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1) - // addr1 and addr2 are the addresses that are expected to be generated when - // stack.Stack is configured to generate opaque interface identifiers as - // defined by RFC 7217. - addrBytes := []byte(subnet1.ID()) - addr1 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, nicName, 0, secretKey)), - PrefixLen: 64, - } - addrBytes = []byte(subnet2.ID()) - addr2 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, nicName, 0, secretKey)), - PrefixLen: 64, - } - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - SecretKey: secretKey, - }, - }) - opts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v, _) = %s", nicID, opts, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - // Receive an RA with prefix1 in a PI. - const validLifetimeSecondPrefix1 = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - - // Receive an RA with prefix2 in a PI with a large valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - - // Wait for addr of prefix1 to be invalidated. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } -} - -// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event -// to the integrator when an RA is received with the NDP Recursive DNS Server -// option with at least one valid address. -func TestNDPRecursiveDNSServerDispatch(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - opt header.NDPRecursiveDNSServer - expected *ndpRDNSS - }{ - { - "Unspecified", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }), - nil, - }, - { - "Multicast", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, - }), - nil, - }, - { - "OptionTooSmall", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 1, 2, 3, 4, 5, 6, 7, 8, - }), - nil, - }, - { - "0Addresses", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - }), - nil, - }, - { - "Valid1Address", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, - }), - &ndpRDNSS{ - []tcpip.Address{ - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", - }, - 2 * time.Second, - }, - }, - { - "Valid2Addresses", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 1, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, - }), - &ndpRDNSS{ - []tcpip.Address{ - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", - }, - time.Second, - }, - }, - { - "Valid3Addresses", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 0, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3, - }), - &ndpRDNSS{ - []tcpip.Address{ - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03", - }, - 0, - }, - }, - } - - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - // We do not expect more than a single RDNSS - // event at any time for this test. - rdnssC: make(chan ndpRDNSSEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt})) - - if test.expected != nil { - select { - case e := <-ndpDisp.rdnssC: - if e.nicID != 1 { - t.Errorf("got rdnss nicID = %d, want = 1", e.nicID) - } - if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" { - t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff) - } - if e.rdnss.lifetime != test.expected.lifetime { - t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime) - } - default: - t.Fatal("expected an RDNSS option event") - } - } - - // Should have no more RDNSS options. - select { - case e := <-ndpDisp.rdnssC: - t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e) - default: - } - }) - } -} - -// TestCleanupNDPState tests that all discovered routers and prefixes, and -// auto-generated addresses are invalidated when a NIC becomes a router. -func TestCleanupNDPState(t *testing.T) { - t.Parallel() - - const ( - lifetimeSeconds = 5 - maxRouterAndPrefixEvents = 4 - nicID1 = 1 - nicID2 = 2 - ) - - prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, subnet2, e1Addr2 := prefixSubnetAddr(1, linkAddr1) - e2Addr1 := addrForSubnet(subnet1, linkAddr2) - e2Addr2 := addrForSubnet(subnet2, linkAddr2) - llAddrWithPrefix1 := tcpip.AddressWithPrefix{ - Address: llAddr1, - PrefixLen: 64, - } - llAddrWithPrefix2 := tcpip.AddressWithPrefix{ - Address: llAddr2, - PrefixLen: 64, - } - - tests := []struct { - name string - cleanupFn func(t *testing.T, s *stack.Stack) - keepAutoGenLinkLocal bool - maxAutoGenAddrEvents int - skipFinalAddrCheck bool - }{ - // A NIC should still keep its auto-generated link-local address when - // becoming a router. - { - name: "Enable forwarding", - cleanupFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - s.SetForwarding(true) - }, - keepAutoGenLinkLocal: true, - maxAutoGenAddrEvents: 4, - }, - - // A NIC should cleanup all NDP state when it is disabled. - { - name: "Disable NIC", - cleanupFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - - if err := s.DisableNIC(nicID1); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) - } - if err := s.DisableNIC(nicID2); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) - } - }, - keepAutoGenLinkLocal: false, - maxAutoGenAddrEvents: 6, - }, - - // A NIC should cleanup all NDP state when it is removed. - { - name: "Remove NIC", - cleanupFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - - if err := s.RemoveNIC(nicID1); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID1, err) - } - if err := s.RemoveNIC(nicID2); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID2, err) - } - }, - keepAutoGenLinkLocal: false, - maxAutoGenAddrEvents: 6, - // The NICs are removed so we can't check their addresses after calling - // stopFn. - skipFinalAddrCheck: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: true, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - DiscoverOnLinkPrefixes: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - }) - - expectRouterEvent := func() (bool, ndpRouterEvent) { - select { - case e := <-ndpDisp.routerC: - return true, e - default: - } - - return false, ndpRouterEvent{} - } - - expectPrefixEvent := func() (bool, ndpPrefixEvent) { - select { - case e := <-ndpDisp.prefixC: - return true, e - default: - } - - return false, ndpPrefixEvent{} - } - - expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) { - select { - case e := <-ndpDisp.autoGenAddrC: - return true, e - default: - } - - return false, ndpAutoGenAddrEvent{} - } - - e1 := channel.New(0, 1280, linkAddr1) - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) - } - // We have other tests that make sure we receive the *correct* events - // on normal discovery of routers/prefixes, and auto-generated - // addresses. Here we just make sure we get an event and let other tests - // handle the correctness check. - expectAutoGenAddrEvent() - - e2 := channel.New(0, 1280, linkAddr2) - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) - } - expectAutoGenAddrEvent() - - // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and - // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from - // llAddr4) to discover multiple routers and prefixes, and auto-gen - // multiple addresses. - - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) - } - - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) - } - - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) - } - - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2) - } - - // We should have the auto-generated addresses added. - nicinfo := s.NICInfo() - nic1Addrs := nicinfo[nicID1].ProtocolAddresses - nic2Addrs := nicinfo[nicID2].ProtocolAddresses - if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } - if !containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if !containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } - if !containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if !containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) - } - - // We can't proceed any further if we already failed the test (missing - // some discovery/auto-generated address events or addresses). - if t.Failed() { - t.FailNow() - } - - test.cleanupFn(t, s) - - // Collect invalidation events after having NDP state cleaned up. - gotRouterEvents := make(map[ndpRouterEvent]int) - for i := 0; i < maxRouterAndPrefixEvents; i++ { - ok, e := expectRouterEvent() - if !ok { - t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) - break - } - gotRouterEvents[e]++ - } - gotPrefixEvents := make(map[ndpPrefixEvent]int) - for i := 0; i < maxRouterAndPrefixEvents; i++ { - ok, e := expectPrefixEvent() - if !ok { - t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) - break - } - gotPrefixEvents[e]++ - } - gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int) - for i := 0; i < test.maxAutoGenAddrEvents; i++ { - ok, e := expectAutoGenAddrEvent() - if !ok { - t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i) - break - } - gotAutoGenAddrEvents[e]++ - } - - // No need to proceed any further if we already failed the test (missing - // some invalidation events). - if t.Failed() { - t.FailNow() - } - - expectedRouterEvents := map[ndpRouterEvent]int{ - {nicID: nicID1, addr: llAddr3, discovered: false}: 1, - {nicID: nicID1, addr: llAddr4, discovered: false}: 1, - {nicID: nicID2, addr: llAddr3, discovered: false}: 1, - {nicID: nicID2, addr: llAddr4, discovered: false}: 1, - } - if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { - t.Errorf("router events mismatch (-want +got):\n%s", diff) - } - expectedPrefixEvents := map[ndpPrefixEvent]int{ - {nicID: nicID1, prefix: subnet1, discovered: false}: 1, - {nicID: nicID1, prefix: subnet2, discovered: false}: 1, - {nicID: nicID2, prefix: subnet1, discovered: false}: 1, - {nicID: nicID2, prefix: subnet2, discovered: false}: 1, - } - if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" { - t.Errorf("prefix events mismatch (-want +got):\n%s", diff) - } - expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{ - {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1, - {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1, - {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1, - {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1, - } - - if !test.keepAutoGenLinkLocal { - expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1 - expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1 - } - - if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" { - t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) - } - - if !test.skipFinalAddrCheck { - // Make sure the auto-generated addresses got removed. - nicinfo = s.NICInfo() - nic1Addrs = nicinfo[nicID1].ProtocolAddresses - nic2Addrs = nicinfo[nicID2].ProtocolAddresses - if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal { - if test.keepAutoGenLinkLocal { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } else { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } - } - if containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal { - if test.keepAutoGenLinkLocal { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } else { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } - } - if containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) - } - } - - // Should not get any more events (invalidation timers should have been - // cancelled when the NDP state was cleaned up). - time.Sleep(lifetimeSeconds*time.Second + defaultTimeout) - select { - case <-ndpDisp.routerC: - t.Error("unexpected router event") - default: - } - select { - case <-ndpDisp.prefixC: - t.Error("unexpected prefix event") - default: - } - select { - case <-ndpDisp.autoGenAddrC: - t.Error("unexpected auto-generated address event") - default: - } - }) - } -} - -// TestDHCPv6ConfigurationFromNDPDA tests that the NDPDispatcher is properly -// informed when new information about what configurations are available via -// DHCPv6 is learned. -func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { - const nicID = 1 - - ndpDisp := ndpDispatcher{ - dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1), - rememberRouter: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) { - t.Helper() - select { - case e := <-ndpDisp.dhcpv6ConfigurationC: - if diff := cmp.Diff(ndpDHCPv6Event{nicID: nicID, configuration: configuration}, e, cmp.AllowUnexported(e)); diff != "" { - t.Errorf("dhcpv6 event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DHCPv6 configuration event") - } - } - - expectNoDHCPv6Event := func() { - t.Helper() - select { - case <-ndpDisp.dhcpv6ConfigurationC: - t.Fatal("unexpected DHCPv6 configuration event") - default: - } - } - - // The initial DHCPv6 configuration should be stack.DHCPv6NoConfiguration. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to Other - // Configurations. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(stack.DHCPv6OtherConfigurations) - // Receiving the same update again should not result in an event to the - // NDPDispatcher. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to Managed Address. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) - expectDHCPv6Event(stack.DHCPv6ManagedAddress) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to none. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectDHCPv6Event(stack.DHCPv6NoConfiguration) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to Managed Address. - // - // Note, when the M flag is set, the O flag is redundant. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) - expectDHCPv6Event(stack.DHCPv6ManagedAddress) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) - expectNoDHCPv6Event() - // Even though the DHCPv6 flags are different, the effective configuration is - // the same so we should not receive a new event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) - expectNoDHCPv6Event() - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to Other - // Configurations. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(stack.DHCPv6OtherConfigurations) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectNoDHCPv6Event() -} - -// TestRouterSolicitation tests the initial Router Solicitations that are sent -// when a NIC newly becomes enabled. -func TestRouterSolicitation(t *testing.T) { - t.Parallel() - - const nicID = 1 - - tests := []struct { - name string - linkHeaderLen uint16 - linkAddr tcpip.LinkAddress - nicAddr tcpip.Address - expectedSrcAddr tcpip.Address - expectedNDPOpts []header.NDPOption - maxRtrSolicit uint8 - rtrSolicitInt time.Duration - effectiveRtrSolicitInt time.Duration - maxRtrSolicitDelay time.Duration - effectiveMaxRtrSolicitDelay time.Duration - }{ - { - name: "Single RS with delay", - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 1, - rtrSolicitInt: time.Second, - effectiveRtrSolicitInt: time.Second, - maxRtrSolicitDelay: time.Second, - effectiveMaxRtrSolicitDelay: time.Second, - }, - { - name: "Two RS with delay", - linkHeaderLen: 1, - nicAddr: llAddr1, - expectedSrcAddr: llAddr1, - maxRtrSolicit: 2, - rtrSolicitInt: time.Second, - effectiveRtrSolicitInt: time.Second, - maxRtrSolicitDelay: 500 * time.Millisecond, - effectiveMaxRtrSolicitDelay: 500 * time.Millisecond, - }, - { - name: "Single RS without delay", - linkHeaderLen: 2, - linkAddr: linkAddr1, - nicAddr: llAddr1, - expectedSrcAddr: llAddr1, - expectedNDPOpts: []header.NDPOption{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }, - maxRtrSolicit: 1, - rtrSolicitInt: time.Second, - effectiveRtrSolicitInt: time.Second, - maxRtrSolicitDelay: 0, - effectiveMaxRtrSolicitDelay: 0, - }, - { - name: "Two RS without delay and invalid zero interval", - linkHeaderLen: 3, - linkAddr: linkAddr1, - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 2, - rtrSolicitInt: 0, - effectiveRtrSolicitInt: 4 * time.Second, - maxRtrSolicitDelay: 0, - effectiveMaxRtrSolicitDelay: 0, - }, - { - name: "Three RS without delay", - linkAddr: linkAddr1, - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 3, - rtrSolicitInt: 500 * time.Millisecond, - effectiveRtrSolicitInt: 500 * time.Millisecond, - maxRtrSolicitDelay: 0, - effectiveMaxRtrSolicitDelay: 0, - }, - { - name: "Two RS with invalid negative delay", - linkAddr: linkAddr1, - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 2, - rtrSolicitInt: time.Second, - effectiveRtrSolicitInt: time.Second, - maxRtrSolicitDelay: -3 * time.Second, - effectiveMaxRtrSolicitDelay: time.Second, - }, - } - - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), - headerLength: test.linkHeaderLen, - } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - waitForPkt := func(timeout time.Duration) { - t.Helper() - ctx, _ := context.WithTimeout(context.Background(), timeout) - p, ok := e.ReadContext(ctx) - if !ok { - t.Fatal("timed out waiting for packet") - return - } - - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - - // Make sure the right remote link address is used. - if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } - - checker.IPv6(t, - p.Pkt.Header.View(), - checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), - ) - - if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) - } - } - waitForNothing := func(timeout time.Duration) { - t.Helper() - ctx, _ := context.WithTimeout(context.Background(), timeout) - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet") - } - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) - } - } - - // Make sure each RS is sent at the right time. - remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout) - remaining-- - } - - for ; remaining > 0; remaining-- { - waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout) - waitForPkt(defaultAsyncEventTimeout) - } - - // Make sure no more RS. - if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt + defaultTimeout) - } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultTimeout) - } - - // Make sure the counter got properly - // incremented. - if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { - t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) - } - }) - } - }) -} - -func TestStopStartSolicitingRouters(t *testing.T) { - t.Parallel() - - const nicID = 1 - const interval = 500 * time.Millisecond - const delay = time.Second - const maxRtrSolicitations = 3 - - tests := []struct { - name string - startFn func(t *testing.T, s *stack.Stack) - // first is used to tell stopFn that it is being called for the first time - // after router solicitations were last enabled. - stopFn func(t *testing.T, s *stack.Stack, first bool) - }{ - // Tests that when forwarding is enabled or disabled, router solicitations - // are stopped or started, respectively. - { - name: "Enable and disable forwarding", - startFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - s.SetForwarding(false) - }, - stopFn: func(t *testing.T, s *stack.Stack, _ bool) { - t.Helper() - s.SetForwarding(true) - }, - }, - - // Tests that when a NIC is enabled or disabled, router solicitations - // are started or stopped, respectively. - { - name: "Enable and disable NIC", - startFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - }, - stopFn: func(t *testing.T, s *stack.Stack, _ bool) { - t.Helper() - - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - }, - }, - - // Tests that when a NIC is removed, router solicitations are stopped. We - // cannot start router solications on a removed NIC. - { - name: "Remove NIC", - stopFn: func(t *testing.T, s *stack.Stack, first bool) { - t.Helper() - - // Only try to remove the NIC the first time stopFn is called since it's - // impossible to remove an already removed NIC. - if !first { - return - } - - if err := s.RemoveNIC(nicID); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) - } - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(maxRtrSolicitations, 1280, linkAddr1) - waitForPkt := func(timeout time.Duration) { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - p, ok := e.ReadContext(ctx) - if !ok { - t.Fatal("timed out waiting for packet") - return - } - - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS()) - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - MaxRtrSolicitations: maxRtrSolicitations, - RtrSolicitationInterval: interval, - MaxRtrSolicitationDelay: delay, - }, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // Stop soliciting routers. - test.stopFn(t, s, true /* first */) - ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { - // A single RS may have been sent before forwarding was enabled. - ctx, cancel := context.WithTimeout(context.Background(), interval+defaultTimeout) - defer cancel() - if _, ok = e.ReadContext(ctx); ok { - t.Fatal("should not have sent more than one RS message") - } - } - - // Stopping router solicitations after it has already been stopped should - // do nothing. - test.stopFn(t, s, false /* first */) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") - } - - // If test.startFn is nil, there is no way to restart router solications. - if test.startFn == nil { - return - } - - // Start soliciting routers. - test.startFn(t, s) - waitForPkt(delay + defaultAsyncEventTimeout) - waitForPkt(interval + defaultAsyncEventTimeout) - waitForPkt(interval + defaultAsyncEventTimeout) - ctx, cancel = context.WithTimeout(context.Background(), interval+defaultTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") - } - - // Starting router solicitations after it has already completed should do - // nothing. - test.startFn(t, s) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet after finishing router solicitations") - } - }) - } -} diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go deleted file mode 100644 index edaee3b86..000000000 --- a/pkg/tcpip/stack/nic_test.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { - // When the NIC is disabled, the only field that matters is the stats field. - // This test is limited to stats counter checks. - nic := NIC{ - stats: makeNICStats(), - } - - if got := nic.stats.DisabledRx.Packets.Value(); got != 0 { - t.Errorf("got DisabledRx.Packets = %d, want = 0", got) - } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 { - t.Errorf("got DisabledRx.Bytes = %d, want = 0", got) - } - if got := nic.stats.Rx.Packets.Value(); got != 0 { - t.Errorf("got Rx.Packets = %d, want = 0", got) - } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { - t.Errorf("got Rx.Bytes = %d, want = 0", got) - } - - if t.Failed() { - t.FailNow() - } - - nic.DeliverNetworkPacket(nil, "", "", 0, tcpip.PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) - - if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { - t.Errorf("got DisabledRx.Packets = %d, want = 1", got) - } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 { - t.Errorf("got DisabledRx.Bytes = %d, want = 4", got) - } - if got := nic.stats.Rx.Packets.Value(); got != 0 { - t.Errorf("got Rx.Packets = %d, want = 0", got) - } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { - t.Errorf("got Rx.Bytes = %d, want = 0", got) - } -} diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go new file mode 100755 index 000000000..3e28a4e34 --- /dev/null +++ b/pkg/tcpip/stack/stack_state_autogen.go @@ -0,0 +1,129 @@ +// automatically generated by stateify. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *linkAddrEntryList) beforeSave() {} +func (x *linkAddrEntryList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *linkAddrEntryList) afterLoad() {} +func (x *linkAddrEntryList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *linkAddrEntryEntry) beforeSave() {} +func (x *linkAddrEntryEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *linkAddrEntryEntry) afterLoad() {} +func (x *linkAddrEntryEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *TransportEndpointID) beforeSave() {} +func (x *TransportEndpointID) save(m state.Map) { + x.beforeSave() + m.Save("LocalPort", &x.LocalPort) + m.Save("LocalAddress", &x.LocalAddress) + m.Save("RemotePort", &x.RemotePort) + m.Save("RemoteAddress", &x.RemoteAddress) +} + +func (x *TransportEndpointID) afterLoad() {} +func (x *TransportEndpointID) load(m state.Map) { + m.Load("LocalPort", &x.LocalPort) + m.Load("LocalAddress", &x.LocalAddress) + m.Load("RemotePort", &x.RemotePort) + m.Load("RemoteAddress", &x.RemoteAddress) +} + +func (x *GSOType) save(m state.Map) { + m.SaveValue("", (int)(*x)) +} + +func (x *GSOType) load(m state.Map) { + m.LoadValue("", new(int), func(y interface{}) { *x = (GSOType)(y.(int)) }) +} + +func (x *GSO) beforeSave() {} +func (x *GSO) save(m state.Map) { + x.beforeSave() + m.Save("Type", &x.Type) + m.Save("NeedsCsum", &x.NeedsCsum) + m.Save("CsumOffset", &x.CsumOffset) + m.Save("MSS", &x.MSS) + m.Save("L3HdrLen", &x.L3HdrLen) + m.Save("MaxSize", &x.MaxSize) +} + +func (x *GSO) afterLoad() {} +func (x *GSO) load(m state.Map) { + m.Load("Type", &x.Type) + m.Load("NeedsCsum", &x.NeedsCsum) + m.Load("CsumOffset", &x.CsumOffset) + m.Load("MSS", &x.MSS) + m.Load("L3HdrLen", &x.L3HdrLen) + m.Load("MaxSize", &x.MaxSize) +} + +func (x *TransportEndpointInfo) beforeSave() {} +func (x *TransportEndpointInfo) save(m state.Map) { + x.beforeSave() + m.Save("NetProto", &x.NetProto) + m.Save("TransProto", &x.TransProto) + m.Save("ID", &x.ID) + m.Save("BindNICID", &x.BindNICID) + m.Save("BindAddr", &x.BindAddr) + m.Save("RegisterNICID", &x.RegisterNICID) +} + +func (x *TransportEndpointInfo) afterLoad() {} +func (x *TransportEndpointInfo) load(m state.Map) { + m.Load("NetProto", &x.NetProto) + m.Load("TransProto", &x.TransProto) + m.Load("ID", &x.ID) + m.Load("BindNICID", &x.BindNICID) + m.Load("BindAddr", &x.BindAddr) + m.Load("RegisterNICID", &x.RegisterNICID) +} + +func (x *multiPortEndpoint) beforeSave() {} +func (x *multiPortEndpoint) save(m state.Map) { + x.beforeSave() + m.Save("demux", &x.demux) + m.Save("netProto", &x.netProto) + m.Save("transProto", &x.transProto) + m.Save("endpoints", &x.endpoints) + m.Save("reuse", &x.reuse) +} + +func (x *multiPortEndpoint) afterLoad() {} +func (x *multiPortEndpoint) load(m state.Map) { + m.Load("demux", &x.demux) + m.Load("netProto", &x.netProto) + m.Load("transProto", &x.transProto) + m.Load("endpoints", &x.endpoints) + m.Load("reuse", &x.reuse) +} + +func init() { + state.Register("pkg/tcpip/stack.linkAddrEntryList", (*linkAddrEntryList)(nil), state.Fns{Save: (*linkAddrEntryList).save, Load: (*linkAddrEntryList).load}) + state.Register("pkg/tcpip/stack.linkAddrEntryEntry", (*linkAddrEntryEntry)(nil), state.Fns{Save: (*linkAddrEntryEntry).save, Load: (*linkAddrEntryEntry).load}) + state.Register("pkg/tcpip/stack.TransportEndpointID", (*TransportEndpointID)(nil), state.Fns{Save: (*TransportEndpointID).save, Load: (*TransportEndpointID).load}) + state.Register("pkg/tcpip/stack.GSOType", (*GSOType)(nil), state.Fns{Save: (*GSOType).save, Load: (*GSOType).load}) + state.Register("pkg/tcpip/stack.GSO", (*GSO)(nil), state.Fns{Save: (*GSO).save, Load: (*GSO).load}) + state.Register("pkg/tcpip/stack.TransportEndpointInfo", (*TransportEndpointInfo)(nil), state.Fns{Save: (*TransportEndpointInfo).save, Load: (*TransportEndpointInfo).load}) + state.Register("pkg/tcpip/stack.multiPortEndpoint", (*multiPortEndpoint)(nil), state.Fns{Save: (*multiPortEndpoint).save, Load: (*multiPortEndpoint).load}) +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go deleted file mode 100644 index 9836b340f..000000000 --- a/pkg/tcpip/stack/stack_test.go +++ /dev/null @@ -1,3280 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package stack_test contains tests for the stack. It is in its own package so -// that the tests can also validate that all definitions needed to implement -// transport and network protocols are properly exported by the stack package. -package stack_test - -import ( - "bytes" - "fmt" - "math" - "sort" - "strings" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -const ( - fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - fakeNetHeaderLen = 12 - fakeDefaultPrefixLen = 8 - - // fakeControlProtocol is used for control packets that represent - // destination port unreachable. - fakeControlProtocol tcpip.TransportProtocolNumber = 2 - - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65536 -) - -// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and -// received packets; the counts of all endpoints are aggregated in the protocol -// descriptor. -// -// Headers of this protocol are fakeNetHeaderLen bytes, but we currently only -// use the first three: destination address, source address, and transport -// protocol. They're all one byte fields to simplify parsing. -type fakeNetworkEndpoint struct { - nicID tcpip.NICID - id stack.NetworkEndpointID - prefixLen int - proto *fakeNetworkProtocol - dispatcher stack.TransportDispatcher - ep stack.LinkEndpoint -} - -func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) -} - -func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { - return f.nicID -} - -func (f *fakeNetworkEndpoint) PrefixLen() int { - return f.prefixLen -} - -func (*fakeNetworkEndpoint) DefaultTTL() uint8 { - return 123 -} - -func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { - return &f.id -} - -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { - // Increment the received packet count in the protocol descriptor. - f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ - - // Consume the network header. - b := pkt.Data.First() - pkt.Data.TrimFront(fakeNetHeaderLen) - - // Handle control packets. - if b[2] == uint8(fakeControlProtocol) { - nb := pkt.Data.First() - if len(nb) < fakeNetHeaderLen { - return - } - - pkt.Data.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) - return - } - - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) -} - -func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.ep.MaxHeaderLength() + fakeNetHeaderLen -} - -func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - -func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return f.ep.Capabilities() -} - -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { - // Increment the sent packet count in the protocol descriptor. - f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ - - // Add the protocol's header to the packet and send it to the link - // endpoint. - b := pkt.Header.Prepend(fakeNetHeaderLen) - b[0] = r.RemoteAddress[0] - b[1] = f.id.LocalAddress[0] - b[2] = byte(params.Protocol) - - if r.Loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - f.HandlePacket(r, tcpip.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) - } - if r.Loop&stack.PacketOut == 0 { - return nil - } - - return f.ep.WritePacket(r, gso, fakeNetNumber, pkt) -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { - panic("not implemented") -} - -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported -} - -func (*fakeNetworkEndpoint) Close() {} - -type fakeNetGoodOption bool - -type fakeNetBadOption bool - -type fakeNetInvalidValueOption int - -type fakeNetOptions struct { - good bool -} - -// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the -// number of packets sent and received via endpoints of this protocol. The index -// where packets are added is given by the packet's destination address MOD 10. -type fakeNetworkProtocol struct { - packetCount [10]int - sendPacketCount [10]int - opts fakeNetOptions -} - -func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { - return fakeNetNumber -} - -func (f *fakeNetworkProtocol) MinimumPacketSize() int { - return fakeNetHeaderLen -} - -func (f *fakeNetworkProtocol) DefaultPrefixLen() int { - return fakeDefaultPrefixLen -} - -func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { - return f.packetCount[int(intfAddr)%len(f.packetCount)] -} - -func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) -} - -func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { - return &fakeNetworkEndpoint{ - nicID: nicID, - id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, - proto: f, - dispatcher: dispatcher, - ep: ep, - }, nil -} - -func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error { - switch v := option.(type) { - case fakeNetGoodOption: - f.opts.good = bool(v) - return nil - case fakeNetInvalidValueOption: - return tcpip.ErrInvalidOptionValue - default: - return tcpip.ErrUnknownProtocolOption - } -} - -func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { - switch v := option.(type) { - case *fakeNetGoodOption: - *v = fakeNetGoodOption(f.opts.good) - return nil - default: - return tcpip.ErrUnknownProtocolOption - } -} - -// Close implements TransportProtocol.Close. -func (*fakeNetworkProtocol) Close() {} - -// Wait implements TransportProtocol.Wait. -func (*fakeNetworkProtocol) Wait() {} - -func fakeNetFactory() stack.NetworkProtocol { - return &fakeNetworkProtocol{} -} - -// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify -// that LinkEndpoint.Attach was called. -type linkEPWithMockedAttach struct { - stack.LinkEndpoint - attached bool -} - -// Attach implements stack.LinkEndpoint.Attach. -func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) { - l.LinkEndpoint.Attach(d) - l.attached = d != nil -} - -func (l *linkEPWithMockedAttach) isAttached() bool { - return l.attached -} - -func TestNetworkReceive(t *testing.T) { - // Create a stack with the fake network protocol, one nic, and two - // addresses attached to it: 1 & 2. - ep := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Make sure packet with wrong address is not delivered. - buf[0] = 3 - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } - if fakeNet.packetCount[2] != 0 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0) - } - - // Make sure packet is delivered to first endpoint. - buf[0] = 1 - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 0 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0) - } - - // Make sure packet is delivered to second endpoint. - buf[0] = 2 - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } - - // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } - - // Make sure packet that is too small is dropped. - buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } -} - -func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error { - r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - return err - } - defer r.Release() - return send(r, payload) -} - -func send(r stack.Route, payload buffer.View) *tcpip.Error { - hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }) -} - -func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { - t.Helper() - ep.Drain() - if err := sendTo(s, addr, payload); err != nil { - t.Error("sendTo failed:", err) - } - if got, want := ep.Drain(), 1; got != want { - t.Errorf("sendTo packet count: got = %d, want %d", got, want) - } -} - -func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) { - t.Helper() - ep.Drain() - if err := send(r, payload); err != nil { - t.Error("send failed:", err) - } - if got, want := ep.Drain(), 1; got != want { - t.Errorf("send packet count: got = %d, want %d", got, want) - } -} - -func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { - t.Helper() - if gotErr := send(r, payload); gotErr != wantErr { - t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) - } -} - -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { - t.Helper() - if gotErr := sendTo(s, addr, payload); gotErr != wantErr { - t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) - } -} - -func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { - t.Helper() - // testRecvInternal injects one packet, and we expect to receive it. - want := fakeNet.PacketCount(localAddrByte) + 1 - testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) -} - -func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { - t.Helper() - // testRecvInternal injects one packet, and we do NOT expect to receive it. - want := fakeNet.PacketCount(localAddrByte) - testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) -} - -func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { - t.Helper() - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if got := fakeNet.PacketCount(localAddrByte); got != want { - t.Errorf("receive packet count: got = %d, want %d", got, want) - } -} - -func TestNetworkSend(t *testing.T) { - // Create a stack with the fake network protocol, one nic, and one - // address: 1. The route table sends all packets through the only - // existing nic. - ep := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("NewNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - // Make sure that the link-layer endpoint received the outbound packet. - testSendTo(t, s, "\x03", ep, nil) -} - -func TestNetworkSendMultiRoute(t *testing.T) { - // Create a stack with the fake network protocol, two nics, and two - // addresses per nic, the first nic has odd address, the second one has - // even addresses. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - }) - } - - // Send a packet to an odd destination. - testSendTo(t, s, "\x05", ep1, nil) - - // Send a packet to an even destination. - testSendTo(t, s, "\x06", ep2, nil) -} - -func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { - r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - - defer r.Release() - - if r.LocalAddress != expectedSrcAddr { - t.Fatalf("Bad source address: expected %v, got %v", expectedSrcAddr, r.LocalAddress) - } - - if r.RemoteAddress != dstAddr { - t.Fatalf("Bad destination address: expected %v, got %v", dstAddr, r.RemoteAddress) - } -} - -func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) { - _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, tcpip.ErrNoRoute) - } -} - -// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to -// a NetworkDispatcher when the NIC is created. -func TestAttachToLinkEndpointImmediately(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - nicOpts stack.NICOptions - }{ - { - name: "Create enabled NIC", - nicOpts: stack.NICOptions{Disabled: false}, - }, - { - name: "Create disabled NIC", - nicOpts: stack.NICOptions{Disabled: true}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - e := linkEPWithMockedAttach{ - LinkEndpoint: loopback.New(), - } - - if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err) - } - if !e.isAttached() { - t.Fatal("link endpoint not attached to a network dispatcher") - } - }) - } -} - -func TestDisableUnknownNIC(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID { - t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) - } -} - -func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - e := loopback.New() - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - checkNIC := func(enabled bool) { - t.Helper() - - allNICInfo := s.NICInfo() - nicInfo, ok := allNICInfo[nicID] - if !ok { - t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) - } else if nicInfo.Flags.Running != enabled { - t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled) - } - - if got := s.CheckNIC(nicID); got != enabled { - t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled) - } - } - - // NIC should initially report itself as disabled. - checkNIC(false) - - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - checkNIC(true) - - // If the NIC is not reporting a correct enabled status, we cannot trust the - // next check so end the test here. - if t.Failed() { - t.FailNow() - } - - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - checkNIC(false) -} - -func TestRemoveUnknownNIC(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID { - t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) - } -} - -func TestRemoveNIC(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - e := linkEPWithMockedAttach{ - LinkEndpoint: loopback.New(), - } - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // NIC should be present in NICInfo and attached to a NetworkDispatcher. - allNICInfo := s.NICInfo() - if _, ok := allNICInfo[nicID]; !ok { - t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) - } - if !e.isAttached() { - t.Fatal("link endpoint not attached to a network dispatcher") - } - - // Removing a NIC should remove it from NICInfo and e should be detached from - // the NetworkDispatcher. - if err := s.RemoveNIC(nicID); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) - } - if nicInfo, ok := s.NICInfo()[nicID]; ok { - t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) - } - if e.isAttached() { - t.Error("link endpoint for removed NIC still attached to a network dispatcher") - } -} - -func TestRouteWithDownNIC(t *testing.T) { - tests := []struct { - name string - downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error - upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error - }{ - { - name: "Disabled NIC", - downFn: (*stack.Stack).DisableNIC, - upFn: (*stack.Stack).EnableNIC, - }, - - // Once a NIC is removed, it cannot be brought up. - { - name: "Removed NIC", - downFn: (*stack.Stack).RemoveNIC, - }, - } - - const unspecifiedNIC = 0 - const nicID1 = 1 - const nicID2 = 2 - const addr1 = tcpip.Address("\x01") - const addr2 = tcpip.Address("\x02") - const nic1Dst = tcpip.Address("\x05") - const nic2Dst = tcpip.Address("\x06") - - setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep1 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - - if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) - } - - ep2 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID2, ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - - if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) - } - - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, - {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, - }) - } - - return s, ep1, ep2 - } - - // Tests that routes through a down NIC are not used when looking up a route - // for a destination. - t.Run("Find", func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, _, _ := setup(t) - - // Test routes to odd address. - testRoute(t, s, unspecifiedNIC, "", "\x05", addr1) - testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1) - testRoute(t, s, nicID1, addr1, "\x05", addr1) - - // Test routes to even address. - testRoute(t, s, unspecifiedNIC, "", "\x06", addr2) - testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2) - testRoute(t, s, nicID2, addr2, "\x06", addr2) - - // Bringing NIC1 down should result in no routes to odd addresses. Routes to - // even addresses should continue to be available as NIC2 is still up. - if err := test.downFn(s, nicID1); err != nil { - t.Fatalf("test.downFn(_, %d): %s", nicID1, err) - } - testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) - testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) - testNoRoute(t, s, nicID1, addr1, nic1Dst) - testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2) - testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2) - testRoute(t, s, nicID2, addr2, nic2Dst, addr2) - - // Bringing NIC2 down should result in no routes to even addresses. No - // route should be available to any address as routes to odd addresses - // were made unavailable by bringing NIC1 down above. - if err := test.downFn(s, nicID2); err != nil { - t.Fatalf("test.downFn(_, %d): %s", nicID2, err) - } - testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) - testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) - testNoRoute(t, s, nicID1, addr1, nic1Dst) - testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) - testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) - testNoRoute(t, s, nicID2, addr2, nic2Dst) - - if upFn := test.upFn; upFn != nil { - // Bringing NIC1 up should make routes to odd addresses available - // again. Routes to even addresses should continue to be unavailable - // as NIC2 is still down. - if err := upFn(s, nicID1); err != nil { - t.Fatalf("test.upFn(_, %d): %s", nicID1, err) - } - testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1) - testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1) - testRoute(t, s, nicID1, addr1, nic1Dst, addr1) - testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) - testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) - testNoRoute(t, s, nicID2, addr2, nic2Dst) - } - }) - } - }) - - // Tests that writing a packet using a Route through a down NIC fails. - t.Run("WritePacket", func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, ep1, ep2 := setup(t) - - r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err) - } - defer r1.Release() - - r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err) - } - defer r2.Release() - - // If we failed to get routes r1 or r2, we cannot proceed with the test. - if t.Failed() { - t.FailNow() - } - - buf := buffer.View([]byte{1}) - testSend(t, r1, ep1, buf) - testSend(t, r2, ep2, buf) - - // Writes with Routes that use NIC1 after being brought down should fail. - if err := test.downFn(s, nicID1); err != nil { - t.Fatalf("test.downFn(_, %d): %s", nicID1, err) - } - testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) - testSend(t, r2, ep2, buf) - - // Writes with Routes that use NIC2 after being brought down should fail. - if err := test.downFn(s, nicID2); err != nil { - t.Fatalf("test.downFn(_, %d): %s", nicID2, err) - } - testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) - testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) - - if upFn := test.upFn; upFn != nil { - // Writes with Routes that use NIC1 after being brought up should - // succeed. - // - // TODO(b/147015577): Should we instead completely invalidate all - // Routes that were bound to a NIC that was brought down at some - // point? - if err := upFn(s, nicID1); err != nil { - t.Fatalf("test.upFn(_, %d): %s", nicID1, err) - } - testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) - } - }) - } - }) -} - -func TestRoutes(t *testing.T) { - // Create a stack with the fake network protocol, two nics, and two - // addresses per nic, the first nic has odd address, the second one has - // even addresses. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - }) - } - - // Test routes to odd address. - testRoute(t, s, 0, "", "\x05", "\x01") - testRoute(t, s, 0, "\x01", "\x05", "\x01") - testRoute(t, s, 1, "\x01", "\x05", "\x01") - testRoute(t, s, 0, "\x03", "\x05", "\x03") - testRoute(t, s, 1, "\x03", "\x05", "\x03") - - // Test routes to even address. - testRoute(t, s, 0, "", "\x06", "\x02") - testRoute(t, s, 0, "\x02", "\x06", "\x02") - testRoute(t, s, 2, "\x02", "\x06", "\x02") - testRoute(t, s, 0, "\x04", "\x06", "\x04") - testRoute(t, s, 2, "\x04", "\x06", "\x04") - - // Try to send to odd numbered address from even numbered ones, then - // vice-versa. - testNoRoute(t, s, 0, "\x02", "\x05") - testNoRoute(t, s, 2, "\x02", "\x05") - testNoRoute(t, s, 0, "\x04", "\x05") - testNoRoute(t, s, 2, "\x04", "\x05") - - testNoRoute(t, s, 0, "\x01", "\x06") - testNoRoute(t, s, 1, "\x01", "\x06") - testNoRoute(t, s, 0, "\x03", "\x06") - testNoRoute(t, s, 1, "\x03", "\x06") -} - -func TestAddressRemoval(t *testing.T) { - const localAddrByte byte = 0x01 - localAddr := tcpip.Address([]byte{localAddrByte}) - remoteAddr := tcpip.Address("\x02") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Send and receive packets, and verify they are received. - buf[0] = localAddrByte - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // Remove the address, then check that send/receive doesn't work anymore. - if err := s.RemoveAddress(1, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) - - // Check that removing the same address fails. - if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) - } -} - -func TestAddressRemovalWithRouteHeld(t *testing.T) { - const localAddrByte byte = 0x01 - localAddr := tcpip.Address([]byte{localAddrByte}) - remoteAddr := tcpip.Address("\x02") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - buf := buffer.NewView(30) - - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - - // Send and receive packets, and verify they are received. - buf[0] = localAddrByte - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSend(t, r, ep, nil) - testSendTo(t, s, remoteAddr, ep, nil) - - // Remove the address, then check that send/receive doesn't work anymore. - if err := s.RemoveAddress(1, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) - - // Check that removing the same address fails. - if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) - } -} - -func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) { - t.Helper() - info, ok := s.NICInfo()[nicID] - if !ok { - t.Fatalf("NICInfo() failed to find nicID=%d", nicID) - } - if len(addr) == 0 { - // No address given, verify that there is no address assigned to the NIC. - for _, a := range info.ProtocolAddresses { - if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { - t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) - } - } - return - } - // Address given, verify the address is assigned to the NIC and no other - // address is. - found := false - for _, a := range info.ProtocolAddresses { - if a.Protocol == fakeNetNumber { - if a.AddressWithPrefix.Address == addr { - found = true - } else { - t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr) - } - } - } - if !found { - t.Errorf("verify address: couldn't find %s on the NIC", addr) - } -} - -func TestEndpointExpiration(t *testing.T) { - const ( - localAddrByte byte = 0x01 - remoteAddr tcpip.Address = "\x03" - noAddr tcpip.Address = "" - nicID tcpip.NICID = 1 - ) - localAddr := tcpip.Address([]byte{localAddrByte}) - - for _, promiscuous := range []bool{true, false} { - for _, spoofing := range []bool{true, false} { - t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - buf := buffer.NewView(30) - buf[0] = localAddrByte - - if promiscuous { - if err := s.SetPromiscuousMode(nicID, true); err != nil { - t.Fatal("SetPromiscuousMode failed:", err) - } - } - - if spoofing { - if err := s.SetSpoofing(nicID, true); err != nil { - t.Fatal("SetSpoofing failed:", err) - } - } - - // 1. No Address yet, send should only work for spoofing, receive for - // promiscuous mode. - //----------------------- - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) - } - - // 2. Add Address, everything should work. - //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // 3. Remove the address, send should only work for spoofing, receive - // for promiscuous mode. - //----------------------- - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) - } - - // 4. Add Address back, everything should work again. - //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // 5. Take a reference to the endpoint by getting a route. Verify that - // we can still send/receive, including sending using the route. - //----------------------- - r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - testSend(t, r, ep, nil) - - // 6. Remove the address. Send should only work for spoofing, receive - // for promiscuous mode. - //----------------------- - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - testSend(t, r, ep, nil) - testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) - } - - // 7. Add Address back, everything should work again. - //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - testSend(t, r, ep, nil) - - // 8. Remove the route, sendTo/recv should still work. - //----------------------- - r.Release() - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // 9. Remove the address. Send should only work for spoofing, receive - // for promiscuous mode. - //----------------------- - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) - } - }) - } - } -} - -func TestPromiscuousMode(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Write a packet, and check that it doesn't get delivered as we don't - // have a matching endpoint. - const localAddrByte byte = 0x01 - buf[0] = localAddrByte - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - - // Set promiscuous mode, then check that packet is delivered. - if err := s.SetPromiscuousMode(1, true); err != nil { - t.Fatal("SetPromiscuousMode failed:", err) - } - testRecv(t, fakeNet, localAddrByte, ep, buf) - - // Check that we can't get a route as there is no local address. - _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) - if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, tcpip.ErrNoRoute) - } - - // Set promiscuous mode to false, then check that packet can't be - // delivered anymore. - if err := s.SetPromiscuousMode(1, false); err != nil { - t.Fatal("SetPromiscuousMode failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) -} - -func TestSpoofingWithAddress(t *testing.T) { - localAddr := tcpip.Address("\x01") - nonExistentLocalAddr := tcpip.Address("\x02") - dstAddr := tcpip.Address("\x03") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // With address spoofing disabled, FindRoute does not permit an address - // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err == nil { - t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) - } - - // With address spoofing enabled, FindRoute permits any address to be used - // as the source. - if err := s.SetSpoofing(1, true); err != nil { - t.Fatal("SetSpoofing failed:", err) - } - r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) - } - if r.RemoteAddress != dstAddr { - t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) - } - // Sending a packet works. - testSendTo(t, s, dstAddr, ep, nil) - testSend(t, r, ep, nil) - - // FindRoute should also work with a local address that exists on the NIC. - r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - if r.LocalAddress != localAddr { - t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) - } - if r.RemoteAddress != dstAddr { - t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) - } - // Sending a packet using the route works. - testSend(t, r, ep, nil) -} - -func TestSpoofingNoAddress(t *testing.T) { - nonExistentLocalAddr := tcpip.Address("\x01") - dstAddr := tcpip.Address("\x02") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // With address spoofing disabled, FindRoute does not permit an address - // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err == nil { - t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) - } - // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute) - - // With address spoofing enabled, FindRoute permits any address to be used - // as the source. - if err := s.SetSpoofing(1, true); err != nil { - t.Fatal("SetSpoofing failed:", err) - } - r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) - } - if r.RemoteAddress != dstAddr { - t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) - } - // Sending a packet works. - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) -} - -func verifyRoute(gotRoute, wantRoute stack.Route) error { - if gotRoute.LocalAddress != wantRoute.LocalAddress { - return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress) - } - if gotRoute.RemoteAddress != wantRoute.RemoteAddress { - return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress) - } - if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress { - return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress) - } - if gotRoute.NextHop != wantRoute.NextHop { - return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop) - } - return nil -} - -func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - s.SetRouteTable([]tcpip.Route{}) - - // If there is no endpoint, it won't work. - if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) - } - - protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} - if err := s.AddProtocolAddress(1, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err) - } - r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) - } - if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) - } - - // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) - } -} - -func TestOutgoingBroadcastWithRouteTable(t *testing.T) { - defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} - // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. - nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} - nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01") - // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. - nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} - nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01") - - // Create a new stack with two NICs. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - if err := s.CreateNIC(2, ep); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} - if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err) - } - - nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} - if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { - t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err) - } - - // Set the initial route table. - rt := []tcpip.Route{ - {Destination: nic1Addr.Subnet(), NIC: 1}, - {Destination: nic2Addr.Subnet(), NIC: 2}, - {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2}, - {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1}, - } - s.SetRouteTable(rt) - - // When an interface is given, the route for a broadcast goes through it. - r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) - } - if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) - } - - // When an interface is not given, it consults the route table. - // 1. Case: Using the default route. - r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) - } - if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) - } - - // 2. Case: Having an explicit route for broadcast will select that one. - rt = append( - []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, - }, - rt..., - ) - s.SetRouteTable(rt) - r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) - } - if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) - } -} - -func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { - for _, tc := range []struct { - name string - routeNeeded bool - address tcpip.Address - }{ - // IPv4 multicast address range: 224.0.0.0 - 239.255.255.255 - // <=> 0xe0.0x00.0x00.0x00 - 0xef.0xff.0xff.0xff - {"IPv4 Multicast 1", false, "\xe0\x00\x00\x00"}, - {"IPv4 Multicast 2", false, "\xef\xff\xff\xff"}, - {"IPv4 Unicast 1", true, "\xdf\xff\xff\xff"}, - {"IPv4 Unicast 2", true, "\xf0\x00\x00\x00"}, - {"IPv4 Unicast 3", true, "\x00\x00\x00\x00"}, - - // IPv6 multicast address is 0xff[8] + flags[4] + scope[4] + groupId[112] - {"IPv6 Multicast 1", false, "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Multicast 2", false, "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Multicast 3", false, "\xff\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - - // IPv6 link-local address starts with fe80::/10. - {"IPv6 Unicast Link-Local 1", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Link-Local 2", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"}, - {"IPv6 Unicast Link-Local 3", false, "\xfe\x80\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff"}, - {"IPv6 Unicast Link-Local 4", false, "\xfe\xbf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Link-Local 5", false, "\xfe\xbf\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - - // IPv6 addresses that are neither multicast nor link-local. - {"IPv6 Unicast Not Link-Local 1", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 2", true, "\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - {"IPv6 Unicast Not Link-local 3", true, "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 4", true, "\xfe\xc0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 5", true, "\xfe\xdf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 6", true, "\xfd\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - } { - t.Run(tc.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - s.SetRouteTable([]tcpip.Route{}) - - var anyAddr tcpip.Address - if len(tc.address) == header.IPv4AddressSize { - anyAddr = header.IPv4Any - } else { - anyAddr = header.IPv6Any - } - - want := tcpip.ErrNetworkUnreachable - if tc.routeNeeded { - want = tcpip.ErrNoRoute - } - - // If there is no endpoint, it won't work. - if _, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want) - } - - if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err) - } - - if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded { - // Route table is empty but we need a route, this should cause an error. - if err != tcpip.ErrNoRoute { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, tcpip.ErrNoRoute) - } - } else { - if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", anyAddr, tc.address, fakeNetNumber, err) - } - if r.LocalAddress != anyAddr { - t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, anyAddr) - } - if r.RemoteAddress != tc.address { - t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, tc.address) - } - } - // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want { - t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", anyAddr, tc.address, fakeNetNumber, err, want) - } - }) - } -} - -// Add a range of addresses, then check that a packet is delivered. -func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - const localAddrByte byte = 0x01 - buf[0] = localAddrByte - subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - testRecv(t, fakeNet, localAddrByte, ep, buf) -} - -func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) { - t.Helper() - - // Loop over all addresses and check them. - numOfAddresses := 1 << uint(8-subnet.Prefix()) - if numOfAddresses < 1 || numOfAddresses > 255 { - t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) - } - - addrBytes := []byte(subnet.ID()) - for i := 0; i < numOfAddresses; i++ { - addr := tcpip.Address(addrBytes) - wantNicID := nicID - // The subnet and broadcast addresses are skipped. - if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() { - wantNicID = 0 - } - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID) - } - addrBytes[0]++ - } - - // Trying the next address should always fail since it is outside the range. - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0) - } -} - -// Set a range of addresses, then remove it again, and check at each step that -// CheckLocalAddress returns the correct NIC for each address or zero if not -// existent. -func TestCheckLocalAddressForSubnet(t *testing.T) { - const nicID tcpip.NICID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID}}) - } - - subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - - testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) - - if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */) - - if err := s.RemoveAddressRange(nicID, subnet); err != nil { - t.Fatal("RemoveAddressRange failed:", err) - } - - testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) -} - -// Set a range of addresses, then send a packet to a destination outside the -// range and then check it doesn't get delivered. -func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - const localAddrByte byte = 0x01 - buf[0] = localAddrByte - subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) -} - -func TestNetworkOptions(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{}, - }) - - // Try an unsupported network protocol. - if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol { - t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) - } - - testCases := []struct { - option interface{} - wantErr *tcpip.Error - verifier func(t *testing.T, p stack.NetworkProtocol) - }{ - {fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) { - t.Helper() - fakeNet := p.(*fakeNetworkProtocol) - if fakeNet.opts.good != true { - t.Fatalf("fakeNet.opts.good = false, want = true") - } - var v fakeNetGoodOption - if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil { - t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", v, err) - } - if v != true { - t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v) - } - }}, - {fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, - {fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, - } - for _, tc := range testCases { - if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr { - t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr) - } - if tc.verifier != nil { - tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber)) - } - } -} - -func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool { - ranges, ok := s.NICAddressRanges()[id] - if !ok { - return false - } - for _, r := range ranges { - if r == addrRange { - return true - } - } - return false -} - -func TestAddresRangeAddRemove(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addr := tcpip.Address("\x01\x01\x01\x01") - mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr))) - addrRange, err := tcpip.NewSubnet(addr, mask) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) - } - - if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) - } - - if err := s.RemoveAddressRange(1, addrRange); err != nil { - t.Fatal("RemoveAddressRange failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) - } -} - -func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { - for _, addrLen := range []int{4, 16} { - t.Run(fmt.Sprintf("addrLen=%d", addrLen), func(t *testing.T) { - for canBe := 0; canBe < 3; canBe++ { - t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) { - for never := 0; never < 3; never++ { - t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - // Insert <canBe> primary and <never> never-primary addresses. - // Each one will add a network endpoint to the NIC. - primaryAddrAdded := make(map[tcpip.AddressWithPrefix]struct{}) - for i := 0; i < canBe+never; i++ { - var behavior stack.PrimaryEndpointBehavior - if i < canBe { - behavior = stack.CanBePrimaryEndpoint - } else { - behavior = stack.NeverPrimaryEndpoint - } - // Add an address and in case of a primary one include a - // prefixLen. - address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen)) - if behavior == stack.CanBePrimaryEndpoint { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: addrLen * 8, - }, - } - if err := s.AddProtocolAddressWithOptions(1, protocolAddress, behavior); err != nil { - t.Fatal("AddProtocolAddressWithOptions failed:", err) - } - // Remember the address/prefix. - primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{} - } else { - if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) - } - } - } - // Check that GetMainNICAddress returns an address if at least - // one primary address was added. In that case make sure the - // address/prefixLen matches what we added. - gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if len(primaryAddrAdded) == 0 { - // No primary addresses present. - if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr) - } - } else { - // At least one primary address was added, verify the returned - // address is in the list of primary addresses we added. - if _, ok := primaryAddrAdded[gotAddr]; !ok { - t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded) - } - } - }) - } - }) - } - }) - } -} - -func TestGetMainNICAddressAddRemove(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - for _, tc := range []struct { - name string - address tcpip.Address - prefixLen int - }{ - {"IPv4", "\x01\x01\x01\x01", 24}, - {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 116}, - } { - t.Run(tc.name, func(t *testing.T) { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tc.address, - PrefixLen: tc.prefixLen, - }, - } - if err := s.AddProtocolAddress(1, protocolAddress); err != nil { - t.Fatal("AddProtocolAddress failed:", err) - } - - // Check that we get the right initial address and prefix length. - gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr { - t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) - } - - if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - - // Check that we get no address after removal. - gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) - } - }) - } -} - -// Simple network address generator. Good for 255 addresses. -type addressGenerator struct{ cnt byte } - -func (g *addressGenerator) next(addrLen int) tcpip.Address { - g.cnt++ - return tcpip.Address(bytes.Repeat([]byte{g.cnt}, addrLen)) -} - -func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) { - t.Helper() - - if len(gotAddresses) != len(expectedAddresses) { - t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses)) - } - - sort.Slice(gotAddresses, func(i, j int) bool { - return gotAddresses[i].AddressWithPrefix.Address < gotAddresses[j].AddressWithPrefix.Address - }) - sort.Slice(expectedAddresses, func(i, j int) bool { - return expectedAddresses[i].AddressWithPrefix.Address < expectedAddresses[j].AddressWithPrefix.Address - }) - - for i, gotAddr := range gotAddresses { - expectedAddr := expectedAddresses[i] - if gotAddr != expectedAddr { - t.Errorf("got address = %+v, wanted = %+v", gotAddr, expectedAddr) - } - } -} - -func TestAddAddress(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - var addrGen addressGenerator - expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) - for _, addrLen := range []int{4, 16} { - address := addrGen.next(addrLen) - if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { - t.Fatalf("AddAddress(address=%s) failed: %s", address, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, - }) - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddProtocolAddress(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - var addrGen addressGenerator - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)) - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) - } - expectedAddresses = append(expectedAddresses, protocolAddress) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addrLenRange := []int{4, 16} - behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)) - var addrGen addressGenerator - for _, addrLen := range addrLenRange { - for _, behavior := range behaviorRange { - address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, - }) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange)) - var addrGen addressGenerator - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - for _, behavior := range behaviorRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) - } - expectedAddresses = append(expectedAddresses, protocolAddress) - } - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestCreateNICWithOptions(t *testing.T) { - type callArgsAndExpect struct { - nicID tcpip.NICID - opts stack.NICOptions - err *tcpip.Error - } - - tests := []struct { - desc string - calls []callArgsAndExpect - }{ - { - desc: "DuplicateNICID", - calls: []callArgsAndExpect{ - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{Name: "eth1"}, - err: nil, - }, - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{Name: "eth2"}, - err: tcpip.ErrDuplicateNICID, - }, - }, - }, - { - desc: "DuplicateName", - calls: []callArgsAndExpect{ - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{Name: "lo"}, - err: nil, - }, - { - nicID: tcpip.NICID(2), - opts: stack.NICOptions{Name: "lo"}, - err: tcpip.ErrDuplicateNICID, - }, - }, - }, - { - desc: "Unnamed", - calls: []callArgsAndExpect{ - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{}, - err: nil, - }, - { - nicID: tcpip.NICID(2), - opts: stack.NICOptions{}, - err: nil, - }, - }, - }, - { - desc: "UnnamedDuplicateNICID", - calls: []callArgsAndExpect{ - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{}, - err: nil, - }, - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{}, - err: tcpip.ErrDuplicateNICID, - }, - }, - }, - } - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - s := stack.New(stack.Options{}) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) - for _, call := range test.calls { - if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { - t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) - } - } - }) - } -} - -func TestNICStats(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed: ", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - // Route all packets for address \x01 to NIC 1. - { - subnet, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // Send a packet to address 1. - buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) - } - - if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) - } - - payload := buffer.NewView(10) - // Write a packet out via the address for NIC 1 - if err := sendTo(s, "\x01", payload); err != nil { - t.Fatal("sendTo failed: ", err) - } - want := uint64(ep1.Drain()) - if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) - } - - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } -} - -func TestNICForwarding(t *testing.T) { - const nicID1 = 1 - const nicID2 = 2 - const dstAddr = tcpip.Address("\x03") - - tests := []struct { - name string - headerLen uint16 - }{ - { - name: "Zero header length", - }, - { - name: "Non-zero header length", - headerLen: 16, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - s.SetForwarding(true) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err) - } - - ep2 := channelLinkWithHeaderLength{ - Endpoint: channel.New(10, defaultMTU, ""), - headerLength: test.headerLen, - } - if err := s.CreateNIC(nicID2, &ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err) - } - - // Route all packets to dstAddr to NIC 2. - { - subnet, err := tcpip.NewSubnet(dstAddr, "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}}) - } - - // Send a packet to dstAddr. - buf := buffer.NewView(30) - buf[0] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - - pkt, ok := ep2.Read() - if !ok { - t.Fatal("packet not forwarded") - } - - // Test that the link's MaxHeaderLength is honoured. - if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want { - t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want) - } - - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) - } - - if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } - }) - } -} - -// TestNICContextPreservation tests that you can read out via stack.NICInfo the -// Context data you pass via NICContext.Context in stack.CreateNICWithOptions. -func TestNICContextPreservation(t *testing.T) { - var ctx *int - tests := []struct { - name string - opts stack.NICOptions - want stack.NICContext - }{ - { - "context_set", - stack.NICOptions{Context: ctx}, - ctx, - }, - { - "context_not_set", - stack.NICOptions{}, - nil, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{}) - id := tcpip.NICID(1) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) - if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { - t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) - } - nicinfos := s.NICInfo() - nicinfo, ok := nicinfos[id] - if !ok { - t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos) - } - if got, want := nicinfo.Context == test.want, true; got != want { - t.Fatal("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) - } - }) - } -} - -// TestNICAutoGenLinkLocalAddr tests the auto-generation of IPv6 link-local -// addresses. -func TestNICAutoGenLinkLocalAddr(t *testing.T) { - const nicID = 1 - - var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte - n, err := rand.Read(secretKey[:]) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n) - } - - nicNameFunc := func(_ tcpip.NICID, name string) string { - return name - } - - tests := []struct { - name string - nicName string - autoGen bool - linkAddr tcpip.LinkAddress - iidOpts stack.OpaqueInterfaceIdentifierOptions - shouldGen bool - expectedAddr tcpip.Address - }{ - { - name: "Disabled", - nicName: "nic1", - autoGen: false, - linkAddr: linkAddr1, - shouldGen: false, - }, - { - name: "Disabled without OIID options", - nicName: "nic1", - autoGen: false, - linkAddr: linkAddr1, - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:], - }, - shouldGen: false, - }, - - // Tests for EUI64 based addresses. - { - name: "EUI64 Enabled", - autoGen: true, - linkAddr: linkAddr1, - shouldGen: true, - expectedAddr: header.LinkLocalAddr(linkAddr1), - }, - { - name: "EUI64 Empty MAC", - autoGen: true, - shouldGen: false, - }, - { - name: "EUI64 Invalid MAC", - autoGen: true, - linkAddr: "\x01\x02\x03", - shouldGen: false, - }, - { - name: "EUI64 Multicast MAC", - autoGen: true, - linkAddr: "\x01\x02\x03\x04\x05\x06", - shouldGen: false, - }, - { - name: "EUI64 Unspecified MAC", - autoGen: true, - linkAddr: "\x00\x00\x00\x00\x00\x00", - shouldGen: false, - }, - - // Tests for Opaque IID based addresses. - { - name: "OIID Enabled", - nicName: "nic1", - autoGen: true, - linkAddr: linkAddr1, - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:], - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("nic1", 0, secretKey[:]), - }, - // These are all cases where we would not have generated a - // link-local address if opaque IIDs were disabled. - { - name: "OIID Empty MAC and empty nicName", - autoGen: true, - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:1], - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("", 0, secretKey[:1]), - }, - { - name: "OIID Invalid MAC", - nicName: "test", - autoGen: true, - linkAddr: "\x01\x02\x03", - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:2], - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("test", 0, secretKey[:2]), - }, - { - name: "OIID Multicast MAC", - nicName: "test2", - autoGen: true, - linkAddr: "\x01\x02\x03\x04\x05\x06", - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:3], - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("test2", 0, secretKey[:3]), - }, - { - name: "OIID Unspecified MAC and nil SecretKey", - nicName: "test3", - autoGen: true, - linkAddr: "\x00\x00\x00\x00\x00\x00", - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("test3", 0, nil), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: test.autoGen, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: test.iidOpts, - } - - e := channel.New(0, 1280, test.linkAddr) - s := stack.New(opts) - nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) - } - - // A new disabled NIC should not have any address, even if auto generation - // was enabled. - allStackAddrs := s.AllAddresses() - allNICAddrs, ok := allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 0 { - t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) - } - - // Enabling the NIC should attempt auto-generation of a link-local - // address. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - - var expectedMainAddr tcpip.AddressWithPrefix - if test.shouldGen { - expectedMainAddr = tcpip.AddressWithPrefix{ - Address: test.expectedAddr, - PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, - } - - // Should have auto-generated an address and resolved immediately (DAD - // is disabled). - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, expectedMainAddr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } else { - // Should not have auto-generated an address. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address") - default: - } - } - - gotMainAddr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - if gotMainAddr != expectedMainAddr { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", gotMainAddr, expectedMainAddr) - } - }) - } -} - -// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are -// not auto-generated for loopback NICs. -func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { - const nicID = 1 - const nicName = "nicName" - - tests := []struct { - name string - opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions - }{ - { - name: "IID From MAC", - opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{}, - }, - { - name: "Opaque IID", - opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: true, - OpaqueIIDOpts: test.opaqueIIDOpts, - } - - e := loopback.New() - s := stack.New(opts) - nicOpts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want) - } - }) - } -} - -// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 -// link-local addresses will only be assigned after the DAD process resolves. -func TestNICAutoGenAddrDoesDAD(t *testing.T) { - const nicID = 1 - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), - } - ndpConfigs := stack.DefaultNDPConfigurations() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - AutoGenIPv6LinkLocal: true, - NDPDisp: &ndpDisp, - } - - e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // Address should not be considered bound to the - // NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) - } - - linkLocalAddr := header.LinkLocalAddr(linkAddr1) - - // Wait for DAD to resolve. - select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): - // We should get a resolution event after 1s (default time to - // resolve as per default NDP configurations). Waiting for that - // resolution time + an extra 1s without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, linkLocalAddr, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) - } -} - -// TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected -// when an address's kind gets "promoted" to permanent from permanentExpired. -func TestNewPEBOnPromotionToPermanent(t *testing.T) { - pebs := []stack.PrimaryEndpointBehavior{ - stack.NeverPrimaryEndpoint, - stack.CanBePrimaryEndpoint, - stack.FirstPrimaryEndpoint, - } - - for _, pi := range pebs { - for _, ps := range pebs { - t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - // Add a permanent address with initial - // PrimaryEndpointBehavior (peb), pi. If pi is - // NeverPrimaryEndpoint, the address should not - // be returned by a call to GetMainNICAddress; - // else, it should. - if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) - } - addr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("s.GetMainNICAddress failed:", err) - } - if pi == stack.NeverPrimaryEndpoint { - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) - - } - } else if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatalf("NewSubnet failed:", err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // Take a route through the address so its ref - // count gets incremented and does not actually - // get deleted when RemoveAddress is called - // below. This is because we want to test that a - // new peb is respected when an address gets - // "promoted" to permanent from a - // permanentExpired kind. - r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - defer r.Release() - if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed:", err) - } - - // - // At this point, the address should still be - // known by the NIC, but have its - // kind = permanentExpired. - // - - // Add some other address with peb set to - // FirstPrimaryEndpoint. - if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) - - } - - // Add back the address we removed earlier and - // make sure the new peb was respected. - // (The address should just be promoted now). - if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) - } - var primaryAddrs []tcpip.Address - for _, pa := range s.NICInfo()[1].ProtocolAddresses { - primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address) - } - var expectedList []tcpip.Address - switch ps { - case stack.FirstPrimaryEndpoint: - expectedList = []tcpip.Address{ - "\x01", - "\x03", - } - case stack.CanBePrimaryEndpoint: - expectedList = []tcpip.Address{ - "\x03", - "\x01", - } - case stack.NeverPrimaryEndpoint: - expectedList = []tcpip.Address{ - "\x03", - } - } - if !cmp.Equal(primaryAddrs, expectedList) { - t.Fatalf("got NIC's primary addresses = %v, want = %v", primaryAddrs, expectedList) - } - - // Once we remove the other address, if the new - // peb, ps, was NeverPrimaryEndpoint, no address - // should be returned by a call to - // GetMainNICAddress; else, our original address - // should be returned. - if err := s.RemoveAddress(1, "\x03"); err != nil { - t.Fatalf("RemoveAddress failed:", err) - } - addr, err = s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("s.GetMainNICAddress failed:", err) - } - if ps == stack.NeverPrimaryEndpoint { - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) - - } - } else { - if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) - } - } - }) - } - } -} - -func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { - const ( - linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - nicID = 1 - ) - - // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test. - tests := []struct { - name string - nicAddrs []tcpip.Address - connectAddr tcpip.Address - expectedLocalAddr tcpip.Address - }{ - // Test Rule 1 of RFC 6724 section 5. - { - name: "Same Global most preferred (last address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: globalAddr1, - expectedLocalAddr: globalAddr1, - }, - { - name: "Same Global most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: globalAddr1, - expectedLocalAddr: globalAddr1, - }, - { - name: "Same Link Local most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, - connectAddr: linkLocalAddr1, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Same Link Local most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: linkLocalAddr1, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Same Unique Local most preferred (last address)", - nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, - connectAddr: uniqueLocalAddr1, - expectedLocalAddr: uniqueLocalAddr1, - }, - { - name: "Same Unique Local most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: uniqueLocalAddr1, - expectedLocalAddr: uniqueLocalAddr1, - }, - - // Test Rule 2 of RFC 6724 section 5. - { - name: "Global most preferred (last address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: globalAddr2, - expectedLocalAddr: globalAddr1, - }, - { - name: "Global most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: globalAddr2, - expectedLocalAddr: globalAddr1, - }, - { - name: "Link Local most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, - connectAddr: linkLocalAddr2, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Link Local most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: linkLocalAddr2, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Link Local most preferred for link local multicast (last address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, - connectAddr: linkLocalMulticastAddr, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Link Local most preferred for link local multicast (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: linkLocalMulticastAddr, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Unique Local most preferred (last address)", - nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, - connectAddr: uniqueLocalAddr2, - expectedLocalAddr: uniqueLocalAddr1, - }, - { - name: "Unique Local most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: uniqueLocalAddr2, - expectedLocalAddr: uniqueLocalAddr1, - }, - - // Test returning the endpoint that is closest to the front when - // candidate addresses are "equal" from the perspective of RFC 6724 - // section 5. - { - name: "Unique Local for Global", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2}, - connectAddr: globalAddr2, - expectedLocalAddr: uniqueLocalAddr1, - }, - { - name: "Link Local for Global", - nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, - connectAddr: globalAddr2, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Link Local for Unique Local", - nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, - connectAddr: uniqueLocalAddr2, - expectedLocalAddr: linkLocalAddr1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: llAddr3, - NIC: nicID, - }}) - s.AddLinkAddress(nicID, llAddr3, linkAddr3) - - for _, a := range test.nicAddrs { - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil { - t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err) - } - } - - if t.Failed() { - t.FailNow() - } - - if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr { - t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr) - } - }) - } -} - -func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { - const nicID = 1 - - e := loopback.New() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - }) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - allStackAddrs := s.AllAddresses() - allNICAddrs, ok := allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 0 { - t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) - } - - // Enabling the NIC should add the IPv4 broadcast address. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - allStackAddrs = s.AllAddresses() - allNICAddrs, ok = allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 1 { - t.Fatalf("got len(allNICAddrs) = %d, want = 1", l) - } - want := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: header.IPv4Broadcast, - PrefixLen: 32, - }, - } - if allNICAddrs[0] != want { - t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want) - } - - // Disabling the NIC should remove the IPv4 broadcast address. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - allStackAddrs = s.AllAddresses() - allNICAddrs, ok = allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 0 { - t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) - } -} - -// TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval tests that removing an IPv6 -// address after leaving its solicited node multicast address does not result in -// an error. -func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - }) - e := channel.New(10, 1280, linkAddr1) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err) - } - - // The NIC should have joined addr1's solicited node multicast address. - snmc := header.SolicitedNodeAddr(addr1) - in, err := s.IsInGroup(nicID, snmc) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) - } - if !in { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, snmc) - } - - if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, snmc); err != nil { - t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, snmc, err) - } - in, err = s.IsInGroup(nicID, snmc) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) - } - if in { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, snmc) - } - - if err := s.RemoveAddress(nicID, addr1); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) - } -} - -func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) { - const nicID = 1 - - e := loopback.New() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - }) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - // Should not be in the IPv6 all-nodes multicast group yet because the NIC has - // not been enabled yet. - isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) - } - - // The all-nodes multicast group should be joined when the NIC is enabled. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if !isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress) - } - - // The all-nodes multicast group should be left when the NIC is disabled. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) - } -} - -// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC -// was disabled have DAD performed on them when the NIC is enabled. -func TestDoDADWhenNICEnabled(t *testing.T) { - t.Parallel() - - const dadTransmits = 1 - const retransmitTimer = time.Second - const nicID = 1 - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), - } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPDisp: &ndpDisp, - } - - e := channel.New(dadTransmits, 1280, linkAddr1) - s := stack.New(opts) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - addr := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: llAddr1, - PrefixLen: 128, - }, - } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) - } - - // Address should be in the list of all addresses. - if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { - t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) - } - - // Address should be tentative so it should not be a main address. - got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); got != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) - } - - // Enabling the NIC should start DAD for the address. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { - t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) - } - - // Address should not be considered bound to the NIC yet (DAD ongoing). - got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); got != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) - } - - // Wait for DAD to resolve. - select { - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): - t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - } - if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { - t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) - } - got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if got != addr.AddressWithPrefix { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) - } - - // Enabling the NIC again should be a no-op. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { - t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) - } - got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if got != addr.AddressWithPrefix { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) - } -} diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go deleted file mode 100644 index 0e3e239c5..000000000 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ /dev/null @@ -1,358 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "math" - "math/rand" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testPort = 4096 -) - -type testContext struct { - t *testing.T - linkEps map[tcpip.NICID]*channel.Endpoint - s *stack.Stack - - ep tcpip.Endpoint - wq waiter.Queue -} - -func (c *testContext) cleanup() { - if c.ep != nil { - c.ep.Close() - } -} - -func (c *testContext) createV6Endpoint(v6only bool) { - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } -} - -// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. -func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - linkEps := make(map[tcpip.NICID]*channel.Endpoint) - for _, linkEpID := range linkEpIDs { - channelEp := channel.New(256, mtu, "") - if err := s.CreateNIC(linkEpID, channelEp); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - linkEps[linkEpID] = channelEp - - if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress IPv4 failed: %v", err) - } - - if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress IPv6 failed: %v", err) - } - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - return &testContext{ - t: t, - s: s, - linkEps: linkEps, - } -} - -type headers struct { - srcPort uint16 - dstPort uint16 -} - -func newPayload() []byte { - b := make([]byte, 30+rand.Intn(100)) - for i := range b { - b[i] = byte(rand.Intn(256)) - } - return b -} - -func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: testV6Addr, - DstAddr: stackV6Addr, - }) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) -} - -func TestTransportDemuxerRegister(t *testing.T) { - for _, test := range []struct { - name string - proto tcpip.NetworkProtocolNumber - want *tcpip.Error - }{ - {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol}, - {"success", ipv4.ProtocolNumber, nil}, - } { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - }) - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - tEP, ok := ep.(stack.TransportEndpoint) - if !ok { - t.Fatalf("%T does not implement stack.TransportEndpoint", ep) - } - if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want { - t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) - } - }) - } -} - -// TestReuseBindToDevice injects varied packets on input devices and checks that -// the distribution of packets received matches expectations. -func TestDistribution(t *testing.T) { - type endpointSockopts struct { - reuse int - bindToDevice tcpip.NICID - } - for _, test := range []struct { - name string - // endpoints will received the inject packets. - endpoints []endpointSockopts - // wantedDistribution is the wanted ratio of packets received on each - // endpoint for each NIC on which packets are injected. - wantedDistributions map[tcpip.NICID][]float64 - }{ - { - "BindPortReuse", - // 5 endpoints that all have reuse set. - []endpointSockopts{ - {1, 0}, - {1, 0}, - {1, 0}, - {1, 0}, - {1, 0}, - }, - map[tcpip.NICID][]float64{ - // Injected packets on dev0 get distributed evenly. - 1: {0.2, 0.2, 0.2, 0.2, 0.2}, - }, - }, - { - "BindToDevice", - // 3 endpoints with various bindings. - []endpointSockopts{ - {0, 1}, - {0, 2}, - {0, 3}, - }, - map[tcpip.NICID][]float64{ - // Injected packets on dev0 go only to the endpoint bound to dev0. - 1: {1, 0, 0}, - // Injected packets on dev1 go only to the endpoint bound to dev1. - 2: {0, 1, 0}, - // Injected packets on dev2 go only to the endpoint bound to dev2. - 3: {0, 0, 1}, - }, - }, - { - "ReuseAndBindToDevice", - // 6 endpoints with various bindings. - []endpointSockopts{ - {1, 1}, - {1, 1}, - {1, 2}, - {1, 2}, - {1, 2}, - {1, 0}, - }, - map[tcpip.NICID][]float64{ - // Injected packets on dev0 get distributed among endpoints bound to - // dev0. - 1: {0.5, 0.5, 0, 0, 0, 0}, - // Injected packets on dev1 get distributed among endpoints bound to - // dev1 or unbound. - 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, - // Injected packets on dev999 go only to the unbound. - 1000: {0, 0, 0, 0, 0, 1}, - }, - }, - } { - t.Run(test.name, func(t *testing.T) { - for device, wantedDistribution := range test.wantedDistributions { - t.Run(string(device), func(t *testing.T) { - var devices []tcpip.NICID - for d := range test.wantedDistributions { - devices = append(devices, d) - } - c := newDualTestContextMultiNIC(t, defaultMTU, devices) - defer c.cleanup() - - c.createV6Endpoint(false) - - eps := make(map[tcpip.Endpoint]int) - - pollChannel := make(chan tcpip.Endpoint) - for i, endpoint := range test.endpoints { - // Try to receive the data. - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - - var err *tcpip.Error - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - eps[ep] = i - - go func(ep tcpip.Endpoint) { - for range ch { - pollChannel <- ep - } - }(ep) - - defer ep.Close() - reusePortOption := tcpip.ReusePortOption(endpoint.reuse) - if err := ep.SetSockOpt(reusePortOption); err != nil { - c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) - } - bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) - if err := ep.SetSockOpt(bindToDeviceOption); err != nil { - c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) - } - if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { - t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) - } - } - - npackets := 100000 - nports := 10000 - if got, want := len(test.endpoints), len(wantedDistribution); got != want { - t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) - } - ports := make(map[uint16]tcpip.Endpoint) - stats := make(map[tcpip.Endpoint]int) - for i := 0; i < npackets; i++ { - // Send a packet. - port := uint16(i % nports) - payload := newPayload() - c.sendV6Packet(payload, - &headers{ - srcPort: testPort + port, - dstPort: stackPort}, - device) - - var addr tcpip.FullAddress - ep := <-pollChannel - _, _, err := ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) - } - stats[ep]++ - if i < nports { - ports[uint16(i)] = ep - } else { - // Check that all packets from one client are handled by the same - // socket. - if want, got := ports[port], ep; want != got { - t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) - } - } - } - - // Check that a packet distribution is as expected. - for ep, i := range eps { - wantedRatio := wantedDistribution[i] - wantedRecv := wantedRatio * float64(npackets) - actualRecv := stats[ep] - actualRatio := float64(stats[ep]) / float64(npackets) - // The deviation is less than 10%. - if math.Abs(actualRatio-wantedRatio) > 0.05 { - t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) - } - } - }) - } - }) - } -} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go deleted file mode 100644 index 5d1da2f8b..000000000 --- a/pkg/tcpip/stack/transport_test.go +++ /dev/null @@ -1,650 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/iptables" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeTransHeaderLen = 3 -) - -// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts -// received packets; the counts of all endpoints are aggregated in the protocol -// descriptor. -// -// Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't -// use it. -type fakeTransportEndpoint struct { - stack.TransportEndpointInfo - stack *stack.Stack - proto *fakeTransportProtocol - peerAddr tcpip.Address - route stack.Route - uniqueID uint64 - - // acceptQueue is non-nil iff bound. - acceptQueue []fakeTransportEndpoint -} - -func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { - return &f.TransportEndpointInfo -} - -func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { - return nil -} - -func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} -} - -func (f *fakeTransportEndpoint) Abort() { - f.Close() -} - -func (f *fakeTransportEndpoint) Close() { - f.route.Release() -} - -func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask { - return mask -} - -func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - return buffer.View{}, tcpip.ControlMessages{}, nil -} - -func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { - if len(f.route.RemoteAddress) == 0 { - return 0, nil, tcpip.ErrNoRoute - } - - hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) - v, err := p.FullPayload() - if err != nil { - return 0, nil, err - } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ - Header: hdr, - Data: buffer.View(v).ToVectorisedView(), - }); err != nil { - return 0, nil, err - } - - return int64(len(v)), nil, nil -} - -func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil -} - -// SetSockOpt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { - return tcpip.ErrInvalidEndpointState -} - -// SetSockOptBool sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error { - return tcpip.ErrInvalidEndpointState -} - -// SetSockOptInt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { - return tcpip.ErrInvalidEndpointState -} - -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption -} - -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { - return -1, tcpip.ErrUnknownProtocolOption -} - -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: - return nil - } - return tcpip.ErrInvalidEndpointState -} - -// Disconnect implements tcpip.Endpoint.Disconnect. -func (*fakeTransportEndpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported -} - -func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - f.peerAddr = addr.Addr - - // Find the route. - r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - return tcpip.ErrNoRoute - } - defer r.Release() - - // Try to register so that we can start receiving packets. - f.ID.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */) - if err != nil { - return err - } - - f.route = r.Clone() - - return nil -} - -func (f *fakeTransportEndpoint) UniqueID() uint64 { - return f.uniqueID -} - -func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { - return nil -} - -func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error { - return nil -} - -func (*fakeTransportEndpoint) Reset() { -} - -func (*fakeTransportEndpoint) Listen(int) *tcpip.Error { - return nil -} - -func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - if len(f.acceptQueue) == 0 { - return nil, nil, nil - } - a := f.acceptQueue[0] - f.acceptQueue = f.acceptQueue[1:] - return &a, nil, nil -} - -func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { - if err := f.stack.RegisterTransportEndpoint( - a.NIC, - []tcpip.NetworkProtocolNumber{fakeNetNumber}, - fakeTransNumber, - stack.TransportEndpointID{LocalAddress: a.Addr}, - f, - false, /* reuse */ - 0, /* bindtoDevice */ - ); err != nil { - return err - } - f.acceptQueue = []fakeTransportEndpoint{} - return nil -} - -func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, nil -} - -func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, nil -} - -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) { - // Increment the number of received packets. - f.proto.packetCount++ - if f.acceptQueue != nil { - f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - stack: f.stack, - TransportEndpointInfo: stack.TransportEndpointInfo{ - ID: f.ID, - NetProto: f.NetProto, - }, - proto: f.proto, - peerAddr: r.RemoteAddress, - route: r.Clone(), - }) - } -} - -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) { - // Increment the number of received control packets. - f.proto.controlCount++ -} - -func (f *fakeTransportEndpoint) State() uint32 { - return 0 -} - -func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} - -func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { - return iptables.IPTables{}, nil -} - -func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} - -func (f *fakeTransportEndpoint) Wait() {} - -type fakeTransportGoodOption bool - -type fakeTransportBadOption bool - -type fakeTransportInvalidValueOption int - -type fakeTransportProtocolOptions struct { - good bool -} - -// fakeTransportProtocol is a transport-layer protocol descriptor. It -// aggregates the number of packets received via endpoints of this protocol. -type fakeTransportProtocol struct { - packetCount int - controlCount int - opts fakeTransportProtocolOptions -} - -func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { - return fakeTransNumber -} - -func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil -} - -func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return nil, tcpip.ErrUnknownProtocol -} - -func (*fakeTransportProtocol) MinimumPacketSize() int { - return fakeTransHeaderLen -} - -func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) { - return 0, 0, nil -} - -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { - return true -} - -func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error { - switch v := option.(type) { - case fakeTransportGoodOption: - f.opts.good = bool(v) - return nil - case fakeTransportInvalidValueOption: - return tcpip.ErrInvalidOptionValue - default: - return tcpip.ErrUnknownProtocolOption - } -} - -func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { - switch v := option.(type) { - case *fakeTransportGoodOption: - *v = fakeTransportGoodOption(f.opts.good) - return nil - default: - return tcpip.ErrUnknownProtocolOption - } -} - -// Abort implements TransportProtocol.Abort. -func (*fakeTransportProtocol) Abort() {} - -// Close implements tcpip.Endpoint.Close. -func (*fakeTransportProtocol) Close() {} - -// Wait implements TransportProtocol.Wait. -func (*fakeTransportProtocol) Wait() {} - -func fakeTransFactory() stack.TransportProtocol { - return &fakeTransportProtocol{} -} - -func TestTransportReceive(t *testing.T) { - linkEP := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, - }) - if err := s.CreateNIC(1, linkEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - // Create endpoint and connect to remote address. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) - - // Create buffer that will hold the packet. - buf := buffer.NewView(30) - - // Make sure packet with wrong protocol is not delivered. - buf[0] = 1 - buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeTrans.packetCount != 0 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) - } - - // Make sure packet from the wrong source is not delivered. - buf[0] = 1 - buf[1] = 3 - buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeTrans.packetCount != 0 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) - } - - // Make sure packet is delivered. - buf[0] = 1 - buf[1] = 2 - buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeTrans.packetCount != 1 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) - } -} - -func TestTransportControlReceive(t *testing.T) { - linkEP := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, - }) - if err := s.CreateNIC(1, linkEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - // Create endpoint and connect to remote address. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) - - // Create buffer that will hold the control packet. - buf := buffer.NewView(2*fakeNetHeaderLen + 30) - - // Outer packet contains the control protocol number. - buf[0] = 1 - buf[1] = 0xfe - buf[2] = uint8(fakeControlProtocol) - - // Make sure packet with wrong protocol is not delivered. - buf[fakeNetHeaderLen+0] = 0 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeTrans.controlCount != 0 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) - } - - // Make sure packet from the wrong source is not delivered. - buf[fakeNetHeaderLen+0] = 3 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeTrans.controlCount != 0 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) - } - - // Make sure packet is delivered. - buf[fakeNetHeaderLen+0] = 2 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeTrans.controlCount != 1 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) - } -} - -func TestTransportSend(t *testing.T) { - linkEP := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, - }) - if err := s.CreateNIC(1, linkEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // Create endpoint and bind it. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - // Create buffer that will hold the payload. - view := buffer.NewView(30) - _, _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("write failed: %v", err) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - if fakeNet.sendPacketCount[2] != 1 { - t.Errorf("sendPacketCount = %d, want %d", fakeNet.sendPacketCount[2], 1) - } -} - -func TestTransportOptions(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, - }) - - // Try an unsupported transport protocol. - if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol { - t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) - } - - testCases := []struct { - option interface{} - wantErr *tcpip.Error - verifier func(t *testing.T, p stack.TransportProtocol) - }{ - {fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) { - t.Helper() - fakeTrans := p.(*fakeTransportProtocol) - if fakeTrans.opts.good != true { - t.Fatalf("fakeTrans.opts.good = false, want = true") - } - var v fakeTransportGoodOption - if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err) - } - if v != true { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v) - } - - }}, - {fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, - {fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, - } - for _, tc := range testCases { - if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr { - t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr) - } - if tc.verifier != nil { - tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber)) - } - } -} - -func TestTransportForwarding(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, - }) - s.SetForwarding(true) - - // TODO(b/123449044): Change this to a channel NIC. - ep1 := loopback.New() - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatalf("CreateNIC #1 failed: %v", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress #1 failed: %v", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatalf("CreateNIC #2 failed: %v", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress #2 failed: %v", err) - } - - // Route all packets to address 3 to NIC 2 and all packets to address - // 1 to NIC 1. - { - subnet0, err := tcpip.NewSubnet("\x03", "\xff") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - }) - } - - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Send a packet to address 1 from address 3. - req := buffer.NewView(30) - req[0] = 1 - req[1] = 3 - req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: req.ToVectorisedView(), - }) - - aep, _, err := ep.Accept() - if err != nil || aep == nil { - t.Fatalf("Accept failed: %v, %v", aep, err) - } - - resp := buffer.NewView(30) - if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - p, ok := ep2.Read() - if !ok { - t.Fatal("Response packet not forwarded") - } - - if dst := p.Pkt.Header.View()[0]; dst != 3 { - t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) - } - if src := p.Pkt.Header.View()[1]; src != 1 { - t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) - } -} diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go new file mode 100755 index 000000000..6753503f0 --- /dev/null +++ b/pkg/tcpip/tcpip_state_autogen.go @@ -0,0 +1,108 @@ +// automatically generated by stateify. + +// +build go1.9 +// +build !go1.15 + +package tcpip + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *PacketBuffer) save(m state.Map) { + x.beforeSave() + m.Save("Data", &x.Data) + m.Save("DataOffset", &x.DataOffset) + m.Save("DataSize", &x.DataSize) + m.Save("Header", &x.Header) + m.Save("LinkHeader", &x.LinkHeader) + m.Save("NetworkHeader", &x.NetworkHeader) + m.Save("TransportHeader", &x.TransportHeader) +} + +func (x *PacketBuffer) afterLoad() {} +func (x *PacketBuffer) load(m state.Map) { + m.Load("Data", &x.Data) + m.Load("DataOffset", &x.DataOffset) + m.Load("DataSize", &x.DataSize) + m.Load("Header", &x.Header) + m.Load("LinkHeader", &x.LinkHeader) + m.Load("NetworkHeader", &x.NetworkHeader) + m.Load("TransportHeader", &x.TransportHeader) +} + +func (x *FullAddress) beforeSave() {} +func (x *FullAddress) save(m state.Map) { + x.beforeSave() + m.Save("NIC", &x.NIC) + m.Save("Addr", &x.Addr) + m.Save("Port", &x.Port) +} + +func (x *FullAddress) afterLoad() {} +func (x *FullAddress) load(m state.Map) { + m.Load("NIC", &x.NIC) + m.Load("Addr", &x.Addr) + m.Load("Port", &x.Port) +} + +func (x *ControlMessages) beforeSave() {} +func (x *ControlMessages) save(m state.Map) { + x.beforeSave() + m.Save("HasTimestamp", &x.HasTimestamp) + m.Save("Timestamp", &x.Timestamp) + m.Save("HasInq", &x.HasInq) + m.Save("Inq", &x.Inq) + m.Save("HasTOS", &x.HasTOS) + m.Save("TOS", &x.TOS) + m.Save("HasTClass", &x.HasTClass) + m.Save("TClass", &x.TClass) + m.Save("HasIPPacketInfo", &x.HasIPPacketInfo) + m.Save("PacketInfo", &x.PacketInfo) +} + +func (x *ControlMessages) afterLoad() {} +func (x *ControlMessages) load(m state.Map) { + m.Load("HasTimestamp", &x.HasTimestamp) + m.Load("Timestamp", &x.Timestamp) + m.Load("HasInq", &x.HasInq) + m.Load("Inq", &x.Inq) + m.Load("HasTOS", &x.HasTOS) + m.Load("TOS", &x.TOS) + m.Load("HasTClass", &x.HasTClass) + m.Load("TClass", &x.TClass) + m.Load("HasIPPacketInfo", &x.HasIPPacketInfo) + m.Load("PacketInfo", &x.PacketInfo) +} + +func (x *IPPacketInfo) beforeSave() {} +func (x *IPPacketInfo) save(m state.Map) { + x.beforeSave() + m.Save("NIC", &x.NIC) + m.Save("LocalAddr", &x.LocalAddr) + m.Save("DestinationAddr", &x.DestinationAddr) +} + +func (x *IPPacketInfo) afterLoad() {} +func (x *IPPacketInfo) load(m state.Map) { + m.Load("NIC", &x.NIC) + m.Load("LocalAddr", &x.LocalAddr) + m.Load("DestinationAddr", &x.DestinationAddr) +} + +func (x *StdClock) beforeSave() {} +func (x *StdClock) save(m state.Map) { + x.beforeSave() +} + +func (x *StdClock) afterLoad() {} +func (x *StdClock) load(m state.Map) { +} + +func init() { + state.Register("pkg/tcpip.PacketBuffer", (*PacketBuffer)(nil), state.Fns{Save: (*PacketBuffer).save, Load: (*PacketBuffer).load}) + state.Register("pkg/tcpip.FullAddress", (*FullAddress)(nil), state.Fns{Save: (*FullAddress).save, Load: (*FullAddress).load}) + state.Register("pkg/tcpip.ControlMessages", (*ControlMessages)(nil), state.Fns{Save: (*ControlMessages).save, Load: (*ControlMessages).load}) + state.Register("pkg/tcpip.IPPacketInfo", (*IPPacketInfo)(nil), state.Fns{Save: (*IPPacketInfo).save, Load: (*IPPacketInfo).load}) + state.Register("pkg/tcpip.StdClock", (*StdClock)(nil), state.Fns{Save: (*StdClock).save, Load: (*StdClock).load}) +} diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go deleted file mode 100644 index 8c0aacffa..000000000 --- a/pkg/tcpip/tcpip_test.go +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpip - -import ( - "fmt" - "net" - "strings" - "testing" -) - -func TestSubnetContains(t *testing.T) { - tests := []struct { - s Address - m AddressMask - a Address - want bool - }{ - {"\xa0", "\xf0", "\x90", false}, - {"\xa0", "\xf0", "\xa0", true}, - {"\xa0", "\xf0", "\xa5", true}, - {"\xa0", "\xf0", "\xaf", true}, - {"\xa0", "\xf0", "\xb0", false}, - {"\xa0", "\xf0", "", false}, - {"\xa0", "\xf0", "\xa0\x00", false}, - {"\xc2\x80", "\xff\xf0", "\xc2\x80", true}, - {"\xc2\x80", "\xff\xf0", "\xc2\x00", false}, - {"\xc2\x00", "\xff\xf0", "\xc2\x00", true}, - {"\xc2\x00", "\xff\xf0", "\xc2\x80", false}, - } - for _, tt := range tests { - s, err := NewSubnet(tt.s, tt.m) - if err != nil { - t.Errorf("NewSubnet(%v, %v) = %v", tt.s, tt.m, err) - continue - } - if got := s.Contains(tt.a); got != tt.want { - t.Errorf("Subnet(%v).Contains(%v) = %v, want %v", s, tt.a, got, tt.want) - } - } -} - -func TestSubnetBits(t *testing.T) { - tests := []struct { - a AddressMask - want1 int - want0 int - }{ - {"\x00", 0, 8}, - {"\x00\x00", 0, 16}, - {"\x36", 0, 8}, - {"\x5c", 0, 8}, - {"\x5c\x5c", 0, 16}, - {"\x5c\x36", 0, 16}, - {"\x36\x5c", 0, 16}, - {"\x36\x36", 0, 16}, - {"\xff", 8, 0}, - {"\xff\xff", 16, 0}, - } - for _, tt := range tests { - s := &Subnet{mask: tt.a} - got1, got0 := s.Bits() - if got1 != tt.want1 || got0 != tt.want0 { - t.Errorf("Subnet{mask: %x}.Bits() = %d, %d, want %d, %d", tt.a, got1, got0, tt.want1, tt.want0) - } - } -} - -func TestSubnetPrefix(t *testing.T) { - tests := []struct { - a AddressMask - want int - }{ - {"\x00", 0}, - {"\x00\x00", 0}, - {"\x36", 0}, - {"\x86", 1}, - {"\xc5", 2}, - {"\xff\x00", 8}, - {"\xff\x36", 8}, - {"\xff\x8c", 9}, - {"\xff\xc8", 10}, - {"\xff", 8}, - {"\xff\xff", 16}, - } - for _, tt := range tests { - s := &Subnet{mask: tt.a} - got := s.Prefix() - if got != tt.want { - t.Errorf("Subnet{mask: %x}.Bits() = %d want %d", tt.a, got, tt.want) - } - } -} - -func TestSubnetCreation(t *testing.T) { - tests := []struct { - a Address - m AddressMask - want error - }{ - {"\xa0", "\xf0", nil}, - {"\xa0\xa0", "\xf0", errSubnetLengthMismatch}, - {"\xaa", "\xf0", errSubnetAddressMasked}, - {"", "", nil}, - } - for _, tt := range tests { - if _, err := NewSubnet(tt.a, tt.m); err != tt.want { - t.Errorf("NewSubnet(%v, %v) = %v, want %v", tt.a, tt.m, err, tt.want) - } - } -} - -func TestAddressString(t *testing.T) { - for _, want := range []string{ - // Taken from stdlib. - "2001:db8::123:12:1", - "2001:db8::1", - "2001:db8:0:1:0:1:0:1", - "2001:db8:1:0:1:0:1:0", - "2001::1:0:0:1", - "2001:db8:0:0:1::", - "2001:db8::1:0:0:1", - "2001:db8::a:b:c:d", - - // Leading zeros. - "::1", - // Trailing zeros. - "8::", - // No zeros. - "1:1:1:1:1:1:1:1", - // Longer sequence is after other zeros, but not at the end. - "1:0:0:1::1", - // Longer sequence is at the beginning, shorter sequence is at - // the end. - "::1:1:1:0:0", - // Longer sequence is not at the beginning, shorter sequence is - // at the end. - "1::1:1:0:0", - // Longer sequence is at the beginning, shorter sequence is not - // at the end. - "::1:1:0:0:1", - // Neither sequence is at an end, longer is after shorter. - "1:0:0:1::1", - // Shorter sequence is at the beginning, longer sequence is not - // at the end. - "0:0:1:1::1", - // Shorter sequence is at the beginning, longer sequence is at - // the end. - "0:0:1:1:1::", - // Short sequences at both ends, longer one in the middle. - "0:1:1::1:1:0", - // Short sequences at both ends, longer one in the middle. - "0:1::1:0:0", - // Short sequences at both ends, longer one in the middle. - "0:0:1::1:0", - // Longer sequence surrounded by shorter sequences, but none at - // the end. - "1:0:1::1:0:1", - } { - addr := Address(net.ParseIP(want)) - if got := addr.String(); got != want { - t.Errorf("Address(%x).String() = '%s', want = '%s'", addr, got, want) - } - } -} - -func TestStatsString(t *testing.T) { - got := fmt.Sprintf("%+v", Stats{}.FillIn()) - - matchers := []string{ - // Print root-level stats correctly. - "UnknownProtocolRcvdPackets:0", - // Print protocol-specific stats correctly. - "TCP:{ActiveConnectionOpenings:0", - } - - for _, m := range matchers { - if !strings.Contains(got, m) { - t.Errorf("string.Contains(got, %q) = false", m) - } - } - if t.Failed() { - t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got) - } -} - -func TestAddressWithPrefixSubnet(t *testing.T) { - tests := []struct { - addr Address - prefixLen int - subnetAddr Address - subnetMask AddressMask - }{ - {"\xaa\x55\x33\x42", -1, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 0, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 1, "\x80\x00\x00\x00", "\x80\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 7, "\xaa\x00\x00\x00", "\xfe\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 8, "\xaa\x00\x00\x00", "\xff\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 24, "\xaa\x55\x33\x00", "\xff\xff\xff\x00"}, - {"\xaa\x55\x33\x42", 31, "\xaa\x55\x33\x42", "\xff\xff\xff\xfe"}, - {"\xaa\x55\x33\x42", 32, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, - {"\xaa\x55\x33\x42", 33, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, - } - for _, tt := range tests { - ap := AddressWithPrefix{Address: tt.addr, PrefixLen: tt.prefixLen} - gotSubnet := ap.Subnet() - wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask) - if err != nil { - t.Error("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) - continue - } - if gotSubnet != wantSubnet { - t.Errorf("got subnet = %q, want = %q", gotSubnet, wantSubnet) - } - } -} diff --git a/pkg/tcpip/time.s b/pkg/tcpip/time.s deleted file mode 100644 index fb37360ac..000000000 --- a/pkg/tcpip/time.s +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Empty assembly file so empty func definitions work. diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go index f5f01f32f..f5f01f32f 100644..100755 --- a/pkg/tcpip/timer.go +++ b/pkg/tcpip/timer.go diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go deleted file mode 100644 index 2d20f7ef3..000000000 --- a/pkg/tcpip/timer_test.go +++ /dev/null @@ -1,236 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpip_test - -import ( - "sync" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" -) - -const ( - shortDuration = 1 * time.Nanosecond - middleDuration = 100 * time.Millisecond - longDuration = 1 * time.Second -) - -func TestCancellableTimerFire(t *testing.T) { - t.Parallel() - - ch := make(chan struct{}) - var lock sync.Mutex - - timer := tcpip.MakeCancellableTimer(&lock, func() { - ch <- struct{}{} - }) - timer.Reset(shortDuration) - - // Wait for timer to fire. - select { - case <-ch: - case <-time.After(middleDuration): - t.Fatal("timed out waiting for timer to fire") - } - - // The timer should have fired only once. - select { - case <-ch: - t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): - } -} - -func TestCancellableTimerResetFromLongDuration(t *testing.T) { - t.Parallel() - - ch := make(chan struct{}) - var lock sync.Mutex - - timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(middleDuration) - - lock.Lock() - timer.StopLocked() - lock.Unlock() - - timer.Reset(shortDuration) - - // Wait for timer to fire. - select { - case <-ch: - case <-time.After(middleDuration): - t.Fatal("timed out waiting for timer to fire") - } - - // The timer should have fired only once. - select { - case <-ch: - t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): - } -} - -func TestCancellableTimerResetFromShortDuration(t *testing.T) { - t.Parallel() - - ch := make(chan struct{}) - var lock sync.Mutex - - lock.Lock() - timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() - lock.Unlock() - - // Wait for timer to fire if it wasn't correctly stopped. - select { - case <-ch: - t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration): - } - - timer.Reset(shortDuration) - - // Wait for timer to fire. - select { - case <-ch: - case <-time.After(middleDuration): - t.Fatal("timed out waiting for timer to fire") - } - - // The timer should have fired only once. - select { - case <-ch: - t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): - } -} - -func TestCancellableTimerImmediatelyStop(t *testing.T) { - t.Parallel() - - ch := make(chan struct{}) - var lock sync.Mutex - - for i := 0; i < 1000; i++ { - lock.Lock() - timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() - lock.Unlock() - } - - // Wait for timer to fire if it wasn't correctly stopped. - select { - case <-ch: - t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration): - } -} - -func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) { - t.Parallel() - - ch := make(chan struct{}) - var lock sync.Mutex - - lock.Lock() - timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() - lock.Unlock() - - for i := 0; i < 10; i++ { - timer.Reset(middleDuration) - - lock.Lock() - // Sleep until the timer fires and gets blocked trying to take the lock. - time.Sleep(middleDuration * 2) - timer.StopLocked() - lock.Unlock() - } - - // Wait for double the duration so timers that weren't correctly stopped can - // fire. - select { - case <-ch: - t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration * 2): - } -} - -func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { - t.Parallel() - - ch := make(chan struct{}) - var lock sync.Mutex - - lock.Lock() - timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - for i := 0; i < 10; i++ { - // Sleep until the timer fires and gets blocked trying to take the lock. - time.Sleep(middleDuration) - timer.StopLocked() - timer.Reset(shortDuration) - } - lock.Unlock() - - // Wait for double the duration for the last timer to fire. - select { - case <-ch: - case <-time.After(middleDuration): - t.Fatal("timed out waiting for timer to fire") - } - - // The timer should have fired only once. - select { - case <-ch: - t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): - } -} - -func TestManyCancellableTimerResetUnderLock(t *testing.T) { - t.Parallel() - - ch := make(chan struct{}) - var lock sync.Mutex - - lock.Lock() - timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - for i := 0; i < 10; i++ { - timer.StopLocked() - timer.Reset(shortDuration) - } - lock.Unlock() - - // Wait for double the duration for the last timer to fire. - select { - case <-ch: - case <-time.After(middleDuration): - t.Fatal("timed out waiting for timer to fire") - } - - // The timer should have fired only once. - select { - case <-ch: - t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): - } -} diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD deleted file mode 100644 index ac18ec5b1..000000000 --- a/pkg/tcpip/transport/icmp/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "icmp_packet_list", - out = "icmp_packet_list.go", - package = "icmp", - prefix = "icmpPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*icmpPacket", - "Linker": "*icmpPacket", - }, -) - -go_library( - name = "icmp", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "icmp_packet_list.go", - "protocol.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/icmp/icmp_packet_list.go b/pkg/tcpip/transport/icmp/icmp_packet_list.go new file mode 100755 index 000000000..ddee31adb --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_packet_list.go @@ -0,0 +1,186 @@ +package icmp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type icmpPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (icmpPacketElementMapper) linkerFor(elem *icmpPacket) *icmpPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type icmpPacketList struct { + head *icmpPacket + tail *icmpPacket +} + +// Reset resets list l to the empty state. +func (l *icmpPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *icmpPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *icmpPacketList) Front() *icmpPacket { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *icmpPacketList) Back() *icmpPacket { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *icmpPacketList) PushFront(e *icmpPacket) { + linker := icmpPacketElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + icmpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *icmpPacketList) PushBack(e *icmpPacket) { + linker := icmpPacketElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *icmpPacketList) PushBackList(m *icmpPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + icmpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *icmpPacketList) InsertAfter(b, e *icmpPacket) { + bLinker := icmpPacketElementMapper{}.linkerFor(b) + eLinker := icmpPacketElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + icmpPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *icmpPacketList) InsertBefore(a, e *icmpPacket) { + aLinker := icmpPacketElementMapper{}.linkerFor(a) + eLinker := icmpPacketElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + icmpPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *icmpPacketList) Remove(e *icmpPacket) { + linker := icmpPacketElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + icmpPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + icmpPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type icmpPacketEntry struct { + next *icmpPacket + prev *icmpPacket +} + +// Next returns the entry that follows e in the list. +func (e *icmpPacketEntry) Next() *icmpPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *icmpPacketEntry) Prev() *icmpPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *icmpPacketEntry) SetNext(elem *icmpPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *icmpPacketEntry) SetPrev(elem *icmpPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/icmp/icmp_state_autogen.go b/pkg/tcpip/transport/icmp/icmp_state_autogen.go new file mode 100755 index 000000000..b856a4b89 --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_state_autogen.go @@ -0,0 +1,92 @@ +// automatically generated by stateify. + +package icmp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (x *icmpPacket) beforeSave() {} +func (x *icmpPacket) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + m.Save("icmpPacketEntry", &x.icmpPacketEntry) + m.Save("senderAddress", &x.senderAddress) + m.Save("timestamp", &x.timestamp) +} + +func (x *icmpPacket) afterLoad() {} +func (x *icmpPacket) load(m state.Map) { + m.Load("icmpPacketEntry", &x.icmpPacketEntry) + m.Load("senderAddress", &x.senderAddress) + m.Load("timestamp", &x.timestamp) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var rcvBufSizeMax int = x.saveRcvBufSizeMax() + m.SaveValue("rcvBufSizeMax", rcvBufSizeMax) + m.Save("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("uniqueID", &x.uniqueID) + m.Save("rcvReady", &x.rcvReady) + m.Save("rcvList", &x.rcvList) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("shutdownFlags", &x.shutdownFlags) + m.Save("state", &x.state) + m.Save("ttl", &x.ttl) +} + +func (x *endpoint) load(m state.Map) { + m.Load("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Load("waiterQueue", &x.waiterQueue) + m.Load("uniqueID", &x.uniqueID) + m.Load("rcvReady", &x.rcvReady) + m.Load("rcvList", &x.rcvList) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("shutdownFlags", &x.shutdownFlags) + m.Load("state", &x.state) + m.Load("ttl", &x.ttl) + m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *icmpPacketList) beforeSave() {} +func (x *icmpPacketList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *icmpPacketList) afterLoad() {} +func (x *icmpPacketList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *icmpPacketEntry) beforeSave() {} +func (x *icmpPacketEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *icmpPacketEntry) afterLoad() {} +func (x *icmpPacketEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("pkg/tcpip/transport/icmp.icmpPacket", (*icmpPacket)(nil), state.Fns{Save: (*icmpPacket).save, Load: (*icmpPacket).load}) + state.Register("pkg/tcpip/transport/icmp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("pkg/tcpip/transport/icmp.icmpPacketList", (*icmpPacketList)(nil), state.Fns{Save: (*icmpPacketList).save, Load: (*icmpPacketList).load}) + state.Register("pkg/tcpip/transport/icmp.icmpPacketEntry", (*icmpPacketEntry)(nil), state.Fns{Save: (*icmpPacketEntry).save, Load: (*icmpPacketEntry).load}) +} diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD deleted file mode 100644 index d22de6b26..000000000 --- a/pkg/tcpip/transport/packet/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "packet_list", - out = "packet_list.go", - package = "packet", - prefix = "packet", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*packet", - "Linker": "*packet", - }, -) - -go_library( - name = "packet", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "packet_list.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/stack", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 09a1cd436..09a1cd436 100644..100755 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index 9b88f17e4..9b88f17e4 100644..100755 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go diff --git a/pkg/tcpip/transport/packet/packet_list.go b/pkg/tcpip/transport/packet/packet_list.go new file mode 100755 index 000000000..ad27c7c06 --- /dev/null +++ b/pkg/tcpip/transport/packet/packet_list.go @@ -0,0 +1,186 @@ +package packet + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type packetElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (packetElementMapper) linkerFor(elem *packet) *packet { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type packetList struct { + head *packet + tail *packet +} + +// Reset resets list l to the empty state. +func (l *packetList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *packetList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *packetList) Front() *packet { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *packetList) Back() *packet { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *packetList) PushFront(e *packet) { + linker := packetElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + packetElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *packetList) PushBack(e *packet) { + linker := packetElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + packetElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *packetList) PushBackList(m *packetList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + packetElementMapper{}.linkerFor(l.tail).SetNext(m.head) + packetElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *packetList) InsertAfter(b, e *packet) { + bLinker := packetElementMapper{}.linkerFor(b) + eLinker := packetElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + packetElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *packetList) InsertBefore(a, e *packet) { + aLinker := packetElementMapper{}.linkerFor(a) + eLinker := packetElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + packetElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *packetList) Remove(e *packet) { + linker := packetElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + packetElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + packetElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type packetEntry struct { + next *packet + prev *packet +} + +// Next returns the entry that follows e in the list. +func (e *packetEntry) Next() *packet { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *packetEntry) Prev() *packet { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *packetEntry) SetNext(elem *packet) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *packetEntry) SetPrev(elem *packet) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go new file mode 100755 index 000000000..8ff339e08 --- /dev/null +++ b/pkg/tcpip/transport/packet/packet_state_autogen.go @@ -0,0 +1,90 @@ +// automatically generated by stateify. + +package packet + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (x *packet) beforeSave() {} +func (x *packet) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + m.Save("packetEntry", &x.packetEntry) + m.Save("timestampNS", &x.timestampNS) + m.Save("senderAddr", &x.senderAddr) +} + +func (x *packet) afterLoad() {} +func (x *packet) load(m state.Map) { + m.Load("packetEntry", &x.packetEntry) + m.Load("timestampNS", &x.timestampNS) + m.Load("senderAddr", &x.senderAddr) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var rcvBufSizeMax int = x.saveRcvBufSizeMax() + m.SaveValue("rcvBufSizeMax", rcvBufSizeMax) + m.Save("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Save("netProto", &x.netProto) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("cooked", &x.cooked) + m.Save("rcvList", &x.rcvList) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("closed", &x.closed) + m.Save("bound", &x.bound) +} + +func (x *endpoint) load(m state.Map) { + m.Load("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Load("netProto", &x.netProto) + m.Load("waiterQueue", &x.waiterQueue) + m.Load("cooked", &x.cooked) + m.Load("rcvList", &x.rcvList) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("closed", &x.closed) + m.Load("bound", &x.bound) + m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *packetList) beforeSave() {} +func (x *packetList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *packetList) afterLoad() {} +func (x *packetList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *packetEntry) beforeSave() {} +func (x *packetEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *packetEntry) afterLoad() {} +func (x *packetEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("pkg/tcpip/transport/packet.packet", (*packet)(nil), state.Fns{Save: (*packet).save, Load: (*packet).load}) + state.Register("pkg/tcpip/transport/packet.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("pkg/tcpip/transport/packet.packetList", (*packetList)(nil), state.Fns{Save: (*packetList).save, Load: (*packetList).load}) + state.Register("pkg/tcpip/transport/packet.packetEntry", (*packetEntry)(nil), state.Fns{Save: (*packetEntry).save, Load: (*packetEntry).load}) +} diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD deleted file mode 100644 index c9baf4600..000000000 --- a/pkg/tcpip/transport/raw/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "raw_packet_list", - out = "raw_packet_list.go", - package = "raw", - prefix = "rawPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*rawPacket", - "Linker": "*rawPacket", - }, -) - -go_library( - name = "raw", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "protocol.go", - "raw_packet_list.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/packet", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/raw/raw_packet_list.go b/pkg/tcpip/transport/raw/raw_packet_list.go new file mode 100755 index 000000000..e8c1bc997 --- /dev/null +++ b/pkg/tcpip/transport/raw/raw_packet_list.go @@ -0,0 +1,186 @@ +package raw + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type rawPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (rawPacketElementMapper) linkerFor(elem *rawPacket) *rawPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type rawPacketList struct { + head *rawPacket + tail *rawPacket +} + +// Reset resets list l to the empty state. +func (l *rawPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *rawPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *rawPacketList) Front() *rawPacket { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *rawPacketList) Back() *rawPacket { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *rawPacketList) PushFront(e *rawPacket) { + linker := rawPacketElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + rawPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *rawPacketList) PushBack(e *rawPacket) { + linker := rawPacketElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + rawPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *rawPacketList) PushBackList(m *rawPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + rawPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + rawPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *rawPacketList) InsertAfter(b, e *rawPacket) { + bLinker := rawPacketElementMapper{}.linkerFor(b) + eLinker := rawPacketElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + rawPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *rawPacketList) InsertBefore(a, e *rawPacket) { + aLinker := rawPacketElementMapper{}.linkerFor(a) + eLinker := rawPacketElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + rawPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *rawPacketList) Remove(e *rawPacket) { + linker := rawPacketElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + rawPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + rawPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type rawPacketEntry struct { + next *rawPacket + prev *rawPacket +} + +// Next returns the entry that follows e in the list. +func (e *rawPacketEntry) Next() *rawPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *rawPacketEntry) Prev() *rawPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *rawPacketEntry) SetNext(elem *rawPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *rawPacketEntry) SetPrev(elem *rawPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go new file mode 100755 index 000000000..41b72cf93 --- /dev/null +++ b/pkg/tcpip/transport/raw/raw_state_autogen.go @@ -0,0 +1,90 @@ +// automatically generated by stateify. + +package raw + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (x *rawPacket) beforeSave() {} +func (x *rawPacket) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + m.Save("rawPacketEntry", &x.rawPacketEntry) + m.Save("timestampNS", &x.timestampNS) + m.Save("senderAddr", &x.senderAddr) +} + +func (x *rawPacket) afterLoad() {} +func (x *rawPacket) load(m state.Map) { + m.Load("rawPacketEntry", &x.rawPacketEntry) + m.Load("timestampNS", &x.timestampNS) + m.Load("senderAddr", &x.senderAddr) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var rcvBufSizeMax int = x.saveRcvBufSizeMax() + m.SaveValue("rcvBufSizeMax", rcvBufSizeMax) + m.Save("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("associated", &x.associated) + m.Save("rcvList", &x.rcvList) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("closed", &x.closed) + m.Save("connected", &x.connected) + m.Save("bound", &x.bound) +} + +func (x *endpoint) load(m state.Map) { + m.Load("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Load("waiterQueue", &x.waiterQueue) + m.Load("associated", &x.associated) + m.Load("rcvList", &x.rcvList) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("closed", &x.closed) + m.Load("connected", &x.connected) + m.Load("bound", &x.bound) + m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *rawPacketList) beforeSave() {} +func (x *rawPacketList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *rawPacketList) afterLoad() {} +func (x *rawPacketList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *rawPacketEntry) beforeSave() {} +func (x *rawPacketEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *rawPacketEntry) afterLoad() {} +func (x *rawPacketEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("pkg/tcpip/transport/raw.rawPacket", (*rawPacket)(nil), state.Fns{Save: (*rawPacket).save, Load: (*rawPacket).load}) + state.Register("pkg/tcpip/transport/raw.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("pkg/tcpip/transport/raw.rawPacketList", (*rawPacketList)(nil), state.Fns{Save: (*rawPacketList).save, Load: (*rawPacketList).load}) + state.Register("pkg/tcpip/transport/raw.rawPacketEntry", (*rawPacketEntry)(nil), state.Fns{Save: (*rawPacketEntry).save, Load: (*rawPacketEntry).load}) +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD deleted file mode 100644 index a32f9eacf..000000000 --- a/pkg/tcpip/transport/tcp/BUILD +++ /dev/null @@ -1,110 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "tcp_segment_list", - out = "tcp_segment_list.go", - package = "tcp", - prefix = "segment", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*segment", - "Linker": "*segment", - }, -) - -go_template_instance( - name = "tcp_endpoint_list", - out = "tcp_endpoint_list.go", - package = "tcp", - prefix = "endpoint", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*endpoint", - "Linker": "*endpoint", - }, -) - -go_library( - name = "tcp", - srcs = [ - "accept.go", - "connect.go", - "connect_unsafe.go", - "cubic.go", - "cubic_state.go", - "dispatcher.go", - "endpoint.go", - "endpoint_state.go", - "forwarder.go", - "protocol.go", - "rcv.go", - "rcv_state.go", - "reno.go", - "sack.go", - "sack_scoreboard.go", - "segment.go", - "segment_heap.go", - "segment_queue.go", - "segment_state.go", - "snd.go", - "snd_state.go", - "tcp_endpoint_list.go", - "tcp_segment_list.go", - "timer.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/rand", - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/hash/jenkins", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/ports", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/tmutex", - "//pkg/waiter", - "@com_github_google_btree//:go_default_library", - ], -) - -go_test( - name = "tcp_test", - size = "medium", - srcs = [ - "dual_stack_test.go", - "sack_scoreboard_test.go", - "tcp_noracedetector_test.go", - "tcp_sack_test.go", - "tcp_test.go", - "tcp_timestamp_test.go", - ], - # FIXME(b/68809571) - tags = ["flaky"], - deps = [ - ":tcp", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/ports", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp/testing/context", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/tcp/connect_unsafe.go b/pkg/tcpip/transport/tcp/connect_unsafe.go index cfc304616..cfc304616 100644..100755 --- a/pkg/tcpip/transport/tcp/connect_unsafe.go +++ b/pkg/tcpip/transport/tcp/connect_unsafe.go diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index d792b07d6..d792b07d6 100644..100755 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go deleted file mode 100644 index 4f361b226..000000000 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ /dev/null @@ -1,652 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestV4MappedConnectOnV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Start connection attempt, it must fail. - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if err != tcpip.ErrNoRoute { - t.Fatalf("Unexpected return value from Connect: %v", err) - } -} - -func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { - // Start connection attempt. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventOut) - defer c.WQ.EventUnregister(&we) - - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetPacket() - synCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - )) - checker.IPv4(t, b, synCheckers...) - - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - ackCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - )) - checker.IPv4(t, c.GetPacket(), ackCheckers...) - - // Wait for connection to be established. - select { - case <-ch: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { - t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestV4MappedConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToV4MappedWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped address. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { - // Start connection attempt to IPv6 address. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventOut) - defer c.WQ.EventUnregister(&we) - - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) - if err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetV6Packet() - synCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - )) - checker.IPv6(t, b, synCheckers...) - - tcp := header.TCP(header.IPv6(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - iss := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - ackCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - )) - checker.IPv6(t, c.GetV6Packet(), ackCheckers...) - - // Wait for connection to be established. - select { - case <-ch: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { - t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestV6Connect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectWhenBoundToWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to local address. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV4RefuseOnV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the RST reply. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.AckNum(uint32(irs)+1), - ), - ) -} - -func TestV6RefuseOnBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the RST reply. - checker.IPv6(t, c.GetV6Packet(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.AckNum(uint32(irs)+1), - ), - ) -} - -func testV4Accept(t *testing.T, c *context.Context) { - c.SetGSOEnabled(true) - defer c.SetGSOEnabled(false) - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - checker.IPv4(t, b, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1), - ), - ) - - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - nep, _, err := c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - nep, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Make sure we get the same error when calling the original ep and the - // new one. This validates that v4-mapped endpoints are still able to - // query the V6Only flag, whereas pure v4 endpoints are not. - _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption) - if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected { - t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected) - } - - // Check the peer address. - addr, err := nep.GetRemoteAddress() - if err != nil { - t.Fatalf("GetRemoteAddress failed failed: %v", err) - } - - if addr.Addr != context.TestAddr { - t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr) - } - - data := "Don't panic" - nep.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) - b = c.GetPacket() - tcp = header.TCP(header.IPv4(b).Payload()) - if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) - } -} - -func TestV4AcceptOnV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV4AcceptOnBoundToV4MappedWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV4AcceptOnBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV6AcceptOnV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetV6Packet() - tcp := header.TCP(header.IPv6(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - checker.IPv6(t, b, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1), - ), - ) - - // Send ACK. - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - nep, _, err := c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - nep, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Make sure we can still query the v6 only status of the new endpoint, - // that is, that it is in fact a v6 socket. - if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil { - t.Fatalf("GetSockOpt failed failed: %v", err) - } - - // Check the peer address. - addr, err := nep.GetRemoteAddress() - if err != nil { - t.Fatalf("GetRemoteAddress failed failed: %v", err) - } - - if addr.Addr != context.TestV6Addr { - t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr) - } -} - -func TestV4AcceptOnV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func testV4ListenClose(t *testing.T, c *context.Context) { - // Set the SynRcvd threshold to zero to force a syn cookie based accept - // to happen. - saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 0 - const n = uint16(32) - - // Start listening. - if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - irs := seqnum.Value(789) - for i := uint16(0); i < n; i++ { - // Send a SYN request. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + i, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - } - - // Each of these ACK's will cause a syn-cookie based connection to be - // accepted and delivered to the listening endpoint. - for i := uint16(0); i < n; i++ { - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - } - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - nep, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(10 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - nep.Close() - c.EP.Close() -} - -func TestV4ListenCloseOnV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4ListenClose(t, c) -} diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go index 2bf21a2e7..2bf21a2e7 100644..100755 --- a/pkg/tcpip/transport/tcp/rcv_state.go +++ b/pkg/tcpip/transport/tcp/rcv_state.go diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go deleted file mode 100644 index b4e5ba0df..000000000 --- a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" -) - -const smss = 1500 - -func initScoreboard(blocks []header.SACKBlock, iss seqnum.Value) *tcp.SACKScoreboard { - s := tcp.NewSACKScoreboard(smss, iss) - for _, blk := range blocks { - s.Insert(blk) - } - return s -} - -func TestSACKScoreboardIsSACKED(t *testing.T) { - type blockTest struct { - block header.SACKBlock - sacked bool - } - testCases := []struct { - comment string - scoreboardBlocks []header.SACKBlock - blockTests []blockTest - iss seqnum.Value - }{ - { - "Test holes and unsacked SACK blocks in SACKed ranges and insertion of overlapping SACK blocks", - []header.SACKBlock{{10, 20}, {10, 30}, {30, 40}, {41, 50}, {5, 10}, {1, 50}, {111, 120}, {101, 110}, {52, 120}}, - []blockTest{ - {header.SACKBlock{15, 21}, true}, - {header.SACKBlock{200, 201}, false}, - {header.SACKBlock{50, 51}, false}, - {header.SACKBlock{53, 120}, true}, - }, - 0, - }, - { - "Test disjoint SACKBlocks", - []header.SACKBlock{{2288624809, 2288810057}, {2288811477, 2288838565}}, - []blockTest{ - {header.SACKBlock{2288624809, 2288810057}, true}, - {header.SACKBlock{2288811477, 2288838565}, true}, - {header.SACKBlock{2288810057, 2288811477}, false}, - }, - 2288624809, - }, - { - "Test sequence number wrap around", - []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}}, - []blockTest{ - {header.SACKBlock{4294254144, 4294254145}, true}, - {header.SACKBlock{4294254143, 4294254144}, false}, - {header.SACKBlock{4294254144, 1}, true}, - {header.SACKBlock{225652, 5350509}, false}, - {header.SACKBlock{5340409, 5350509}, true}, - {header.SACKBlock{5350509, 5350609}, false}, - }, - 4294254144, - }, - { - "Test disjoint SACKBlocks out of order", - []header.SACKBlock{{827450276, 827454536}, {827426028, 827428868}}, - []blockTest{ - {header.SACKBlock{827426028, 827428867}, true}, - {header.SACKBlock{827450168, 827450275}, false}, - }, - 827426000, - }, - } - for _, tc := range testCases { - sb := initScoreboard(tc.scoreboardBlocks, tc.iss) - for _, blkTest := range tc.blockTests { - if want, got := blkTest.sacked, sb.IsSACKED(blkTest.block); got != want { - t.Errorf("%s: s.IsSACKED(%v) = %v, want %v", tc.comment, blkTest.block, got, want) - } - } - } -} - -func TestSACKScoreboardIsRangeLost(t *testing.T) { - s := tcp.NewSACKScoreboard(10, 0) - s.Insert(header.SACKBlock{1, 25}) - s.Insert(header.SACKBlock{25, 50}) - s.Insert(header.SACKBlock{51, 100}) - s.Insert(header.SACKBlock{111, 120}) - s.Insert(header.SACKBlock{101, 110}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{145, 146}) - s.Insert(header.SACKBlock{147, 148}) - s.Insert(header.SACKBlock{149, 150}) - s.Insert(header.SACKBlock{153, 154}) - s.Insert(header.SACKBlock{155, 156}) - testCases := []struct { - block header.SACKBlock - lost bool - }{ - // Block not covered by SACK block and has more than - // nDupAckThreshold discontiguous SACK blocks after it as well - // as (nDupAckThreshold -1) * 10 (smss) bytes that have been - // SACKED above the sequence number covered by this block. - {block: header.SACKBlock{0, 1}, lost: true}, - - // These blocks have all been SACKed and should not be - // considered lost. - {block: header.SACKBlock{1, 2}, lost: false}, - {block: header.SACKBlock{25, 26}, lost: false}, - {block: header.SACKBlock{1, 45}, lost: false}, - - // Same as the first case above. - {block: header.SACKBlock{50, 51}, lost: true}, - - // This block has been SACKed and should not be considered lost. - {block: header.SACKBlock{119, 120}, lost: false}, - - // This one should return true because there are > - // (nDupAckThreshold - 1) * 10 (smss) bytes that have been - // sacked above this sequence number. - {block: header.SACKBlock{120, 121}, lost: true}, - - // This block has been SACKed and should not be considered lost. - {block: header.SACKBlock{125, 126}, lost: false}, - - // This block has not been SACKed and there are nDupAckThreshold - // number of SACKed blocks after it. - {block: header.SACKBlock{141, 145}, lost: true}, - - // This block has not been SACKed and there are less than - // nDupAckThreshold SACKed sequences after it. - {block: header.SACKBlock{151, 152}, lost: false}, - } - for _, tc := range testCases { - if want, got := tc.lost, s.IsRangeLost(tc.block); got != want { - t.Errorf("s.IsRangeLost(%v) = %v, want %v", tc.block, got, want) - } - } -} - -func TestSACKScoreboardIsLost(t *testing.T) { - s := tcp.NewSACKScoreboard(10, 0) - s.Insert(header.SACKBlock{1, 25}) - s.Insert(header.SACKBlock{25, 50}) - s.Insert(header.SACKBlock{51, 100}) - s.Insert(header.SACKBlock{111, 120}) - s.Insert(header.SACKBlock{101, 110}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{145, 146}) - s.Insert(header.SACKBlock{147, 148}) - s.Insert(header.SACKBlock{149, 150}) - s.Insert(header.SACKBlock{153, 154}) - s.Insert(header.SACKBlock{155, 156}) - testCases := []struct { - seq seqnum.Value - lost bool - }{ - // Sequence number not covered by SACK block and has more than - // nDupAckThreshold discontiguous SACK blocks after it as well - // as (nDupAckThreshold -1) * 10 (smss) bytes that have been - // SACKED above the sequence number. - {seq: 0, lost: true}, - - // These sequence numbers have all been SACKed and should not be - // considered lost. - {seq: 1, lost: false}, - {seq: 25, lost: false}, - {seq: 45, lost: false}, - - // Same as first case above. - {seq: 50, lost: true}, - - // This block has been SACKed and should not be considered lost. - {seq: 119, lost: false}, - - // This one should return true because there are > - // (nDupAckThreshold - 1) * 10 (smss) bytes that have been - // sacked above this sequence number. - {seq: 120, lost: true}, - - // This sequence number has been SACKed and should not be - // considered lost. - {seq: 125, lost: false}, - - // This sequence number has not been SACKed and there are - // nDupAckThreshold number of SACKed blocks after it. - {seq: 141, lost: true}, - - // This sequence number has not been SACKed and there are less - // than nDupAckThreshold SACKed sequences after it. - {seq: 151, lost: false}, - } - for _, tc := range testCases { - if want, got := tc.lost, s.IsLost(tc.seq); got != want { - t.Errorf("s.IsLost(%v) = %v, want %v", tc.seq, got, want) - } - } -} - -func TestSACKScoreboardDelete(t *testing.T) { - blocks := []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}} - s := initScoreboard(blocks, 4294254143) - s.Delete(5340408) - if s.Empty() { - t.Fatalf("s.Empty() = true, want false") - } - if got, want := s.Sacked(), blocks[1].Start.Size(blocks[1].End); got != want { - t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want) - } - s.Delete(5340410) - if s.Empty() { - t.Fatal("s.Empty() = true, want false") - } - newSB := header.SACKBlock{5340410, 5350509} - if !s.IsSACKED(newSB) { - t.Fatalf("s.IsSACKED(%v) = false, want true, scoreboard: %v", newSB, s) - } - s.Delete(5350509) - lastOctet := header.SACKBlock{5350508, 5350509} - if s.IsSACKED(lastOctet) { - t.Fatalf("s.IsSACKED(%v) = false, want true", lastOctet) - } - - s.Delete(5350510) - if !s.Empty() { - t.Fatal("s.Empty() = false, want true") - } - if got, want := s.Sacked(), seqnum.Size(0); got != want { - t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want) - } -} diff --git a/pkg/tcpip/transport/tcp/tcp_endpoint_list.go b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go new file mode 100755 index 000000000..62c042aff --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go @@ -0,0 +1,186 @@ +package tcp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type endpointElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (endpointElementMapper) linkerFor(elem *endpoint) *endpoint { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type endpointList struct { + head *endpoint + tail *endpoint +} + +// Reset resets list l to the empty state. +func (l *endpointList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *endpointList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *endpointList) Front() *endpoint { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *endpointList) Back() *endpoint { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *endpointList) PushFront(e *endpoint) { + linker := endpointElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + endpointElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *endpointList) PushBack(e *endpoint) { + linker := endpointElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + endpointElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *endpointList) PushBackList(m *endpointList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + endpointElementMapper{}.linkerFor(l.tail).SetNext(m.head) + endpointElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *endpointList) InsertAfter(b, e *endpoint) { + bLinker := endpointElementMapper{}.linkerFor(b) + eLinker := endpointElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + endpointElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *endpointList) InsertBefore(a, e *endpoint) { + aLinker := endpointElementMapper{}.linkerFor(a) + eLinker := endpointElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + endpointElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *endpointList) Remove(e *endpoint) { + linker := endpointElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + endpointElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + endpointElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type endpointEntry struct { + next *endpoint + prev *endpoint +} + +// Next returns the entry that follows e in the list. +func (e *endpointEntry) Next() *endpoint { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *endpointEntry) Prev() *endpoint { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *endpointEntry) SetNext(elem *endpoint) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *endpointEntry) SetPrev(elem *endpoint) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go deleted file mode 100644 index 782d7b42c..000000000 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ /dev/null @@ -1,527 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// These tests are flaky when run under the go race detector due to some -// iterations taking long enough that the retransmit timer can kick in causing -// the congestion window measurements to fail due to extra packets etc. -// -// +build !race - -package tcp_test - -import ( - "fmt" - "math" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" -) - -func TestFastRecovery(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 7 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Send 3 duplicate acks. This should force an immediate retransmit of - // the pending packet and put the sender into fast recovery. - rtxOffset := bytesRead - maxPayload*expected - for i := 0; i < 3; i++ { - c.SendAck(790, rtxOffset) - } - - // Receive the retransmitted packet. - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want) - } - - // Now send 7 mode duplicate acks. Each of these should cause a window - // inflation by 1 and cause the sender to send an extra packet. - for i := 0; i < 7; i++ { - c.SendAck(790, rtxOffset) - } - - recover := bytesRead - - // Ensure no new packets arrive. - c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", - 50*time.Millisecond) - - // Acknowledge half of the pending data. - rtxOffset = bytesRead - expected*maxPayload/2 - c.SendAck(790, rtxOffset) - - // Receive the retransmit due to partial ack. - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want { - t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) - } - - // Receive the 10 extra packets that should have been released due to - // the congestion window inflation in recovery. - for i := 0; i < 10; i++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // A partial ACK during recovery should reduce congestion window by the - // number acked. Since we had "expected" packets outstanding before sending - // partial ack and we acked expected/2 , the cwnd and outstanding should - // be expected/2 + 10 (7 dupAcks + 3 for the original 3 dupacks that triggered - // fast recovery). Which means the sender should not send any more packets - // till we ack this one. - c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", - 50*time.Millisecond) - - // Acknowledge all pending data to recover point. - c.SendAck(790, recover) - - // At this point, the cwnd should reset to expected/2 and there are 10 - // packets outstanding. - // - // NOTE: Technically netstack is incorrect in that we adjust the cwnd on - // the same segment that takes us out of recovery. But because of that - // the actual cwnd at exit of recovery will be expected/2 + 1 as we - // acked a cwnd worth of packets which will increase the cwnd further by - // 1 in congestion avoidance. - // - // Now in the first iteration since there are 10 packets outstanding. - // We would expect to get expected/2 +1 - 10 packets. But subsequent - // iterations will send us expected/2 + 1 + 1 (per iteration). - expected = expected/2 + 1 - 10 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - if i == 0 { - // After the first iteration we expect to get the full - // congestion window worth of packets in every - // iteration. - expected += 10 - } - expected++ - } -} - -func TestExponentialIncreaseDuringSlowStart(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 7 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // Double the number of expected packets for the next iteration. - expected *= 2 - } -} - -func TestCongestionAvoidance(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 7 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond) - } - - // Don't acknowledge the first packet of the last packet train. Let's - // wait for them to time out, which will trigger a restart of slow - // start, and initialization of ssthresh to cwnd/2. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // This part is tricky: when the timeout happened, we had "expected" - // packets pending, cwnd reset to 1, and ssthresh set to expected/2. - // By acknowledging "expected" packets, the slow-start part will - // increase cwnd to expected/2 (which "consumes" expected/2-1 of the - // acknowledgements), then the congestion avoidance part will consume - // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack - // remains in the "ack count" (which will cause cwnd to be incremented - // once it reaches cwnd acks). - // - // So we're straight into congestion avoidance with cwnd set to - // expected/2 + 1. - // - // Check that packets trains of cwnd packets are sent, and that cwnd is - // incremented by 1 after we acknowledge each packet. - expected = expected/2 + 1 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - expected++ - } -} - -// cubicCwnd returns an estimate of a cubic window given the -// originalCwnd, wMax, last congestion event time and sRTT. -func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int { - cwnd := float64(origCwnd) - // We wait 50ms between each iteration so sRTT as computed by cubic - // should be close to 50ms. - elapsed := (time.Since(congEventTime) + sRTT).Seconds() - k := math.Cbrt(float64(wMax) * 0.3 / 0.7) - wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax) - cwnd += (wtRTT - cwnd) / cwnd - return int(cwnd) -} - -func TestCubicCongestionAvoidance(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - enableCUBIC(t, c) - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 7 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond) - } - - // Don't acknowledge the first packet of the last packet train. Let's - // wait for them to time out, which will trigger a restart of slow - // start, and initialization of ssthresh to cwnd * 0.7. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Acknowledge all pending data. - c.SendAck(790, bytesRead) - - // Store away the time we sent the ACK and assuming a 200ms RTO - // we estimate that the sender will have an RTO 200ms from now - // and go back into slow start. - packetDropTime := time.Now().Add(200 * time.Millisecond) - - // This part is tricky: when the timeout happened, we had "expected" - // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7. - // By acknowledging "expected" packets, the slow-start part will - // increase cwnd to expected/2 essentially putting the connection - // straight into congestion avoidance. - wMax := expected - // Lower expected as per cubic spec after a congestion event. - expected = int(float64(expected) * 0.7) - cwnd := expected - for i := 0; i < iterations; i++ { - // Cubic grows window independent of ACKs. Cubic Window growth - // is a function of time elapsed since last congestion event. - // As a result the congestion window does not grow - // deterministically in response to ACKs. - // - // We need to roughly estimate what the cwnd of the sender is - // based on when we sent the dupacks. - cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond) - - packetsExpected := cwnd - for j := 0; j < packetsExpected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - t.Logf("expected packets received, next trying to receive any extra packets that may come") - - // If our estimate was correct there should be no more pending packets. - // We attempt to read a packet a few times with a short sleep in between - // to ensure that we don't see the sender send any unexpected packets. - unexpectedPackets := 0 - for { - gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload) - if !gotPacket { - break - } - bytesRead += maxPayload - unexpectedPackets++ - time.Sleep(1 * time.Millisecond) - } - if unexpectedPackets != 0 { - t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i) - } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - } -} - -func TestRetransmit(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 7 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in two shots. Packets will only be written at the - // MTU size though. - half := data[:len(data)/2] - if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - half = data[len(data)/2:] - if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Wait for a timeout and retransmit. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { - t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want) - } - - // Acknowledge half of the pending data. - rtxOffset = bytesRead - expected*maxPayload/2 - c.SendAck(790, rtxOffset) - - // Receive the remaining data, making sure that acknowledged data is not - // retransmitted. - for offset := rtxOffset; offset < len(data); offset += maxPayload { - c.ReceiveAndCheckPacket(data, offset, maxPayload) - c.SendAck(790, offset+maxPayload) - } - - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) -} diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go deleted file mode 100644 index afea124ec..000000000 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ /dev/null @@ -1,572 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "fmt" - "log" - "reflect" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" -) - -// createConnectedWithSACKPermittedOption creates and connects c.ep with the -// SACKPermitted option enabled if the stack in the context has the SACK support -// enabled. -func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()}) -} - -// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS -// option enabled if the stack in the context has SACK and TS enabled. -func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}) -} - -func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { - t.Helper() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err) - } -} - -// TestSackPermittedConnect establishes a connection with the SACK option -// enabled. -func TestSackPermittedConnect(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - setStackSACKPermitted(t, c, sackEnabled) - rep := createConnectedWithSACKPermittedOption(c) - data := []byte{1, 2, 3} - - rep.SendPacket(data, nil) - savedSeqNum := rep.NextSeqNum - rep.VerifyACKNoSACK() - - // Make an out of order packet and send it. - rep.NextSeqNum += 3 - sackBlocks := []header.SACKBlock{ - {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, - } - rep.SendPacket(data, nil) - - // Restore the saved sequence number so that the - // VerifyXXX calls use the right sequence number for - // checking ACK numbers. - rep.NextSeqNum = savedSeqNum - if sackEnabled { - rep.VerifyACKHasSACK(sackBlocks) - } else { - rep.VerifyACKNoSACK() - } - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative ACK for all 9 - // bytes sent and no SACK blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned in the ACK. - rep.VerifyACKNoSACK() - }) - } -} - -// TestSackDisabledConnect establishes a connection with the SACK option -// disabled and verifies that no SACKs are sent for out of order segments. -func TestSackDisabledConnect(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - setStackSACKPermitted(t, c, sackEnabled) - - rep := c.CreateConnectedWithOptions(header.TCPSynOptions{}) - - data := []byte{1, 2, 3} - - rep.SendPacket(data, nil) - savedSeqNum := rep.NextSeqNum - rep.VerifyACKNoSACK() - - // Make an out of order packet and send it. - rep.NextSeqNum += 3 - rep.SendPacket(data, nil) - - // The ACK should contain the older sequence number and - // no SACK blocks. - rep.NextSeqNum = savedSeqNum - rep.VerifyACKNoSACK() - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative ACK for all 9 - // bytes sent and no SACK blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned in the ACK. - rep.VerifyACKNoSACK() - }) - } -} - -// TestSackPermittedAccept accepts and establishes a connection with the -// SACKPermitted option enabled if the connection request specifies the -// SACKPermitted option. In case of SYN cookies SACK should be disabled as we -// don't encode the SACK information in the cookie. -func TestSackPermittedAccept(t *testing.T) { - type testCase struct { - cookieEnabled bool - sackPermitted bool - wndScale int - wndSize uint16 - } - - testCases := []testCase{ - // When cookie is used window scaling is disabled. - {true, false, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). - } - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() - for _, tc := range testCases { - t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - if tc.cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } else { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - } - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - setStackSACKPermitted(t, c, sackEnabled) - - rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted}) - // Now verify no SACK blocks are - // received when sack is disabled. - data := []byte{1, 2, 3} - rep.SendPacket(data, nil) - rep.VerifyACKNoSACK() - - savedSeqNum := rep.NextSeqNum - - // Make an out of order packet and send - // it. - rep.NextSeqNum += 3 - sackBlocks := []header.SACKBlock{ - {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, - } - rep.SendPacket(data, nil) - - // The ACK should contain the older - // sequence number. - rep.NextSeqNum = savedSeqNum - if sackEnabled && tc.sackPermitted { - rep.VerifyACKHasSACK(sackBlocks) - } else { - rep.VerifyACKNoSACK() - } - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative - // ACK for all 9 bytes sent and no SACK - // blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned - // in the ACK. - rep.VerifyACKNoSACK() - }) - } - }) - } -} - -// TestSackDisabledAccept accepts and establishes a connection with -// the SACKPermitted option disabled and verifies that no SACKs are -// sent for out of order packets. -func TestSackDisabledAccept(t *testing.T) { - type testCase struct { - cookieEnabled bool - wndScale int - wndSize uint16 - } - - testCases := []testCase{ - // When cookie is used window scaling is disabled. - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). - } - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() - for _, tc := range testCases { - t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - if tc.cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } else { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - } - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - setStackSACKPermitted(t, c, sackEnabled) - - rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Now verify no SACK blocks are - // received when sack is disabled. - data := []byte{1, 2, 3} - rep.SendPacket(data, nil) - rep.VerifyACKNoSACK() - savedSeqNum := rep.NextSeqNum - - // Make an out of order packet and send - // it. - rep.NextSeqNum += 3 - rep.SendPacket(data, nil) - - // The ACK should contain the older - // sequence number and no SACK blocks. - rep.NextSeqNum = savedSeqNum - rep.VerifyACKNoSACK() - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative - // ACK for all 9 bytes sent and no SACK - // blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned - // in the ACK. - rep.VerifyACKNoSACK() - }) - } - }) - } -} - -func TestUpdateSACKBlocks(t *testing.T) { - testCases := []struct { - segStart seqnum.Value - segEnd seqnum.Value - rcvNxt seqnum.Value - sackBlocks []header.SACKBlock - updated []header.SACKBlock - }{ - // Trivial cases where current SACK block list is empty and we - // have an out of order delivery. - {10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}}, - {10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}}, - {10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}}, - - // Cases where current SACK block list is not empty and we have - // an out of order delivery. Tests that the updated SACK block - // list has the first block as the one that contains the new - // SACK block representing the segment that was just delivered. - {10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}}, - {24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}}, - {24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}}, - - // Ensure that we only retain header.MaxSACKBlocks and drop the - // oldest one if adding a new block exceeds - // header.MaxSACKBlocks. - {24, 30, 9, - []header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}}, - []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}}, - - // Cases where segment extends an existing SACK block. - {10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, - {15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}}, - {15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}}, - {11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}}, - {10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, - {15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}}, - {15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}}, - {11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}}, - - // Cases where segment contains rcvNxt. - {10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}}, - } - - for _, tc := range testCases { - var sack tcp.SACKInfo - copy(sack.Blocks[:], tc.sackBlocks) - sack.NumBlocks = len(tc.sackBlocks) - tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt) - if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !reflect.DeepEqual(got, want) { - t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want) - } - - } -} - -func TestTrimSackBlockList(t *testing.T) { - testCases := []struct { - rcvNxt seqnum.Value - sackBlocks []header.SACKBlock - trimmed []header.SACKBlock - }{ - // Simple cases where we trim whole entries. - {2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}}, - {21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}}, - {31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}}, - {40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, - // Cases where we need to update a block. - {12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}}, - {23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}}, - {33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}}, - {41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, - } - for _, tc := range testCases { - var sack tcp.SACKInfo - copy(sack.Blocks[:], tc.sackBlocks) - sack.NumBlocks = len(tc.sackBlocks) - tcp.TrimSACKBlockList(&sack, tc.rcvNxt) - if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !reflect.DeepEqual(got, want) { - t.Errorf("TrimSackBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want) - } - } -} - -func TestSACKRecovery(t *testing.T) { - const maxPayload = 10 - // See: tcp.makeOptions for why tsOptionSize is set to 12 here. - const tsOptionSize = 12 - // Enabling SACK means the payload size is reduced to account - // for the extra space required for the TCP options. - // - // We increase the MTU by 40 bytes to account for SACK and Timestamp - // options. - const maxTCPOptionSize = 40 - - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) - defer c.Cleanup() - - c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) { - // We use log.Printf instead of t.Logf here because this probe - // can fire even when the test function has finished. This is - // because closing the endpoint in cleanup() does not mean the - // actual worker loop terminates immediately as it still has to - // do a full TCP shutdown. But this test can finish running - // before the shutdown is done. Using t.Logf in such a case - // causes the test to panic due to logging after test finished. - log.Printf("state: %+v\n", s) - }) - setStackSACKPermitted(t, c, true) - createConnectedWithSACKAndTS(c) - - const iterations = 7 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Send 3 duplicate acks. This should force an immediate retransmit of - // the pending packet and put the sender into fast recovery. - rtxOffset := bytesRead - maxPayload*expected - start := c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) - end := start.Add(10) - for i := 0; i < 3; i++ { - c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}}) - end = end.Add(10) - } - - // Receive the retransmitted packet. - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) - - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - t.Errorf("got %s.Value() = %v, want = %v", s.name, got, want) - } - } - - // Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause - // window inflation and sending of packets is completely handled by the - // SACK Recovery algorithm. We should see no packets being released, as - // the cwnd at this point after entering recovery should be half of the - // outstanding number of packets in flight. - for i := 0; i < 7; i++ { - c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}}) - end = end.Add(10) - } - - recover := bytesRead - - // Ensure no new packets arrive. - c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", - 50*time.Millisecond) - - // Acknowledge half of the pending data. This along with the 10 sacked - // segments above should reduce the outstanding below the current - // congestion window allowing the sender to transmit data. - rtxOffset = bytesRead - expected*maxPayload/2 - - // Now send a partial ACK w/ a SACK block that indicates that the next 3 - // segments are lost and we have received 6 segments after the lost - // segments. This should cause the sender to immediately transmit all 3 - // segments in response to this ACK unlike in FastRecovery where only 1 - // segment is retransmitted per ACK. - start = c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) - end = start.Add(60) - c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}}) - - // At this point, we acked expected/2 packets and we SACKED 6 packets and - // 3 segments were considered lost due to the SACK block we sent. - // - // So total packets outstanding can be calculated as follows after 7 - // iterations of slow start -> 10/20/40/80/160/320/640. So expected - // should be 640 at start, then we went to recover at which point the - // cwnd should be set to 320 + 3 (for the 3 dupAcks which have left the - // network). - // Outstanding at this point after acking half the window - // (320 packets) will be: - // outstanding = 640-320-6(due to SACK block)-3 = 311 - // - // The last 3 is due to the fact that the first 3 packets after - // rtxOffset will be considered lost due to the SACK blocks sent. - // Receive the retransmit due to partial ack. - - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) - // Receive the 2 extra packets that should have been retransmitted as - // those should be considered lost and immediately retransmitted based - // on the SACK information in the previous ACK sent above. - for i := 0; i < 2; i++ { - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset+maxPayload*(i+1), maxPayload, tsOptionSize) - } - - // Now we should get 9 more new unsent packets as the cwnd is 323 and - // outstanding is 311. - for i := 0; i < 9; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - - // In SACK recovery only the first segment is fast retransmitted when - // entering recovery. - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { - t.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { - t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { - t.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want) - } - - c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) - - // Acknowledge all pending data to recover point. - c.SendAck(790, recover) - - // At this point, the cwnd should reset to expected/2 and there are 9 - // packets outstanding. - // - // Now in the first iteration since there are 9 packets outstanding. - // We would expect to get expected/2 - 9 packets. But subsequent - // iterations will send us expected/2 + 1 (per iteration). - expected = expected/2 - 9 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd and iteration: %d.", expected, i), 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - if i == 0 { - // After the first iteration we expect to get the full - // congestion window worth of packets in every - // iteration. - expected += 9 - } - expected++ - } -} diff --git a/pkg/tcpip/transport/tcp/tcp_segment_list.go b/pkg/tcpip/transport/tcp/tcp_segment_list.go new file mode 100755 index 000000000..27f17f037 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_segment_list.go @@ -0,0 +1,186 @@ +package tcp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type segmentElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type segmentList struct { + head *segment + tail *segment +} + +// Reset resets list l to the empty state. +func (l *segmentList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *segmentList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *segmentList) Front() *segment { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *segmentList) Back() *segment { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *segmentList) PushFront(e *segment) { + linker := segmentElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + segmentElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *segmentList) PushBack(e *segment) { + linker := segmentElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *segmentList) PushBackList(m *segmentList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head) + segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *segmentList) InsertAfter(b, e *segment) { + bLinker := segmentElementMapper{}.linkerFor(b) + eLinker := segmentElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + segmentElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *segmentList) InsertBefore(a, e *segment) { + aLinker := segmentElementMapper{}.linkerFor(a) + eLinker := segmentElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + segmentElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *segmentList) Remove(e *segment) { + linker := segmentElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + segmentElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + segmentElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type segmentEntry struct { + next *segment + prev *segment +} + +// Next returns the entry that follows e in the list. +func (e *segmentEntry) Next() *segment { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *segmentEntry) Prev() *segment { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *segmentEntry) SetNext(elem *segment) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *segmentEntry) SetPrev(elem *segment) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go new file mode 100755 index 000000000..9c1514e62 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -0,0 +1,531 @@ +// automatically generated by stateify. + +package tcp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (x *cubicState) beforeSave() {} +func (x *cubicState) save(m state.Map) { + x.beforeSave() + var t unixTime = x.saveT() + m.SaveValue("t", t) + m.Save("wLastMax", &x.wLastMax) + m.Save("wMax", &x.wMax) + m.Save("numCongestionEvents", &x.numCongestionEvents) + m.Save("c", &x.c) + m.Save("k", &x.k) + m.Save("beta", &x.beta) + m.Save("wC", &x.wC) + m.Save("wEst", &x.wEst) + m.Save("s", &x.s) +} + +func (x *cubicState) afterLoad() {} +func (x *cubicState) load(m state.Map) { + m.Load("wLastMax", &x.wLastMax) + m.Load("wMax", &x.wMax) + m.Load("numCongestionEvents", &x.numCongestionEvents) + m.Load("c", &x.c) + m.Load("k", &x.k) + m.Load("beta", &x.beta) + m.Load("wC", &x.wC) + m.Load("wEst", &x.wEst) + m.Load("s", &x.s) + m.LoadValue("t", new(unixTime), func(y interface{}) { x.loadT(y.(unixTime)) }) +} + +func (x *SACKInfo) beforeSave() {} +func (x *SACKInfo) save(m state.Map) { + x.beforeSave() + m.Save("Blocks", &x.Blocks) + m.Save("NumBlocks", &x.NumBlocks) +} + +func (x *SACKInfo) afterLoad() {} +func (x *SACKInfo) load(m state.Map) { + m.Load("Blocks", &x.Blocks) + m.Load("NumBlocks", &x.NumBlocks) +} + +func (x *rcvBufAutoTuneParams) beforeSave() {} +func (x *rcvBufAutoTuneParams) save(m state.Map) { + x.beforeSave() + var measureTime unixTime = x.saveMeasureTime() + m.SaveValue("measureTime", measureTime) + var rttMeasureTime unixTime = x.saveRttMeasureTime() + m.SaveValue("rttMeasureTime", rttMeasureTime) + m.Save("copied", &x.copied) + m.Save("prevCopied", &x.prevCopied) + m.Save("rtt", &x.rtt) + m.Save("rttMeasureSeqNumber", &x.rttMeasureSeqNumber) + m.Save("disabled", &x.disabled) +} + +func (x *rcvBufAutoTuneParams) afterLoad() {} +func (x *rcvBufAutoTuneParams) load(m state.Map) { + m.Load("copied", &x.copied) + m.Load("prevCopied", &x.prevCopied) + m.Load("rtt", &x.rtt) + m.Load("rttMeasureSeqNumber", &x.rttMeasureSeqNumber) + m.Load("disabled", &x.disabled) + m.LoadValue("measureTime", new(unixTime), func(y interface{}) { x.loadMeasureTime(y.(unixTime)) }) + m.LoadValue("rttMeasureTime", new(unixTime), func(y interface{}) { x.loadRttMeasureTime(y.(unixTime)) }) +} + +func (x *EndpointInfo) beforeSave() {} +func (x *EndpointInfo) save(m state.Map) { + x.beforeSave() + var HardError string = x.saveHardError() + m.SaveValue("HardError", HardError) + m.Save("TransportEndpointInfo", &x.TransportEndpointInfo) +} + +func (x *EndpointInfo) afterLoad() {} +func (x *EndpointInfo) load(m state.Map) { + m.Load("TransportEndpointInfo", &x.TransportEndpointInfo) + m.LoadValue("HardError", new(string), func(y interface{}) { x.loadHardError(y.(string)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var lastError string = x.saveLastError() + m.SaveValue("lastError", lastError) + var state EndpointState = x.saveState() + m.SaveValue("state", state) + var acceptedChan []*endpoint = x.saveAcceptedChan() + m.SaveValue("acceptedChan", acceptedChan) + m.Save("EndpointInfo", &x.EndpointInfo) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("uniqueID", &x.uniqueID) + m.Save("rcvList", &x.rcvList) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvBufUsed", &x.rcvBufUsed) + m.Save("rcvAutoParams", &x.rcvAutoParams) + m.Save("zeroWindow", &x.zeroWindow) + m.Save("isRegistered", &x.isRegistered) + m.Save("ttl", &x.ttl) + m.Save("v6only", &x.v6only) + m.Save("isConnectNotified", &x.isConnectNotified) + m.Save("broadcast", &x.broadcast) + m.Save("boundBindToDevice", &x.boundBindToDevice) + m.Save("boundPortFlags", &x.boundPortFlags) + m.Save("workerRunning", &x.workerRunning) + m.Save("workerCleanup", &x.workerCleanup) + m.Save("sendTSOk", &x.sendTSOk) + m.Save("recentTS", &x.recentTS) + m.Save("tsOffset", &x.tsOffset) + m.Save("shutdownFlags", &x.shutdownFlags) + m.Save("sackPermitted", &x.sackPermitted) + m.Save("sack", &x.sack) + m.Save("reusePort", &x.reusePort) + m.Save("bindToDevice", &x.bindToDevice) + m.Save("delay", &x.delay) + m.Save("cork", &x.cork) + m.Save("scoreboard", &x.scoreboard) + m.Save("reuseAddr", &x.reuseAddr) + m.Save("slowAck", &x.slowAck) + m.Save("segmentQueue", &x.segmentQueue) + m.Save("synRcvdCount", &x.synRcvdCount) + m.Save("userMSS", &x.userMSS) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("sndBufUsed", &x.sndBufUsed) + m.Save("sndClosed", &x.sndClosed) + m.Save("sndBufInQueue", &x.sndBufInQueue) + m.Save("sndQueue", &x.sndQueue) + m.Save("cc", &x.cc) + m.Save("packetTooBigCount", &x.packetTooBigCount) + m.Save("sndMTU", &x.sndMTU) + m.Save("keepalive", &x.keepalive) + m.Save("userTimeout", &x.userTimeout) + m.Save("deferAccept", &x.deferAccept) + m.Save("rcv", &x.rcv) + m.Save("snd", &x.snd) + m.Save("connectingAddress", &x.connectingAddress) + m.Save("amss", &x.amss) + m.Save("sendTOS", &x.sendTOS) + m.Save("gso", &x.gso) + m.Save("tcpLingerTimeout", &x.tcpLingerTimeout) + m.Save("closed", &x.closed) +} + +func (x *endpoint) load(m state.Map) { + m.Load("EndpointInfo", &x.EndpointInfo) + m.LoadWait("waiterQueue", &x.waiterQueue) + m.Load("uniqueID", &x.uniqueID) + m.LoadWait("rcvList", &x.rcvList) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvBufUsed", &x.rcvBufUsed) + m.Load("rcvAutoParams", &x.rcvAutoParams) + m.Load("zeroWindow", &x.zeroWindow) + m.Load("isRegistered", &x.isRegistered) + m.Load("ttl", &x.ttl) + m.Load("v6only", &x.v6only) + m.Load("isConnectNotified", &x.isConnectNotified) + m.Load("broadcast", &x.broadcast) + m.Load("boundBindToDevice", &x.boundBindToDevice) + m.Load("boundPortFlags", &x.boundPortFlags) + m.Load("workerRunning", &x.workerRunning) + m.Load("workerCleanup", &x.workerCleanup) + m.Load("sendTSOk", &x.sendTSOk) + m.Load("recentTS", &x.recentTS) + m.Load("tsOffset", &x.tsOffset) + m.Load("shutdownFlags", &x.shutdownFlags) + m.Load("sackPermitted", &x.sackPermitted) + m.Load("sack", &x.sack) + m.Load("reusePort", &x.reusePort) + m.Load("bindToDevice", &x.bindToDevice) + m.Load("delay", &x.delay) + m.Load("cork", &x.cork) + m.Load("scoreboard", &x.scoreboard) + m.Load("reuseAddr", &x.reuseAddr) + m.Load("slowAck", &x.slowAck) + m.LoadWait("segmentQueue", &x.segmentQueue) + m.Load("synRcvdCount", &x.synRcvdCount) + m.Load("userMSS", &x.userMSS) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("sndBufUsed", &x.sndBufUsed) + m.Load("sndClosed", &x.sndClosed) + m.Load("sndBufInQueue", &x.sndBufInQueue) + m.LoadWait("sndQueue", &x.sndQueue) + m.Load("cc", &x.cc) + m.Load("packetTooBigCount", &x.packetTooBigCount) + m.Load("sndMTU", &x.sndMTU) + m.Load("keepalive", &x.keepalive) + m.Load("userTimeout", &x.userTimeout) + m.Load("deferAccept", &x.deferAccept) + m.LoadWait("rcv", &x.rcv) + m.LoadWait("snd", &x.snd) + m.Load("connectingAddress", &x.connectingAddress) + m.Load("amss", &x.amss) + m.Load("sendTOS", &x.sendTOS) + m.Load("gso", &x.gso) + m.Load("tcpLingerTimeout", &x.tcpLingerTimeout) + m.Load("closed", &x.closed) + m.LoadValue("lastError", new(string), func(y interface{}) { x.loadLastError(y.(string)) }) + m.LoadValue("state", new(EndpointState), func(y interface{}) { x.loadState(y.(EndpointState)) }) + m.LoadValue("acceptedChan", new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *keepalive) beforeSave() {} +func (x *keepalive) save(m state.Map) { + x.beforeSave() + m.Save("enabled", &x.enabled) + m.Save("idle", &x.idle) + m.Save("interval", &x.interval) + m.Save("count", &x.count) + m.Save("unacked", &x.unacked) +} + +func (x *keepalive) afterLoad() {} +func (x *keepalive) load(m state.Map) { + m.Load("enabled", &x.enabled) + m.Load("idle", &x.idle) + m.Load("interval", &x.interval) + m.Load("count", &x.count) + m.Load("unacked", &x.unacked) +} + +func (x *receiver) beforeSave() {} +func (x *receiver) save(m state.Map) { + x.beforeSave() + var lastRcvdAckTime unixTime = x.saveLastRcvdAckTime() + m.SaveValue("lastRcvdAckTime", lastRcvdAckTime) + m.Save("ep", &x.ep) + m.Save("rcvNxt", &x.rcvNxt) + m.Save("rcvAcc", &x.rcvAcc) + m.Save("rcvWnd", &x.rcvWnd) + m.Save("rcvWndScale", &x.rcvWndScale) + m.Save("closed", &x.closed) + m.Save("pendingRcvdSegments", &x.pendingRcvdSegments) + m.Save("pendingBufUsed", &x.pendingBufUsed) + m.Save("pendingBufSize", &x.pendingBufSize) +} + +func (x *receiver) afterLoad() {} +func (x *receiver) load(m state.Map) { + m.Load("ep", &x.ep) + m.Load("rcvNxt", &x.rcvNxt) + m.Load("rcvAcc", &x.rcvAcc) + m.Load("rcvWnd", &x.rcvWnd) + m.Load("rcvWndScale", &x.rcvWndScale) + m.Load("closed", &x.closed) + m.Load("pendingRcvdSegments", &x.pendingRcvdSegments) + m.Load("pendingBufUsed", &x.pendingBufUsed) + m.Load("pendingBufSize", &x.pendingBufSize) + m.LoadValue("lastRcvdAckTime", new(unixTime), func(y interface{}) { x.loadLastRcvdAckTime(y.(unixTime)) }) +} + +func (x *renoState) beforeSave() {} +func (x *renoState) save(m state.Map) { + x.beforeSave() + m.Save("s", &x.s) +} + +func (x *renoState) afterLoad() {} +func (x *renoState) load(m state.Map) { + m.Load("s", &x.s) +} + +func (x *SACKScoreboard) beforeSave() {} +func (x *SACKScoreboard) save(m state.Map) { + x.beforeSave() + m.Save("smss", &x.smss) + m.Save("maxSACKED", &x.maxSACKED) +} + +func (x *SACKScoreboard) afterLoad() {} +func (x *SACKScoreboard) load(m state.Map) { + m.Load("smss", &x.smss) + m.Load("maxSACKED", &x.maxSACKED) +} + +func (x *segment) beforeSave() {} +func (x *segment) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + var options []byte = x.saveOptions() + m.SaveValue("options", options) + var rcvdTime unixTime = x.saveRcvdTime() + m.SaveValue("rcvdTime", rcvdTime) + var xmitTime unixTime = x.saveXmitTime() + m.SaveValue("xmitTime", xmitTime) + m.Save("segmentEntry", &x.segmentEntry) + m.Save("refCnt", &x.refCnt) + m.Save("viewToDeliver", &x.viewToDeliver) + m.Save("sequenceNumber", &x.sequenceNumber) + m.Save("ackNumber", &x.ackNumber) + m.Save("flags", &x.flags) + m.Save("window", &x.window) + m.Save("csum", &x.csum) + m.Save("csumValid", &x.csumValid) + m.Save("parsedOptions", &x.parsedOptions) + m.Save("hasNewSACKInfo", &x.hasNewSACKInfo) +} + +func (x *segment) afterLoad() {} +func (x *segment) load(m state.Map) { + m.Load("segmentEntry", &x.segmentEntry) + m.Load("refCnt", &x.refCnt) + m.Load("viewToDeliver", &x.viewToDeliver) + m.Load("sequenceNumber", &x.sequenceNumber) + m.Load("ackNumber", &x.ackNumber) + m.Load("flags", &x.flags) + m.Load("window", &x.window) + m.Load("csum", &x.csum) + m.Load("csumValid", &x.csumValid) + m.Load("parsedOptions", &x.parsedOptions) + m.Load("hasNewSACKInfo", &x.hasNewSACKInfo) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) + m.LoadValue("options", new([]byte), func(y interface{}) { x.loadOptions(y.([]byte)) }) + m.LoadValue("rcvdTime", new(unixTime), func(y interface{}) { x.loadRcvdTime(y.(unixTime)) }) + m.LoadValue("xmitTime", new(unixTime), func(y interface{}) { x.loadXmitTime(y.(unixTime)) }) +} + +func (x *segmentQueue) beforeSave() {} +func (x *segmentQueue) save(m state.Map) { + x.beforeSave() + m.Save("list", &x.list) + m.Save("limit", &x.limit) + m.Save("used", &x.used) +} + +func (x *segmentQueue) afterLoad() {} +func (x *segmentQueue) load(m state.Map) { + m.LoadWait("list", &x.list) + m.Load("limit", &x.limit) + m.Load("used", &x.used) +} + +func (x *sender) beforeSave() {} +func (x *sender) save(m state.Map) { + x.beforeSave() + var lastSendTime unixTime = x.saveLastSendTime() + m.SaveValue("lastSendTime", lastSendTime) + var rttMeasureTime unixTime = x.saveRttMeasureTime() + m.SaveValue("rttMeasureTime", rttMeasureTime) + var firstRetransmittedSegXmitTime unixTime = x.saveFirstRetransmittedSegXmitTime() + m.SaveValue("firstRetransmittedSegXmitTime", firstRetransmittedSegXmitTime) + m.Save("ep", &x.ep) + m.Save("dupAckCount", &x.dupAckCount) + m.Save("fr", &x.fr) + m.Save("sndCwnd", &x.sndCwnd) + m.Save("sndSsthresh", &x.sndSsthresh) + m.Save("sndCAAckCount", &x.sndCAAckCount) + m.Save("outstanding", &x.outstanding) + m.Save("sndWnd", &x.sndWnd) + m.Save("sndUna", &x.sndUna) + m.Save("sndNxt", &x.sndNxt) + m.Save("sndNxtList", &x.sndNxtList) + m.Save("rttMeasureSeqNum", &x.rttMeasureSeqNum) + m.Save("closed", &x.closed) + m.Save("writeNext", &x.writeNext) + m.Save("writeList", &x.writeList) + m.Save("rtt", &x.rtt) + m.Save("rto", &x.rto) + m.Save("maxPayloadSize", &x.maxPayloadSize) + m.Save("gso", &x.gso) + m.Save("sndWndScale", &x.sndWndScale) + m.Save("maxSentAck", &x.maxSentAck) + m.Save("state", &x.state) + m.Save("cc", &x.cc) +} + +func (x *sender) load(m state.Map) { + m.Load("ep", &x.ep) + m.Load("dupAckCount", &x.dupAckCount) + m.Load("fr", &x.fr) + m.Load("sndCwnd", &x.sndCwnd) + m.Load("sndSsthresh", &x.sndSsthresh) + m.Load("sndCAAckCount", &x.sndCAAckCount) + m.Load("outstanding", &x.outstanding) + m.Load("sndWnd", &x.sndWnd) + m.Load("sndUna", &x.sndUna) + m.Load("sndNxt", &x.sndNxt) + m.Load("sndNxtList", &x.sndNxtList) + m.Load("rttMeasureSeqNum", &x.rttMeasureSeqNum) + m.Load("closed", &x.closed) + m.Load("writeNext", &x.writeNext) + m.Load("writeList", &x.writeList) + m.Load("rtt", &x.rtt) + m.Load("rto", &x.rto) + m.Load("maxPayloadSize", &x.maxPayloadSize) + m.Load("gso", &x.gso) + m.Load("sndWndScale", &x.sndWndScale) + m.Load("maxSentAck", &x.maxSentAck) + m.Load("state", &x.state) + m.Load("cc", &x.cc) + m.LoadValue("lastSendTime", new(unixTime), func(y interface{}) { x.loadLastSendTime(y.(unixTime)) }) + m.LoadValue("rttMeasureTime", new(unixTime), func(y interface{}) { x.loadRttMeasureTime(y.(unixTime)) }) + m.LoadValue("firstRetransmittedSegXmitTime", new(unixTime), func(y interface{}) { x.loadFirstRetransmittedSegXmitTime(y.(unixTime)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *rtt) beforeSave() {} +func (x *rtt) save(m state.Map) { + x.beforeSave() + m.Save("srtt", &x.srtt) + m.Save("rttvar", &x.rttvar) + m.Save("srttInited", &x.srttInited) +} + +func (x *rtt) afterLoad() {} +func (x *rtt) load(m state.Map) { + m.Load("srtt", &x.srtt) + m.Load("rttvar", &x.rttvar) + m.Load("srttInited", &x.srttInited) +} + +func (x *fastRecovery) beforeSave() {} +func (x *fastRecovery) save(m state.Map) { + x.beforeSave() + m.Save("active", &x.active) + m.Save("first", &x.first) + m.Save("last", &x.last) + m.Save("maxCwnd", &x.maxCwnd) + m.Save("highRxt", &x.highRxt) + m.Save("rescueRxt", &x.rescueRxt) +} + +func (x *fastRecovery) afterLoad() {} +func (x *fastRecovery) load(m state.Map) { + m.Load("active", &x.active) + m.Load("first", &x.first) + m.Load("last", &x.last) + m.Load("maxCwnd", &x.maxCwnd) + m.Load("highRxt", &x.highRxt) + m.Load("rescueRxt", &x.rescueRxt) +} + +func (x *unixTime) beforeSave() {} +func (x *unixTime) save(m state.Map) { + x.beforeSave() + m.Save("second", &x.second) + m.Save("nano", &x.nano) +} + +func (x *unixTime) afterLoad() {} +func (x *unixTime) load(m state.Map) { + m.Load("second", &x.second) + m.Load("nano", &x.nano) +} + +func (x *endpointList) beforeSave() {} +func (x *endpointList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *endpointList) afterLoad() {} +func (x *endpointList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *endpointEntry) beforeSave() {} +func (x *endpointEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *endpointEntry) afterLoad() {} +func (x *endpointEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *segmentList) beforeSave() {} +func (x *segmentList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *segmentList) afterLoad() {} +func (x *segmentList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *segmentEntry) beforeSave() {} +func (x *segmentEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *segmentEntry) afterLoad() {} +func (x *segmentEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("pkg/tcpip/transport/tcp.cubicState", (*cubicState)(nil), state.Fns{Save: (*cubicState).save, Load: (*cubicState).load}) + state.Register("pkg/tcpip/transport/tcp.SACKInfo", (*SACKInfo)(nil), state.Fns{Save: (*SACKInfo).save, Load: (*SACKInfo).load}) + state.Register("pkg/tcpip/transport/tcp.rcvBufAutoTuneParams", (*rcvBufAutoTuneParams)(nil), state.Fns{Save: (*rcvBufAutoTuneParams).save, Load: (*rcvBufAutoTuneParams).load}) + state.Register("pkg/tcpip/transport/tcp.EndpointInfo", (*EndpointInfo)(nil), state.Fns{Save: (*EndpointInfo).save, Load: (*EndpointInfo).load}) + state.Register("pkg/tcpip/transport/tcp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("pkg/tcpip/transport/tcp.keepalive", (*keepalive)(nil), state.Fns{Save: (*keepalive).save, Load: (*keepalive).load}) + state.Register("pkg/tcpip/transport/tcp.receiver", (*receiver)(nil), state.Fns{Save: (*receiver).save, Load: (*receiver).load}) + state.Register("pkg/tcpip/transport/tcp.renoState", (*renoState)(nil), state.Fns{Save: (*renoState).save, Load: (*renoState).load}) + state.Register("pkg/tcpip/transport/tcp.SACKScoreboard", (*SACKScoreboard)(nil), state.Fns{Save: (*SACKScoreboard).save, Load: (*SACKScoreboard).load}) + state.Register("pkg/tcpip/transport/tcp.segment", (*segment)(nil), state.Fns{Save: (*segment).save, Load: (*segment).load}) + state.Register("pkg/tcpip/transport/tcp.segmentQueue", (*segmentQueue)(nil), state.Fns{Save: (*segmentQueue).save, Load: (*segmentQueue).load}) + state.Register("pkg/tcpip/transport/tcp.sender", (*sender)(nil), state.Fns{Save: (*sender).save, Load: (*sender).load}) + state.Register("pkg/tcpip/transport/tcp.rtt", (*rtt)(nil), state.Fns{Save: (*rtt).save, Load: (*rtt).load}) + state.Register("pkg/tcpip/transport/tcp.fastRecovery", (*fastRecovery)(nil), state.Fns{Save: (*fastRecovery).save, Load: (*fastRecovery).load}) + state.Register("pkg/tcpip/transport/tcp.unixTime", (*unixTime)(nil), state.Fns{Save: (*unixTime).save, Load: (*unixTime).load}) + state.Register("pkg/tcpip/transport/tcp.endpointList", (*endpointList)(nil), state.Fns{Save: (*endpointList).save, Load: (*endpointList).load}) + state.Register("pkg/tcpip/transport/tcp.endpointEntry", (*endpointEntry)(nil), state.Fns{Save: (*endpointEntry).save, Load: (*endpointEntry).load}) + state.Register("pkg/tcpip/transport/tcp.segmentList", (*segmentList)(nil), state.Fns{Save: (*segmentList).save, Load: (*segmentList).load}) + state.Register("pkg/tcpip/transport/tcp.segmentEntry", (*segmentEntry)(nil), state.Fns{Save: (*segmentEntry).save, Load: (*segmentEntry).load}) +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go deleted file mode 100644 index 5b2b16afa..000000000 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ /dev/null @@ -1,6970 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "fmt" - "math" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/ports" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65535 - - // defaultIPv4MSS is the MSS sent by the network stack in SYN/SYN-ACK for an - // IPv4 endpoint when the MTU is set to defaultMTU in the test. - defaultIPv4MSS = defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize -) - -func TestGiveUpConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Register for notification, then start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventOut) - defer wq.EventUnregister(&waitEntry) - - if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) - } - - // Close the connection, wait for completion. - ep.Close() - - // Wait for ep to become writable. - <-notifyCh - if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted { - t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted) - } - - // Call Connect again to retreive the handshake failure status - // and stats updates. - if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted { - t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted) - } - - if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got) - } - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) - } -} - -func TestConnectIncrementActiveConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ActiveConnectionOpenings.Value() + 1 - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want) - } -} - -func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.FailedConnectionAttempts.Value() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { - t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want) - } -} - -func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - c.EP = ep - want := stats.TCP.FailedConnectionAttempts.Value() + 1 - - if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute { - t.Errorf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute) - } - - if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { - t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want) - } -} - -func TestTCPSegmentsSentIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - // SYN and ACK - want := stats.TCP.SegmentsSent.Value() + 2 - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - if got := stats.TCP.SegmentsSent.Value(); got != want { - t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want { - t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want) - } -} - -func TestTCPResetsSentIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - stats := c.Stack().Stats() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - want := stats.TCP.SegmentsSent.Value() + 1 - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - // If the AckNum is not the increment of the last sequence number, a RST - // segment is sent back in response. - AckNum: c.IRS + 2, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - c.GetPacket() - if got := stats.TCP.ResetsSent.Value(); got != want { - t.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want) - } -} - -// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates -// a RST if an ACK is received on the listening socket for which there is no -// active handshake in progress and we are not using SYN cookies. -func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Lower stackwide TIME_WAIT timeout so that the reservations - // are released instantly on Close. - tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { - t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err) - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - c.GetPacket() - - // Since an active close was done we need to wait for a little more than - // tcpLingerTimeout for the port reservations to be released and the - // socket to move to a CLOSED state. - time.Sleep(20 * time.Millisecond) - - // Now resend the same ACK, this ACK should generate a RST as there - // should be no endpoint in SYN-RCVD state and we are not using - // syn-cookies yet. The reason we send the same ACK is we need a valid - // cookie(IRS) generated by the netstack without which the ACK will be - // rejected. - c.SendPacket(nil, ackHeaders) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(0), - checker.TCPFlags(header.TCPFlagRst))) -} - -func TestTCPResetsReceivedIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ResetsReceived.Value() + 1 - iss := seqnum.Value(789) - rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(1), - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - Flags: header.TCPFlagRst, - }) - - if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want) - } -} - -func TestTCPResetsDoNotGenerateResets(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ResetsReceived.Value() + 1 - iss := seqnum.Value(789) - rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(1), - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - Flags: header.TCPFlagRst, - }) - - if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want) - } - c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond) -} - -func TestActiveHandshake(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) -} - -func TestNonBlockingClose(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - // Close the endpoint and measure how long it takes. - t0 := time.Now() - ep.Close() - if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %v", diff) - } -} - -func TestConnectResetAfterClose(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPLinger to 3 seconds so that sockets are marked closed - // after 3 second in FIN_WAIT2 state. - tcpLingerTimeout := 3 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err) - } - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - // Close the endpoint, make sure we get a FIN segment, then acknowledge - // to complete closure of sender, but don't send our own FIN. - ep.Close() - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Wait for the ep to give up waiting for a FIN. - time.Sleep(tcpLingerTimeout + 1*time.Second) - - // Now send an ACK and it should trigger a RST as the endpoint should - // not exist anymore. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - for { - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin { - // This is a retransmit of the FIN, ignore it. - continue - } - - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - // RST is always generated with sndNxt which if the FIN - // has been sent will be 1 higher than the sequence number - // of the FIN itself. - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(0), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - break - } -} - -// TestCurrentConnectedIncrement tests increment of the current -// established and connected counters. -func TestCurrentConnectedIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed - // after 1 second in TIME_WAIT state. - tcpTimeWaitTimeout := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) - } - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 1", got) - } - gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value() - if gotConnected != 1 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 1", gotConnected) - } - - ep.Close() - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = %v", got, gotConnected) - } - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Wait for a little more than the TIME-WAIT duration for the socket to - // transition to CLOSED state. - time.Sleep(1200 * time.Millisecond) - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 0", got) - } -} - -// TestClosingWithEnqueuedSegments tests handling of still enqueued segments -// when the endpoint transitions to StateClose. The in-flight segments would be -// re-enqueued to a any listening endpoint. -func TestClosingWithEnqueuedSegments(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Send a FIN for ESTABLISHED --> CLOSED-WAIT - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Get the ACK for the FIN we sent. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Close the application endpoint for CLOSE_WAIT --> LAST_ACK - ep.Close() - - // Get the FIN - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Pause the endpoint`s protocolMainLoop. - ep.(interface{ StopWork() }).StopWork() - - // Enqueue last ACK followed by an ACK matching the endpoint - // - // Send Last ACK for LAST_ACK --> CLOSED - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 791, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Send a packet with ACK set, this would generate RST when - // not using SYN cookies as in this test. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 792, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Unpause endpoint`s protocolMainLoop. - ep.(interface{ ResumeWork() }).ResumeWork() - - // Wait for the protocolMainLoop to resume and update state. - time.Sleep(10 * time.Millisecond) - - // Expect the endpoint to be closed. - if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got) - } - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) - } - - // Check if the endpoint was moved to CLOSED and netstack a reset in - // response to the ACK packet that we sent after last-ACK. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(0), - checker.TCPFlags(header.TCPFlagRst), - ), - ) -} - -func TestSimpleReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Receive data. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - // Check that ACK is received. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when -// creating a new active IPv4 TCP socket. It should be present in the sent TCP -// SYN segment. -func TestUserSuppliedMSSOnConnectV4(t *testing.T) { - const mtu = 5000 - const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize - tests := []struct { - name string - setMSS uint16 - expMSS uint16 - }{ - { - "EqualToMaxMSS", - maxMSS, - maxMSS, - }, - { - "LessThanMTU", - maxMSS - 1, - maxMSS - 1, - }, - { - "GreaterThanMTU", - maxMSS + 1, - maxMSS, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() - - c.Create(-1) - - // Set the MSS socket option. - opt := tcpip.MaxSegOption(test.setMSS) - if err := c.EP.SetSockOpt(opt); err != nil { - t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) - } - - // Get expected window size. - rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) - } - ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) - - // Start connection attempt to IPv4 address. - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet with our user supplied MSS. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) - }) - } -} - -// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when -// creating a new active IPv6 TCP socket. It should be present in the sent TCP -// SYN segment. -func TestUserSuppliedMSSOnConnectV6(t *testing.T) { - const mtu = 5000 - const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize - tests := []struct { - name string - setMSS uint16 - expMSS uint16 - }{ - { - "EqualToMaxMSS", - maxMSS, - maxMSS, - }, - { - "LessThanMTU", - maxMSS - 1, - maxMSS - 1, - }, - { - "GreaterThanMTU", - maxMSS + 1, - maxMSS, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Set the MSS socket option. - opt := tcpip.MaxSegOption(test.setMSS) - if err := c.EP.SetSockOpt(opt); err != nil { - t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) - } - - // Get expected window size. - rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) - } - ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) - - // Start connection attempt to IPv6 address. - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet with our user supplied MSS. - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) - }) - } -} - -func TestSendRstOnListenerRxSynAckV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) -} - -func TestSendRstOnListenerRxSynAckV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) -} - -// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete, -// peers can send data and expect a response within a reasonable ammount of time -// without calling Accept on the listening endpoint first. -// -// This test uses IPv4. -func TestTCPAckBeforeAcceptV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - - // Send data before accepting the connection. - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) -} - -// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete, -// peers can send data and expect a response within a reasonable ammount of time -// without calling Accept on the listening endpoint first. -// -// This test uses IPv6. -func TestTCPAckBeforeAcceptV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */) - - // Send data before accepting the connection. - c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) -} - -func TestSendRstOnListenerRxAckV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1 /* epRcvBuf */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) -} - -func TestSendRstOnListenerRxAckV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true /* v6Only */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) -} - -// TestListenShutdown tests for the listening endpoint not processing -// any receive when it is on read shutdown. -func TestListenShutdown(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1 /* epRcvBuf */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatal("Shutdown failed:", err) - } - - // Wait for the endpoint state to be propagated. - time.Sleep(10 * time.Millisecond) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: 100, - AckNum: 200, - }) - - c.CheckNoPacket("Packet received when listening socket was shutdown") -} - -func TestTOSV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - - const tos = 0xC0 - if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) - } - - var v tcpip.IPv4TOSOption - if err := c.EP.GetSockOpt(&v); err != nil { - t.Errorf("GetSockopt failed: %s", err) - } - - if want := tcpip.IPv4TOSOption(tos); v != want { - t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) - } - - testV4Connect(t, c, checker.TOS(tos, 0)) - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - checker.TOS(tos, 0), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } -} - -func TestTrafficClassV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - const tos = 0xC0 - if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { - t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err) - } - - var v tcpip.IPv6TrafficClassOption - if err := c.EP.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockopt failed: %s", err) - } - - if want := tcpip.IPv6TrafficClassOption(tos); v != want { - t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) - } - - // Test the connection request. - testV6Connect(t, c, checker.TOS(tos, 0)) - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b := c.GetV6Packet() - checker.IPv6(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - checker.TOS(tos, 0), - ) - - if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } -} - -func TestConnectBindToDevice(t *testing.T) { - for _, test := range []struct { - name string - device tcpip.NICID - want tcp.EndpointState - }{ - {"RightDevice", 1, tcp.StateEstablished}, - {"WrongDevice", 2, tcp.StateSynSent}, - {"AnyDevice", 0, tcp.StateEstablished}, - } { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - bindToDevice := tcpip.BindToDeviceOption(test.device) - c.EP.SetSockOpt(bindToDevice) - // Start connection attempt. - waitEntry, _ := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventOut) - defer c.WQ.EventUnregister(&waitEntry) - - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - iss := seqnum.Value(789) - rcvWnd := seqnum.Size(30000) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: nil, - }) - - c.GetPacket() - if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { - t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - }) - } -} - -func TestRstOnSynSent(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create an endpoint, don't handshake because we want to interfere with the - // handshake process. - c.Create(-1) - - // Start connection attempt. - waitEntry, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventOut) - defer c.WQ.EventUnregister(&waitEntry) - - addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} - if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted { - t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, tcpip.ErrConnectStarted) - } - - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - - // Ensure that we've reached SynSent state - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - t.Fatalf("got State() = %s, want %s", got, want) - } - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - // Send a packet with a proper ACK and a RST flag to cause the socket - // to Error and close out - iss := seqnum.Value(789) - rcvWnd := seqnum.Size(30000) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagRst | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: nil, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatal("timed out waiting for packet to arrive") - } - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused { - t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrConnectionRefused) - } - - // Due to the RST the endpoint should be in an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Fatalf("got State() = %s, want %s", got, want) - } -} - -func TestOutOfOrderReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Send second half of data first, with seqnum 3 ahead of expected. - data := []byte{1, 2, 3, 4, 5, 6} - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 793, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that we get an ACK specifying which seqnum is expected. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Wait 200ms and check that no data has been received. - time.Sleep(200 * time.Millisecond) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Send the first 3 bytes now. - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive data. - read := make([]byte, 0, 6) - for len(read) < len(data) { - v, _, err := c.EP.Read(nil) - if err != nil { - if err == tcpip.ErrWouldBlock { - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - continue - } - t.Fatalf("Read failed: %v", err) - } - - read = append(read, v...) - } - - // Check that we received the data in proper order. - if !bytes.Equal(data, read) { - t.Fatalf("got data = %v, want = %v", read, data) - } - - // Check that the whole data is acknowledged. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestOutOfOrderFlood(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create a new connection with initial window size of 10. - c.CreateConnected(789, 30000, 10) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Send 100 packets before the actual one that is expected. - data := []byte{1, 2, 3, 4, 5, 6} - for i := 0; i < 100; i++ { - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 796, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Send packet with seqnum 793. It must be discarded because the - // out-of-order buffer was filled by the previous packets. - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 793, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Now send the expected packet, seqnum 790. - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that only packet 790 is acknowledged. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(793), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestRstOnCloseWithUnreadData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that ACK is received, this happens regardless of the read. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Now that we know we have unread data, let's just close the connection - // and verify that netstack sends an RST rather than a FIN. - c.EP.Close() - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), - )) - // The RST puts the endpoint into an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // This final ACK should be ignored because an ACK on a reset doesn't mean - // anything. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + len(data)), - AckNum: c.IRS.Add(seqnum.Size(2)), - RcvWnd: 30000, - }) -} - -func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that ACK is received, this happens regardless of the read. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Cause a FIN to be generated. - c.EP.Shutdown(tcpip.ShutdownWrite) - - // Make sure we get the FIN but DON't ACK IT. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - checker.SeqNum(uint32(c.IRS)+1), - )) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Cause a RST to be generated by closing the read end now since we have - // unread data. - c.EP.Shutdown(tcpip.ShutdownRead) - - // Make sure we get the RST - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // RST is always generated with sndNxt which if the FIN - // has been sent will be 1 higher than the sequence - // number of the FIN itself. - checker.SeqNum(uint32(c.IRS)+2), - )) - // The RST puts the endpoint into an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // The ACK to the FIN should now be rejected since the connection has been - // closed by a RST. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + len(data)), - AckNum: c.IRS.Add(seqnum.Size(2)), - RcvWnd: 30000, - }) -} - -func TestShutdownRead(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) - } - var want uint64 = 1 - if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { - t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want) - } -} - -func TestFullWindowReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, 10) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - _, _, err := c.EP.Read(nil) - if err != tcpip.ErrWouldBlock { - t.Fatalf("Read failed: %v", err) - } - - // Fill up the window. - data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that data is acknowledged, and window goes to zero. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.Window(0), - ), - ) - - // Receive data and check it. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - var want uint64 = 1 - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want { - t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want) - } - - // Check that we get an ACK for the newly non-zero window. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.Window(10), - ), - ) -} - -func TestNoWindowShrinking(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Start off with a window size of 10, then shrink it to 5. - c.CreateConnected(789, 30000, 10) - - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { - t.Fatalf("SetSockOpt failed: %v", err) - } - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Send 3 bytes, check that the peer acknowledges them. - data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that data is acknowledged, and that window doesn't go to zero - // just yet because it was previously set to 10. It must go to 7 now. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(793), - checker.TCPFlags(header.TCPFlagAck), - checker.Window(7), - ), - ) - - // Send 7 more bytes, check that the window fills up. - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 793, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.Window(0), - ), - ) - - // Receive data and check it. - read := make([]byte, 0, 10) - for len(read) < len(data) { - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - read = append(read, v...) - } - - if !bytes.Equal(data, read) { - t.Fatalf("got data = %v, want = %v", read, data) - } - - // Check that we get an ACK for the newly non-zero window, which is the - // new size. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.Window(5), - ), - ) -} - -func TestSimpleSend(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Fatalf("got data = %v, want = %v", p, data) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), - RcvWnd: 30000, - }) -} - -func TestZeroWindowSend(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 0, -1 /* epRcvBuf */) - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Since the window is currently zero, check that no packet is received. - c.CheckNoPacket("Packet received when window is zero") - - // Open up the window. Data should be received now. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Fatalf("got data = %v, want = %v", p, data) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), - RcvWnd: 30000, - }) -} - -func TestScaledWindowConnect(t *testing.T) { - // This test ensures that window scaling is used when the peer - // does advertise it and connection is established with Connect(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the window size greater than the maximum non-scaled window. - c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received, and that advertised window is 0xbfff, - // that is, that it is scaled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xbfff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) -} - -func TestNonScaledWindowConnect(t *testing.T) { - // This test ensures that window scaling is not used when the peer - // doesn't advertise it and connection is established with Connect(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the window size greater than the maximum non-scaled window. - c.CreateConnected(789, 30000, 65535*3) - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received, and that advertised window is 0xffff, - // that is, that it's not scaled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) -} - -func TestScaledWindowAccept(t *testing.T) { - // This test ensures that window scaling is used when the peer - // does advertise it and connection is established with Accept(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Do 3-way handshake. - c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received, and that advertised window is 0xbfff, - // that is, that it is scaled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xbfff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) -} - -func TestNonScaledWindowAccept(t *testing.T) { - // This test ensures that window scaling is not used when the peer - // doesn't advertise it and connection is established with Accept(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN - // should not carry the window scaling option. - c.PassiveConnect(100, -1, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received, and that advertised window is 0xffff, - // that is, that it's not scaled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) -} - -func TestZeroScaledWindowReceive(t *testing.T) { - // This test ensures that the endpoint sends a non-zero window size - // advertisement when the scaled window transitions from 0 to non-zero, - // but the actual window (not scaled) hasn't gotten to zero. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the window size such that a window scale of 4 will be used. - const wnd = 65535 * 10 - const ws = uint32(4) - c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) - - // Write chunks of 50000 bytes. - remain := wnd - sent := 0 - data := make([]byte, 50000) - for remain > len(data) { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(remain>>ws)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Make the window non-zero, but the scaled window zero. - if remain >= 16 { - data = data[:remain-15] - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(0), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Read at least 1MSS of data. An ack should be sent in response to that. - sz := 0 - for sz < defaultMTU { - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - sz += len(v) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(sz>>ws)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestSegmentMerging(t *testing.T) { - tests := []struct { - name string - stop func(tcpip.Endpoint) - resume func(tcpip.Endpoint) - }{ - { - "stop work", - func(ep tcpip.Endpoint) { - ep.(interface{ StopWork() }).StopWork() - }, - func(ep tcpip.Endpoint) { - ep.(interface{ ResumeWork() }).ResumeWork() - }, - }, - { - "cork", - func(ep tcpip.Endpoint) { - ep.SetSockOpt(tcpip.CorkOption(1)) - }, - func(ep tcpip.Endpoint) { - ep.SetSockOpt(tcpip.CorkOption(0)) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Prevent the endpoint from processing packets. - test.stop(c.EP) - - var allData []byte - for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { - allData = append(allData, data...) - view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) - } - } - - // Let the endpoint process the segments that we just sent. - test.resume(c.EP) - - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(allData)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) { - t.Fatalf("got data = %v, want = %v", got, allData) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))), - RcvWnd: 30000, - }) - }) - } -} - -func TestDelay(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - c.EP.SetSockOptInt(tcpip.DelayOption, 1) - - var allData []byte - for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { - allData = append(allData, data...) - view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) - } - } - - seq := c.IRS.Add(1) - for _, want := range [][]byte{allData[:1], allData[1:]} { - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(want)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, want) { - t.Fatalf("got data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(want))) - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seq, - RcvWnd: 30000, - }) - } -} - -func TestUndelay(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - c.EP.SetSockOptInt(tcpip.DelayOption, 1) - - allData := [][]byte{{0}, {1, 2, 3}} - for i, data := range allData { - view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) - } - } - - seq := c.IRS.Add(1) - - // Check that data is received. - first := c.GetPacket() - checker.IPv4(t, first, - checker.PayloadLen(len(allData[0])+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got, want := first[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[0]; !bytes.Equal(got, want) { - t.Fatalf("got first packet's data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(allData[0]))) - - // Check that we don't get the second packet yet. - c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - - c.EP.SetSockOptInt(tcpip.DelayOption, 0) - - // Check that data is received. - second := c.GetPacket() - checker.IPv4(t, second, - checker.PayloadLen(len(allData[1])+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got, want := second[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[1]; !bytes.Equal(got, want) { - t.Fatalf("got second packet's data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(allData[1]))) - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seq, - RcvWnd: 30000, - }) -} - -func TestMSSNotDelayed(t *testing.T) { - tests := []struct { - name string - fn func(tcpip.Endpoint) - }{ - {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }}, - {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const maxPayload = 100 - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - - test.fn(c.EP) - - allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} - for i, data := range allData { - view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) - } - } - - seq := c.IRS.Add(1) - - for i, data := range allData { - // Check that data is received. - packet := c.GetPacket() - checker.IPv4(t, packet, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got, want := packet[header.IPv4MinimumSize+header.TCPMinimumSize:], data; !bytes.Equal(got, want) { - t.Fatalf("got packet #%d's data = %v, want = %v", i+1, got, want) - } - - seq = seq.Add(seqnum.Size(len(data))) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seq, - RcvWnd: 30000, - }) - }) - } -} - -func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { - payloadMultiplier := 10 - dataLen := payloadMultiplier * maxPayload - data := make([]byte, dataLen) - for i := range data { - data[i] = byte(i) - } - - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received in chunks. - bytesReceived := 0 - numPackets := 0 - for bytesReceived != dataLen { - b := c.GetPacket() - numPackets++ - tcpHdr := header.TCP(header.IPv4(b).Payload()) - payloadLen := len(tcpHdr.Payload()) - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - pdata := data[bytesReceived : bytesReceived+payloadLen] - if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) { - t.Fatalf("got data = %v, want = %v", p, pdata) - } - bytesReceived += payloadLen - var options []byte - if c.TimeStampEnabled { - // If timestamp option is enabled, echo back the timestamp and increment - // the TSEcr value included in the packet and send that back as the TSVal. - parsedOpts := tcpHdr.ParsedOptions() - tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) - options = tsOpt[:] - } - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), - RcvWnd: 30000, - TCPOpts: options, - }) - } - if numPackets == 1 { - t.Fatalf("expected write to be broken up into multiple packets, but got 1 packet") - } -} - -func TestSendGreaterThanMTU(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSetTTL(t *testing.T) { - for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { - t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { - c := context.New(t, 65535) - defer c.Cleanup() - - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { - t.Fatalf("SetSockOpt failed: %v", err) - } - - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetPacket() - - checker.IPv4(t, b, checker.TTL(wantTTL)) - }) - } -} - -func TestActiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - c := context.New(t, 65535) - defer c.Cleanup() - - c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestPassiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - // Set the buffer size to a deterministic size so that we can check the - // window scaling option. - const rcvBufferSize = 0x20000 - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 536 - const mtu = 2000 - c := context.New(t, mtu) - defer c.Cleanup() - - // Set the SynRcvd threshold to zero to force a syn cookie based accept - // to happen. - saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 0 - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestForwarderSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() - - s := c.Stack() - ch := make(chan *tcpip.Error, 1) - f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { - var err *tcpip.Error - c.EP, err = r.CreateEndpoint(&c.WQ) - ch <- err - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Wait for connection to be available. - select { - case err := <-ch: - if err != nil { - t.Fatalf("Error creating endpoint: %v", err) - } - case <-time.After(2 * time.Second): - t.Fatalf("Timed out waiting for connection") - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSynOptionsOnActiveConnect(t *testing.T) { - const mtu = 1400 - c := context.New(t, mtu) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Set the buffer size to a deterministic size so that we can check the - // window scaling option. - const rcvBufferSize = 0x20000 - const wndScale = 2 - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) - } - - // Start connection attempt. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventOut) - defer c.WQ.EventUnregister(&we) - - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) - } - - // Receive SYN packet. - b := c.GetPacket() - mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize) - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), - ), - ) - - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - // Wait for retransmit. - time.Sleep(1 * time.Second) - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.SrcPort(tcpHdr.SourcePort()), - checker.SeqNum(tcpHdr.SequenceNumber()), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), - ), - ) - - // Send SYN-ACK. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - ), - ) - - // Wait for connection to be established. - select { - case <-ch: - if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestCloseListener(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create listener. - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Close the listener and measure how long it takes. - t0 := time.Now() - ep.Close() - if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %v", diff) - } -} - -func TestReceiveOnResetConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Send RST segment. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: 790, - RcvWnd: 30000, - }) - - // Try to read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - -loop: - for { - switch _, _, err := c.EP.Read(nil); err { - case tcpip.ErrWouldBlock: - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for reset to arrive") - } - case tcpip.ErrConnectionReset: - break loop - default: - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) - } - } - // Expect the state to be StateError and subsequent Reads to fail with HardError. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) - } - if tcp.EndpointState(c.EP.State()) != tcp.StateError { - t.Fatalf("got EP state is not StateError") - } - - if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { - t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got) - } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) - } -} - -func TestSendOnResetConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Send RST segment. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: 790, - RcvWnd: 30000, - }) - - // Wait for the RST to be received. - time.Sleep(1 * time.Second) - - // Try to write. - view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Write(...) = %v, want = %v", err, tcpip.ErrConnectionReset) - } -} - -func TestFinImmediately(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Shutdown immediately, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinRetransmit(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Shutdown immediately, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Don't acknowledge yet. We should get a retransmit of the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithNoPendingData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Write something out, and have it acknowledged. - view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Shutdown, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPendingDataCwndFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Write enough segments to fill the congestion window before ACK'ing - // any of them. - view := buffer.NewView(10) - for i := tcp.InitialCwnd; i > 0; i-- { - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - } - - next := uint32(c.IRS) + 1 - for i := tcp.InitialCwnd; i > 0; i-- { - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - } - - // Shutdown the connection, check that the FIN segment isn't sent - // because the congestion window doesn't allow it. Wait until a - // retransmit is received. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Send the ACK that will allow the FIN to be sent as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send a FIN that acknowledges everything. Get an ACK back. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPendingData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Write something out, and acknowledge it to get cwnd to 2. - view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Write new data, but don't acknowledge it. - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - // Shutdown the connection, check that we do get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send a FIN that acknowledges everything. Get an ACK back. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPartialAck(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Write something out, and acknowledge it to get cwnd to 2. Also send - // FIN from the test side. - view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Check that we get an ACK for the fin. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Write new data, but don't acknowledge it. - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - // Shutdown the connection, check that we do get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send an ACK for the data, but not for the FIN yet. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 791, - AckNum: seqnum.Value(next - 1), - RcvWnd: 30000, - }) - - // Check that we don't get a retransmit of the FIN. - c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond) - - // Ack the FIN. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 791, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) -} - -func TestUpdateListenBacklog(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create listener. - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Update the backlog with another Listen() on the same endpoint. - if err := ep.Listen(20); err != nil { - t.Fatalf("Listen failed to update backlog: %v", err) - } - - ep.Close() -} - -func scaledSendWindow(t *testing.T, scale uint8) { - // This test ensures that the endpoint is using the right scaling by - // sending a buffer that is larger than the window size, and ensuring - // that the endpoint doesn't send more than allowed. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize - c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - header.TCPOptionWS, 3, scale, header.TCPOptionNOP, - }) - - // Open up the window with a scaled value. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 1, - }) - - // Send some data. Check that it's capped by the window size. - view := buffer.NewView(65535) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that only data that fits in the scaled window is sent. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen((1<<scale)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Reset the connection to free resources. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: 790, - }) -} - -func TestScaledSendWindow(t *testing.T) { - for scale := uint8(0); scale <= 14; scale++ { - scaledSendWindow(t, scale) - } -} - -func TestReceivedValidSegmentCountIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.ValidSegmentsReceived.Value() + 1 - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want { - t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want) - } - // Ensure there were no errors during handshake. If these stats have - // incremented, then the connection should not have been established. - if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { - t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0) - } - if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 { - t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0) - } -} - -func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.InvalidSegmentsReceived.Value() + 1 - vv := c.BuildSegment(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] - tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 - - c.SendSegment(vv) - - if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) - } -} - -func TestReceivedIncorrectChecksumIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.ChecksumErrors.Value() + 1 - vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] - // Overwrite a byte in the payload which should cause checksum - // verification to fail. - tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 - - c.SendSegment(vv) - - if got := stats.TCP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want) - } -} - -func TestReceivedSegmentQueuing(t *testing.T) { - // This test sends 200 segments containing a few bytes each to an - // endpoint and checks that they're all received and acknowledged by - // the endpoint, that is, that none of the segments are dropped by - // internal queues. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Send 200 segments. - data := []byte{1, 2, 3} - for i := 0; i < 200; i++ { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + i*len(data)), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - } - - // Receive ACKs for all segments. - last := seqnum.Value(790 + 200*len(data)) - for { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - tcpHdr := header.TCP(header.IPv4(b).Payload()) - ack := seqnum.Value(tcpHdr.AckNumber()) - if ack == last { - break - } - - if last.LessThan(ack) { - t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last) - } - } -} - -func TestReadAfterClosedState(t *testing.T) { - // This test ensures that calling Read() or Peek() after the endpoint - // has transitioned to closedState still works if there is pending - // data. To transition to stateClosed without calling Close(), we must - // shutdown the send path and the peer must send its own FIN. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed - // after 1 second in TIME_WAIT state. - tcpTimeWaitTimeout := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) - } - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock) - } - - // Shutdown immediately for write, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Send some data and acknowledge the FIN. - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that ACK is received. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(uint32(791+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Give the stack the chance to transition to closed state from - // TIME_WAIT. - time.Sleep(tcpTimeWaitTimeout * 2) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that peek works. - peekBuf := make([]byte, 10) - n, _, err := c.EP.Peek([][]byte{peekBuf}) - if err != nil { - t.Fatalf("Peek failed: %s", err) - } - - peekBuf = peekBuf[:n] - if !bytes.Equal(data, peekBuf) { - t.Fatalf("got data = %v, want = %v", peekBuf, data) - } - - // Receive data. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - // Now that we drained the queue, check that functions fail with the - // right error code. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive) - } - - if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive) - } -} - -func TestReusePort(t *testing.T) { - // This test ensures that ports are immediately available for reuse - // after Close on the endpoints using them returns. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // First case, just an endpoint that was bound. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - c.EP.Close() - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - c.EP.Close() - - // Second case, an endpoint that was bound and is connecting.. - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) - } - c.EP.Close() - - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - c.EP.Close() - - // Third case, an endpoint that was bound and is listening. - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - c.EP.Close() - - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } -} - -func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { - t.Helper() - - s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - if int(s) != v { - t.Fatalf("got receive buffer size = %v, want = %v", s, v) - } -} - -func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { - t.Helper() - - s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - if int(s) != v { - t.Fatalf("got send buffer size = %v, want = %v", s, v) - } -} - -func TestDefaultBufferSizes(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, - }) - - // Check the default values. - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - defer func() { - if ep != nil { - ep.Close() - } - }() - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) - - // Change the default send buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize * 2, tcp.DefaultSendBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - ep.Close() - ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) - - // Change the default receive buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize * 3, tcp.DefaultReceiveBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - ep.Close() - ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3) -} - -func TestMinMaxBufferSizes(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, - }) - - // Check the default values. - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - defer ep.Close() - - // Change the min/max values for send/receive - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultReceiveBufferSize * 2, tcp.DefaultReceiveBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultSendBufferSize * 3, tcp.DefaultSendBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - // Set values below the min. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - checkRecvBufferSize(t, ep, 200) - - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - checkSendBufferSize(t, ep, 300) - - // Set values above the max. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) - - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) -} - -func TestBindToDeviceOption(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}}) - - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - defer ep.Close() - - if err := s.CreateNIC(321, loopback.New()); err != nil { - t.Errorf("CreateNIC failed: %v", err) - } - - // nicIDPtr is used instead of taking the address of NICID literals, which is - // a compiler error. - nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { - return &s - } - - testActions := []struct { - name string - setBindToDevice *tcpip.NICID - setBindToDeviceError *tcpip.Error - getBindToDevice tcpip.BindToDeviceOption - }{ - {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, - {"BindToExistent", nicIDPtr(321), nil, 321}, - {"UnbindToDevice", nicIDPtr(0), nil, 0}, - } - for _, testAction := range testActions { - t.Run(testAction.name, func(t *testing.T) { - if testAction.setBindToDevice != nil { - bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) - } - } - bindToDevice := tcpip.BindToDeviceOption(88888) - if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt got %v, want %v", err, nil) - } - if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %d, want %d", got, want) - } - }) - } -} - -func makeStack() (*stack.Stack, *tcpip.Error) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - ipv4.NewProtocol(), - ipv6.NewProtocol(), - }, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, - }) - - id := loopback.New() - if testing.Verbose() { - id = sniffer.New(id) - } - - if err := s.CreateNIC(1, id); err != nil { - return nil, err - } - - for _, ct := range []struct { - number tcpip.NetworkProtocolNumber - address tcpip.Address - }{ - {ipv4.ProtocolNumber, context.StackAddr}, - {ipv6.ProtocolNumber, context.StackV6Addr}, - } { - if err := s.AddAddress(1, ct.number, ct.address); err != nil { - return nil, err - } - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - return s, nil -} - -func TestSelfConnect(t *testing.T) { - // This test ensures that intentional self-connects work. In particular, - // it checks that if an endpoint binds to say 127.0.0.1:1000 then - // connects to 127.0.0.1:1000, then it will be connected to itself, and - // is able to send and receive data through the same endpoint. - s, err := makeStack() - if err != nil { - t.Fatal(err) - } - - var wq waiter.Queue - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Register for notification, then start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventOut) - defer wq.EventUnregister(&waitEntry) - - if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) - } - - <-notifyCh - if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - // Write something. - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Read back what was written. - wq.EventUnregister(&waitEntry) - wq.EventRegister(&waitEntry, waiter.EventIn) - rd, _, err := ep.Read(nil) - if err != nil { - if err != tcpip.ErrWouldBlock { - t.Fatalf("Read failed: %v", err) - } - <-notifyCh - rd, _, err = ep.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - } - - if !bytes.Equal(data, rd) { - t.Fatalf("got data = %v, want = %v", rd, data) - } -} - -func TestConnectAvoidsBoundPorts(t *testing.T) { - addressTypes := func(t *testing.T, network string) []string { - switch network { - case "ipv4": - return []string{"v4"} - case "ipv6": - return []string{"v6"} - case "dual": - return []string{"v6", "mapped"} - default: - t.Fatalf("unknown network: '%s'", network) - } - - panic("unreachable") - } - - address := func(t *testing.T, addressType string, isAny bool) tcpip.Address { - switch addressType { - case "v4": - if isAny { - return "" - } - return context.StackAddr - case "v6": - if isAny { - return "" - } - return context.StackV6Addr - case "mapped": - if isAny { - return context.V4MappedWildcardAddr - } - return context.StackV4MappedAddr - default: - t.Fatalf("unknown address type: '%s'", addressType) - } - - panic("unreachable") - } - // This test ensures that Endpoint.Connect doesn't select already-bound ports. - networks := []string{"ipv4", "ipv6", "dual"} - for _, exhaustedNetwork := range networks { - t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) { - for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) { - t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) { - for _, isAny := range []bool{false, true} { - t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) { - for _, candidateNetwork := range networks { - t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) { - for _, candidateAddressType := range addressTypes(t, candidateNetwork) { - t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) { - s, err := makeStack() - if err != nil { - t.Fatal(err) - } - - var wq waiter.Queue - var eps []tcpip.Endpoint - defer func() { - for _, ep := range eps { - ep.Close() - } - }() - makeEP := func(network string) tcpip.Endpoint { - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch network { - case "ipv4": - networkProtocolNumber = ipv4.ProtocolNumber - case "ipv6", "dual": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatalf("unknown network: '%s'", network) - } - ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - eps = append(eps, ep) - switch network { - case "ipv4": - case "ipv6": - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(true)) failed: %v", err) - } - case "dual": - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(false)) failed: %v", err) - } - default: - t.Fatalf("unknown network: '%s'", network) - } - return ep - } - - var v4reserved, v6reserved bool - switch exhaustedAddressType { - case "v4", "mapped": - v4reserved = true - case "v6": - v6reserved = true - // Dual stack sockets bound to v6 any reserve on v4 as - // well. - if isAny { - switch exhaustedNetwork { - case "ipv6": - case "dual": - v4reserved = true - default: - t.Fatalf("unknown address type: '%s'", exhaustedNetwork) - } - } - default: - t.Fatalf("unknown address type: '%s'", exhaustedAddressType) - } - var collides bool - switch candidateAddressType { - case "v4", "mapped": - collides = v4reserved - case "v6": - collides = v6reserved - default: - t.Fatalf("unknown address type: '%s'", candidateAddressType) - } - - for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ { - if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { - t.Fatalf("Bind(%d) failed: %v", i, err) - } - } - want := tcpip.ErrConnectStarted - if collides { - want = tcpip.ErrNoPortAvailable - } - if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { - t.Fatalf("got ep.Connect(..) = %v, want = %v", err, want) - } - }) - } - }) - } - }) - } - }) - } - }) - } -} - -func TestPathMTUDiscovery(t *testing.T) { - // This test verifies the stack retransmits packets after it receives an - // ICMP packet indicating that the path MTU has been exceeded. - c := context.New(t, 1500) - defer c.Cleanup() - - // Create new connection with MSS of 1460. - const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize - c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - - // Send 3200 bytes of data. - const writeSize = 3200 - data := buffer.NewView(writeSize) - for i := range data { - data[i] = byte(i) - } - - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte { - var ret []byte - for i, size := range sizes { - p := c.GetPacket() - if i == which { - ret = p - } - checker.IPv4(t, p, - checker.PayloadLen(size+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(seqNum), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - seqNum += uint32(size) - } - return ret - } - - // Receive three packets. - sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload} - first := receivePackets(c, sizes, 0, uint32(c.IRS)+1) - - // Send "packet too big" messages back to netstack. - const newMTU = 1200 - const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize - mtu := []byte{0, 0, newMTU / 256, newMTU % 256} - c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU) - - // See retransmitted packets. None exceeding the new max. - sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload} - receivePackets(c, sizes, -1, uint32(c.IRS)+1) -} - -func TestTCPEndpointProbe(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - invoked := make(chan struct{}) - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that the endpoint ID is what we expect. - // - // We don't do an extensive validation of every field but a - // basic sanity test. - if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want { - t.Fatalf("got LocalAddress: %q, want: %q", got, want) - } - if got, want := state.ID.LocalPort, c.Port; got != want { - t.Fatalf("got LocalPort: %d, want: %d", got, want) - } - if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want { - t.Fatalf("got RemoteAddress: %q, want: %q", got, want) - } - if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want { - t.Fatalf("got RemotePort: %d, want: %d", got, want) - } - - invoked <- struct{}{} - }) - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - select { - case <-invoked: - case <-time.After(100 * time.Millisecond): - t.Fatalf("TCP Probe function was not called") - } -} - -func TestStackSetCongestionControl(t *testing.T) { - testCases := []struct { - cc tcpip.CongestionControlOption - err *tcpip.Error - }{ - {"reno", nil}, - {"cubic", nil}, - {"blahblah", tcpip.ErrNoSuchFile}, - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - var oldCC tcpip.CongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err) - } - - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err { - t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err) - } - - var cc tcpip.CongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) - } - - got, want := cc, oldCC - // If SetTransportProtocolOption is expected to succeed - // then the returned value for congestion control should - // match the one specified in the - // SetTransportProtocolOption call above, else it should - // be what it was before the call to - // SetTransportProtocolOption. - if tc.err == nil { - want = tc.cc - } - if got != want { - t.Fatalf("got congestion control: %v, want: %v", got, want) - } - }) - } -} - -func TestStackAvailableCongestionControl(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - // Query permitted congestion control algorithms. - var aCC tcpip.AvailableCongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) - } - if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) - } -} - -func TestStackSetAvailableCongestionControl(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - // Setting AvailableCongestionControlOption should fail. - aCC := tcpip.AvailableCongestionControlOption("xyz") - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC) - } - - // Verify that we still get the expected list of congestion control options. - var cc tcpip.AvailableCongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) - } - if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) - } -} - -func TestEndpointSetCongestionControl(t *testing.T) { - testCases := []struct { - cc tcpip.CongestionControlOption - err *tcpip.Error - }{ - {"reno", nil}, - {"cubic", nil}, - {"blahblah", tcpip.ErrNoSuchFile}, - } - - for _, connected := range []bool{false, true} { - for _, tc := range testCases { - t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - var oldCC tcpip.CongestionControlOption - if err := c.EP.GetSockOpt(&oldCC); err != nil { - t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err) - } - - if connected { - c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil) - } - - if err := c.EP.SetSockOpt(tc.cc); err != tc.err { - t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err) - } - - var cc tcpip.CongestionControlOption - if err := c.EP.GetSockOpt(&cc); err != nil { - t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err) - } - - got, want := cc, oldCC - // If SetSockOpt is expected to succeed then the - // returned value for congestion control should match - // the one specified in the SetSockOpt above, else it - // should be what it was before the call to SetSockOpt. - if tc.err == nil { - want = tc.cc - } - if got != want { - t.Fatalf("got congestion control: %v, want: %v", got, want) - } - }) - } - } -} - -func enableCUBIC(t *testing.T, c *context.Context) { - t.Helper() - opt := tcpip.CongestionControlOption("cubic") - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err) - } -} - -func TestKeepalive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const keepAliveInterval = 10 * time.Millisecond - c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) - c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5)) - c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) - - // 5 unacked keepalives are sent. ACK each one, and check that the - // connection stays alive after 5. - for i := 0; i < 10; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Acknowledge the keepalive. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS, - RcvWnd: 30000, - }) - } - - // Check that the connection is still alive. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Send some data and wait before ACKing it. Keepalives should be disabled - // during this period. - view := buffer.NewView(3) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Wait for the packet to be retransmitted. Verify that no keepalives - // were sent. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), - ), - ) - c.CheckNoPacket("Keepalive packet received while unACKed data is pending") - - next += uint32(len(view)) - - // Send ACK. Keepalives should start sending again. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Now receive 5 keepalives, but don't ACK them. The connection - // should be reset after 5. - for i := 0; i < 5; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next-1)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Sleep for a litte over the KeepAlive interval to make sure - // the timer has time to fire after the last ACK and close the - // close the socket. - time.Sleep(keepAliveInterval + 5*time.Millisecond) - - // The connection should be terminated after 5 unacked keepalives. - // Send an ACK to trigger a RST from the stack as the endpoint should - // be dead. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next)), - checker.AckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - - if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got) - } - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) - } - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) - } -} - -func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { - // Send a SYN request. - irs = seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss = seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), - } - - if synCookieInUse { - // When cookies are in use window scaling is disabled. - tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ - WS: -1, - MSS: c.MSSWithoutOptions(), - })) - } - - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - return irs, iss -} - -func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { - // Send a SYN request. - irs = seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetV6Packet() - tcp := header.TCP(header.IPv6(b).Payload()) - iss = seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), - } - - if synCookieInUse { - // When cookies are in use window scaling is disabled. - tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ - WS: -1, - MSS: c.MSSWithoutOptionsV6(), - })) - } - - checker.IPv6(t, b, checker.TCP(tcpCheckers...)) - - // Send ACK. - c.SendV6Packet(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - return irs, iss -} - -// TestListenBacklogFull tests that netstack does not complete handshakes if the -// listen backlog for the endpoint is full. -func TestListenBacklogFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - // Start listening. - listenBacklog := 2 - if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - for i := 0; i < listenBacklog; i++ { - executeHandshake(t, c, context.TestPort+uint16(i), false /*synCookieInUse */) - } - - time.Sleep(50 * time.Millisecond) - - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 2, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - for i := 0; i < listenBacklog; i++ { - _, _, err = c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - } - - // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() - if err != tcpip.ErrWouldBlock { - select { - case <-ch: - t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) - case <-time.After(1 * time.Second): - } - } - - // Now a new handshake must succeed. - executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */) - - newEP, _, err := c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) - } -} - -// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a -// non unicast IPv4 address are not accepted. -func TestListenNoAcceptNonUnicastV4(t *testing.T) { - multicastAddr := tcpip.Address("\xe0\x00\x01\x02") - otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03") - - tests := []struct { - name string - srcAddr tcpip.Address - dstAddr tcpip.Address - }{ - { - "SourceUnspecified", - header.IPv4Any, - context.StackAddr, - }, - { - "SourceBroadcast", - header.IPv4Broadcast, - context.StackAddr, - }, - { - "SourceOurMulticast", - multicastAddr, - context.StackAddr, - }, - { - "SourceOtherMulticast", - otherMulticastAddr, - context.StackAddr, - }, - { - "DestUnspecified", - context.TestAddr, - header.IPv4Any, - }, - { - "DestBroadcast", - context.TestAddr, - header.IPv4Broadcast, - }, - { - "DestOurMulticast", - context.TestAddr, - multicastAddr, - }, - { - "DestOtherMulticast", - context.TestAddr, - otherMulticastAddr, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil { - t.Fatalf("JoinGroup failed: %s", err) - } - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - irs := seqnum.Value(789) - c.SendPacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, test.srcAddr, test.dstAddr) - c.CheckNoPacket("Should not have received a response") - - // Handle normal packet. - c.SendPacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, context.TestAddr, context.StackAddr) - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) - }) - } -} - -// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a -// non unicast IPv6 address are not accepted. -func TestListenNoAcceptNonUnicastV6(t *testing.T) { - multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") - otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") - - tests := []struct { - name string - srcAddr tcpip.Address - dstAddr tcpip.Address - }{ - { - "SourceUnspecified", - header.IPv6Any, - context.StackV6Addr, - }, - { - "SourceAllNodes", - header.IPv6AllNodesMulticastAddress, - context.StackV6Addr, - }, - { - "SourceOurMulticast", - multicastAddr, - context.StackV6Addr, - }, - { - "SourceOtherMulticast", - otherMulticastAddr, - context.StackV6Addr, - }, - { - "DestUnspecified", - context.TestV6Addr, - header.IPv6Any, - }, - { - "DestAllNodes", - context.TestV6Addr, - header.IPv6AllNodesMulticastAddress, - }, - { - "DestOurMulticast", - context.TestV6Addr, - multicastAddr, - }, - { - "DestOtherMulticast", - context.TestV6Addr, - otherMulticastAddr, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil { - t.Fatalf("JoinGroup failed: %s", err) - } - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - irs := seqnum.Value(789) - c.SendV6PacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, test.srcAddr, test.dstAddr) - c.CheckNoPacket("Should not have received a response") - - // Handle normal packet. - c.SendV6PacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, context.TestV6Addr, context.StackV6Addr) - checker.IPv6(t, c.GetV6Packet(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) - }) - } -} - -func TestListenSynRcvdQueueFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - // Start listening. - listenBacklog := 1 - if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send two SYN's the first one should get a SYN-ACK, the - // second one should not get any response and is dropped as - // the synRcvd count will be equal to backlog. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - // - // NOTE: we did not complete the handshake for the previous one so the - // accept backlog should be empty and there should be one connection in - // synRcvd state. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(889), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Now complete the previous connection and verify that there is a connection - // to accept. - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - newEP, _, err := c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) - pkt := c.GetPacket() - tcp = header.TCP(header.IPv4(pkt).Payload()) - if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) - } -} - -func TestListenBacklogFullSynCookieInUse(t *testing.T) { - saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 1 - - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - // Start listening. - listenBacklog := 1 - portOffset := uint16(0) - if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - executeHandshake(t, c, context.TestPort+portOffset, false) - portOffset++ - // Wait for this to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - // pick a different src port for new SYN. - SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - // The Syn should be dropped as the endpoint's backlog is full. - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Verify that there is only one acceptable connection at this point. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - _, _, err = c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() - if err != tcpip.ErrWouldBlock { - select { - case <-ch: - t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) - case <-time.After(1 * time.Second): - } - } -} - -func TestSynRcvdBadSeqNumber(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcpHdr.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Now send a packet with an out-of-window sequence number - largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1 - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: largeSeqnum, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Should receive an ACK with the expected SEQ number - b = c.GetPacket() - tcpCheckers = []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.AckNum(uint32(irs) + 1), - checker.SeqNum(uint32(iss + 1)), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Now that the socket replied appropriately with the ACK, - // complete the connection to test that the large SEQ num - // did not change the state from SYN-RCVD. - - // Send ACK to move to ESTABLISHED state. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - newEP, _, err := c.EP.Accept() - - if err != nil && err != tcpip.ErrWouldBlock { - t.Fatalf("Accept failed: %s", err) - } - - if err == tcpip.ErrWouldBlock { - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - _, _, err = newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) - - if err != nil { - t.Fatalf("Write failed: %s", err) - } - - pkt := c.GetPacket() - tcpHdr = header.TCP(header.IPv4(pkt).Payload()) - if string(tcpHdr.Payload()) != data { - t.Fatalf("Unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) - } -} - -func TestPassiveConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - c.EP = ep - if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %v", err) - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - stats := c.Stack().Stats() - want := stats.TCP.PassiveConnectionOpenings.Value() + 1 - - srcPort := uint16(context.TestPort) - executeHandshake(t, c, srcPort+1, false) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - // Verify that there is only one acceptable connection at this point. - _, _, err = c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %v, want = %v", got, want) - } -} - -func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - c.EP = ep - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - srcPort := uint16(context.TestPort) - // Now attempt a handshakes it will fill up the accept backlog. - executeHandshake(t, c, srcPort, false) - - // Give time for the final ACK to be processed as otherwise the next handshake could - // get accepted before the previous one based on goroutine scheduling. - time.Sleep(50 * time.Millisecond) - - want := stats.TCP.ListenOverflowSynDrop.Value() + 1 - - // Now we will send one more SYN and this one should get dropped - // Send a SYN request. - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort + 2, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), - RcvWnd: 30000, - }) - - time.Sleep(50 * time.Millisecond) - if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want) - } - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - // Now check that there is one acceptable connections. - _, _, err = c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } -} - -func TestEndpointBindListenAcceptState(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected { - t.Errorf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrNotConnected) - } - if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { - t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %v want %v", got, 1) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - aep, _, err := ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - aep, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected { - t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected) - } - // Listening endpoint remains in listen state. - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - ep.Close() - // Give worker goroutines time to receive the close notification. - time.Sleep(1 * time.Second) - if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - // Accepted endpoint remains open when the listen endpoint is closed. - if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - -} - -// This test verifies that the auto tuning does not grow the receive buffer if -// the application is not reading the data actively. -func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { - const mtu = 1500 - const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize - - c := context.New(t, mtu) - defer c.Cleanup() - - stk := c.Stack() - // Set lower limits for auto-tuning tests. This is required because the - // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 500. - const receiveBufferSize = 80 << 10 // 80KB. - const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - // Enable auto-tuning. - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - // Change the expected window scale to match the value needed for the - // maximum buffer size defined above. - c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - - rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - - // NOTE: The timestamp values in the sent packets are meaningless to the - // peer so we just increment the timestamp value by 1 every batch as we - // are not really using them for anything. Send a single byte to verify - // the advertised window. - tsVal := rawEP.TSVal + 1 - - // Introduce a 25ms latency by delaying the first byte. - latency := 25 * time.Millisecond - time.Sleep(latency) - rawEP.SendPacketWithTS([]byte{1}, tsVal) - - // Verify that the ACK has the expected window. - wantRcvWnd := receiveBufferSize - wantRcvWnd = (wantRcvWnd >> uint32(c.WindowScale)) - rawEP.VerifyACKRcvWnd(uint16(wantRcvWnd - 1)) - time.Sleep(25 * time.Millisecond) - - // Allocate a large enough payload for the test. - b := make([]byte, int(receiveBufferSize)*2) - offset := 0 - payloadSize := receiveBufferSize - 1 - worker := (c.EP).(interface { - StopWork() - ResumeWork() - }) - tsVal++ - - // Stop the worker goroutine. - worker.StopWork() - start := offset - end := offset + payloadSize - packetsSent := 0 - for ; start < end; start += mss { - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - packetsSent++ - } - - // Resume the worker so that it only sees the packets once all of them - // are waiting to be read. - worker.ResumeWork() - - // Since we read no bytes the window should goto zero till the - // application reads some of the data. - // Discard all intermediate acks except the last one. - if packetsSent > 100 { - for i := 0; i < (packetsSent / 100); i++ { - _ = c.GetPacket() - } - } - rawEP.VerifyACKRcvWnd(0) - - time.Sleep(25 * time.Millisecond) - // Verify that sending more data when window is closed is dropped and - // not acked. - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - - // Verify that the stack sends us back an ACK with the sequence number - // of the last packet sent indicating it was dropped. - p := c.GetPacket() - checker.IPv4(t, p, checker.TCP( - checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)), - checker.Window(0), - )) - - // Now read all the data from the endpoint and verify that advertised - // window increases to the full available buffer size. - for { - _, _, err := c.EP.Read(nil) - if err == tcpip.ErrWouldBlock { - break - } - } - - // Verify that we receive a non-zero window update ACK. When running - // under thread santizer this test can end up sending more than 1 - // ack, 1 for the non-zero window - p = c.GetPacket() - checker.IPv4(t, p, checker.TCP( - checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)), - func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) { - t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w, wantRcvWnd) - } - }, - )) -} - -// This test verifies that the auto tuning does not grow the receive buffer if -// the application is not reading the data actively. -func TestReceiveBufferAutoTuning(t *testing.T) { - const mtu = 1500 - const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize - - c := context.New(t, mtu) - defer c.Cleanup() - - // Enable Auto-tuning. - stk := c.Stack() - // Set lower limits for auto-tuning tests. This is required because the - // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 300. - const receiveBufferSize = 80 << 10 // 80KB. - const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - // Enable auto-tuning. - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - // Change the expected window scale to match the value needed for the - // maximum buffer size used by stack. - c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - - rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - - wantRcvWnd := receiveBufferSize - scaleRcvWnd := func(rcvWnd int) uint16 { - return uint16(rcvWnd >> uint16(c.WindowScale)) - } - // Allocate a large array to send to the endpoint. - b := make([]byte, receiveBufferSize*48) - - // In every iteration we will send double the number of bytes sent in - // the previous iteration and read the same from the app. The received - // window should grow by at least 2x of bytes read by the app in every - // RTT. - offset := 0 - payloadSize := receiveBufferSize / 8 - worker := (c.EP).(interface { - StopWork() - ResumeWork() - }) - tsVal := rawEP.TSVal - // We are going to do our own computation of what the moderated receive - // buffer should be based on sent/copied data per RTT and verify that - // the advertised window by the stack matches our calculations. - prevCopied := 0 - done := false - latency := 1 * time.Millisecond - for i := 0; !done; i++ { - tsVal++ - - // Stop the worker goroutine. - worker.StopWork() - start := offset - end := offset + payloadSize - totalSent := 0 - packetsSent := 0 - for ; start < end; start += mss { - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - totalSent += mss - packetsSent++ - } - - // Resume it so that it only sees the packets once all of them - // are waiting to be read. - worker.ResumeWork() - - // Give 1ms for the worker to process the packets. - time.Sleep(1 * time.Millisecond) - - // Verify that the advertised window on the ACK is reduced by - // the total bytes sent. - expectedWnd := wantRcvWnd - totalSent - if packetsSent > 100 { - for i := 0; i < (packetsSent / 100); i++ { - _ = c.GetPacket() - } - } - rawEP.VerifyACKRcvWnd(scaleRcvWnd(expectedWnd)) - - // Now read all the data from the endpoint and invoke the - // moderation API to allow for receive buffer auto-tuning - // to happen before we measure the new window. - totalCopied := 0 - for { - b, _, err := c.EP.Read(nil) - if err == tcpip.ErrWouldBlock { - break - } - totalCopied += len(b) - } - - // Invoke the moderation API. This is required for auto-tuning - // to happen. This method is normally expected to be invoked - // from a higher layer than tcpip.Endpoint. So we simulate - // copying to user-space by invoking it explicitly here. - c.EP.ModerateRecvBuf(totalCopied) - - // Now send a keep-alive packet to trigger an ACK so that we can - // measure the new window. - rawEP.NextSeqNum-- - rawEP.SendPacketWithTS(nil, tsVal) - rawEP.NextSeqNum++ - - if i == 0 { - // In the first iteration the receiver based RTT is not - // yet known as a result the moderation code should not - // increase the advertised window. - rawEP.VerifyACKRcvWnd(scaleRcvWnd(wantRcvWnd)) - prevCopied = totalCopied - } else { - rttCopied := totalCopied - if i == 1 { - // The moderation code accumulates copied bytes till - // RTT is established. So add in the bytes sent in - // the first iteration to the total bytes for this - // RTT. - rttCopied += prevCopied - // Now reset it to the initial value used by the - // auto tuning logic. - prevCopied = tcp.InitialCwnd * mss * 2 - } - newWnd := rttCopied<<1 + 16*mss - grow := (newWnd * (rttCopied - prevCopied)) / prevCopied - newWnd += (grow << 1) - if newWnd > maxReceiveBufferSize { - newWnd = maxReceiveBufferSize - done = true - } - rawEP.VerifyACKRcvWnd(scaleRcvWnd(newWnd)) - wantRcvWnd = newWnd - prevCopied = rttCopied - // Increase the latency after first two iterations to - // establish a low RTT value in the receiver since it - // only tracks the lowest value. This ensures that when - // ModerateRcvBuf is called the elapsed time is always > - // rtt. Without this the test is flaky due to delays due - // to scheduling/wakeup etc. - latency += 50 * time.Millisecond - } - time.Sleep(latency) - offset += payloadSize - payloadSize *= 2 - } -} - -func TestDelayEnabled(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - checkDelayOption(t, c, false, 0) // Delay is disabled by default. - - for _, v := range []struct { - delayEnabled tcp.DelayEnabled - wantDelayOption int - }{ - {delayEnabled: false, wantDelayOption: 0}, - {delayEnabled: true, wantDelayOption: 1}, - } { - c := context.New(t, defaultMTU) - defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil { - t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %v", v.delayEnabled, err) - } - checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) - } -} - -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) { - t.Helper() - - var gotDelayEnabled tcp.DelayEnabled - if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { - t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %v", err) - } - if gotDelayEnabled != wantDelayEnabled { - t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled) - } - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue)) - if err != nil { - t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err) - } - gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption) - if err != nil { - t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err) - } - if gotDelayOption != wantDelayOption { - t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption) - } -} - -func TestTCPLingerTimeout(t *testing.T) { - c := context.New(t, 1500 /* mtu */) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - testCases := []struct { - name string - tcpLingerTimeout time.Duration - want time.Duration - }{ - {"NegativeLingerTimeout", -123123, 0}, - {"ZeroLingerTimeout", 0, 0}, - {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, - // Values > stack's TCPLingerTimeout are capped to the stack's - // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) - {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second}, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil { - t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err) - } - var v tcpip.TCPLingerTimeoutOption - if err := c.EP.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err) - } - if got, want := time.Duration(v), tc.want; got != want { - t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want) - } - }) - } -} - -func TestTCPTimeWaitRSTIgnored(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Now send a RST and this should be ignored and not - // generate an ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagRst, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - }) - - c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) - - // Out of order ACK should generate an immediate ACK in - // TIME_WAIT. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 3, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) -} - -func TestTCPTimeWaitOutOfOrder(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Out of order ACK should generate an immediate ACK in - // TIME_WAIT. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 3, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) -} - -func TestTCPTimeWaitNewSyn(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Send a SYN request w/ sequence number lower than - // the highest sequence number sent. We just reuse - // the same number. - iss = seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) - - // Send a SYN request w/ sequence number higher than - // the highest sequence number sent. - iss = seqnum.Value(792) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b = c.GetPacket() - tcpHdr = header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } -} - -func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed - // after 5 seconds in TIME_WAIT state. - tcpTimeWaitTimeout := 5 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) - } - - want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - time.Sleep(2 * time.Second) - - // Now send a duplicate FIN. This should cause the TIME_WAIT to extend - // by another 5 seconds and also send us a duplicate ACK as it should - // indicate that the final ACK was potentially lost. - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Sleep for 4 seconds so at this point we are 1 second past the - // original tcpLingerTimeout of 5 seconds. - time.Sleep(4 * time.Second) - - // Send an ACK and it should not generate any packet as the socket - // should still be in TIME_WAIT for another another 5 seconds due - // to the duplicate FIN we sent earlier. - *ackHeaders = *finHeaders - ackHeaders.SeqNum = ackHeaders.SeqNum + 1 - ackHeaders.Flags = header.TCPFlagAck - c.SendPacket(nil, ackHeaders) - - c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second) - // Now sleep for another 2 seconds so that we are past the - // extended TIME_WAIT of 7 seconds (2 + 5). - time.Sleep(2 * time.Second) - - // Resend the same ACK. - c.SendPacket(nil, ackHeaders) - - // Receive the RST that should be generated as there is no valid - // endpoint. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(0), - checker.TCPFlags(header.TCPFlagRst))) - - if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want) - } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) - } -} - -func TestTCPCloseWithData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed - // after 5 seconds in TIME_WAIT state. - tcpTimeWaitTimeout := 5 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) - } - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - RcvWnd: 30000, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now trigger a passive close by sending a FIN. - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - RcvWnd: 30000, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Now write a few bytes and then close the endpoint. - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b = c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } - - c.EP.Close() - // Check the FIN. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))), - checker.AckNum(uint32(iss+2)), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - // First send a partial ACK. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)-1), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Now send a full ACK. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Now ACK the FIN. - ackHeaders.AckNum++ - c.SendPacket(nil, ackHeaders) - - // Now send an ACK and we should get a RST back as the endpoint should - // be in CLOSED state. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Check the RST. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(0), - checker.TCPFlags(header.TCPFlagRst))) -} - -func TestTCPUserTimeout(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() - - userTimeout := 50 * time.Millisecond - c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) - - // Send some data and wait before ACKing it. - view := buffer.NewView(3) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Wait for a little over the minimum retransmit timeout of 200ms for - // the retransmitTimer to fire and close the connection. - time.Sleep(tcp.MinRTO + 10*time.Millisecond) - - // No packet should be received as the connection should be silently - // closed due to timeout. - c.CheckNoPacket("unexpected packet received after userTimeout has expired") - - next += uint32(len(view)) - - // The connection should be terminated after userTimeout has expired. - // Send an ACK to trigger a RST from the stack as the endpoint should - // be dead. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next)), - checker.AckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) - } - - if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) - } -} - -func TestKeepaliveWithUserTimeout(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() - - const keepAliveInterval = 10 * time.Millisecond - c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) - c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10)) - c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) - - // Set userTimeout to be the duration for 3 keepalive probes. - userTimeout := 30 * time.Millisecond - c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) - - // Check that the connection is still alive. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Now receive 2 keepalives, but don't ACK them. The connection should - // be reset when the 3rd one should be sent due to userTimeout being - // 30ms and each keepalive probe should be sent 10ms apart as set above after - // the connection has been idle for 10ms. - for i := 0; i < 2; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Sleep for a litte over the KeepAlive interval to make sure - // the timer has time to fire after the last ACK and close the - // close the socket. - time.Sleep(keepAliveInterval + 5*time.Millisecond) - - // The connection should be terminated after 30ms. - // Send an ACK to trigger a RST from the stack as the endpoint should - // be dead. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(c.IRS + 1), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) - } - if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) - } -} - -func TestIncreaseWindowOnReceive(t *testing.T) { - // This test ensures that the endpoint sends an ack, - // after recv() when the window grows to more than 1 MSS. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - const rcvBuf = 65535 * 10 - c.CreateConnected(789, 30000, rcvBuf) - - // Write chunks of ~30000 bytes. It's important that two - // payloads make it equal or longer than MSS. - remain := rcvBuf - sent := 0 - data := make([]byte, defaultMTU/2) - lastWnd := uint16(0) - - for remain > len(data) { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - remain -= len(data) - - lastWnd = uint16(remain) - if remain > 0xffff { - lastWnd = 0xffff - } - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(lastWnd), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - if lastWnd == 0xffff || lastWnd == 0 { - t.Fatalf("expected small, non-zero window: %d", lastWnd) - } - - // We now have < 1 MSS in the buffer space. Read the data! An - // ack should be sent in response to that. The window was not - // zero, but it grew to larger than MSS. - if _, _, err := c.EP.Read(nil); err != nil { - t.Fatalf("Read failed: %v", err) - } - - if _, _, err := c.EP.Read(nil); err != nil { - t.Fatalf("Read failed: %v", err) - } - - // After reading two packets, we surely crossed MSS. See the ack: - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(0xffff)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestIncreaseWindowOnBufferResize(t *testing.T) { - // This test ensures that the endpoint sends an ack, - // after available recv buffer grows to more than 1 MSS. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - const rcvBuf = 65535 * 10 - c.CreateConnected(789, 30000, rcvBuf) - - // Write chunks of ~30000 bytes. It's important that two - // payloads make it equal or longer than MSS. - remain := rcvBuf - sent := 0 - data := make([]byte, defaultMTU/2) - lastWnd := uint16(0) - - for remain > len(data) { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - remain -= len(data) - - lastWnd = uint16(remain) - if remain > 0xffff { - lastWnd = 0xffff - } - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(lastWnd), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - if lastWnd == 0xffff || lastWnd == 0 { - t.Fatalf("expected small, non-zero window: %d", lastWnd) - } - - // Increasing the buffer from should generate an ACK, - // since window grew from small value to larger equal MSS - c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2) - - // After reading two packets, we surely crossed MSS. See the ack: - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(0xffff)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestTCPDeferAccept(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - const tcpDeferAccept = 1 * time.Second - if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { - t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err) - } - - irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - - if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { - t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock) - } - - // Send data. This should result in an acceptable endpoint. - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) - - // Give a bit of time for the socket to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept() - if err != nil { - t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err) - } - - aep.Close() - // Closing aep without reading the data should trigger a RST. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) -} - -func TestTCPDeferAcceptTimeout(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - const tcpDeferAccept = 1 * time.Second - if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { - t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err) - } - - irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - - if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { - t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock) - } - - // Sleep for a little of the tcpDeferAccept timeout. - time.Sleep(tcpDeferAccept + 100*time.Millisecond) - - // On timeout expiry we should get a SYN-ACK retransmission. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) - - // Send data. This should result in an acceptable endpoint. - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) - - // Give sometime for the endpoint to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept() - if err != nil { - t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err) - } - - aep.Close() - // Closing aep without reading the data should trigger a RST. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) -} - -func TestResetDuringClose(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - iss := seqnum.Value(789) - c.CreateConnected(iss, 30000, -1 /* epRecvBuf */) - // Send some data to make sure there is some unread - // data to trigger a reset on c.Close. - irs := c.IRS - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(1), - AckNum: irs.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(irs.Add(1))), - checker.AckNum(uint32(iss.Add(5))))) - - // Close in a separate goroutine so that we can trigger - // a race with the RST we send below. This should not - // panic due to the route being released depeding on - // whether Close() sends an active RST or the RST sent - // below is processed by the worker first. - var wg sync.WaitGroup - - wg.Add(1) - go func() { - defer wg.Done() - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(5), - AckNum: c.IRS.Add(5), - RcvWnd: 30000, - Flags: header.TCPFlagRst, - }) - }() - - wg.Add(1) - go func() { - defer wg.Done() - c.EP.Close() - }() - - wg.Wait() -} diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go deleted file mode 100644 index a641e953d..000000000 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ /dev/null @@ -1,295 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "math/rand" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/waiter" -) - -// createConnectedWithTimestampOption creates and connects c.ep with the -// timestamp option enabled. -func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1}) -} - -// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on -// an active connect and sets the TS Echo Reply fields correctly when the -// SYN-ACK also indicates support for the TS option and provides a TSVal. -func TestTimeStampEnabledConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - rep := createConnectedWithTimestampOption(c) - - // Register for read and validate that we have data to read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - // The following tests ensure that TS option once enabled behaves - // correctly as described in - // https://tools.ietf.org/html/rfc7323#section-4.3. - // - // We are not testing delayed ACKs here, but we do test out of order - // packet delivery and filling the sequence number hole created due to - // the out of order packet. - // - // The test also verifies that the sequence numbers and timestamps are - // as expected. - data := []byte{1, 2, 3} - - // First we increment tsVal by a small amount. - tsVal := rep.TSVal + 100 - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - // Next we send an out of order packet. - rep.NextSeqNum += 3 - tsVal += 200 - rep.SendPacketWithTS(data, tsVal) - - // The ACK should contain the original sequenceNumber and an older TS. - rep.NextSeqNum -= 6 - rep.VerifyACKWithTS(tsVal - 200) - - // Next we fill the hole and the returned ACK should contain the - // cumulative sequence number acking all data sent till now and have the - // latest timestamp sent below in its TSEcr field. - tsVal -= 100 - rep.SendPacketWithTS(data, tsVal) - rep.NextSeqNum += 3 - rep.VerifyACKWithTS(tsVal) - - // Increment tsVal by a large value that doesn't result in a wrap around. - tsVal += 0x7fffffff - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - // Increment tsVal again by a large value which should cause the - // timestamp value to wrap around. The returned ACK should contain the - // wrapped around timestamp in its tsEcr field and not the tsVal from - // the previous packet sent above. - tsVal += 0x7fffffff - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // There should be 5 views to read and each of them should - // contain the same data. - for i := 0; i < 5; i++ { - got, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) - } - if want := data; bytes.Compare(got, want) != 0 { - t.Fatalf("Data is different: got: %v, want: %v", got, want) - } - } -} - -// TestTimeStampDisabledConnect tests that netstack sends timestamp option on an -// active connect but if the SYN-ACK doesn't specify the TS option then -// timestamp option is not enabled and future packets do not contain a -// timestamp. -func TestTimeStampDisabledConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnectedWithOptions(header.TCPSynOptions{}) -} - -func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() - - if cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } - c := context.New(t, defaultMTU) - defer c.Cleanup() - - t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) - tsVal := rand.Uint32() - c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal}) - - // Now send some data and validate that timestamp is echoed correctly in the ACK. - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) - } - - // Check that data is received and that the timestamp option TSEcr field - // matches the expected value. - b := c.GetPacket() - checker.IPv4(t, b, - // Add 12 bytes for the timestamp option + 2 NOPs to align at 4 - // byte boundary. - checker.PayloadLen(len(data)+header.TCPMinimumSize+12), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - checker.TCPTimestampChecker(true, 0, tsVal+1), - ), - ) -} - -// TestTimeStampEnabledAccept tests that if the SYN on a passive connect -// specifies the Timestamp option then the Timestamp option is sent on a SYN-ACK -// and echoes the tsVal field of the original SYN in the tcEcr field of the -// SYN-ACK. We cover the cases where SYN cookies are enabled/disabled and verify -// that Timestamp option is enabled in both cases if requested in the original -// SYN. -func TestTimeStampEnabledAccept(t *testing.T) { - testCases := []struct { - cookieEnabled bool - wndScale int - wndSize uint16 - }{ - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5. - } - for _, tc := range testCases { - timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) - } -} - -func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() - if cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } - - c := context.New(t, defaultMTU) - defer c.Cleanup() - - t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) - c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Now send some data with the accepted connection endpoint and validate - // that no timestamp option is sent in the TCP segment. - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) - } - - // Check that data is received and that the timestamp option is disabled - // when SYN cookies are enabled/disabled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - checker.TCPTimestampChecker(false, 0, 0), - ), - ) -} - -// TestTimeStampDisabledAccept tests that Timestamp option is not used when the -// peer doesn't advertise it and connection is established with Accept(). -func TestTimeStampDisabledAccept(t *testing.T) { - testCases := []struct { - cookieEnabled bool - wndScale int - wndSize uint16 - }{ - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5. - } - for _, tc := range testCases { - timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) - } -} - -func TestSendGreaterThanMTUWithOptions(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - createConnectedWithTimestampOption(c) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - rep := createConnectedWithTimestampOption(c) - - // Register for read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - droppedPacketsStat := c.Stack().Stats().DroppedPackets - droppedPackets := droppedPacketsStat.Value() - data := []byte{1, 2, 3} - // Send a packet with no TCP options/timestamp. - rep.SendPacket(data, nil) - - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Assert that DroppedPackets was not incremented. - if got, want := droppedPacketsStat.Value(), droppedPackets; got != want { - t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want) - } - - // Issue a read and we should data. - got, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) - } - if want := data; bytes.Compare(got, want) != 0 { - t.Fatalf("Data is different: got: %v, want: %v", got, want) - } -} diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD deleted file mode 100644 index ce6a2c31d..000000000 --- a/pkg/tcpip/transport/tcp/testing/context/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "context", - testonly = 1, - srcs = ["context.go"], - visibility = [ - "//visibility:public", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go deleted file mode 100644 index 8cea20fb5..000000000 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ /dev/null @@ -1,1103 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package context provides a test context for use in tcp tests. It also -// provides helper methods to assert/check certain behaviours. -package context - -import ( - "bytes" - "context" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - // StackAddr is the IPv4 address assigned to the stack. - StackAddr = "\x0a\x00\x00\x01" - - // StackPort is used as the listening port in tests for passive - // connects. - StackPort = 1234 - - // TestAddr is the source address for packets sent to the stack via the - // link layer endpoint. - TestAddr = "\x0a\x00\x00\x02" - - // TestPort is the TCP port used for packets sent to the stack - // via the link layer endpoint. - TestPort = 4096 - - // StackV6Addr is the IPv6 address assigned to the stack. - StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - - // TestV6Addr is the source address for packets sent to the stack via - // the link layer endpoint. - TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - - // StackV4MappedAddr is StackAddr as a mapped v6 address. - StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr - - // TestV4MappedAddr is TestAddr as a mapped v6 address. - TestV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + TestAddr - - // V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0. - V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" - - // testInitialSequenceNumber is the initial sequence number sent in packets that - // are sent in response to a SYN or in the initial SYN sent to the stack. - testInitialSequenceNumber = 789 -) - -// Headers is used to represent the TCP header fields when building a -// new packet. -type Headers struct { - // SrcPort holds the src port value to be used in the packet. - SrcPort uint16 - - // DstPort holds the destination port value to be used in the packet. - DstPort uint16 - - // SeqNum is the value of the sequence number field in the TCP header. - SeqNum seqnum.Value - - // AckNum represents the acknowledgement number field in the TCP header. - AckNum seqnum.Value - - // Flags are the TCP flags in the TCP header. - Flags int - - // RcvWnd is the window to be advertised in the ReceiveWindow field of - // the TCP header. - RcvWnd seqnum.Size - - // TCPOpts holds the options to be sent in the option field of the TCP - // header. - TCPOpts []byte -} - -// Context provides an initialized Network stack and a link layer endpoint -// for use in TCP tests. -type Context struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack - - // IRS holds the initial sequence number in the SYN sent by endpoint in - // case of an active connect or the sequence number sent by the endpoint - // in the SYN-ACK sent in response to a SYN when listening in passive - // mode. - IRS seqnum.Value - - // Port holds the port bound by EP below in case of an active connect or - // the listening port number in case of a passive connect. - Port uint16 - - // EP is the test endpoint in the stack owned by this context. This endpoint - // is used in various tests to either initiate an active connect or is used - // as a passive listening endpoint to accept inbound connections. - EP tcpip.Endpoint - - // Wq is the wait queue associated with EP and is used to block for events - // on EP. - WQ waiter.Queue - - // TimeStampEnabled is true if ep is connected with the timestamp option - // enabled. - TimeStampEnabled bool - - // WindowScale is the expected window scale in SYN packets sent by - // the stack. - WindowScale uint8 -} - -// New allocates and initializes a test context containing a new -// stack and a link-layer endpoint. -func New(t *testing.T, mtu uint32) *Context { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, - }) - - // Allow minimum send/receive buffer sizes to be 1 during tests. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize, 10 * tcp.DefaultReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - // Some of the congestion control tests send up to 640 packets, we so - // set the channel size to 1000. - ep := channel.New(1000, mtu, "") - wep := stack.LinkEndpoint(ep) - if testing.Verbose() { - wep = sniffer.New(ep) - } - opts := stack.NICOptions{Name: "nic1"} - if err := s.CreateNICWithOptions(1, wep, opts); err != nil { - t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) - } - wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) - if testing.Verbose() { - wep2 = sniffer.New(channel.New(1000, mtu, "")) - } - opts2 := stack.NICOptions{Name: "nic2"} - if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil { - t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - return &Context{ - t: t, - s: s, - linkEP: ep, - WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)), - } -} - -// Cleanup closes the context endpoint if required. -func (c *Context) Cleanup() { - if c.EP != nil { - c.EP.Close() - } - c.Stack().Close() -} - -// Stack returns a reference to the stack in the Context. -func (c *Context) Stack() *stack.Stack { - return c.s -} - -// CheckNoPacketTimeout verifies that no packet is received during the time -// specified by wait. -func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { - c.t.Helper() - - ctx, _ := context.WithTimeout(context.Background(), wait) - if _, ok := c.linkEP.ReadContext(ctx); ok { - c.t.Fatal(errMsg) - } -} - -// CheckNoPacket verifies that no packet is received for 1 second. -func (c *Context) CheckNoPacket(errMsg string) { - c.CheckNoPacketTimeout(errMsg, 1*time.Second) -} - -// GetPacket reads a packet from the link layer endpoint and verifies -// that it is an IPv4 packet with the expected source and destination -// addresses. It will fail with an error if no packet is received for -// 2 seconds. -func (c *Context) GetPacket() []byte { - c.t.Helper() - - ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - c.t.Fatalf("Packet wasn't written out") - return nil - } - - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - - if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { - c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) - } - - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b -} - -// GetPacketNonBlocking reads a packet from the link layer endpoint -// and verifies that it is an IPv4 packet with the expected source -// and destination address. If no packet is available it will return -// nil immediately. -func (c *Context) GetPacketNonBlocking() []byte { - c.t.Helper() - - p, ok := c.linkEP.Read() - if !ok { - return nil - } - - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b -} - -// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. -func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { - // Allocate a buffer data and headers. - buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2)) - if len(buf) > maxTotalSize { - buf = buf[:maxTotalSize] - } - - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(header.ICMPv4ProtocolNumber), - SrcAddr: TestAddr, - DstAddr: StackAddr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - icmp := header.ICMPv4(buf[header.IPv4MinimumSize:]) - icmp.SetType(typ) - icmp.SetCode(code) - const icmpv4VariableHeaderOffset = 4 - copy(icmp[icmpv4VariableHeaderOffset:], p1) - copy(icmp[header.ICMPv4PayloadOffset:], p2) - - // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) -} - -// BuildSegment builds a TCP segment based on the given Headers and payload. -func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView { - return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr) -} - -// BuildSegmentWithAddrs builds a TCP segment based on the given Headers, -// payload and source and destination IPv4 addresses. -func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts) - - // Initialize the IP header. - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(tcp.ProtocolNumber), - SrcAddr: src, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Initialize the TCP header. - t := header.TCP(buf[header.IPv4MinimumSize:]) - t.Encode(&header.TCPFields{ - SrcPort: h.SrcPort, - DstPort: h.DstPort, - SeqNum: uint32(h.SeqNum), - AckNum: uint32(h.AckNum), - DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)), - Flags: uint8(h.Flags), - WindowSize: uint16(h.RcvWnd), - }) - - // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) - - // Calculate the TCP checksum and set it. - xsum = header.Checksum(payload, xsum) - t.SetChecksum(^t.CalculateChecksum(xsum)) - - // Inject packet. - return buf.ToVectorisedView() -} - -// SendSegment sends a TCP segment that has already been built and written to a -// buffer.VectorisedView. -func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ - Data: s, - }) -} - -// SendPacket builds and sends a TCP segment(with the provided payload & TCP -// headers) in an IPv4 packet via the link layer endpoint. -func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ - Data: c.BuildSegment(payload, h), - }) -} - -// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload -// & TCPheaders) in an IPv4 packet via the link layer endpoint using the -// provided source and destination IPv4 addresses. -func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ - Data: c.BuildSegmentWithAddrs(payload, h, src, dst), - }) -} - -// SendAck sends an ACK packet. -func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) { - c.SendAckWithSACK(seq, bytesReceived, nil) -} - -// SendAckWithSACK sends an ACK packet which includes the sackBlocks specified. -func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlocks []header.SACKBlock) { - options := make([]byte, 40) - offset := 0 - if len(sackBlocks) > 0 { - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeSACKBlocks(sackBlocks, options[offset:]) - } - - c.SendPacket(nil, &Headers{ - SrcPort: TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seq, - AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), - RcvWnd: 30000, - TCPOpts: options[:offset], - }) -} - -// ReceiveAndCheckPacket reads a packet from the link layer endpoint and -// verifies that the packet packet payload of packet matches the slice -// of data indicated by offset & size. -func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { - c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0) -} - -// ReceiveAndCheckPacketWithOptions reads a packet from the link layer endpoint -// and verifies that the packet packet payload of packet matches the slice of -// data indicated by offset & size and skips optlen bytes in addition to the IP -// TCP headers when comparing the data. -func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) { - b := c.GetPacket() - checker.IPv4(c.t, b, - checker.PayloadLen(size+header.TCPMinimumSize+optlen), - checker.TCP( - checker.DstPort(TestPort), - checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - pdata := data[offset:][:size] - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize+optlen:]; bytes.Compare(pdata, p) != 0 { - c.t.Fatalf("Data is different: expected %v, got %v", pdata, p) - } -} - -// ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint -// and verifies that the packet packet payload of packet matches the slice of -// data indicated by offset & size. It returns true if a packet was received and -// processed. -func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool { - b := c.GetPacketNonBlocking() - if b == nil { - return false - } - checker.IPv4(c.t, b, - checker.PayloadLen(size+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(TestPort), - checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - pdata := data[offset:][:size] - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 { - c.t.Fatalf("Data is different: expected %v, got %v", pdata, p) - } - return true -} - -// CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only -// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6 -// only endpoint instead of a default dual stack socket. -func (c *Context) CreateV6Endpoint(v6only bool) { - var err *tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } -} - -// GetV6Packet reads a single packet from the link layer endpoint of the context -// and asserts that it is an IPv6 Packet with the expected src/dest addresses. -func (c *Context) GetV6Packet() []byte { - c.t.Helper() - - ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - c.t.Fatalf("Packet wasn't written out") - return nil - } - - if p.Proto != ipv6.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) - } - b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) - copy(b, p.Pkt.Header.View()) - copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) - - checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) - return b -} - -// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of -// the context. -func (c *Context) SendV6Packet(payload []byte, h *Headers) { - c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr) -} - -// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer -// endpoint of the context using the provided source and destination IPv6 -// addresses. -func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.TCPMinimumSize + len(payload)), - NextHeader: uint8(tcp.ProtocolNumber), - HopLimit: 65, - SrcAddr: src, - DstAddr: dst, - }) - - // Initialize the TCP header. - t := header.TCP(buf[header.IPv6MinimumSize:]) - t.Encode(&header.TCPFields{ - SrcPort: h.SrcPort, - DstPort: h.DstPort, - SeqNum: uint32(h.SeqNum), - AckNum: uint32(h.AckNum), - DataOffset: header.TCPMinimumSize, - Flags: uint8(h.Flags), - WindowSize: uint16(h.RcvWnd), - }) - - // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) - - // Calculate the TCP checksum and set it. - xsum = header.Checksum(payload, xsum) - t.SetChecksum(^t.CalculateChecksum(xsum)) - - // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) -} - -// CreateConnected creates a connected TCP endpoint. -func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) { - c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) -} - -// Connect performs the 3-way handshake for c.EP with the provided Initial -// Sequence Number (iss) and receive window(rcvWnd) and any options if -// specified. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -// -// PreCondition: c.EP must already be created. -func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) { - // Start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventOut) - defer c.WQ.EventUnregister(&waitEntry) - - if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted { - c.t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(c.t, b, - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - c.SendPacket(nil, &Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: options, - }) - - // Receive ACK packet. - checker.IPv4(c.t, c.GetPacket(), - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - ), - ) - - // Wait for connection to be established. - select { - case <-notifyCh: - if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { - c.t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for connection") - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - c.Port = tcpHdr.SourcePort() -} - -// Create creates a TCP endpoint. -func (c *Context) Create(epRcvBuf int) { - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - if epRcvBuf != -1 { - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } - } -} - -// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends -// the specified option bytes as the Option field in the initial SYN packet. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { - c.Create(epRcvBuf) - c.Connect(iss, rcvWnd, options) -} - -// RawEndpoint is just a small wrapper around a TCP endpoint's state to make -// sending data and ACK packets easy while being able to manipulate the sequence -// numbers and timestamp values as needed. -type RawEndpoint struct { - C *Context - SrcPort uint16 - DstPort uint16 - Flags int - NextSeqNum seqnum.Value - AckNum seqnum.Value - WndSize seqnum.Size - RecentTS uint32 // Stores the latest timestamp to echo back. - TSVal uint32 // TSVal stores the last timestamp sent by this endpoint. - - // SackPermitted is true if SACKPermitted option was negotiated for this endpoint. - SACKPermitted bool -} - -// SendPacketWithTS embeds the provided tsVal in the Timestamp option -// for the packet to be sent out. -func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) { - r.TSVal = tsVal - tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:]) - r.SendPacket(payload, tsOpt[:]) -} - -// SendPacket is a small wrapper function to build and send packets. -func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) { - packetHeaders := &Headers{ - SrcPort: r.SrcPort, - DstPort: r.DstPort, - Flags: r.Flags, - SeqNum: r.NextSeqNum, - AckNum: r.AckNum, - RcvWnd: r.WndSize, - TCPOpts: opts, - } - r.C.SendPacket(payload, packetHeaders) - r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload))) -} - -// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided -// tsVal. -func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { - // Read ACK and verify that tsEcr of ACK packet is [1,2,3,4] - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), - checker.TCPTimestampChecker(true, 0, tsVal), - ), - ) - // Store the parsed TSVal from the ack as recentTS. - tcpSeg := header.TCP(header.IPv4(ackPacket).Payload()) - opts := tcpSeg.ParsedOptions() - r.RecentTS = opts.TSVal -} - -// VerifyACKRcvWnd verifies that the window advertised by the incoming ACK -// matches the provided rcvWnd. -func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) { - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), - checker.Window(rcvWnd), - ), - ) -} - -// VerifyACKNoSACK verifies that the ACK does not contain a SACK block. -func (r *RawEndpoint) VerifyACKNoSACK() { - r.VerifyACKHasSACK(nil) -} - -// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks. -func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { - // Read ACK and verify that the TCP options in the segment do - // not contain a SACK block. - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), - checker.TCPSACKBlockChecker(sackBlocks), - ), - ) -} - -// CreateConnectedWithOptions creates and connects c.ep with the specified TCP -// options enabled and returns a RawEndpoint which represents the other end of -// the connection. -// -// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK -// does not carry an option that was not requested. -func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint { - var err *tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err) - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventOut) - defer c.WQ.EventUnregister(&waitEntry) - - testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort} - err = c.EP.Connect(testFullAddr) - if err != tcpip.ErrConnectStarted { - c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err) - } - // Receive SYN packet. - b := c.GetPacket() - // Validate that the syn has the timestamp option and a valid - // TS value. - mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) - - checker.IPv4(c.t, b, - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{ - MSS: mss, - TS: true, - WS: int(c.WindowScale), - SACKPermitted: c.SACKEnabled(), - }), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - tcpSeg := header.TCP(header.IPv4(b).Payload()) - synOptions := header.ParseSynOptions(tcpSeg.Options(), false) - - // Build options w/ tsVal to be sent in the SYN-ACK. - synAckOptions := make([]byte, header.TCPOptionsMaximumSize) - offset := 0 - if wantOptions.WS != -1 { - offset += header.EncodeWSOption(wantOptions.WS, synAckOptions[offset:]) - } - if wantOptions.TS { - offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:]) - } - if wantOptions.SACKPermitted { - offset += header.EncodeSACKPermittedOption(synAckOptions[offset:]) - } - - offset += header.AddTCPOptionPadding(synAckOptions, offset) - - // Build SYN-ACK. - c.IRS = seqnum.Value(tcpSeg.SequenceNumber()) - iss := seqnum.Value(testInitialSequenceNumber) - c.SendPacket(nil, &Headers{ - SrcPort: tcpSeg.DestinationPort(), - DstPort: tcpSeg.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - TCPOpts: synAckOptions[:offset], - }) - - // Read ACK. - ackPacket := c.GetPacket() - - // Verify TCP header fields. - tcpCheckers := []checker.TransportChecker{ - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS) + 1), - checker.AckNum(uint32(iss) + 1), - } - - // Verify that tsEcr of ACK packet is wantOptions.TSVal if the - // timestamp option was enabled, if not then we verify that - // there is no timestamp in the ACK packet. - if wantOptions.TS { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal)) - } else { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) - } - - checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...)) - - ackSeg := header.TCP(header.IPv4(ackPacket).Payload()) - ackOptions := ackSeg.ParsedOptions() - - // Wait for connection to be established. - select { - case <-notifyCh: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { - c.t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for connection") - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Store the source port in use by the endpoint. - c.Port = tcpSeg.SourcePort() - - // Mark in context that timestamp option is enabled for this endpoint. - c.TimeStampEnabled = true - - return &RawEndpoint{ - C: c, - SrcPort: tcpSeg.DestinationPort(), - DstPort: tcpSeg.SourcePort(), - Flags: header.TCPFlagAck | header.TCPFlagPsh, - NextSeqNum: iss + 1, - AckNum: c.IRS.Add(1), - WndSize: 30000, - RecentTS: ackOptions.TSVal, - TSVal: wantOptions.TSVal, - SACKPermitted: wantOptions.SACKPermitted, - } -} - -// AcceptWithOptions initializes a listening endpoint and connects to it with the -// provided options enabled. It also verifies that the SYN-ACK has the expected -// values for the provided options. -// -// The function returns a RawEndpoint representing the other end of the accepted -// endpoint. -func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - if err := ep.Listen(10); err != nil { - c.t.Fatalf("Listen failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - rep := c.PassiveConnectWithOptions(100, wndScale, synOptions) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - c.t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for accept") - } - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - return rep -} - -// PassiveConnect just disables WindowScaling and delegates the call to -// PassiveConnectWithOptions. -func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) { - synOptions.WS = -1 - c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions) -} - -// PassiveConnectWithOptions initiates a new connection (with the specified TCP -// options enabled) to the port on which the Context.ep is listening for new -// connections. It also validates that the SYN-ACK has the expected values for -// the enabled options. -// -// NOTE: MSS is not a negotiated option and it can be asymmetric -// in each direction. This function uses the maxPayload to set the MSS to be -// sent to the peer on a connect and validates that the MSS in the SYN-ACK -// response is equal to the MTU - (tcphdr len + iphdr len). -// -// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the -// value of the window scaling option to be sent in the SYN. If synOptions.WS > -// 0 then we send the WindowScale option. -func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { - opts := make([]byte, header.TCPOptionsMaximumSize) - offset := 0 - offset += header.EncodeMSSOption(uint32(maxPayload), opts) - - if synOptions.WS >= 0 { - offset += header.EncodeWSOption(3, opts[offset:]) - } - if synOptions.TS { - offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:]) - } - - if synOptions.SACKPermitted { - offset += header.EncodeSACKPermittedOption(opts[offset:]) - } - - paddingToAdd := 4 - offset%4 - // Now add any padding bytes that might be required to quad align the - // options. - for i := offset; i < offset+paddingToAdd; i++ { - opts[i] = header.TCPOptionNOP - } - offset += paddingToAdd - - // Send a SYN request. - iss := seqnum.Value(testInitialSequenceNumber) - c.SendPacket(nil, &Headers{ - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - TCPOpts: opts[:offset], - }) - - // Receive the SYN-ACK reply. Make sure MSS and other expected options - // are present. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(StackPort), - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(iss) + 1), - checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}), - } - - // If TS option was enabled in the original SYN then add a checker to - // validate the Timestamp option in the SYN-ACK. - if synOptions.TS { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal)) - } else { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) - } - - checker.IPv4(c.t, b, checker.TCP(tcpCheckers...)) - rcvWnd := seqnum.Size(30000) - ackHeaders := &Headers{ - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - RcvWnd: rcvWnd, - } - - // If WS was expected to be in effect then scale the advertised window - // correspondingly. - if synOptions.WS > 0 { - ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS) - } - - parsedOpts := tcp.ParsedOptions() - if synOptions.TS { - // Echo the tsVal back to the peer in the tsEcr field of the - // timestamp option. - // Increment TSVal by 1 from the value sent in the SYN and echo - // the TSVal in the SYN-ACK in the TSEcr field. - opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:]) - ackHeaders.TCPOpts = opts[:] - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - c.Port = StackPort - - return &RawEndpoint{ - C: c, - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagPsh | header.TCPFlagAck, - NextSeqNum: iss + 1, - AckNum: c.IRS + 1, - WndSize: rcvWnd, - SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(), - RecentTS: parsedOpts.TSVal, - TSVal: synOptions.TSVal + 1, - } -} - -// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true -// for the Stack in the context. -func (c *Context) SACKEnabled() bool { - var v tcp.SACKEnabled - if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { - // Stack doesn't support SACK. So just return. - return false - } - return bool(v) -} - -// SetGSOEnabled enables or disables generic segmentation offload. -func (c *Context) SetGSOEnabled(enable bool) { - if enable { - c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO - } else { - c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO - } -} - -// MSSWithoutOptions returns the value for the MSS used by the stack when no -// options are in use. -func (c *Context) MSSWithoutOptions() uint16 { - return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) -} - -// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no -// options are in use for IPv6 packets. -func (c *Context) MSSWithoutOptionsV6() uint16 { - return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize) -} diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD deleted file mode 100644 index 3ad6994a7..000000000 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tcpconntrack", - srcs = ["tcp_conntrack.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) - -go_test( - name = "tcpconntrack_test", - size = "small", - srcs = ["tcp_conntrack_test.go"], - deps = [ - ":tcpconntrack", - "//pkg/tcpip/header", - ], -) diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go index 93712cd45..93712cd45 100644..100755 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go deleted file mode 100644 index 5e271b7ca..000000000 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go +++ /dev/null @@ -1,511 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpconntrack_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" -) - -// connected creates a connection tracker TCB and sets it to a connected state -// by performing a 3-way handshake. -func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: iss, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: irw, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: irs, - AckNum: iss + 1, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: isw, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: iss + 1, - AckNum: irs + 1, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: irw, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - return &tcb -} - -func TestConnectionRefused(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive RST. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst | header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestConnectionRefusedInSynRcvd(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive RST with no ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestConnectionResetInSynRcvd(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send RST with no ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestRetransmitOnSynSent(t *testing.T) { - // Send initial SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Retransmit the same SYN. - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting) - } -} - -func TestRetransmitOnSynRcvd(t *testing.T) { - // Send initial SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. This will cause the state to go to SYN-RCVD. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Retransmit the original SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Transmit a SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } -} - -func TestClosedBySelf(t *testing.T) { - tcb := connected(t, 1234, 789, 30000, 50000) - - // Send FIN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 1236, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1236, - AckNum: 791, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) - } -} - -func TestClosedByPeer(t *testing.T) { - tcb := connected(t, 1234, 789, 30000, 50000) - - // Receive FIN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 791, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 791, - AckNum: 1236, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer) - } -} - -func TestSendAndReceiveDataClosedBySelf(t *testing.T) { - sseq := uint32(1234) - rseq := uint32(789) - tcb := connected(t, sseq, rseq, 30000, 50000) - sseq++ - rseq++ - - // Send some data. - tcp := make(header.TCP, header.TCPMinimumSize+1024) - - for i := uint32(0); i < 10; i++ { - // Send some data. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - sseq += uint32(len(tcp)) - header.TCPMinimumSize - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive ack for data. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - - for i := uint32(0); i < 10; i++ { - // Receive some data. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - rseq += uint32(len(tcp)) - header.TCPMinimumSize - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ack for data. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - - // Send FIN. - tcp = tcp[:header.TCPMinimumSize] - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - sseq++ - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - rseq++ - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) - } -} - -func TestIgnoreBadResetOnSynSent(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive a RST with a bad ACK, it should not cause the connection to - // be reset. - acks := []uint32{1234, 1236, 1000, 5000} - flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck} - for _, a := range acks { - for _, f := range flags { - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: a, - DataOffset: header.TCPMinimumSize, - Flags: f, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - } - - // Complete the handshake. - // Receive SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } -} diff --git a/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go b/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go new file mode 100755 index 000000000..ff53204da --- /dev/null +++ b/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package tcpconntrack diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD deleted file mode 100644 index adc908e24..000000000 --- a/pkg/tcpip/transport/udp/BUILD +++ /dev/null @@ -1,61 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "udp_packet_list", - out = "udp_packet_list.go", - package = "udp", - prefix = "udpPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*udpPacket", - "Linker": "*udpPacket", - }, -) - -go_library( - name = "udp", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "forwarder.go", - "protocol.go", - "udp_packet_list.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/ports", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/waiter", - ], -) - -go_test( - name = "udp_x_test", - size = "small", - srcs = ["udp_test.go"], - deps = [ - ":udp", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/udp/udp_packet_list.go b/pkg/tcpip/transport/udp/udp_packet_list.go new file mode 100755 index 000000000..2ae846eaa --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_packet_list.go @@ -0,0 +1,186 @@ +package udp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type udpPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type udpPacketList struct { + head *udpPacket + tail *udpPacket +} + +// Reset resets list l to the empty state. +func (l *udpPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *udpPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *udpPacketList) Front() *udpPacket { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *udpPacketList) Back() *udpPacket { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *udpPacketList) PushFront(e *udpPacket) { + linker := udpPacketElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *udpPacketList) PushBack(e *udpPacket) { + linker := udpPacketElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *udpPacketList) PushBackList(m *udpPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *udpPacketList) InsertAfter(b, e *udpPacket) { + bLinker := udpPacketElementMapper{}.linkerFor(b) + eLinker := udpPacketElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + udpPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *udpPacketList) InsertBefore(a, e *udpPacket) { + aLinker := udpPacketElementMapper{}.linkerFor(a) + eLinker := udpPacketElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + udpPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *udpPacketList) Remove(e *udpPacket) { + linker := udpPacketElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + udpPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + udpPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type udpPacketEntry struct { + next *udpPacket + prev *udpPacket +} + +// Next returns the entry that follows e in the list. +func (e *udpPacketEntry) Next() *udpPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *udpPacketEntry) Prev() *udpPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *udpPacketEntry) SetNext(elem *udpPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *udpPacketEntry) SetPrev(elem *udpPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go new file mode 100755 index 000000000..5cf9848cc --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_state_autogen.go @@ -0,0 +1,144 @@ +// automatically generated by stateify. + +package udp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (x *udpPacket) beforeSave() {} +func (x *udpPacket) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + m.Save("udpPacketEntry", &x.udpPacketEntry) + m.Save("senderAddress", &x.senderAddress) + m.Save("packetInfo", &x.packetInfo) + m.Save("timestamp", &x.timestamp) + m.Save("tos", &x.tos) +} + +func (x *udpPacket) afterLoad() {} +func (x *udpPacket) load(m state.Map) { + m.Load("udpPacketEntry", &x.udpPacketEntry) + m.Load("senderAddress", &x.senderAddress) + m.Load("packetInfo", &x.packetInfo) + m.Load("timestamp", &x.timestamp) + m.Load("tos", &x.tos) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var rcvBufSizeMax int = x.saveRcvBufSizeMax() + m.SaveValue("rcvBufSizeMax", rcvBufSizeMax) + m.Save("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("uniqueID", &x.uniqueID) + m.Save("rcvReady", &x.rcvReady) + m.Save("rcvList", &x.rcvList) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("state", &x.state) + m.Save("dstPort", &x.dstPort) + m.Save("v6only", &x.v6only) + m.Save("ttl", &x.ttl) + m.Save("multicastTTL", &x.multicastTTL) + m.Save("multicastAddr", &x.multicastAddr) + m.Save("multicastNICID", &x.multicastNICID) + m.Save("multicastLoop", &x.multicastLoop) + m.Save("reusePort", &x.reusePort) + m.Save("bindToDevice", &x.bindToDevice) + m.Save("broadcast", &x.broadcast) + m.Save("boundBindToDevice", &x.boundBindToDevice) + m.Save("boundPortFlags", &x.boundPortFlags) + m.Save("sendTOS", &x.sendTOS) + m.Save("receiveTOS", &x.receiveTOS) + m.Save("receiveTClass", &x.receiveTClass) + m.Save("receiveIPPacketInfo", &x.receiveIPPacketInfo) + m.Save("shutdownFlags", &x.shutdownFlags) + m.Save("multicastMemberships", &x.multicastMemberships) + m.Save("effectiveNetProtos", &x.effectiveNetProtos) +} + +func (x *endpoint) load(m state.Map) { + m.Load("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Load("waiterQueue", &x.waiterQueue) + m.Load("uniqueID", &x.uniqueID) + m.Load("rcvReady", &x.rcvReady) + m.Load("rcvList", &x.rcvList) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("state", &x.state) + m.Load("dstPort", &x.dstPort) + m.Load("v6only", &x.v6only) + m.Load("ttl", &x.ttl) + m.Load("multicastTTL", &x.multicastTTL) + m.Load("multicastAddr", &x.multicastAddr) + m.Load("multicastNICID", &x.multicastNICID) + m.Load("multicastLoop", &x.multicastLoop) + m.Load("reusePort", &x.reusePort) + m.Load("bindToDevice", &x.bindToDevice) + m.Load("broadcast", &x.broadcast) + m.Load("boundBindToDevice", &x.boundBindToDevice) + m.Load("boundPortFlags", &x.boundPortFlags) + m.Load("sendTOS", &x.sendTOS) + m.Load("receiveTOS", &x.receiveTOS) + m.Load("receiveTClass", &x.receiveTClass) + m.Load("receiveIPPacketInfo", &x.receiveIPPacketInfo) + m.Load("shutdownFlags", &x.shutdownFlags) + m.Load("multicastMemberships", &x.multicastMemberships) + m.Load("effectiveNetProtos", &x.effectiveNetProtos) + m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *multicastMembership) beforeSave() {} +func (x *multicastMembership) save(m state.Map) { + x.beforeSave() + m.Save("nicID", &x.nicID) + m.Save("multicastAddr", &x.multicastAddr) +} + +func (x *multicastMembership) afterLoad() {} +func (x *multicastMembership) load(m state.Map) { + m.Load("nicID", &x.nicID) + m.Load("multicastAddr", &x.multicastAddr) +} + +func (x *udpPacketList) beforeSave() {} +func (x *udpPacketList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *udpPacketList) afterLoad() {} +func (x *udpPacketList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *udpPacketEntry) beforeSave() {} +func (x *udpPacketEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *udpPacketEntry) afterLoad() {} +func (x *udpPacketEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("pkg/tcpip/transport/udp.udpPacket", (*udpPacket)(nil), state.Fns{Save: (*udpPacket).save, Load: (*udpPacket).load}) + state.Register("pkg/tcpip/transport/udp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("pkg/tcpip/transport/udp.multicastMembership", (*multicastMembership)(nil), state.Fns{Save: (*multicastMembership).save, Load: (*multicastMembership).load}) + state.Register("pkg/tcpip/transport/udp.udpPacketList", (*udpPacketList)(nil), state.Fns{Save: (*udpPacketList).save, Load: (*udpPacketList).load}) + state.Register("pkg/tcpip/transport/udp.udpPacketEntry", (*udpPacketEntry)(nil), state.Fns{Save: (*udpPacketEntry).save, Load: (*udpPacketEntry).load}) +} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go deleted file mode 100644 index 34b7c2360..000000000 --- a/pkg/tcpip/transport/udp/udp_test.go +++ /dev/null @@ -1,1802 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package udp_test - -import ( - "bytes" - "context" - "fmt" - "math/rand" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// Addresses and ports used for testing. It is recommended that tests stick to -// using these addresses as it allows using the testFlow helper. -// Naming rules: 'stack*'' denotes local addresses and ports, while 'test*' -// represents the remote endpoint. -const ( - v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackV4MappedAddr = v4MappedAddrPrefix + stackAddr - testV4MappedAddr = v4MappedAddrPrefix + testAddr - multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr - broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr - v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00" - - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testAddr = "\x0a\x00\x00\x02" - testPort = 4096 - multicastAddr = "\xe8\x2b\xd3\xea" - multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - broadcastAddr = header.IPv4Broadcast - testTOS = 0x80 - - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65536 -) - -// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in -// a packet header. These values are used to populate a header or verify one. -// Note that because they are used in packet headers, the addresses are never in -// a V4-mapped format. -type header4Tuple struct { - srcAddr tcpip.FullAddress - dstAddr tcpip.FullAddress -} - -// testFlow implements a helper type used for sending and receiving test -// packets. A given test flow value defines 1) the socket endpoint used for the -// test and 2) the type of packet send or received on the endpoint. E.g., a -// multicastV6Only flow is a V6 multicast packet passing through a V6-only -// endpoint. The type provides helper methods to characterize the flow (e.g., -// isV4) as well as return a proper header4Tuple for it. -type testFlow int - -const ( - unicastV4 testFlow = iota // V4 unicast on a V4 socket - unicastV4in6 // V4-mapped unicast on a V6-dual socket - unicastV6 // V6 unicast on a V6 socket - unicastV6Only // V6 unicast on a V6-only socket - multicastV4 // V4 multicast on a V4 socket - multicastV4in6 // V4-mapped multicast on a V6-dual socket - multicastV6 // V6 multicast on a V6 socket - multicastV6Only // V6 multicast on a V6-only socket - broadcast // V4 broadcast on a V4 socket - broadcastIn6 // V4-mapped broadcast on a V6-dual socket -) - -func (flow testFlow) String() string { - switch flow { - case unicastV4: - return "unicastV4" - case unicastV6: - return "unicastV6" - case unicastV6Only: - return "unicastV6Only" - case unicastV4in6: - return "unicastV4in6" - case multicastV4: - return "multicastV4" - case multicastV6: - return "multicastV6" - case multicastV6Only: - return "multicastV6Only" - case multicastV4in6: - return "multicastV4in6" - case broadcast: - return "broadcast" - case broadcastIn6: - return "broadcastIn6" - default: - return "unknown" - } -} - -// packetDirection explains if a flow is incoming (read) or outgoing (write). -type packetDirection int - -const ( - incoming packetDirection = iota - outgoing -) - -// header4Tuple returns the header4Tuple for the given flow and direction. Note -// that the tuple contains no mapped addresses as those only exist at the socket -// level but not at the packet header level. -func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { - var h header4Tuple - if flow.isV4() { - if d == outgoing { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, - dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, - } - } else { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, - dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, - } - } - if flow.isMulticast() { - h.dstAddr.Addr = multicastAddr - } else if flow.isBroadcast() { - h.dstAddr.Addr = broadcastAddr - } - } else { // IPv6 - if d == outgoing { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - } - } else { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - } - } - if flow.isMulticast() { - h.dstAddr.Addr = multicastV6Addr - } - } - return h -} - -func (flow testFlow) getMcastAddr() tcpip.Address { - if flow.isV4() { - return multicastAddr - } - return multicastV6Addr -} - -// mapAddrIfApplicable converts the given V4 address into its V4-mapped version -// if it is applicable to the flow. -func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address { - if flow.isMapped() { - return v4MappedAddrPrefix + v4Addr - } - return v4Addr -} - -// netProto returns the protocol number used for the network packet. -func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { - if flow.isV4() { - return ipv4.ProtocolNumber - } - return ipv6.ProtocolNumber -} - -// sockProto returns the protocol number used when creating the socket -// endpoint for this flow. -func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { - switch flow { - case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: - return ipv6.ProtocolNumber - case unicastV4, multicastV4, broadcast: - return ipv4.ProtocolNumber - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) { - if flow.isV4() { - return checker.IPv4 - } - return checker.IPv6 -} - -func (flow testFlow) isV6() bool { return !flow.isV4() } -func (flow testFlow) isV4() bool { - return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped() -} - -func (flow testFlow) isV6Only() bool { - switch flow { - case unicastV6Only, multicastV6Only: - return true - case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isMulticast() bool { - switch flow { - case multicastV4, multicastV4in6, multicastV6, multicastV6Only: - return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isBroadcast() bool { - switch flow { - case broadcast, broadcastIn6: - return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isMapped() bool { - switch flow { - case unicastV4in6, multicastV4in6, broadcastIn6: - return true - case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -type testContext struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack - - ep tcpip.Endpoint - wq waiter.Queue -} - -func newDualTestContext(t *testing.T, mtu uint32) *testContext { - t.Helper() - return newDualTestContextWithOptions(t, mtu, stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - }) -} - -func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext { - t.Helper() - - s := stack.New(options) - ep := channel.New(256, mtu, "") - wep := stack.LinkEndpoint(ep) - - if testing.Verbose() { - wep = sniffer.New(ep) - } - if err := s.CreateNIC(1, wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - return &testContext{ - t: t, - s: s, - linkEP: ep, - } -} - -func (c *testContext) cleanup() { - if c.ep != nil { - c.ep.Close() - } -} - -func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { - c.t.Helper() - - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) - if err != nil { - c.t.Fatal("NewEndpoint failed: ", err) - } -} - -func (c *testContext) createEndpointForFlow(flow testFlow) { - c.t.Helper() - - c.createEndpoint(flow.sockProto()) - if flow.isV6Only() { - if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - } else if flow.isBroadcast() { - if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil { - c.t.Fatal("SetSockOpt failed:", err) - } - } -} - -// getPacketAndVerify reads a packet from the link endpoint and verifies the -// header against expected values from the given test flow. In addition, it -// calls any extra checker functions provided. -func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { - c.t.Helper() - - ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - c.t.Fatalf("Packet wasn't written out") - return nil - } - - if p.Proto != flow.netProto() { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) - } - - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - - h := flow.header4Tuple(outgoing) - checkers = append( - checkers, - checker.SrcAddr(h.srcAddr.Addr), - checker.DstAddr(h.dstAddr.Addr), - checker.UDP(checker.DstPort(h.dstAddr.Port)), - ) - flow.checkerFn()(c.t, b, checkers...) - return b -} - -// injectPacket creates a packet of the given flow and with the given payload, -// and injects it into the link endpoint. -func (c *testContext) injectPacket(flow testFlow, payload []byte) { - c.t.Helper() - - h := flow.header4Tuple(incoming) - if flow.isV4() { - c.injectV4Packet(payload, &h, true /* valid */) - } else { - c.injectV6Packet(payload, &h, true /* valid */) - } -} - -// injectV6Packet creates a V6 test packet with the given payload and header -// values, and injects it into the link endpoint. valid indicates if the -// caller intends to inject a packet with a valid or an invalid UDP header. -// We can invalidate the header by corrupting the UDP payload length. -func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - payloadStart := len(buf) - len(payload) - copy(buf[payloadStart:], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, - }) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv6MinimumSize:]) - l := uint16(header.UDPMinimumSize + len(payload)) - if !valid { - // Change the UDP payload length to corrupt the header - // as requested by the caller. - l++ - } - u.Encode(&header.UDPFields{ - SrcPort: h.srcAddr.Port, - DstPort: h.dstAddr.Port, - Length: l, - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), - }) -} - -// injectV4Packet creates a V4 test packet with the given payload and header -// values, and injects it into the link endpoint. valid indicates if the -// caller intends to inject a packet with a valid or an invalid UDP header. -// We can invalidate the header by corrupting the UDP payload length. -func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - payloadStart := len(buf) - len(payload) - copy(buf[payloadStart:], payload) - - // Initialize the IP header. - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TOS: testTOS, - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(udp.ProtocolNumber), - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv4MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcAddr.Port, - DstPort: h.dstAddr.Port, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - // Inject packet. - - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), - }) -} - -func newPayload() []byte { - return newMinPayload(30) -} - -func newMinPayload(minSize int) []byte { - b := make([]byte, minSize+rand.Intn(100)) - for i := range b { - b[i] = byte(rand.Intn(256)) - } - return b -} - -func TestBindToDeviceOption(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - defer ep.Close() - - opts := stack.NICOptions{Name: "my_device"} - if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { - t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) - } - - // nicIDPtr is used instead of taking the address of NICID literals, which is - // a compiler error. - nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { - return &s - } - - testActions := []struct { - name string - setBindToDevice *tcpip.NICID - setBindToDeviceError *tcpip.Error - getBindToDevice tcpip.BindToDeviceOption - }{ - {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, - {"BindToExistent", nicIDPtr(321), nil, 321}, - {"UnbindToDevice", nicIDPtr(0), nil, 0}, - } - for _, testAction := range testActions { - t.Run(testAction.name, func(t *testing.T) { - if testAction.setBindToDevice != nil { - bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) - } - } - bindToDevice := tcpip.BindToDeviceOption(88888) - if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt got %v, want %v", err, nil) - } - if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %d, want %d", got, want) - } - }) - } -} - -// testReadInternal sends a packet of the given test flow into the stack by -// injecting it into the link endpoint. It then attempts to read it from the -// UDP endpoint and depending on if this was expected to succeed verifies its -// correctness including any additional checker functions provided. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) { - c.t.Helper() - - payload := newPayload() - c.injectPacket(flow, payload) - - // Try to receive the data. - we, ch := waiter.NewChannelEntry(nil) - c.wq.EventRegister(&we, waiter.EventIn) - defer c.wq.EventUnregister(&we) - - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - - var addr tcpip.FullAddress - v, cm, err := c.ep.Read(&addr) - if err == tcpip.ErrWouldBlock { - // Wait for data to become available. - select { - case <-ch: - v, cm, err = c.ep.Read(&addr) - - case <-time.After(300 * time.Millisecond): - if packetShouldBeDropped { - return // expected to time out - } - c.t.Fatal("timed out waiting for data") - } - } - - if expectReadError && err != nil { - c.checkEndpointReadStats(1, epstats, err) - return - } - - if err != nil { - c.t.Fatal("Read failed:", err) - } - - if packetShouldBeDropped { - c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr) - } - - // Check the peer address. - h := flow.header4Tuple(incoming) - if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr) - } - - // Check the payload. - if !bytes.Equal(payload, v) { - c.t.Fatalf("bad payload: got %x, want %x", v, payload) - } - - // Run any checkers against the ControlMessages. - for _, f := range checkers { - f(c.t, cm) - } - - c.checkEndpointReadStats(1, epstats, err) -} - -// testRead sends a packet of the given test flow into the stack by injecting it -// into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness including any additional checker functions provided. -func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) { - c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...) -} - -// testFailingRead sends a packet of the given test flow into the stack by -// injecting it into the link endpoint. It then tries to read it from the UDP -// endpoint and expects this to fail. -func testFailingRead(c *testContext, flow testFlow, expectReadError bool) { - c.t.Helper() - testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError) -} - -func TestBindEphemeralPort(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } -} - -func TestBindReservedPort(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - addr, err := c.ep.GetLocalAddress() - if err != nil { - t.Fatalf("GetLocalAddress failed: %v", err) - } - - // We can't bind the address reserved by the connected endpoint above. - { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want) - } - } - - func() { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - // We can't bind ipv4-any on the port reserved by the connected endpoint - // above, since the endpoint is dual-stack. - if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want) - } - // We can bind an ipv4 address on this port, though. - if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } - }() - - // Once the connected endpoint releases its port reservation, we are able to - // bind ipv4-any once again. - c.ep.Close() - func() { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } - }() -} - -func TestV4ReadOnV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to v4 mapped wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV4ReadOnBoundToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to local address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV6ReadOnV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV6) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testRead(c, unicastV6) -} - -// TestV4ReadSelfSource checks that packets coming from a local IP address are -// correctly dropped when handleLocal is true and not otherwise. -func TestV4ReadSelfSource(t *testing.T) { - for _, tt := range []struct { - name string - handleLocal bool - wantErr *tcpip.Error - wantInvalidSource uint64 - }{ - {"HandleLocal", false, nil, 0}, - {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1}, - } { - t.Run(tt.name, func(t *testing.T) { - c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - HandleLocal: tt.handleLocal, - }) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4) - - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV4.header4Tuple(incoming) - h.srcAddr = h.dstAddr - - c.injectV4Packet(payload, &h, true /* valid */) - - if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { - t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) - } - - if _, _, err := c.ep.Read(nil); err != tt.wantErr { - t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr) - } - }) - } -} - -func TestV4ReadOnV4(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testRead(c, unicastV4) -} - -// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast -// address and receive data sent to that address. -func TestReadOnBoundToMulticast(t *testing.T) { - // FIXME(b/128189410): multicastV4in6 currently doesn't work as - // AddMembershipOption doesn't handle V4in6 addresses. - for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to multicast address. - mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr()) - if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - // Join multicast group. - ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr} - if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatal("SetSockOpt failed:", err) - } - - // Check that we receive multicast packets but not unicast or broadcast - // ones. - testRead(c, flow) - testFailingRead(c, broadcast, false /* expectReadError */) - testFailingRead(c, unicastV4, false /* expectReadError */) - }) - } -} - -// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast -// address and can receive only broadcast data. -func TestV4ReadOnBoundToBroadcast(t *testing.T) { - for _, flow := range []testFlow{broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to broadcast address. - bcastAddr := flow.mapAddrIfApplicable(broadcastAddr) - if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Check that we receive broadcast packets but not unicast ones. - testRead(c, flow) - testFailingRead(c, unicastV4, false /* expectReadError */) - }) - } -} - -// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY -// and receive broadcast and unicast data. -func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { - for _, flow := range []testFlow{broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s (", err) - } - - // Check that we receive both broadcast and unicast packets. - testRead(c, flow) - testRead(c, unicastV4) - }) - } -} - -// testFailingWrite sends a packet of the given test flow into the UDP endpoint -// and verifies it fails with the provided error code. -func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { - c.t.Helper() - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - h := flow.header4Tuple(outgoing) - writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - - payload := buffer.View(newPayload()) - _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, - }) - c.checkEndpointWriteStats(1, epstats, gotErr) - if gotErr != wantErr { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) - } -} - -// testWrite sends a packet of the given test flow from the UDP endpoint to the -// flow's destination address:port. It then receives it from the link endpoint -// and verifies its correctness including any additional checker functions -// provided. -func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - return testWriteInternal(c, flow, true, checkers...) -} - -// testWriteWithoutDestination sends a packet of the given test flow from the -// UDP endpoint without giving a destination address:port. It then receives it -// from the link endpoint and verifies its correctness including any additional -// checker functions provided. -func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - return testWriteInternal(c, flow, false, checkers...) -} - -func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - - writeOpts := tcpip.WriteOptions{} - if setDest { - h := flow.header4Tuple(outgoing) - writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - writeOpts = tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, - } - } - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != int64(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) - } - c.checkEndpointWriteStats(1, epstats, err) - // Received the packet and check the payload. - b := c.getPacketAndVerify(flow, checkers...) - var udp header.UDP - if flow.isV4() { - udp = header.UDP(header.IPv4(b).Payload()) - } else { - udp = header.UDP(header.IPv6(b).Payload()) - } - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) - } - - return udp.SourcePort() -} - -func testDualWrite(c *testContext) uint16 { - c.t.Helper() - - v4Port := testWrite(c, unicastV4in6) - v6Port := testWrite(c, unicastV6) - if v4Port != v6Port { - c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port) - } - - return v4Port -} - -func TestDualWriteUnbound(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - testDualWrite(c) -} - -func TestDualWriteBoundToWildcard(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - p := testDualWrite(c) - if p != stackPort { - c.t.Fatalf("Bad port: got %v, want %v", p, stackPort) - } -} - -func TestDualWriteConnectedToV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v6 address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - testWrite(c, unicastV6) - - // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) - const want = 1 - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { - c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want) - } -} - -func TestDualWriteConnectedToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v4 mapped address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - testWrite(c, unicastV4in6) - - // Write to v6 address. - testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) -} - -func TestV4WriteOnV6Only(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV6Only) - - // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute) -} - -func TestV6WriteOnBoundToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to v4 mapped address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Write to v6 address. - testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) -} - -func TestV6WriteOnConnected(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v6 address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - testWriteWithoutDestination(c, unicastV6) -} - -func TestV4WriteOnConnected(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v4 mapped address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - testWriteWithoutDestination(c, unicastV4) -} - -// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket -// that is bound to a V4 multicast address. -func TestWriteOnBoundToV4Multicast(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a -// socket that is bound to a V4-mapped multicast address. -func TestWriteOnBoundToV4MappedMulticast(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4Mapped mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV6Multicast checks that we can send packets out of a -// socket that is bound to a V6 multicast address. -func TestWriteOnBoundToV6Multicast(t *testing.T) { - for _, flow := range []testFlow{unicastV6, multicastV6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V6 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV6Multicast checks that we can send packets out of a -// V6-only socket that is bound to a V6 multicast address. -func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) { - for _, flow := range []testFlow{unicastV6Only, multicastV6Only} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V6 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToBroadcast checks that we can send packets out of a -// socket that is bound to the broadcast address. -func TestWriteOnBoundToBroadcast(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4 broadcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a -// socket that is bound to the V4-mapped broadcast address. -func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4Mapped mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -func TestReadIncrementsPacketsReceived(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - // Create IPv4 UDP endpoint - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - testRead(c, unicastV4) - - var want uint64 = 1 - if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want { - c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want) - } -} - -func TestWriteIncrementsPacketsSent(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - testDualWrite(c) - - var want uint64 = 2 - if got := c.s.Stats().UDP.PacketsSent.Value(); got != want { - c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want) - } -} - -func TestTTL(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const multicastTTL = 42 - if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - - var wantTTL uint8 - if flow.isMulticast() { - wantTTL = multicastTTL - } else { - var p stack.NetworkProtocol - if flow.isV4() { - p = ipv4.NewProtocol() - } else { - p = ipv6.NewProtocol() - } - ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - })) - if err != nil { - t.Fatal(err) - } - wantTTL = ep.DefaultTTL() - ep.Close() - } - - testWrite(c, flow, checker.TTL(wantTTL)) - }) - } -} - -func TestSetTTL(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { - t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - - var p stack.NetworkProtocol - if flow.isV4() { - p = ipv4.NewProtocol() - } else { - p = ipv6.NewProtocol() - } - ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - })) - if err != nil { - t.Fatal(err) - } - ep.Close() - - testWrite(c, flow, checker.TTL(wantTTL)) - }) - } - }) - } -} - -func TestSetTOS(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const tos = testTOS - var v tcpip.IPv4TOSOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) - } - // Test for expected default value. - if v != 0 { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0) - } - - if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv4TOSOption(tos), err) - } - - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) - } - - if want := tcpip.IPv4TOSOption(tos); v != want { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want) - } - - testWrite(c, flow, checker.TOS(tos, 0)) - }) - } -} - -func TestSetTClass(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const tClass = testTOS - var v tcpip.IPv6TrafficClassOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) - } - // Test for expected default value. - if v != 0 { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0) - } - - if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tClass)); err != nil { - c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv6TrafficClassOption(tClass), err) - } - - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) - } - - if want := tcpip.IPv6TrafficClassOption(tClass); v != want { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want) - } - - // The header getter for TClass is called TOS, so use that checker. - testWrite(c, flow, checker.TOS(tClass, 0)) - }) - } -} - -func TestReceiveTosTClass(t *testing.T) { - testCases := []struct { - name string - getReceiveOption tcpip.SockOptBool - tests []testFlow - }{ - {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}}, - {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, - } - for _, testCase := range testCases { - for _, flow := range testCase.tests { - t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - option := testCase.getReceiveOption - name := testCase.name - - // Verify that setting and reading the option works. - v, err := c.ep.GetSockOptBool(option) - if err != nil { - c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err) - } - // Test for expected default value. - if v != false { - c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) - } - - want := true - if err := c.ep.SetSockOptBool(option, want); err != nil { - c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err) - } - - got, err := c.ep.GetSockOptBool(option) - if err != nil { - c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err) - } - - if got != want { - c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want) - } - - // Verify that the correct received TOS or TClass is handed through as - // ancillary data to the ControlMessages struct. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - switch option { - case tcpip.ReceiveTClassOption: - testRead(c, flow, checker.ReceiveTClass(testTOS)) - case tcpip.ReceiveTOSOption: - testRead(c, flow, checker.ReceiveTOS(testTOS)) - default: - t.Fatalf("unknown test variant: %s", name) - } - }) - } - } -} - -func TestMulticastInterfaceOption(t *testing.T) { - for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - for _, bindTyp := range []string{"bound", "unbound"} { - t.Run(bindTyp, func(t *testing.T) { - for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { - t.Run(optTyp, func(t *testing.T) { - h := flow.header4Tuple(outgoing) - mcastAddr := h.dstAddr.Addr - localIfAddr := h.srcAddr.Addr - - var ifoptSet tcpip.MulticastInterfaceOption - switch optTyp { - case "use local-addr": - ifoptSet.InterfaceAddr = localIfAddr - case "use NICID": - ifoptSet.NIC = 1 - case "use local-addr and NIC": - ifoptSet.InterfaceAddr = localIfAddr - ifoptSet.NIC = 1 - default: - t.Fatal("unknown test variant") - } - - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(flow.sockProto()) - - if bindTyp == "bound" { - // Bind the socket by connecting to the multicast address. - // This may have an influence on how the multicast interface - // is set. - addr := tcpip.FullAddress{ - Addr: flow.mapAddrIfApplicable(mcastAddr), - Port: stackPort, - } - if err := c.ep.Connect(addr); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - } - - if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - - // Verify multicast interface addr and NIC were set correctly. - // Note that NIC must be 1 since this is our outgoing interface. - ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} - var ifoptGot tcpip.MulticastInterfaceOption - if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt failed: %v", err) - } - if ifoptGot != ifoptWant { - c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) - } - }) - } - }) - } - }) - } -} - -// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination -// Unreachable message when a udp datagram is received on ports for which there -// is no bound udp socket. -func TestV4UnknownDestination(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - testCases := []struct { - flow testFlow - icmpRequired bool - // largePayload if true, will result in a payload large enough - // so that the final generated IPv4 packet is larger than - // header.IPv4MinimumProcessableDatagramSize. - largePayload bool - }{ - {unicastV4, true, false}, - {unicastV4, true, true}, - {multicastV4, false, false}, - {multicastV4, false, true}, - {broadcast, false, false}, - {broadcast, false, true}, - } - for _, tc := range testCases { - t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { - payload := newPayload() - if tc.largePayload { - payload = newMinPayload(576) - } - c.injectPacket(tc.flow, payload) - if !tc.icmpRequired { - ctx, _ := context.WithTimeout(context.Background(), time.Second) - if p, ok := c.linkEP.ReadContext(ctx); ok { - t.Fatalf("unexpected packet received: %+v", p) - } - return - } - - // ICMP required. - ctx, _ := context.WithTimeout(context.Background(), time.Second) - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - t.Fatalf("packet wasn't written out") - return - } - - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) - if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } - - hdr := header.IPv4(pkt) - checker.IPv4(t, hdr, checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4PortUnreachable))) - - icmpPkt := header.ICMPv4(hdr.Payload()) - payloadIPHeader := header.IPv4(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize - } - - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) - - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %d, want: %d", got, want) - } - }) - } -} - -// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination -// Unreachable message when a udp datagram is received on ports for which there -// is no bound udp socket. -func TestV6UnknownDestination(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - testCases := []struct { - flow testFlow - icmpRequired bool - // largePayload if true will result in a payload large enough to - // create an IPv6 packet > header.IPv6MinimumMTU bytes. - largePayload bool - }{ - {unicastV6, true, false}, - {unicastV6, true, true}, - {multicastV6, false, false}, - {multicastV6, false, true}, - } - for _, tc := range testCases { - t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { - payload := newPayload() - if tc.largePayload { - payload = newMinPayload(1280) - } - c.injectPacket(tc.flow, payload) - if !tc.icmpRequired { - ctx, _ := context.WithTimeout(context.Background(), time.Second) - if p, ok := c.linkEP.ReadContext(ctx); ok { - t.Fatalf("unexpected packet received: %+v", p) - } - return - } - - // ICMP required. - ctx, _ := context.WithTimeout(context.Background(), time.Second) - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - t.Fatalf("packet wasn't written out") - return - } - - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) - if got, want := len(pkt), header.IPv6MinimumMTU; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } - - hdr := header.IPv6(pkt) - checker.IPv6(t, hdr, checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6DstUnreachable), - checker.ICMPv6Code(header.ICMPv6PortUnreachable))) - - icmpPkt := header.ICMPv6(hdr.Payload()) - payloadIPHeader := header.IPv6(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize - } - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) - - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %v, want: %v", got, want) - } - }) - } -} - -// TestIncrementMalformedPacketsReceived verifies if the malformed received -// global and endpoint stats get incremented. -func TestIncrementMalformedPacketsReceived(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - payload := newPayload() - c.t.Helper() - h := unicastV6.header4Tuple(incoming) - c.injectV6Packet(payload, &h, false /* !valid */) - - var want uint64 = 1 - if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) - } -} - -// TestShutdownRead verifies endpoint read shutdown and error -// stats increment on packet receive. -func TestShutdownRead(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - testFailingRead(c, unicastV6, true /* expectReadError */) - - var want uint64 = 1 - if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want { - t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want) - } -} - -// TestShutdownWrite verifies endpoint write shutdown and error -// stats increment on packet write. -func TestShutdownWrite(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) -} - -func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { - got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err { - case nil: - want.PacketsSent.IncrementBy(incr) - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: - want.WriteErrors.InvalidArgs.IncrementBy(incr) - case tcpip.ErrClosedForSend: - want.WriteErrors.WriteClosed.IncrementBy(incr) - case tcpip.ErrInvalidEndpointState: - want.WriteErrors.InvalidEndpointState.IncrementBy(incr) - case tcpip.ErrNoLinkAddress: - want.SendErrors.NoLinkAddr.IncrementBy(incr) - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: - want.SendErrors.NoRoute.IncrementBy(incr) - default: - want.SendErrors.SendToNetworkFailed.IncrementBy(incr) - } - if got != want { - c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) - } -} - -func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { - got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err { - case nil, tcpip.ErrWouldBlock: - case tcpip.ErrClosedForReceive: - want.ReadErrors.ReadClosed.IncrementBy(incr) - default: - c.t.Errorf("Endpoint error missing stats update err %v", err) - } - if got != want { - c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) - } -} diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD deleted file mode 100644 index 2dcba84ae..000000000 --- a/pkg/tmutex/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tmutex", - srcs = ["tmutex.go"], - visibility = ["//:sandbox"], -) - -go_test( - name = "tmutex_test", - size = "medium", - srcs = ["tmutex_test.go"], - library = ":tmutex", - deps = ["//pkg/sync"], -) diff --git a/pkg/tmutex/tmutex_state_autogen.go b/pkg/tmutex/tmutex_state_autogen.go new file mode 100755 index 000000000..2336683e3 --- /dev/null +++ b/pkg/tmutex/tmutex_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package tmutex diff --git a/pkg/tmutex/tmutex_test.go b/pkg/tmutex/tmutex_test.go deleted file mode 100644 index 05540696a..000000000 --- a/pkg/tmutex/tmutex_test.go +++ /dev/null @@ -1,258 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmutex - -import ( - "fmt" - "runtime" - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" -) - -func TestBasicLock(t *testing.T) { - var m Mutex - m.Init() - - m.Lock() - - // Try blocking lock the mutex from a different goroutine. This must - // not block because the mutex is held. - ch := make(chan struct{}, 1) - go func() { - m.Lock() - ch <- struct{}{} - m.Unlock() - ch <- struct{}{} - }() - - select { - case <-ch: - t.Fatalf("Lock succeeded on locked mutex") - case <-time.After(100 * time.Millisecond): - } - - // Unlock the mutex and make sure that the goroutine waiting on Lock() - // unblocks and succeeds. - m.Unlock() - - select { - case <-ch: - case <-time.After(100 * time.Millisecond): - t.Fatalf("Lock failed to acquire unlocked mutex") - } - - // Make sure we can lock and unlock again. - m.Lock() - m.Unlock() -} - -func TestTryLock(t *testing.T) { - var m Mutex - m.Init() - - // Try to lock. It should succeed. - if !m.TryLock() { - t.Fatalf("TryLock failed on unlocked mutex") - } - - // Try to lock again, it should now fail. - if m.TryLock() { - t.Fatalf("TryLock succeeded on locked mutex") - } - - // Try blocking lock the mutex from a different goroutine. This must - // not block because the mutex is held. - ch := make(chan struct{}, 1) - go func() { - m.Lock() - ch <- struct{}{} - m.Unlock() - }() - - select { - case <-ch: - t.Fatalf("Lock succeeded on locked mutex") - case <-time.After(100 * time.Millisecond): - } - - // Unlock the mutex and make sure that the goroutine waiting on Lock() - // unblocks and succeeds. - m.Unlock() - - select { - case <-ch: - case <-time.After(100 * time.Millisecond): - t.Fatalf("Lock failed to acquire unlocked mutex") - } -} - -func TestMutualExclusion(t *testing.T) { - var m Mutex - m.Init() - - // Test mutual exclusion by running "gr" goroutines concurrently, and - // have each one increment a counter "iters" times within the critical - // section established by the mutex. - // - // If at the end the counter is not gr * iters, then we know that - // goroutines ran concurrently within the critical section. - // - // If one of the goroutines doesn't complete, it's likely a bug that - // causes to it to wait forever. - const gr = 1000 - const iters = 100000 - v := 0 - var wg sync.WaitGroup - for i := 0; i < gr; i++ { - wg.Add(1) - go func() { - for j := 0; j < iters; j++ { - m.Lock() - v++ - m.Unlock() - } - wg.Done() - }() - } - - wg.Wait() - - if v != gr*iters { - t.Fatalf("Bad count: got %v, want %v", v, gr*iters) - } -} - -func TestMutualExclusionWithTryLock(t *testing.T) { - var m Mutex - m.Init() - - // Similar to the previous, with the addition of some goroutines that - // only increment the count if TryLock succeeds. - const gr = 1000 - const iters = 100000 - total := int64(gr * iters) - var tryTotal int64 - v := int64(0) - var wg sync.WaitGroup - for i := 0; i < gr; i++ { - wg.Add(2) - go func() { - for j := 0; j < iters; j++ { - m.Lock() - v++ - m.Unlock() - } - wg.Done() - }() - go func() { - local := int64(0) - for j := 0; j < iters; j++ { - if m.TryLock() { - v++ - m.Unlock() - local++ - } - } - atomic.AddInt64(&tryTotal, local) - wg.Done() - }() - } - - wg.Wait() - - t.Logf("tryTotal = %d", tryTotal) - total += tryTotal - - if v != total { - t.Fatalf("Bad count: got %v, want %v", v, total) - } -} - -// BenchmarkTmutex is equivalent to TestMutualExclusion, with the following -// differences: -// -// - The number of goroutines is variable, with the maximum value depending on -// GOMAXPROCS. -// -// - The number of iterations per benchmark is controlled by the benchmarking -// framework. -// -// - Care is taken to ensure that all goroutines participating in the benchmark -// have been created before the benchmark begins. -func BenchmarkTmutex(b *testing.B) { - for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var m Mutex - m.Init() - - var ready sync.WaitGroup - begin := make(chan struct{}) - var end sync.WaitGroup - for i := 0; i < n; i++ { - ready.Add(1) - end.Add(1) - go func() { - ready.Done() - <-begin - for j := 0; j < b.N; j++ { - m.Lock() - m.Unlock() - } - end.Done() - }() - } - - ready.Wait() - b.ResetTimer() - close(begin) - end.Wait() - }) - } -} - -// BenchmarkSyncMutex is equivalent to BenchmarkTmutex, but uses sync.Mutex as -// a comparison point. -func BenchmarkSyncMutex(b *testing.B) { - for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var m sync.Mutex - - var ready sync.WaitGroup - begin := make(chan struct{}) - var end sync.WaitGroup - for i := 0; i < n; i++ { - ready.Add(1) - end.Add(1) - go func() { - ready.Done() - <-begin - for j := 0; j < b.N; j++ { - m.Lock() - m.Unlock() - } - end.Done() - }() - } - - ready.Wait() - b.ResetTimer() - close(begin) - end.Wait() - }) - } -} diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD deleted file mode 100644 index a86501fa2..000000000 --- a/pkg/unet/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "unet", - srcs = [ - "unet.go", - "unet_unsafe.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/gate", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "unet_test", - size = "small", - srcs = [ - "unet_test.go", - ], - library = ":unet", - deps = ["//pkg/sync"], -) diff --git a/pkg/unet/unet_state_autogen.go b/pkg/unet/unet_state_autogen.go new file mode 100755 index 000000000..9bbf31d35 --- /dev/null +++ b/pkg/unet/unet_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package unet diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go deleted file mode 100644 index 5c4b9e8e9..000000000 --- a/pkg/unet/unet_test.go +++ /dev/null @@ -1,736 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package unet - -import ( - "io/ioutil" - "os" - "path/filepath" - "reflect" - "syscall" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" -) - -func randomFilename() (string, error) { - // Return a randomly generated file in the test dir. - f, err := ioutil.TempFile("", "unet-test") - if err != nil { - return "", err - } - file := f.Name() - os.Remove(file) - f.Close() - - cwd, err := os.Getwd() - if err != nil { - return "", err - } - - // NOTE(b/26918832): We try to use relative path if possible. This is - // to help conforming to the unix path length limit. - if rel, err := filepath.Rel(cwd, file); err == nil { - return rel, nil - } - - return file, nil -} - -func TestConnectFailure(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) - } - - if _, err := Connect(name, false); err == nil { - t.Fatalf("connect was successful, expected err") - } -} - -func TestBindFailure(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) - } - defer ss.Close() - - if _, err = BindAndListen(name, false); err == nil { - t.Fatalf("second bind succeeded, expected non-nil err") - } -} - -func TestMultipleAccept(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) - } - defer ss.Close() - - // Connect backlog times asynchronously. - var wg sync.WaitGroup - defer wg.Wait() - for i := 0; i < backlog; i++ { - wg.Add(1) - go func() { - defer wg.Done() - s, err := Connect(name, false) - if err != nil { - t.Fatalf("connect failed, got err %v expected nil", err) - } - s.Close() - }() - } - - // Accept backlog times. - for i := 0; i < backlog; i++ { - s, err := ss.Accept() - if err != nil { - t.Errorf("accept failed, got err %v expected nil", err) - continue - } - s.Close() - } -} - -func TestServerClose(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) - } - - // Make sure the first close succeeds. - if err := ss.Close(); err != nil { - t.Fatalf("first close failed, got err %v expected nil", err) - } - - // The second one should fail. - if err := ss.Close(); err == nil { - t.Fatalf("second close succeeded, expected non-nil err") - } -} - -func socketPair(t *testing.T, packet bool) (*Socket, *Socket) { - name, err := randomFilename() - if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) - } - - // Bind a server. - ss, err := BindAndListen(name, packet) - if err != nil { - t.Fatalf("error binding, got %v expected nil", err) - } - defer ss.Close() - - // Accept a client. - acceptSocket := make(chan *Socket) - acceptErr := make(chan error) - go func() { - server, err := ss.Accept() - if err != nil { - acceptErr <- err - } - acceptSocket <- server - }() - - // Connect the client. - client, err := Connect(name, packet) - if err != nil { - t.Fatalf("error connecting, got %v expected nil", err) - } - - // Grab the server handle. - select { - case server := <-acceptSocket: - return server, client - case err := <-acceptErr: - t.Fatalf("accept error: %v", err) - } - panic("unreachable") -} - -func TestSendRecv(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Write on the client. - w := client.Writer(true) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Read on the server. - b := [][]byte{{'b'}} - r := server.Reader(true) - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) - } -} - -// TestSymmetric exists to assert that the two sockets received from socketPair -// are interchangeable. They should be, this just provides a basic sanity check -// by running TestSendRecv "backwards". -func TestSymmetric(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Write on the server. - w := server.Writer(true) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Read on the client. - b := [][]byte{{'b'}} - r := client.Reader(true) - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) - } -} - -func TestPacket(t *testing.T) { - server, client := socketPair(t, true) - defer server.Close() - defer client.Close() - - // Write on the client. - w := client.Writer(true) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Write on the client again. - w = client.Writer(true) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Read on the server. - // - // This should only get back a single byte, despite the buffer - // being size two. This is because it's a _packet_ buffer. - b := [][]byte{{'b', 'b'}} - r := server.Reader(true) - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) - } - - // Do it again. - r = server.Reader(true) - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) - } -} - -func TestClose(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - - // Make sure the first close succeeds. - if err := client.Close(); err != nil { - t.Fatalf("first close failed, got err %v expected nil", err) - } - - // The second one should fail. - if err := client.Close(); err == nil { - t.Fatalf("second close succeeded, expected non-nil err") - } -} - -func TestNonBlockingSend(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Try up to 1000 writes, of 1000 bytes. - blockCount := 0 - for i := 0; i < 1000; i++ { - w := client.Writer(false) - if n, err := w.WriteVec([][]byte{make([]byte, 1000)}); n != 1000 || err != nil { - if err == syscall.EWOULDBLOCK || err == syscall.EAGAIN { - // We're good. That's what we wanted. - blockCount++ - } else { - t.Fatalf("for client write, got n=%d err=%v, expected n=1000 err=nil", n, err) - } - } - } - - if blockCount == 1000 { - // Shouldn't have _always_ blocked. - t.Fatalf("socket always blocked!") - } else if blockCount == 0 { - // Should have started blocking eventually. - t.Fatalf("socket never blocked!") - } -} - -func TestNonBlockingRecv(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - b := [][]byte{{'b'}} - r := client.Reader(false) - - // Expected to block immediately. - _, err := r.ReadVec(b) - if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN { - t.Fatalf("read didn't block, got err %v expected blocking err", err) - } - - // Put some data in the pipe. - w := server.Writer(false) - if n, err := w.WriteVec(b); n != 1 || err != nil { - t.Fatalf("write failed with n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Expect it not to block. - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("read failed with n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Expect it to return a block error again. - r = client.Reader(false) - _, err = r.ReadVec(b) - if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN { - t.Fatalf("read didn't block, got err %v expected blocking err", err) - } -} - -func TestRecvVectors(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Write on the client. - w := client.Writer(true) - if n, err := w.WriteVec([][]byte{{'a', 'b'}}); n != 2 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err) - } - - // Read on the server. - b := [][]byte{{'c'}, {'c'}} - r := server.Reader(true) - if n, err := r.ReadVec(b); n != 2 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err) - } - if b[0][0] != 'a' || b[1][0] != 'b' { - t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0]) - } -} - -func TestSendVectors(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Write on the client. - w := client.Writer(true) - if n, err := w.WriteVec([][]byte{{'a'}, {'b'}}); n != 2 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err) - } - - // Read on the server. - b := [][]byte{{'c', 'c'}} - r := server.Reader(true) - if n, err := r.ReadVec(b); n != 2 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err) - } - if b[0][0] != 'a' || b[0][1] != 'b' { - t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1]) - } -} - -func TestSendFDsNotEnabled(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Write on the server. - w := server.Writer(true) - w.PackFDs(0, 1, 2) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Read on the client, without enabling FDs. - b := [][]byte{{'b'}} - r := client.Reader(true) - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) - } - - // Make sure the FDs are not received. - fds, err := r.ExtractFDs() - if len(fds) != 0 || err != nil { - t.Fatalf("got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err) - } -} - -func sendFDs(t *testing.T, s *Socket, fds []int) { - w := s.Writer(true) - w.PackFDs(fds...) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for write, got n=%d err=%v, expected n=1 err=nil", n, err) - } -} - -func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) { - expected := len(origFDs) - - // Count the number of FDs. - preEntries, err := ioutil.ReadDir("/proc/self/fd") - if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) - } - - // Read on the client. - b := [][]byte{{'b'}} - r := s.Reader(true) - if enableSize >= 0 { - r.EnableFDs(enableSize) - } - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) - } - - // Count the new number of FDs. - postEntries, err := ioutil.ReadDir("/proc/self/fd") - if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) - } - if len(preEntries)+expected != len(postEntries) { - t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries)) - } - - // Make sure the FDs are there. - fds, err := r.ExtractFDs() - if len(fds) != expected || err != nil { - t.Fatalf("got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected) - } - - // Make sure they are different from the originals. - for i := 0; i < len(fds); i++ { - if fds[i] == origFDs[i] { - t.Errorf("got original fd for index %d, expected different", i) - } - } - - // Make sure they can be accessed as expected. - for i := 0; i < len(fds); i++ { - var st syscall.Stat_t - if err := syscall.Fstat(fds[i], &st); err != nil { - t.Errorf("fds[%d] can't be stated, got err %v expected nil", i, err) - } - } - - // Close them off. - r.CloseFDs() - - // Make sure the count is back to normal. - finalEntries, err := ioutil.ReadDir("/proc/self/fd") - if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) - } - if len(finalEntries) != len(preEntries) { - t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries)) - } -} - -func TestFDsSingle(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0}) - recvFDs(t, client, 1, []int{0}) -} - -func TestFDsMultiple(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Basic case, multiple FDs. - sendFDs(t, server, []int{0, 1, 2}) - recvFDs(t, client, 3, []int{0, 1, 2}) -} - -// See TestSymmetric above. -func TestFDsSymmetric(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0, 1, 2}) - recvFDs(t, client, 3, []int{0, 1, 2}) -} - -func TestFDsReceiveLargeBuffer(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0}) - recvFDs(t, client, 3, []int{0}) -} - -func TestFDsReceiveSmallBuffer(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0, 1, 2}) - - // Per the spec, we may still receive more than the buffer. In fact, - // it'll be rounded up and we can receive two with a size one buffer. - recvFDs(t, client, 1, []int{0, 1}) -} - -func TestFDsReceiveNotEnabled(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0}) - recvFDs(t, client, -1, []int{}) -} - -func TestFDsReceiveSizeZero(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0}) - recvFDs(t, client, 0, []int{}) -} - -func TestGetPeerCred(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - want := &syscall.Ucred{ - Pid: int32(os.Getpid()), - Uid: uint32(os.Getuid()), - Gid: uint32(os.Getgid()), - } - - if got, err := client.GetPeerCred(); err != nil || !reflect.DeepEqual(got, want) { - t.Errorf("got GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil) - } -} - -func newClosedSocket() (*Socket, error) { - fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - return nil, err - } - - s, err := NewSocket(fd) - if err != nil { - syscall.Close(fd) - return nil, err - } - - return s, s.Close() -} - -func TestGetPeerCredFailure(t *testing.T) { - s, err := newClosedSocket() - if err != nil { - t.Fatalf("newClosedSocket got error %v want nil", err) - } - - want := "bad file descriptor" - if _, err := s.GetPeerCred(); err == nil || err.Error() != want { - t.Errorf("got s.GetPeerCred() = %v, want = %s", err, want) - } -} - -func TestAcceptClosed(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) - } - - if err := ss.Close(); err != nil { - t.Fatalf("close failed, got err %v expected nil", err) - } - - if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) - } -} - -func TestCloseAfterAcceptStart(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) - } - - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - time.Sleep(50 * time.Millisecond) - if err := ss.Close(); err != nil { - t.Fatalf("close failed, got err %v expected nil", err) - } - wg.Done() - }() - - if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) - } - - wg.Wait() -} - -func TestReleaseAfterAcceptStart(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) - } - - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - time.Sleep(50 * time.Millisecond) - fd, err := ss.Release() - if err != nil { - t.Fatalf("Release failed, got err %v expected nil", err) - } - syscall.Close(fd) - wg.Done() - }() - - if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) - } - - wg.Wait() -} - -func TestControlMessage(t *testing.T) { - for i := 0; i <= 10; i++ { - var want []int - for j := 0; j < i; j++ { - want = append(want, i+j+1) - } - - var cm ControlMessage - cm.EnableFDs(i) - cm.PackFDs(want...) - got, err := cm.ExtractFDs() - if err != nil || !reflect.DeepEqual(got, want) { - t.Errorf("got cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil) - } - } -} - -func benchmarkSendRecv(b *testing.B, packet bool) { - server, client, err := SocketPair(packet) - if err != nil { - b.Fatalf("SocketPair: got %v, wanted nil", err) - } - defer server.Close() - defer client.Close() - go func() { - buf := make([]byte, 1) - for i := 0; i < b.N; i++ { - n, err := server.Read(buf) - if n != 1 || err != nil { - b.Fatalf("server.Read: got (%d, %v), wanted (1, nil)", n, err) - } - n, err = server.Write(buf) - if n != 1 || err != nil { - b.Fatalf("server.Write: got (%d, %v), wanted (1, nil)", n, err) - } - } - }() - buf := make([]byte, 1) - b.ResetTimer() - for i := 0; i < b.N; i++ { - n, err := client.Write(buf) - if n != 1 || err != nil { - b.Fatalf("client.Write: got (%d, %v), wanted (1, nil)", n, err) - } - n, err = client.Read(buf) - if n != 1 || err != nil { - b.Fatalf("client.Read: got (%d, %v), wanted (1, nil)", n, err) - } - } -} - -func BenchmarkSendRecvStream(b *testing.B) { - benchmarkSendRecv(b, false) -} - -func BenchmarkSendRecvPacket(b *testing.B) { - benchmarkSendRecv(b, true) -} diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD deleted file mode 100644 index 850c34ed0..000000000 --- a/pkg/urpc/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "urpc", - srcs = ["urpc.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/fd", - "//pkg/log", - "//pkg/sync", - "//pkg/unet", - ], -) - -go_test( - name = "urpc_test", - size = "small", - srcs = ["urpc_test.go"], - library = ":urpc", - deps = ["//pkg/unet"], -) diff --git a/pkg/urpc/urpc_state_autogen.go b/pkg/urpc/urpc_state_autogen.go new file mode 100755 index 000000000..5fdca6717 --- /dev/null +++ b/pkg/urpc/urpc_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package urpc diff --git a/pkg/urpc/urpc_test.go b/pkg/urpc/urpc_test.go deleted file mode 100644 index c6c7ce9d4..000000000 --- a/pkg/urpc/urpc_test.go +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package urpc - -import ( - "errors" - "os" - "testing" - - "gvisor.dev/gvisor/pkg/unet" -) - -type test struct { -} - -type testArg struct { - StringArg string - IntArg int - FilePayload -} - -type testResult struct { - StringResult string - IntResult int - FilePayload -} - -func (t test) Func(a *testArg, r *testResult) error { - r.StringResult = a.StringArg - r.IntResult = a.IntArg - return nil -} - -func (t test) Err(a *testArg, r *testResult) error { - return errors.New("test error") -} - -func (t test) FailNoFile(a *testArg, r *testResult) error { - if a.Files == nil { - return errors.New("no file found") - } - - return nil -} - -func (t test) SendFile(a *testArg, r *testResult) error { - r.Files = []*os.File{os.Stdin, os.Stdout, os.Stderr} - return nil -} - -func (t test) TooManyFiles(a *testArg, r *testResult) error { - for i := 0; i <= maxFiles; i++ { - r.Files = append(r.Files, os.Stdin) - } - return nil -} - -func startServer(socket *unet.Socket) { - s := NewServer() - s.Register(test{}) - s.StartHandling(socket) -} - -func testClient() (*Client, error) { - serverSock, clientSock, err := unet.SocketPair(false) - if err != nil { - return nil, err - } - startServer(serverSock) - - return NewClient(clientSock), nil -} - -func TestCall(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - if err := c.Call("test.Func", &testArg{}, &r); err != nil { - t.Errorf("basic call failed: %v", err) - } else if r.StringResult != "" || r.IntResult != 0 { - t.Errorf("unexpected result, got %v expected zero value", r) - } - if err := c.Call("test.Func", &testArg{StringArg: "hello"}, &r); err != nil { - t.Errorf("basic call failed: %v", err) - } else if r.StringResult != "hello" { - t.Errorf("unexpected result, got %v expected hello", r.StringResult) - } - if err := c.Call("test.Func", &testArg{IntArg: 1}, &r); err != nil { - t.Errorf("basic call failed: %v", err) - } else if r.IntResult != 1 { - t.Errorf("unexpected result, got %v expected 1", r.IntResult) - } -} - -func TestUnknownMethod(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - if err := c.Call("test.Unknown", &testArg{}, &r); err == nil { - t.Errorf("expected non-nil err, got nil") - } else if err.Error() != ErrUnknownMethod.Error() { - t.Errorf("expected test error, got %v", err) - } -} - -func TestErr(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - if err := c.Call("test.Err", &testArg{}, &r); err == nil { - t.Errorf("expected non-nil err, got nil") - } else if err.Error() != "test error" { - t.Errorf("expected test error, got %v", err) - } -} - -func TestSendFile(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - if err := c.Call("test.FailNoFile", &testArg{}, &r); err == nil { - t.Errorf("expected non-nil err, got nil") - } - if err := c.Call("test.FailNoFile", &testArg{FilePayload: FilePayload{Files: []*os.File{os.Stdin, os.Stdout, os.Stdin}}}, &r); err != nil { - t.Errorf("expected nil err, got %v", err) - } -} - -func TestRecvFile(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - if err := c.Call("test.SendFile", &testArg{}, &r); err != nil { - t.Errorf("expected nil err, got %v", err) - } - if r.Files == nil { - t.Errorf("expected file, got nil") - } -} - -func TestShutdown(t *testing.T) { - serverSock, clientSock, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - clientSock.Close() - - s := NewServer() - if err := s.Handle(serverSock); err == nil { - t.Errorf("expected non-nil err, got nil") - } -} - -func TestTooManyFiles(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - var a testArg - for i := 0; i <= maxFiles; i++ { - a.Files = append(a.Files, os.Stdin) - } - - // Client-side error. - if err := c.Call("test.Func", &a, &r); err != ErrTooManyFiles { - t.Errorf("expected ErrTooManyFiles, got %v", err) - } - - // Server-side error. - if err := c.Call("test.TooManyFiles", &testArg{}, &r); err == nil { - t.Errorf("expected non-nil err, got nil") - } else if err.Error() != "too many files" { - t.Errorf("expected too many files, got %v", err.Error()) - } -} diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD deleted file mode 100644 index 6c9ada9c7..000000000 --- a/pkg/usermem/BUILD +++ /dev/null @@ -1,55 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "addr_range", - out = "addr_range.go", - package = "usermem", - prefix = "Addr", - template = "//pkg/segment:generic_range", - types = { - "T": "Addr", - }, -) - -go_library( - name = "usermem", - srcs = [ - "access_type.go", - "addr.go", - "addr_range.go", - "addr_range_seq_unsafe.go", - "bytes_io.go", - "bytes_io_unsafe.go", - "usermem.go", - "usermem_arm64.go", - "usermem_x86.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/atomicbitops", - "//pkg/binary", - "//pkg/context", - "//pkg/gohacks", - "//pkg/log", - "//pkg/safemem", - "//pkg/syserror", - ], -) - -go_test( - name = "usermem_test", - size = "small", - srcs = [ - "addr_range_seq_test.go", - "usermem_test.go", - ], - library = ":usermem", - deps = [ - "//pkg/context", - "//pkg/safemem", - "//pkg/syserror", - ], -) diff --git a/pkg/usermem/README.md b/pkg/usermem/README.md deleted file mode 100644 index f6d2137eb..000000000 --- a/pkg/usermem/README.md +++ /dev/null @@ -1,31 +0,0 @@ -This package defines primitives for sentry access to application memory. - -Major types: - -- The `IO` interface represents a virtual address space and provides I/O - methods on that address space. `IO` is the lowest-level primitive. The - primary implementation of the `IO` interface is `mm.MemoryManager`. - -- `IOSequence` represents a collection of individually-contiguous address - ranges in a `IO` that is operated on sequentially, analogous to Linux's - `struct iov_iter`. - -Major usage patterns: - -- Access to a task's virtual memory, subject to the application's memory - protections and while running on that task's goroutine, from a context that - is at or above the level of the `kernel` package (e.g. most syscall - implementations in `syscalls/linux`); use the `kernel.Task.Copy*` wrappers - defined in `kernel/task_usermem.go`. - -- Access to a task's virtual memory, from a context that is at or above the - level of the `kernel` package, but where any of the above constraints does - not hold (e.g. `PTRACE_POKEDATA`, which ignores application memory - protections); obtain the task's `mm.MemoryManager` by calling - `kernel.Task.MemoryManager`, and call its `IO` methods directly. - -- Access to a task's virtual memory, from a context that is below the level of - the `kernel` package (e.g. filesystem I/O); clients must pass I/O arguments - from higher layers, usually in the form of an `IOSequence`. The - `kernel.Task.SingleIOSequence` and `kernel.Task.IovecsIOSequence` functions - in `kernel/task_usermem.go` are convenience functions for doing so. diff --git a/pkg/usermem/access_type.go b/pkg/usermem/access_type.go index 9c1742a59..9c1742a59 100644..100755 --- a/pkg/usermem/access_type.go +++ b/pkg/usermem/access_type.go diff --git a/pkg/usermem/addr.go b/pkg/usermem/addr.go index e79210804..e79210804 100644..100755 --- a/pkg/usermem/addr.go +++ b/pkg/usermem/addr.go diff --git a/pkg/usermem/addr_range.go b/pkg/usermem/addr_range.go new file mode 100755 index 000000000..152ed1434 --- /dev/null +++ b/pkg/usermem/addr_range.go @@ -0,0 +1,62 @@ +package usermem + +// A Range represents a contiguous range of T. +// +// +stateify savable +type AddrRange struct { + // Start is the inclusive start of the range. + Start Addr + + // End is the exclusive end of the range. + End Addr +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +func (r AddrRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +func (r AddrRange) Length() Addr { + return r.End - r.Start +} + +// Contains returns true if r contains x. +func (r AddrRange) Contains(x Addr) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +func (r AddrRange) Overlaps(r2 AddrRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +func (r AddrRange) IsSupersetOf(r2 AddrRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +func (r AddrRange) Intersect(r2 AddrRange) AddrRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +func (r AddrRange) CanSplitAt(x Addr) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/usermem/addr_range_seq_test.go b/pkg/usermem/addr_range_seq_test.go deleted file mode 100644 index 82f735026..000000000 --- a/pkg/usermem/addr_range_seq_test.go +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package usermem - -import ( - "testing" -) - -var addrRangeSeqTests = []struct { - desc string - ranges []AddrRange -}{ - { - desc: "Empty sequence", - }, - { - desc: "Single empty AddrRange", - ranges: []AddrRange{ - {0x10, 0x10}, - }, - }, - { - desc: "Single non-empty AddrRange of length 1", - ranges: []AddrRange{ - {0x10, 0x11}, - }, - }, - { - desc: "Single non-empty AddrRange of length 2", - ranges: []AddrRange{ - {0x10, 0x12}, - }, - }, - { - desc: "Multiple non-empty AddrRanges", - ranges: []AddrRange{ - {0x10, 0x11}, - {0x20, 0x22}, - }, - }, - { - desc: "Multiple AddrRanges including empty AddrRanges", - ranges: []AddrRange{ - {0x10, 0x10}, - {0x20, 0x20}, - {0x30, 0x33}, - {0x40, 0x44}, - {0x50, 0x50}, - {0x60, 0x60}, - {0x70, 0x77}, - {0x80, 0x88}, - {0x90, 0x90}, - {0xa0, 0xa0}, - }, - }, -} - -func testAddrRangeSeqEqualityWithTailIteration(t *testing.T, ars AddrRangeSeq, wantRanges []AddrRange) { - var wantLen int64 - for _, ar := range wantRanges { - wantLen += int64(ar.Length()) - } - - var i int - for !ars.IsEmpty() { - if gotLen := ars.NumBytes(); gotLen != wantLen { - t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d", i, ars, gotLen, wantLen) - } - if gotN, wantN := ars.NumRanges(), len(wantRanges)-i; gotN != wantN { - t.Errorf("Iteration %d: %v.NumRanges(): got %d, wanted %d", i, ars, gotN, wantN) - } - got := ars.Head() - if i >= len(wantRanges) { - t.Errorf("Iteration %d: %v.Head(): got %s, wanted <end of sequence>", i, ars, got) - } else if want := wantRanges[i]; got != want { - t.Errorf("Iteration %d: %v.Head(): got %s, wanted %s", i, ars, got, want) - } - ars = ars.Tail() - wantLen -= int64(got.Length()) - i++ - } - if gotLen := ars.NumBytes(); gotLen != 0 || wantLen != 0 { - t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d (which should be 0)", i, ars, gotLen, wantLen) - } - if gotN := ars.NumRanges(); gotN != 0 { - t.Errorf("Iteration %d: %v.NumRanges(): got %d, wanted 0", i, ars, gotN) - } -} - -func TestAddrRangeSeqTailIteration(t *testing.T) { - for _, test := range addrRangeSeqTests { - t.Run(test.desc, func(t *testing.T) { - testAddrRangeSeqEqualityWithTailIteration(t, AddrRangeSeqFromSlice(test.ranges), test.ranges) - }) - } -} - -func TestAddrRangeSeqDropFirstEmpty(t *testing.T) { - var ars AddrRangeSeq - if got, want := ars.DropFirst(1), ars; got != want { - t.Errorf("%v.DropFirst(1): got %v, wanted %v", ars, got, want) - } -} - -func TestAddrRangeSeqDropSingleByteIteration(t *testing.T) { - // Tests AddrRangeSeq iteration using Head/DropFirst, simulating - // I/O-per-AddrRange. - for _, test := range addrRangeSeqTests { - t.Run(test.desc, func(t *testing.T) { - // Figure out what AddrRanges we expect to see. - var wantLen int64 - var wantRanges []AddrRange - for _, ar := range test.ranges { - wantLen += int64(ar.Length()) - wantRanges = append(wantRanges, ar) - if ar.Length() == 0 { - // We "do" 0 bytes of I/O and then call DropFirst(0), - // advancing to the next AddrRange. - continue - } - // Otherwise we "do" 1 byte of I/O and then call DropFirst(1), - // advancing the AddrRange by 1 byte, or to the next AddrRange - // if this one is exhausted. - for ar.Start++; ar.Length() != 0; ar.Start++ { - wantRanges = append(wantRanges, ar) - } - } - t.Logf("Expected AddrRanges: %s (%d bytes)", wantRanges, wantLen) - - ars := AddrRangeSeqFromSlice(test.ranges) - var i int - for !ars.IsEmpty() { - if gotLen := ars.NumBytes(); gotLen != wantLen { - t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d", i, ars, gotLen, wantLen) - } - got := ars.Head() - if i >= len(wantRanges) { - t.Errorf("Iteration %d: %v.Head(): got %s, wanted <end of sequence>", i, ars, got) - } else if want := wantRanges[i]; got != want { - t.Errorf("Iteration %d: %v.Head(): got %s, wanted %s", i, ars, got, want) - } - if got.Length() == 0 { - ars = ars.DropFirst(0) - } else { - ars = ars.DropFirst(1) - wantLen-- - } - i++ - } - if gotLen := ars.NumBytes(); gotLen != 0 || wantLen != 0 { - t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d (which should be 0)", i, ars, gotLen, wantLen) - } - }) - } -} - -func TestAddrRangeSeqTakeFirstEmpty(t *testing.T) { - var ars AddrRangeSeq - if got, want := ars.TakeFirst(1), ars; got != want { - t.Errorf("%v.TakeFirst(1): got %v, wanted %v", ars, got, want) - } -} - -func TestAddrRangeSeqTakeFirst(t *testing.T) { - ranges := []AddrRange{ - {0x10, 0x11}, - {0x20, 0x22}, - {0x30, 0x30}, - {0x40, 0x44}, - {0x50, 0x55}, - {0x60, 0x60}, - {0x70, 0x77}, - } - ars := AddrRangeSeqFromSlice(ranges).TakeFirst(5) - want := []AddrRange{ - {0x10, 0x11}, // +1 byte (total 1 byte), not truncated - {0x20, 0x22}, // +2 bytes (total 3 bytes), not truncated - {0x30, 0x30}, // +0 bytes (total 3 bytes), no change - {0x40, 0x42}, // +2 bytes (total 5 bytes), partially truncated - {0x50, 0x50}, // +0 bytes (total 5 bytes), fully truncated - {0x60, 0x60}, // +0 bytes (total 5 bytes), "fully truncated" (no change) - {0x70, 0x70}, // +0 bytes (total 5 bytes), fully truncated - } - testAddrRangeSeqEqualityWithTailIteration(t, ars, want) -} diff --git a/pkg/usermem/addr_range_seq_unsafe.go b/pkg/usermem/addr_range_seq_unsafe.go index c09337c15..c09337c15 100644..100755 --- a/pkg/usermem/addr_range_seq_unsafe.go +++ b/pkg/usermem/addr_range_seq_unsafe.go diff --git a/pkg/usermem/bytes_io.go b/pkg/usermem/bytes_io.go index e177d30eb..e177d30eb 100644..100755 --- a/pkg/usermem/bytes_io.go +++ b/pkg/usermem/bytes_io.go diff --git a/pkg/usermem/bytes_io_unsafe.go b/pkg/usermem/bytes_io_unsafe.go index 20de5037d..20de5037d 100644..100755 --- a/pkg/usermem/bytes_io_unsafe.go +++ b/pkg/usermem/bytes_io_unsafe.go diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index d2f4403b0..d2f4403b0 100644..100755 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go diff --git a/pkg/usermem/usermem_arm64.go b/pkg/usermem/usermem_arm64.go index fdfc30a66..fdfc30a66 100644..100755 --- a/pkg/usermem/usermem_arm64.go +++ b/pkg/usermem/usermem_arm64.go diff --git a/pkg/usermem/usermem_arm64_state_autogen.go b/pkg/usermem/usermem_arm64_state_autogen.go new file mode 100755 index 000000000..d7c365e5d --- /dev/null +++ b/pkg/usermem/usermem_arm64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build arm64 + +package usermem diff --git a/pkg/usermem/usermem_state_autogen.go b/pkg/usermem/usermem_state_autogen.go new file mode 100755 index 000000000..a93af5eaa --- /dev/null +++ b/pkg/usermem/usermem_state_autogen.go @@ -0,0 +1,51 @@ +// automatically generated by stateify. + +// +build amd64 i386 + +package usermem + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *AccessType) beforeSave() {} +func (x *AccessType) save(m state.Map) { + x.beforeSave() + m.Save("Read", &x.Read) + m.Save("Write", &x.Write) + m.Save("Execute", &x.Execute) +} + +func (x *AccessType) afterLoad() {} +func (x *AccessType) load(m state.Map) { + m.Load("Read", &x.Read) + m.Load("Write", &x.Write) + m.Load("Execute", &x.Execute) +} + +func (x *Addr) save(m state.Map) { + m.SaveValue("", (uintptr)(*x)) +} + +func (x *Addr) load(m state.Map) { + m.LoadValue("", new(uintptr), func(y interface{}) { *x = (Addr)(y.(uintptr)) }) +} + +func (x *AddrRange) beforeSave() {} +func (x *AddrRange) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) +} + +func (x *AddrRange) afterLoad() {} +func (x *AddrRange) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) +} + +func init() { + state.Register("pkg/usermem.AccessType", (*AccessType)(nil), state.Fns{Save: (*AccessType).save, Load: (*AccessType).load}) + state.Register("pkg/usermem.Addr", (*Addr)(nil), state.Fns{Save: (*Addr).save, Load: (*Addr).load}) + state.Register("pkg/usermem.AddrRange", (*AddrRange)(nil), state.Fns{Save: (*AddrRange).save, Load: (*AddrRange).load}) +} diff --git a/pkg/usermem/usermem_test.go b/pkg/usermem/usermem_test.go deleted file mode 100644 index bf3c5df2b..000000000 --- a/pkg/usermem/usermem_test.go +++ /dev/null @@ -1,424 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package usermem - -import ( - "bytes" - "encoding/binary" - "fmt" - "reflect" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/safemem" - "gvisor.dev/gvisor/pkg/syserror" -) - -// newContext returns a context.Context that we can use in these tests (we -// can't use contexttest because it depends on usermem). -func newContext() context.Context { - return context.Background() -} - -func newBytesIOString(s string) *BytesIO { - return &BytesIO{[]byte(s)} -} - -func TestBytesIOCopyOutSuccess(t *testing.T) { - b := newBytesIOString("ABCDE") - n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{}) - if wantN := 3; n != wantN || err != nil { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := b.Bytes, []byte("AfooE"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyOutFailure(t *testing.T) { - b := newBytesIOString("ABC") - n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{}) - if wantN, wantErr := 2, syserror.EFAULT; n != wantN || err != wantErr { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := b.Bytes, []byte("Afo"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyInSuccess(t *testing.T) { - b := newBytesIOString("AfooE") - var dst [3]byte - n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{}) - if wantN := 3; n != wantN || err != nil { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyInFailure(t *testing.T) { - b := newBytesIOString("Afo") - var dst [3]byte - n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{}) - if wantN, wantErr := 2, syserror.EFAULT; n != wantN || err != wantErr { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := dst[:], []byte("fo\x00"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -func TestBytesIOZeroOutSuccess(t *testing.T) { - b := newBytesIOString("ABCD") - n, err := b.ZeroOut(newContext(), 1, 2, IOOpts{}) - if wantN := int64(2); n != wantN || err != nil { - t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := b.Bytes, []byte("A\x00\x00D"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOZeroOutFailure(t *testing.T) { - b := newBytesIOString("ABC") - n, err := b.ZeroOut(newContext(), 1, 3, IOOpts{}) - if wantN, wantErr := int64(2), syserror.EFAULT; n != wantN || err != wantErr { - t.Errorf("ZeroOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := b.Bytes, []byte("A\x00\x00"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyOutFromSuccess(t *testing.T) { - b := newBytesIOString("ABCDEFGH") - n, err := b.CopyOutFrom(newContext(), AddrRangeSeqFromSlice([]AddrRange{ - {Start: 4, End: 7}, - {Start: 1, End: 4}, - }), safemem.FromIOReader{bytes.NewBufferString("barfoo")}, IOOpts{}) - if wantN := int64(6); n != wantN || err != nil { - t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := b.Bytes, []byte("AfoobarH"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyOutFromFailure(t *testing.T) { - b := newBytesIOString("ABCDE") - n, err := b.CopyOutFrom(newContext(), AddrRangeSeqFromSlice([]AddrRange{ - {Start: 1, End: 4}, - {Start: 4, End: 7}, - }), safemem.FromIOReader{bytes.NewBufferString("foobar")}, IOOpts{}) - if wantN, wantErr := int64(4), syserror.EFAULT; n != wantN || err != wantErr { - t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := b.Bytes, []byte("Afoob"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyInToSuccess(t *testing.T) { - b := newBytesIOString("AfoobarH") - var dst bytes.Buffer - n, err := b.CopyInTo(newContext(), AddrRangeSeqFromSlice([]AddrRange{ - {Start: 4, End: 7}, - {Start: 1, End: 4}, - }), safemem.FromIOWriter{&dst}, IOOpts{}) - if wantN := int64(6); n != wantN || err != nil { - t.Errorf("CopyInTo: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("barfoo"); !bytes.Equal(got, want) { - t.Errorf("dst.Bytes(): got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyInToFailure(t *testing.T) { - b := newBytesIOString("Afoob") - var dst bytes.Buffer - n, err := b.CopyInTo(newContext(), AddrRangeSeqFromSlice([]AddrRange{ - {Start: 1, End: 4}, - {Start: 4, End: 7}, - }), safemem.FromIOWriter{&dst}, IOOpts{}) - if wantN, wantErr := int64(4), syserror.EFAULT; n != wantN || err != wantErr { - t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) { - t.Errorf("dst.Bytes(): got %q, wanted %q", got, want) - } -} - -type testStruct struct { - Int8 int8 - Uint8 uint8 - Int16 int16 - Uint16 uint16 - Int32 int32 - Uint32 uint32 - Int64 int64 - Uint64 uint64 -} - -func TestCopyObject(t *testing.T) { - wantObj := testStruct{1, 2, 3, 4, 5, 6, 7, 8} - wantN := binary.Size(wantObj) - b := &BytesIO{make([]byte, wantN)} - ctx := newContext() - if n, err := CopyObjectOut(ctx, b, 0, &wantObj, IOOpts{}); n != wantN || err != nil { - t.Fatalf("CopyObjectOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - var gotObj testStruct - if n, err := CopyObjectIn(ctx, b, 0, &gotObj, IOOpts{}); n != wantN || err != nil { - t.Errorf("CopyObjectIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if gotObj != wantObj { - t.Errorf("CopyObject round trip: got %+v, wanted %+v", gotObj, wantObj) - } -} - -func TestCopyStringInShort(t *testing.T) { - // Tests for string length <= copyStringIncrement. - want := strings.Repeat("A", copyStringIncrement-2) - mem := want + "\x00" - if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) - } -} - -func TestCopyStringInLong(t *testing.T) { - // Tests for copyStringIncrement < string length <= copyStringMaxInitBufLen - // (requiring multiple calls to IO.CopyIn()). - want := strings.Repeat("A", copyStringIncrement*3/4) + strings.Repeat("B", copyStringIncrement*3/4) - mem := want + "\x00" - if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) - } -} - -func TestCopyStringInVeryLong(t *testing.T) { - // Tests for string length > copyStringMaxInitBufLen (requiring buffer - // reallocation). - want := strings.Repeat("A", copyStringMaxInitBufLen*3/4) + strings.Repeat("B", copyStringMaxInitBufLen*3/4) - mem := want + "\x00" - if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringMaxInitBufLen, IOOpts{}); got != want || err != nil { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) - } -} - -func TestCopyStringInNoTerminatingZeroByte(t *testing.T) { - want := strings.Repeat("A", copyStringIncrement-1) - got, err := CopyStringIn(newContext(), newBytesIOString(want), 0, 2*copyStringIncrement, IOOpts{}) - if wantErr := syserror.EFAULT; got != want || err != wantErr { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr) - } -} - -func TestCopyStringInTruncatedByMaxlen(t *testing.T) { - got, err := CopyStringIn(newContext(), newBytesIOString(strings.Repeat("A", 10)), 0, 5, IOOpts{}) - if want, wantErr := strings.Repeat("A", 5), syserror.ENAMETOOLONG; got != want || err != wantErr { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr) - } -} - -func TestCopyInt32StringsInVec(t *testing.T) { - for _, test := range []struct { - str string - n int - initial []int32 - final []int32 - }{ - { - str: "100 200", - n: len("100 200"), - initial: []int32{1, 2}, - final: []int32{100, 200}, - }, - { - // Fewer values ok - str: "100", - n: len("100"), - initial: []int32{1, 2}, - final: []int32{100, 2}, - }, - { - // Extra values ok - str: "100 200 300", - n: len("100 200 "), - initial: []int32{1, 2}, - final: []int32{100, 200}, - }, - { - // Leading and trailing whitespace ok - str: " 100\t200\n", - n: len(" 100\t200\n"), - initial: []int32{1, 2}, - final: []int32{100, 200}, - }, - } { - t.Run(fmt.Sprintf("%q", test.str), func(t *testing.T) { - src := BytesIOSequence([]byte(test.str)) - dsts := append([]int32(nil), test.initial...) - if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); n != int64(test.n) || err != nil { - t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (%d, nil)", n, err, test.n) - } - if !reflect.DeepEqual(dsts, test.final) { - t.Errorf("dsts: got %v, wanted %v", dsts, test.final) - } - }) - } -} - -func TestCopyInt32StringsInVecRequiresOneValidValue(t *testing.T) { - for _, s := range []string{"", "\n", "a123"} { - t.Run(fmt.Sprintf("%q", s), func(t *testing.T) { - src := BytesIOSequence([]byte(s)) - initial := []int32{1, 2} - dsts := append([]int32(nil), initial...) - if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); err != syserror.EINVAL { - t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (_, %v)", n, err, syserror.EINVAL) - } - if !reflect.DeepEqual(dsts, initial) { - t.Errorf("dsts: got %v, wanted %v", dsts, initial) - } - }) - } -} - -func TestIOSequenceCopyOut(t *testing.T) { - buf := []byte("ABCD") - s := BytesIOSequence(buf) - - // CopyOut limited by len(src). - n, err := s.CopyOut(newContext(), []byte("fo")) - if wantN := 2; n != wantN || err != nil { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("foCD"); !bytes.Equal(buf, want) { - t.Errorf("buf: got %q, wanted %q", buf, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(2); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - // CopyOut limited by s.NumBytes(). - n, err = s.CopyOut(newContext(), []byte("obar")) - if wantN := 2; n != wantN || err != nil { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("foob"); !bytes.Equal(buf, want) { - t.Errorf("buf: got %q, wanted %q", buf, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(0); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } -} - -func TestIOSequenceCopyIn(t *testing.T) { - s := BytesIOSequence([]byte("foob")) - dst := []byte("ABCDEF") - - // CopyIn limited by len(dst). - n, err := s.CopyIn(newContext(), dst[:2]) - if wantN := 2; n != wantN || err != nil { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("foCDEF"); !bytes.Equal(dst, want) { - t.Errorf("dst: got %q, wanted %q", dst, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(2); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - // CopyIn limited by s.Remaining(). - n, err = s.CopyIn(newContext(), dst[2:]) - if wantN := 2; n != wantN || err != nil { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("foobEF"); !bytes.Equal(dst, want) { - t.Errorf("dst: got %q, wanted %q", dst, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(0); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } -} - -func TestIOSequenceZeroOut(t *testing.T) { - buf := []byte("ABCD") - s := BytesIOSequence(buf) - - // ZeroOut limited by toZero. - n, err := s.ZeroOut(newContext(), 2) - if wantN := int64(2); n != wantN || err != nil { - t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("\x00\x00CD"); !bytes.Equal(buf, want) { - t.Errorf("buf: got %q, wanted %q", buf, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(2); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - // ZeroOut limited by s.NumBytes(). - n, err = s.ZeroOut(newContext(), 4) - if wantN := int64(2); n != wantN || err != nil { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("\x00\x00\x00\x00"); !bytes.Equal(buf, want) { - t.Errorf("buf: got %q, wanted %q", buf, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(0); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } -} - -func TestIOSequenceTakeFirst(t *testing.T) { - s := BytesIOSequence([]byte("foobar")) - if got, want := s.NumBytes(), int64(6); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - s = s.TakeFirst(3) - if got, want := s.NumBytes(), int64(3); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - // TakeFirst(n) where n > s.NumBytes() is a no-op. - s = s.TakeFirst(9) - if got, want := s.NumBytes(), int64(3); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - var dst [3]byte - n, err := s.CopyIn(newContext(), dst[:]) - if wantN := 3; n != wantN || err != nil { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } - s = s.DropFirst(3) - if got, want := s.NumBytes(), int64(0); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } -} diff --git a/pkg/usermem/usermem_x86.go b/pkg/usermem/usermem_x86.go index 8059b72d2..8059b72d2 100644..100755 --- a/pkg/usermem/usermem_x86.go +++ b/pkg/usermem/usermem_x86.go diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD deleted file mode 100644 index 852480a09..000000000 --- a/pkg/waiter/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "waiter_list", - out = "waiter_list.go", - package = "waiter", - prefix = "waiter", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Entry", - "Linker": "*Entry", - }, -) - -go_library( - name = "waiter", - srcs = [ - "waiter.go", - "waiter_list.go", - ], - visibility = ["//visibility:public"], - deps = ["//pkg/sync"], -) - -go_test( - name = "waiter_test", - size = "small", - srcs = [ - "waiter_test.go", - ], - library = ":waiter", -) diff --git a/pkg/waiter/waiter_list.go b/pkg/waiter/waiter_list.go new file mode 100755 index 000000000..07950faa4 --- /dev/null +++ b/pkg/waiter/waiter_list.go @@ -0,0 +1,186 @@ +package waiter + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type waiterElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (waiterElementMapper) linkerFor(elem *Entry) *Entry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type waiterList struct { + head *Entry + tail *Entry +} + +// Reset resets list l to the empty state. +func (l *waiterList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *waiterList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *waiterList) Front() *Entry { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *waiterList) Back() *Entry { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *waiterList) PushFront(e *Entry) { + linker := waiterElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + + if l.head != nil { + waiterElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *waiterList) PushBack(e *Entry) { + linker := waiterElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + + if l.tail != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *waiterList) PushBackList(m *waiterList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(m.head) + waiterElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *waiterList) InsertAfter(b, e *Entry) { + bLinker := waiterElementMapper{}.linkerFor(b) + eLinker := waiterElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + waiterElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *waiterList) InsertBefore(a, e *Entry) { + aLinker := waiterElementMapper{}.linkerFor(a) + eLinker := waiterElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + waiterElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *waiterList) Remove(e *Entry) { + linker := waiterElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + waiterElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + waiterElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type waiterEntry struct { + next *Entry + prev *Entry +} + +// Next returns the entry that follows e in the list. +func (e *waiterEntry) Next() *Entry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *waiterEntry) Prev() *Entry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *waiterEntry) SetNext(elem *Entry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *waiterEntry) SetPrev(elem *Entry) { + e.prev = elem +} diff --git a/pkg/waiter/waiter_state_autogen.go b/pkg/waiter/waiter_state_autogen.go new file mode 100755 index 000000000..93acec042 --- /dev/null +++ b/pkg/waiter/waiter_state_autogen.go @@ -0,0 +1,69 @@ +// automatically generated by stateify. + +package waiter + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Entry) beforeSave() {} +func (x *Entry) save(m state.Map) { + x.beforeSave() + m.Save("Context", &x.Context) + m.Save("Callback", &x.Callback) + m.Save("mask", &x.mask) + m.Save("waiterEntry", &x.waiterEntry) +} + +func (x *Entry) afterLoad() {} +func (x *Entry) load(m state.Map) { + m.Load("Context", &x.Context) + m.Load("Callback", &x.Callback) + m.Load("mask", &x.mask) + m.Load("waiterEntry", &x.waiterEntry) +} + +func (x *Queue) beforeSave() {} +func (x *Queue) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.list) { + m.Failf("list is %v, expected zero", x.list) + } +} + +func (x *Queue) afterLoad() {} +func (x *Queue) load(m state.Map) { +} + +func (x *waiterList) beforeSave() {} +func (x *waiterList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *waiterList) afterLoad() {} +func (x *waiterList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *waiterEntry) beforeSave() {} +func (x *waiterEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *waiterEntry) afterLoad() {} +func (x *waiterEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("pkg/waiter.Entry", (*Entry)(nil), state.Fns{Save: (*Entry).save, Load: (*Entry).load}) + state.Register("pkg/waiter.Queue", (*Queue)(nil), state.Fns{Save: (*Queue).save, Load: (*Queue).load}) + state.Register("pkg/waiter.waiterList", (*waiterList)(nil), state.Fns{Save: (*waiterList).save, Load: (*waiterList).load}) + state.Register("pkg/waiter.waiterEntry", (*waiterEntry)(nil), state.Fns{Save: (*waiterEntry).save, Load: (*waiterEntry).load}) +} diff --git a/pkg/waiter/waiter_test.go b/pkg/waiter/waiter_test.go deleted file mode 100644 index c1b94a4f3..000000000 --- a/pkg/waiter/waiter_test.go +++ /dev/null @@ -1,192 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package waiter - -import ( - "sync/atomic" - "testing" -) - -type callbackStub struct { - f func(e *Entry) -} - -// Callback implements EntryCallback.Callback. -func (c *callbackStub) Callback(e *Entry) { - c.f(e) -} - -func TestEmptyQueue(t *testing.T) { - var q Queue - - // Notify the zero-value of a queue. - q.Notify(EventIn) - - // Register then unregister a waiter, then notify the queue. - cnt := 0 - e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}} - q.EventRegister(&e, EventIn) - q.EventUnregister(&e) - q.Notify(EventIn) - if cnt != 0 { - t.Errorf("Callback was called when it shouldn't have been") - } -} - -func TestMask(t *testing.T) { - // Register a waiter. - var q Queue - var cnt int - e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}} - q.EventRegister(&e, EventIn|EventErr) - - // Notify with an overlapping mask. - cnt = 0 - q.Notify(EventIn | EventOut) - if cnt != 1 { - t.Errorf("Callback wasn't called when it should have been") - } - - // Notify with a subset mask. - cnt = 0 - q.Notify(EventIn) - if cnt != 1 { - t.Errorf("Callback wasn't called when it should have been") - } - - // Notify with a superset mask. - cnt = 0 - q.Notify(EventIn | EventErr | EventOut) - if cnt != 1 { - t.Errorf("Callback wasn't called when it should have been") - } - - // Notify with the exact same mask. - cnt = 0 - q.Notify(EventIn | EventErr) - if cnt != 1 { - t.Errorf("Callback wasn't called when it should have been") - } - - // Notify with a disjoint mask. - cnt = 0 - q.Notify(EventOut | EventHUp) - if cnt != 0 { - t.Errorf("Callback was called when it shouldn't have been") - } -} - -func TestConcurrentRegistration(t *testing.T) { - var q Queue - var cnt int - const concurrency = 1000 - - ch1 := make(chan struct{}) - ch2 := make(chan struct{}) - ch3 := make(chan struct{}) - - // Create goroutines that will all register/unregister concurrently. - for i := 0; i < concurrency; i++ { - go func() { - var e Entry - e.Callback = &callbackStub{func(entry *Entry) { - cnt++ - if entry != &e { - t.Errorf("entry = %p, want %p", entry, &e) - } - }} - - // Wait for notification, then register. - <-ch1 - q.EventRegister(&e, EventIn|EventErr) - - // Tell main goroutine that we're done registering. - ch2 <- struct{}{} - - // Wait for notification, then unregister. - <-ch3 - q.EventUnregister(&e) - - // Tell main goroutine that we're done unregistering. - ch2 <- struct{}{} - }() - } - - // Let the goroutines register. - close(ch1) - for i := 0; i < concurrency; i++ { - <-ch2 - } - - // Issue a notification. - q.Notify(EventIn) - if cnt != concurrency { - t.Errorf("cnt = %d, want %d", cnt, concurrency) - } - - // Let the goroutine unregister. - close(ch3) - for i := 0; i < concurrency; i++ { - <-ch2 - } - - // Issue a notification. - q.Notify(EventIn) - if cnt != concurrency { - t.Errorf("cnt = %d, want %d", cnt, concurrency) - } -} - -func TestConcurrentNotification(t *testing.T) { - var q Queue - var cnt int32 - const concurrency = 1000 - const waiterCount = 1000 - - // Register waiters. - for i := 0; i < waiterCount; i++ { - var e Entry - e.Callback = &callbackStub{func(entry *Entry) { - atomic.AddInt32(&cnt, 1) - if entry != &e { - t.Errorf("entry = %p, want %p", entry, &e) - } - }} - - q.EventRegister(&e, EventIn|EventErr) - } - - // Launch notifiers. - ch1 := make(chan struct{}) - ch2 := make(chan struct{}) - for i := 0; i < concurrency; i++ { - go func() { - <-ch1 - q.Notify(EventIn) - ch2 <- struct{}{} - }() - } - - // Let notifiers go. - close(ch1) - for i := 0; i < concurrency; i++ { - <-ch2 - } - - // Check the count. - if cnt != concurrency*waiterCount { - t.Errorf("cnt = %d, want %d", cnt, concurrency*waiterCount) - } -} diff --git a/runsc/BUILD b/runsc/BUILD deleted file mode 100644 index 757f6d44c..000000000 --- a/runsc/BUILD +++ /dev/null @@ -1,123 +0,0 @@ -load("//tools:defs.bzl", "go_binary", "pkg_deb", "pkg_tar") - -package(licenses = ["notice"]) - -go_binary( - name = "runsc", - srcs = [ - "main.go", - "version.go", - ], - pure = True, - visibility = [ - "//visibility:public", - ], - x_defs = {"main.version": "{STABLE_VERSION}"}, - deps = [ - "//pkg/log", - "//pkg/refs", - "//pkg/sentry/platform", - "//runsc/boot", - "//runsc/cmd", - "//runsc/flag", - "//runsc/specutils", - "@com_github_google_subcommands//:go_default_library", - ], -) - -# The runsc-race target is a race-compatible BUILD target. This must be built -# via: bazel build --features=race :runsc-race -# -# This is neccessary because the race feature must apply to all dependencies -# due a bug in gazelle file selection. The pure attribute must be off because -# the race detector requires linking with non-Go components, although we still -# require a static binary. -# -# Note that in the future this might be convertible to a compatible target by -# using the pure and static attributes within a select function, but select is -# not currently compatible with string attributes [1]. -# -# [1] https://github.com/bazelbuild/bazel/issues/1698 -go_binary( - name = "runsc-race", - srcs = [ - "main.go", - "version.go", - ], - static = True, - visibility = [ - "//visibility:public", - ], - x_defs = {"main.version": "{STABLE_VERSION}"}, - deps = [ - "//pkg/log", - "//pkg/refs", - "//pkg/sentry/platform", - "//runsc/boot", - "//runsc/cmd", - "//runsc/flag", - "//runsc/specutils", - "@com_github_google_subcommands//:go_default_library", - ], -) - -pkg_tar( - name = "runsc-bin", - srcs = [":runsc"], - mode = "0755", - package_dir = "/usr/bin", - strip_prefix = "/runsc/linux_amd64_pure_stripped", -) - -pkg_tar( - name = "debian-data", - extension = "tar.gz", - deps = [ - ":runsc-bin", - ], -) - -genrule( - name = "deb-version", - # Note that runsc must appear in the srcs parameter and not the tools - # parameter, otherwise it will not be stamped. This is reasonable, as tools - # may be encoded differently in the build graph (cached more aggressively - # because they are assumes to be hermetic). - srcs = [":runsc"], - outs = ["version.txt"], - # Note that the little dance here is necessary because files in the $(SRCS) - # attribute are not executable by default, and we can't touch in place. - cmd = "cp $(location :runsc) $(@D)/runsc && \ - chmod a+x $(@D)/runsc && \ - $(@D)/runsc -version | grep version | sed 's/^[^0-9]*//' > $@ && \ - rm -f $(@D)/runsc", - stamp = 1, -) - -pkg_deb( - name = "runsc-debian", - architecture = "amd64", - data = ":debian-data", - # Note that the description_file will be flatten (all newlines removed), - # and therefore it is kept to a simple one-line description. The expected - # format for debian packages is "short summary\nLonger explanation of - # tool." and this is impossible with the flattening. - description_file = "debian/description", - homepage = "https://gvisor.dev/", - maintainer = "The gVisor Authors <gvisor-dev@googlegroups.com>", - package = "runsc", - postinst = "debian/postinst.sh", - version_file = ":version.txt", - visibility = [ - "//visibility:public", - ], -) - -sh_test( - name = "version_test", - size = "small", - srcs = ["version_test.sh"], - args = ["$(location :runsc)"], - data = [":runsc"], - tags = ["noguitar"], -) diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD deleted file mode 100644 index 26f68fe3d..000000000 --- a/runsc/boot/BUILD +++ /dev/null @@ -1,123 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "boot", - srcs = [ - "compat.go", - "compat_amd64.go", - "compat_arm64.go", - "config.go", - "controller.go", - "debug.go", - "events.go", - "fds.go", - "fs.go", - "limits.go", - "loader.go", - "loader_amd64.go", - "loader_arm64.go", - "network.go", - "strace.go", - "user.go", - ], - visibility = [ - "//runsc:__subpackages__", - "//test:__subpackages__", - ], - deps = [ - "//pkg/abi", - "//pkg/abi/linux", - "//pkg/context", - "//pkg/control/server", - "//pkg/cpuid", - "//pkg/eventchannel", - "//pkg/log", - "//pkg/memutil", - "//pkg/rand", - "//pkg/refs", - "//pkg/sentry/arch", - "//pkg/sentry/arch:registers_go_proto", - "//pkg/sentry/control", - "//pkg/sentry/fs", - "//pkg/sentry/fs/dev", - "//pkg/sentry/fs/gofer", - "//pkg/sentry/fs/host", - "//pkg/sentry/fs/proc", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/fs/sys", - "//pkg/sentry/fs/tmpfs", - "//pkg/sentry/fs/tty", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel:uncaught_signal_go_proto", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/limits", - "//pkg/sentry/loader", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/sighandling", - "//pkg/sentry/socket/hostinet", - "//pkg/sentry/socket/netlink", - "//pkg/sentry/socket/netlink/route", - "//pkg/sentry/socket/netlink/uevent", - "//pkg/sentry/socket/netstack", - "//pkg/sentry/socket/unix", - "//pkg/sentry/state", - "//pkg/sentry/strace", - "//pkg/sentry/syscalls/linux", - "//pkg/sentry/syscalls/linux/vfs2", - "//pkg/sentry/time", - "//pkg/sentry/unimpl:unimplemented_syscall_go_proto", - "//pkg/sentry/usage", - "//pkg/sentry/watchdog", - "//pkg/sync", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/raw", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/urpc", - "//pkg/usermem", - "//runsc/boot/filter", - "//runsc/boot/platforms", - "//runsc/boot/pprof", - "//runsc/specutils", - "@com_github_golang_protobuf//proto:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "boot_test", - size = "small", - srcs = [ - "compat_test.go", - "fs_test.go", - "loader_test.go", - "user_test.go", - ], - library = ":boot", - deps = [ - "//pkg/control/server", - "//pkg/log", - "//pkg/p9", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/auth", - "//pkg/sync", - "//pkg/unet", - "//runsc/fsgofer", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", - ], -) diff --git a/runsc/boot/boot_amd64_state_autogen.go b/runsc/boot/boot_amd64_state_autogen.go new file mode 100755 index 000000000..4b7a38bb8 --- /dev/null +++ b/runsc/boot/boot_amd64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build amd64 + +package boot diff --git a/runsc/boot/boot_arm64_state_autogen.go b/runsc/boot/boot_arm64_state_autogen.go new file mode 100755 index 000000000..b94cf6df2 --- /dev/null +++ b/runsc/boot/boot_arm64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build arm64 + +package boot diff --git a/runsc/boot/boot_state_autogen.go b/runsc/boot/boot_state_autogen.go new file mode 100755 index 000000000..167d1cf02 --- /dev/null +++ b/runsc/boot/boot_state_autogen.go @@ -0,0 +1,24 @@ +// automatically generated by stateify. + +package boot + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *sandboxNetstackCreator) beforeSave() {} +func (x *sandboxNetstackCreator) save(m state.Map) { + x.beforeSave() + m.Save("clock", &x.clock) + m.Save("uniqueID", &x.uniqueID) +} + +func (x *sandboxNetstackCreator) afterLoad() {} +func (x *sandboxNetstackCreator) load(m state.Map) { + m.Load("clock", &x.clock) + m.Load("uniqueID", &x.uniqueID) +} + +func init() { + state.Register("runsc/boot.sandboxNetstackCreator", (*sandboxNetstackCreator)(nil), state.Fns{Save: (*sandboxNetstackCreator).save, Load: (*sandboxNetstackCreator).load}) +} diff --git a/runsc/boot/compat_arm64.go b/runsc/boot/compat_arm64.go index f784cd237..f784cd237 100644..100755 --- a/runsc/boot/compat_arm64.go +++ b/runsc/boot/compat_arm64.go diff --git a/runsc/boot/compat_test.go b/runsc/boot/compat_test.go deleted file mode 100644 index 839c5303b..000000000 --- a/runsc/boot/compat_test.go +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package boot - -import ( - "testing" -) - -func TestOnceTracker(t *testing.T) { - o := onceTracker{} - if !o.shouldReport(nil) { - t.Error("first call to checkAndMark, got: false, want: true") - } - o.onReported(nil) - for i := 0; i < 2; i++ { - if o.shouldReport(nil) { - t.Error("after first call to checkAndMark, got: true, want: false") - } - } -} - -func TestArgsTracker(t *testing.T) { - for _, tc := range []struct { - name string - idx []int - arg1_1 uint64 - arg1_2 uint64 - arg2_1 uint64 - arg2_2 uint64 - want bool - }{ - {name: "same arg1", idx: []int{0}, arg1_1: 123, arg1_2: 123, want: false}, - {name: "same arg2", idx: []int{1}, arg2_1: 123, arg2_2: 123, want: false}, - {name: "diff arg1", idx: []int{0}, arg1_1: 123, arg1_2: 321, want: true}, - {name: "diff arg2", idx: []int{1}, arg2_1: 123, arg2_2: 321, want: true}, - {name: "cmd is uint32", idx: []int{0}, arg2_1: 0xdead00000123, arg2_2: 0xbeef00000123, want: false}, - {name: "same 2 args", idx: []int{0, 1}, arg2_1: 123, arg1_1: 321, arg2_2: 123, arg1_2: 321, want: false}, - {name: "diff 2 args", idx: []int{0, 1}, arg2_1: 123, arg1_1: 321, arg2_2: 789, arg1_2: 987, want: true}, - } { - t.Run(tc.name, func(t *testing.T) { - c := newArgsTracker(tc.idx...) - regs := newRegs() - setArgVal(0, tc.arg1_1, regs) - setArgVal(1, tc.arg2_1, regs) - if !c.shouldReport(regs) { - t.Error("first call to shouldReport, got: false, want: true") - } - c.onReported(regs) - - setArgVal(0, tc.arg1_2, regs) - setArgVal(1, tc.arg2_2, regs) - if got := c.shouldReport(regs); tc.want != got { - t.Errorf("second call to shouldReport, got: %t, want: %t", got, tc.want) - } - }) - } -} - -func TestArgsTrackerLimit(t *testing.T) { - c := newArgsTracker(0, 1) - for i := 0; i < reportLimit; i++ { - regs := newRegs() - setArgVal(0, 123, regs) - setArgVal(1, uint64(i), regs) - if !c.shouldReport(regs) { - t.Error("shouldReport before limit was reached, got: false, want: true") - } - c.onReported(regs) - } - - // Should hit the count limit now. - regs := newRegs() - setArgVal(0, 123, regs) - setArgVal(1, 123456, regs) - if c.shouldReport(regs) { - t.Error("shouldReport after limit was reached, got: true, want: false") - } -} diff --git a/runsc/boot/filter/BUILD b/runsc/boot/filter/BUILD deleted file mode 100644 index ed18f0047..000000000 --- a/runsc/boot/filter/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "filter", - srcs = [ - "config.go", - "config_amd64.go", - "config_arm64.go", - "config_profile.go", - "extra_filters.go", - "extra_filters_msan.go", - "extra_filters_race.go", - "filter.go", - ], - visibility = [ - "//runsc/boot:__subpackages__", - ], - deps = [ - "//pkg/abi/linux", - "//pkg/log", - "//pkg/seccomp", - "//pkg/sentry/platform", - "//pkg/tcpip/link/fdbased", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/runsc/boot/filter/config_amd64.go b/runsc/boot/filter/config_amd64.go index 5335ff82c..5335ff82c 100644..100755 --- a/runsc/boot/filter/config_amd64.go +++ b/runsc/boot/filter/config_amd64.go diff --git a/runsc/boot/filter/config_arm64.go b/runsc/boot/filter/config_arm64.go index 7fa9bbda3..7fa9bbda3 100644..100755 --- a/runsc/boot/filter/config_arm64.go +++ b/runsc/boot/filter/config_arm64.go diff --git a/runsc/boot/filter/config_profile.go b/runsc/boot/filter/config_profile.go index 194952a7b..194952a7b 100644..100755 --- a/runsc/boot/filter/config_profile.go +++ b/runsc/boot/filter/config_profile.go diff --git a/runsc/boot/filter/filter_amd64_state_autogen.go b/runsc/boot/filter/filter_amd64_state_autogen.go new file mode 100755 index 000000000..0f27e5568 --- /dev/null +++ b/runsc/boot/filter/filter_amd64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build amd64 + +package filter diff --git a/runsc/boot/filter/filter_arm64_state_autogen.go b/runsc/boot/filter/filter_arm64_state_autogen.go new file mode 100755 index 000000000..e87cf5af7 --- /dev/null +++ b/runsc/boot/filter/filter_arm64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build arm64 + +package filter diff --git a/runsc/boot/filter/filter_state_autogen.go b/runsc/boot/filter/filter_state_autogen.go new file mode 100755 index 000000000..545d526ae --- /dev/null +++ b/runsc/boot/filter/filter_state_autogen.go @@ -0,0 +1,7 @@ +// automatically generated by stateify. + +// +build !msan,!race +// +build msan +// +build race + +package filter diff --git a/runsc/boot/fs_test.go b/runsc/boot/fs_test.go deleted file mode 100644 index 912037075..000000000 --- a/runsc/boot/fs_test.go +++ /dev/null @@ -1,250 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package boot - -import ( - "reflect" - "strings" - "testing" - - specs "github.com/opencontainers/runtime-spec/specs-go" -) - -func TestPodMountHintsHappy(t *testing.T) { - spec := &specs.Spec{ - Annotations: map[string]string{ - MountPrefix + "mount1.source": "foo", - MountPrefix + "mount1.type": "tmpfs", - MountPrefix + "mount1.share": "pod", - - MountPrefix + "mount2.source": "bar", - MountPrefix + "mount2.type": "bind", - MountPrefix + "mount2.share": "container", - MountPrefix + "mount2.options": "rw,private", - }, - } - podHints, err := newPodMountHints(spec) - if err != nil { - t.Fatalf("newPodMountHints failed: %v", err) - } - - // Check that fields were set correctly. - mount1 := podHints.mounts["mount1"] - if want := "mount1"; want != mount1.name { - t.Errorf("mount1 name, want: %q, got: %q", want, mount1.name) - } - if want := "foo"; want != mount1.mount.Source { - t.Errorf("mount1 source, want: %q, got: %q", want, mount1.mount.Source) - } - if want := "tmpfs"; want != mount1.mount.Type { - t.Errorf("mount1 type, want: %q, got: %q", want, mount1.mount.Type) - } - if want := pod; want != mount1.share { - t.Errorf("mount1 type, want: %q, got: %q", want, mount1.share) - } - if want := []string(nil); !reflect.DeepEqual(want, mount1.mount.Options) { - t.Errorf("mount1 type, want: %q, got: %q", want, mount1.mount.Options) - } - - mount2 := podHints.mounts["mount2"] - if want := "mount2"; want != mount2.name { - t.Errorf("mount2 name, want: %q, got: %q", want, mount2.name) - } - if want := "bar"; want != mount2.mount.Source { - t.Errorf("mount2 source, want: %q, got: %q", want, mount2.mount.Source) - } - if want := "bind"; want != mount2.mount.Type { - t.Errorf("mount2 type, want: %q, got: %q", want, mount2.mount.Type) - } - if want := container; want != mount2.share { - t.Errorf("mount2 type, want: %q, got: %q", want, mount2.share) - } - if want := []string{"private", "rw"}; !reflect.DeepEqual(want, mount2.mount.Options) { - t.Errorf("mount2 type, want: %q, got: %q", want, mount2.mount.Options) - } -} - -func TestPodMountHintsErrors(t *testing.T) { - for _, tst := range []struct { - name string - annotations map[string]string - error string - }{ - { - name: "too short", - annotations: map[string]string{ - MountPrefix + "mount1": "foo", - }, - error: "invalid mount annotation", - }, - { - name: "no name", - annotations: map[string]string{ - MountPrefix + ".source": "foo", - }, - error: "invalid mount name", - }, - { - name: "missing source", - annotations: map[string]string{ - MountPrefix + "mount1.type": "tmpfs", - MountPrefix + "mount1.share": "pod", - }, - error: "source field", - }, - { - name: "missing type", - annotations: map[string]string{ - MountPrefix + "mount1.source": "foo", - MountPrefix + "mount1.share": "pod", - }, - error: "type field", - }, - { - name: "missing share", - annotations: map[string]string{ - MountPrefix + "mount1.source": "foo", - MountPrefix + "mount1.type": "tmpfs", - }, - error: "share field", - }, - { - name: "invalid field name", - annotations: map[string]string{ - MountPrefix + "mount1.invalid": "foo", - }, - error: "invalid mount annotation", - }, - { - name: "invalid source", - annotations: map[string]string{ - MountPrefix + "mount1.source": "", - MountPrefix + "mount1.type": "tmpfs", - MountPrefix + "mount1.share": "pod", - }, - error: "source cannot be empty", - }, - { - name: "invalid type", - annotations: map[string]string{ - MountPrefix + "mount1.source": "foo", - MountPrefix + "mount1.type": "invalid-type", - MountPrefix + "mount1.share": "pod", - }, - error: "invalid type", - }, - { - name: "invalid share", - annotations: map[string]string{ - MountPrefix + "mount1.source": "foo", - MountPrefix + "mount1.type": "tmpfs", - MountPrefix + "mount1.share": "invalid-share", - }, - error: "invalid share", - }, - { - name: "invalid options", - annotations: map[string]string{ - MountPrefix + "mount1.source": "foo", - MountPrefix + "mount1.type": "tmpfs", - MountPrefix + "mount1.share": "pod", - MountPrefix + "mount1.options": "invalid-option", - }, - error: "unknown mount option", - }, - { - name: "duplicate source", - annotations: map[string]string{ - MountPrefix + "mount1.source": "foo", - MountPrefix + "mount1.type": "tmpfs", - MountPrefix + "mount1.share": "pod", - - MountPrefix + "mount2.source": "foo", - MountPrefix + "mount2.type": "bind", - MountPrefix + "mount2.share": "container", - }, - error: "have the same mount source", - }, - } { - t.Run(tst.name, func(t *testing.T) { - spec := &specs.Spec{Annotations: tst.annotations} - podHints, err := newPodMountHints(spec) - if err == nil || !strings.Contains(err.Error(), tst.error) { - t.Errorf("newPodMountHints invalid error, want: .*%s.*, got: %v", tst.error, err) - } - if podHints != nil { - t.Errorf("newPodMountHints must return nil on failure: %+v", podHints) - } - }) - } -} - -func TestGetMountAccessType(t *testing.T) { - const source = "foo" - for _, tst := range []struct { - name string - annotations map[string]string - want FileAccessType - }{ - { - name: "container=exclusive", - annotations: map[string]string{ - MountPrefix + "mount1.source": source, - MountPrefix + "mount1.type": "bind", - MountPrefix + "mount1.share": "container", - }, - want: FileAccessExclusive, - }, - { - name: "pod=shared", - annotations: map[string]string{ - MountPrefix + "mount1.source": source, - MountPrefix + "mount1.type": "bind", - MountPrefix + "mount1.share": "pod", - }, - want: FileAccessShared, - }, - { - name: "shared=shared", - annotations: map[string]string{ - MountPrefix + "mount1.source": source, - MountPrefix + "mount1.type": "bind", - MountPrefix + "mount1.share": "shared", - }, - want: FileAccessShared, - }, - { - name: "default=shared", - annotations: map[string]string{ - MountPrefix + "mount1.source": source + "mismatch", - MountPrefix + "mount1.type": "bind", - MountPrefix + "mount1.share": "container", - }, - want: FileAccessShared, - }, - } { - t.Run(tst.name, func(t *testing.T) { - spec := &specs.Spec{Annotations: tst.annotations} - podHints, err := newPodMountHints(spec) - if err != nil { - t.Fatalf("newPodMountHints failed: %v", err) - } - mounter := containerMounter{hints: podHints} - if got := mounter.getMountAccessType(specs.Mount{Source: source}); got != tst.want { - t.Errorf("getMountAccessType(), want: %v, got: %v", tst.want, got) - } - }) - } -} diff --git a/runsc/boot/loader_amd64.go b/runsc/boot/loader_amd64.go index b9669f2ac..b9669f2ac 100644..100755 --- a/runsc/boot/loader_amd64.go +++ b/runsc/boot/loader_amd64.go diff --git a/runsc/boot/loader_arm64.go b/runsc/boot/loader_arm64.go index cf64d28c8..cf64d28c8 100644..100755 --- a/runsc/boot/loader_arm64.go +++ b/runsc/boot/loader_arm64.go diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go deleted file mode 100644 index 44aa63196..000000000 --- a/runsc/boot/loader_test.go +++ /dev/null @@ -1,631 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package boot - -import ( - "fmt" - "math/rand" - "os" - "reflect" - "syscall" - "testing" - "time" - - specs "github.com/opencontainers/runtime-spec/specs-go" - "gvisor.dev/gvisor/pkg/control/server" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/unet" - "gvisor.dev/gvisor/runsc/fsgofer" -) - -func init() { - log.SetLevel(log.Debug) - rand.Seed(time.Now().UnixNano()) - if err := fsgofer.OpenProcSelfFD(); err != nil { - panic(err) - } -} - -func testConfig() *Config { - return &Config{ - RootDir: "unused_root_dir", - Network: NetworkNone, - DisableSeccomp: true, - Platform: "ptrace", - } -} - -// testSpec returns a simple spec that can be used in tests. -func testSpec() *specs.Spec { - return &specs.Spec{ - // The host filesystem root is the sandbox root. - Root: &specs.Root{ - Path: "/", - Readonly: true, - }, - Process: &specs.Process{ - Args: []string{"/bin/true"}, - }, - } -} - -// startGofer starts a new gofer routine serving 'root' path. It returns the -// sandbox side of the connection, and a function that when called will stop the -// gofer. -func startGofer(root string) (int, func(), error) { - fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0) - if err != nil { - return 0, nil, err - } - sandboxEnd, goferEnd := fds[0], fds[1] - - socket, err := unet.NewSocket(goferEnd) - if err != nil { - syscall.Close(sandboxEnd) - syscall.Close(goferEnd) - return 0, nil, fmt.Errorf("error creating server on FD %d: %v", goferEnd, err) - } - at, err := fsgofer.NewAttachPoint(root, fsgofer.Config{ROMount: true}) - if err != nil { - return 0, nil, err - } - go func() { - s := p9.NewServer(at) - if err := s.Handle(socket); err != nil { - log.Infof("Gofer is stopping. FD: %d, err: %v\n", goferEnd, err) - } - }() - // Closing the gofer socket will stop the gofer and exit goroutine above. - cleanup := func() { - if err := socket.Close(); err != nil { - log.Warningf("Error closing gofer socket: %v", err) - } - } - return sandboxEnd, cleanup, nil -} - -func createLoader() (*Loader, func(), error) { - fd, err := server.CreateSocket(ControlSocketAddr(fmt.Sprintf("%010d", rand.Int())[:10])) - if err != nil { - return nil, nil, err - } - conf := testConfig() - spec := testSpec() - - sandEnd, cleanup, err := startGofer(spec.Root.Path) - if err != nil { - return nil, nil, err - } - - stdio := []int{int(os.Stdin.Fd()), int(os.Stdout.Fd()), int(os.Stderr.Fd())} - args := Args{ - ID: "foo", - Spec: spec, - Conf: conf, - ControllerFD: fd, - GoferFDs: []int{sandEnd}, - StdioFDs: stdio, - } - l, err := New(args) - if err != nil { - cleanup() - return nil, nil, err - } - return l, cleanup, nil -} - -// TestRun runs a simple application in a sandbox and checks that it succeeds. -func TestRun(t *testing.T) { - l, cleanup, err := createLoader() - if err != nil { - t.Fatalf("error creating loader: %v", err) - } - defer l.Destroy() - defer cleanup() - - // Start a goroutine to read the start chan result, otherwise Run will - // block forever. - var resultChanErr error - var wg sync.WaitGroup - wg.Add(1) - go func() { - resultChanErr = <-l.ctrl.manager.startResultChan - wg.Done() - }() - - // Run the container. - if err := l.Run(); err != nil { - t.Errorf("error running container: %v", err) - } - - // We should have not gotten an error on the startResultChan. - wg.Wait() - if resultChanErr != nil { - t.Errorf("error on startResultChan: %v", resultChanErr) - } - - // Wait for the application to exit. It should succeed. - if status := l.WaitExit(); status.Code != 0 || status.Signo != 0 { - t.Errorf("application exited with status %+v, want 0", status) - } -} - -// TestStartSignal tests that the controller Start message will cause -// WaitForStartSignal to return. -func TestStartSignal(t *testing.T) { - l, cleanup, err := createLoader() - if err != nil { - t.Fatalf("error creating loader: %v", err) - } - defer l.Destroy() - defer cleanup() - - // We aren't going to wait on this application, so the control server - // needs to be shut down manually. - defer l.ctrl.srv.Stop() - - // Start a goroutine that calls WaitForStartSignal and writes to a - // channel when it returns. - waitFinished := make(chan struct{}) - go func() { - l.WaitForStartSignal() - // Pretend that Run() executed and returned no error. - l.ctrl.manager.startResultChan <- nil - waitFinished <- struct{}{} - }() - - // Nothing has been written to the channel, so waitFinished should not - // return. Give it a little bit of time to make sure the goroutine has - // started. - select { - case <-waitFinished: - t.Errorf("WaitForStartSignal completed but it should not have") - case <-time.After(50 * time.Millisecond): - // OK. - } - - // Trigger the control server StartRoot method. - cid := "foo" - if err := l.ctrl.manager.StartRoot(&cid, nil); err != nil { - t.Errorf("error calling StartRoot: %v", err) - } - - // Now WaitForStartSignal should return (within a short amount of - // time). - select { - case <-waitFinished: - // OK. - case <-time.After(50 * time.Millisecond): - t.Errorf("WaitForStartSignal did not complete but it should have") - } - -} - -// Test that MountNamespace can be created with various specs. -func TestCreateMountNamespace(t *testing.T) { - testCases := []struct { - name string - // Spec that will be used to create the mount manager. Note - // that we can't mount procfs without a kernel, so each spec - // MUST contain something other than procfs mounted at /proc. - spec specs.Spec - // Paths that are expected to exist in the resulting fs. - expectedPaths []string - }{ - { - // Only proc. - name: "only proc mount", - spec: specs.Spec{ - Root: &specs.Root{ - Path: os.TempDir(), - Readonly: true, - }, - Mounts: []specs.Mount{ - { - Destination: "/proc", - Type: "tmpfs", - }, - }, - }, - // /proc, /dev, and /sys should always be mounted. - expectedPaths: []string{"/proc", "/dev", "/sys"}, - }, - { - // Mount at a deep path, with many components that do - // not exist in the root. - name: "deep mount path", - spec: specs.Spec{ - Root: &specs.Root{ - Path: os.TempDir(), - Readonly: true, - }, - Mounts: []specs.Mount{ - { - Destination: "/some/very/very/deep/path", - Type: "tmpfs", - }, - { - Destination: "/proc", - Type: "tmpfs", - }, - }, - }, - // /some/deep/path should be mounted, along with /proc, - // /dev, and /sys. - expectedPaths: []string{"/some/very/very/deep/path", "/proc", "/dev", "/sys"}, - }, - { - // Mounts are nested inside each other. - name: "nested mounts", - spec: specs.Spec{ - Root: &specs.Root{ - Path: os.TempDir(), - Readonly: true, - }, - Mounts: []specs.Mount{ - { - Destination: "/proc", - Type: "tmpfs", - }, - { - Destination: "/foo", - Type: "tmpfs", - }, - { - Destination: "/foo/qux", - Type: "tmpfs", - }, - { - // File mounts with the same prefix. - Destination: "/foo/qux-quz", - Type: "tmpfs", - }, - { - Destination: "/foo/bar", - Type: "tmpfs", - }, - { - Destination: "/foo/bar/baz", - Type: "tmpfs", - }, - { - // A deep path that is in foo but not the other mounts. - Destination: "/foo/some/very/very/deep/path", - Type: "tmpfs", - }, - }, - }, - expectedPaths: []string{"/foo", "/foo/bar", "/foo/bar/baz", "/foo/qux", - "/foo/qux-quz", "/foo/some/very/very/deep/path", "/proc", "/dev", "/sys"}, - }, - { - name: "mount inside /dev", - spec: specs.Spec{ - Root: &specs.Root{ - Path: os.TempDir(), - Readonly: true, - }, - Mounts: []specs.Mount{ - { - Destination: "/proc", - Type: "tmpfs", - }, - { - Destination: "/dev", - Type: "tmpfs", - }, - { - // Mounted by runsc by default. - Destination: "/dev/fd", - Type: "tmpfs", - }, - { - // Mount with the same prefix. - Destination: "/dev/fd-foo", - Type: "tmpfs", - }, - { - // Unsupported fs type. - Destination: "/dev/mqueue", - Type: "mqueue", - }, - { - Destination: "/dev/foo", - Type: "tmpfs", - }, - { - Destination: "/dev/bar", - Type: "tmpfs", - }, - }, - }, - expectedPaths: []string{"/proc", "/dev", "/dev/fd-foo", "/dev/foo", "/dev/bar", "/sys"}, - }, - { - name: "mounts inside mandatory mounts", - spec: specs.Spec{ - Root: &specs.Root{ - Path: os.TempDir(), - Readonly: true, - }, - Mounts: []specs.Mount{ - { - Destination: "/proc", - Type: "tmpfs", - }, - // We don't include /sys, and /tmp in - // the spec, since they will be added - // automatically. - // - // Instead, add submounts inside these - // directories and make sure they are - // visible under the mandatory mounts. - { - Destination: "/sys/bar", - Type: "tmpfs", - }, - { - Destination: "/tmp/baz", - Type: "tmpfs", - }, - }, - }, - expectedPaths: []string{"/proc", "/sys", "/sys/bar", "/tmp", "/tmp/baz"}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - conf := testConfig() - ctx := contexttest.Context(t) - - sandEnd, cleanup, err := startGofer(tc.spec.Root.Path) - if err != nil { - t.Fatalf("failed to create gofer: %v", err) - } - defer cleanup() - - mntr := newContainerMounter(&tc.spec, []int{sandEnd}, nil, &podMountHints{}) - mns, err := mntr.createMountNamespace(ctx, conf) - if err != nil { - t.Fatalf("failed to create mount namespace: %v", err) - } - ctx = fs.WithRoot(ctx, mns.Root()) - if err := mntr.mountSubmounts(ctx, conf, mns); err != nil { - t.Fatalf("failed to create mount namespace: %v", err) - } - - root := mns.Root() - defer root.DecRef() - for _, p := range tc.expectedPaths { - maxTraversals := uint(0) - if d, err := mns.FindInode(ctx, root, root, p, &maxTraversals); err != nil { - t.Errorf("expected path %v to exist with spec %v, but got error %v", p, tc.spec, err) - } else { - d.DecRef() - } - } - }) - } -} - -// TestRestoreEnvironment tests that the correct mounts are collected from the spec and config -// in order to build the environment for restoring. -func TestRestoreEnvironment(t *testing.T) { - testCases := []struct { - name string - spec *specs.Spec - ioFDs []int - errorExpected bool - expectedRenv fs.RestoreEnvironment - }{ - { - name: "basic spec test", - spec: &specs.Spec{ - Root: &specs.Root{ - Path: os.TempDir(), - Readonly: true, - }, - Mounts: []specs.Mount{ - { - Destination: "/some/very/very/deep/path", - Type: "tmpfs", - }, - { - Destination: "/proc", - Type: "tmpfs", - }, - }, - }, - ioFDs: []int{0}, - errorExpected: false, - expectedRenv: fs.RestoreEnvironment{ - MountSources: map[string][]fs.MountArgs{ - "9p": { - { - Dev: "9pfs-/", - Flags: fs.MountSourceFlags{ReadOnly: true}, - DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true,cache=remote_revalidating", - }, - }, - "tmpfs": { - { - Dev: "none", - }, - { - Dev: "none", - }, - { - Dev: "none", - }, - }, - "devtmpfs": { - { - Dev: "none", - }, - }, - "devpts": { - { - Dev: "none", - }, - }, - "sysfs": { - { - Dev: "none", - }, - }, - }, - }, - }, - { - name: "bind type test", - spec: &specs.Spec{ - Root: &specs.Root{ - Path: os.TempDir(), - Readonly: true, - }, - Mounts: []specs.Mount{ - { - Destination: "/dev/fd-foo", - Type: "bind", - }, - }, - }, - ioFDs: []int{0, 1}, - errorExpected: false, - expectedRenv: fs.RestoreEnvironment{ - MountSources: map[string][]fs.MountArgs{ - "9p": { - { - Dev: "9pfs-/", - Flags: fs.MountSourceFlags{ReadOnly: true}, - DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true,cache=remote_revalidating", - }, - { - Dev: "9pfs-/dev/fd-foo", - DataString: "trans=fd,rfdno=1,wfdno=1,privateunixsocket=true,cache=remote_revalidating", - }, - }, - "tmpfs": { - { - Dev: "none", - }, - }, - "devtmpfs": { - { - Dev: "none", - }, - }, - "devpts": { - { - Dev: "none", - }, - }, - "proc": { - { - Dev: "none", - }, - }, - "sysfs": { - { - Dev: "none", - }, - }, - }, - }, - }, - { - name: "options test", - spec: &specs.Spec{ - Root: &specs.Root{ - Path: os.TempDir(), - Readonly: true, - }, - Mounts: []specs.Mount{ - { - Destination: "/dev/fd-foo", - Type: "tmpfs", - Options: []string{"uid=1022", "noatime"}, - }, - }, - }, - ioFDs: []int{0}, - errorExpected: false, - expectedRenv: fs.RestoreEnvironment{ - MountSources: map[string][]fs.MountArgs{ - "9p": { - { - Dev: "9pfs-/", - Flags: fs.MountSourceFlags{ReadOnly: true}, - DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true,cache=remote_revalidating", - }, - }, - "tmpfs": { - { - Dev: "none", - Flags: fs.MountSourceFlags{NoAtime: true}, - DataString: "uid=1022", - }, - { - Dev: "none", - }, - }, - "devtmpfs": { - { - Dev: "none", - }, - }, - "devpts": { - { - Dev: "none", - }, - }, - "proc": { - { - Dev: "none", - }, - }, - "sysfs": { - { - Dev: "none", - }, - }, - }, - }, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - conf := testConfig() - mntr := newContainerMounter(tc.spec, tc.ioFDs, nil, &podMountHints{}) - actualRenv, err := mntr.createRestoreEnvironment(conf) - if !tc.errorExpected && err != nil { - t.Fatalf("could not create restore environment for test:%s", tc.name) - } else if tc.errorExpected { - if err == nil { - t.Errorf("expected an error, but no error occurred.") - } - } else { - if !reflect.DeepEqual(*actualRenv, tc.expectedRenv) { - t.Errorf("restore environments did not match for test:%s\ngot:%+v\nwant:%+v\n", tc.name, *actualRenv, tc.expectedRenv) - } - } - }) - } -} diff --git a/runsc/boot/platforms/BUILD b/runsc/boot/platforms/BUILD deleted file mode 100644 index 77774f43c..000000000 --- a/runsc/boot/platforms/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "platforms", - srcs = ["platforms.go"], - visibility = [ - "//runsc:__subpackages__", - ], - deps = [ - "//pkg/sentry/platform/kvm", - "//pkg/sentry/platform/ptrace", - ], -) diff --git a/runsc/boot/platforms/platforms_state_autogen.go b/runsc/boot/platforms/platforms_state_autogen.go new file mode 100755 index 000000000..8676d25c1 --- /dev/null +++ b/runsc/boot/platforms/platforms_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package platforms diff --git a/runsc/boot/pprof/BUILD b/runsc/boot/pprof/BUILD deleted file mode 100644 index 29cb42b2f..000000000 --- a/runsc/boot/pprof/BUILD +++ /dev/null @@ -1,11 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "pprof", - srcs = ["pprof.go"], - visibility = [ - "//runsc:__subpackages__", - ], -) diff --git a/runsc/boot/pprof/pprof.go b/runsc/boot/pprof/pprof.go index 1ded20dee..1ded20dee 100644..100755 --- a/runsc/boot/pprof/pprof.go +++ b/runsc/boot/pprof/pprof.go diff --git a/runsc/boot/pprof/pprof_state_autogen.go b/runsc/boot/pprof/pprof_state_autogen.go new file mode 100755 index 000000000..cabd43173 --- /dev/null +++ b/runsc/boot/pprof/pprof_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package pprof diff --git a/runsc/boot/user_test.go b/runsc/boot/user_test.go deleted file mode 100644 index fb4e13dfb..000000000 --- a/runsc/boot/user_test.go +++ /dev/null @@ -1,254 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package boot - -import ( - "io/ioutil" - "os" - "path/filepath" - "strings" - "syscall" - "testing" - - specs "github.com/opencontainers/runtime-spec/specs-go" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" -) - -func setupTempDir() (string, error) { - tmpDir, err := ioutil.TempDir(os.TempDir(), "exec-user-test") - if err != nil { - return "", err - } - return tmpDir, nil -} - -func setupPasswd(contents string, perms os.FileMode) func() (string, error) { - return func() (string, error) { - tmpDir, err := setupTempDir() - if err != nil { - return "", err - } - - if err := os.Mkdir(filepath.Join(tmpDir, "etc"), 0777); err != nil { - return "", err - } - - f, err := os.Create(filepath.Join(tmpDir, "etc", "passwd")) - if err != nil { - return "", err - } - defer f.Close() - - _, err = f.WriteString(contents) - if err != nil { - return "", err - } - - err = f.Chmod(perms) - if err != nil { - return "", err - } - return tmpDir, nil - } -} - -// TestGetExecUserHome tests the getExecUserHome function. -func TestGetExecUserHome(t *testing.T) { - tests := map[string]struct { - uid auth.KUID - createRoot func() (string, error) - expected string - }{ - "success": { - uid: 1000, - createRoot: setupPasswd("adin::1000:1111::/home/adin:/bin/sh", 0666), - expected: "/home/adin", - }, - "no_passwd": { - uid: 1000, - createRoot: setupTempDir, - expected: "/", - }, - "no_perms": { - uid: 1000, - createRoot: setupPasswd("adin::1000:1111::/home/adin:/bin/sh", 0000), - expected: "/", - }, - "directory": { - uid: 1000, - createRoot: func() (string, error) { - tmpDir, err := setupTempDir() - if err != nil { - return "", err - } - - if err := os.Mkdir(filepath.Join(tmpDir, "etc"), 0777); err != nil { - return "", err - } - - if err := syscall.Mkdir(filepath.Join(tmpDir, "etc", "passwd"), 0666); err != nil { - return "", err - } - - return tmpDir, nil - }, - expected: "/", - }, - // Currently we don't allow named pipes. - "named_pipe": { - uid: 1000, - createRoot: func() (string, error) { - tmpDir, err := setupTempDir() - if err != nil { - return "", err - } - - if err := os.Mkdir(filepath.Join(tmpDir, "etc"), 0777); err != nil { - return "", err - } - - if err := syscall.Mkfifo(filepath.Join(tmpDir, "etc", "passwd"), 0666); err != nil { - return "", err - } - - return tmpDir, nil - }, - expected: "/", - }, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - tmpDir, err := tc.createRoot() - if err != nil { - t.Fatalf("failed to create root dir: %v", err) - } - - sandEnd, cleanup, err := startGofer(tmpDir) - if err != nil { - t.Fatalf("failed to create gofer: %v", err) - } - defer cleanup() - - ctx := contexttest.Context(t) - conf := &Config{ - RootDir: "unused_root_dir", - Network: NetworkNone, - DisableSeccomp: true, - } - - spec := &specs.Spec{ - Root: &specs.Root{ - Path: tmpDir, - Readonly: true, - }, - // Add /proc mount as tmpfs to avoid needing a kernel. - Mounts: []specs.Mount{ - { - Destination: "/proc", - Type: "tmpfs", - }, - }, - } - - mntr := newContainerMounter(spec, []int{sandEnd}, nil, &podMountHints{}) - mns, err := mntr.createMountNamespace(ctx, conf) - if err != nil { - t.Fatalf("failed to create mount namespace: %v", err) - } - ctx = fs.WithRoot(ctx, mns.Root()) - if err := mntr.mountSubmounts(ctx, conf, mns); err != nil { - t.Fatalf("failed to create mount namespace: %v", err) - } - - got, err := getExecUserHome(ctx, mns, tc.uid) - if err != nil { - t.Fatalf("failed to get user home: %v", err) - } - - if got != tc.expected { - t.Fatalf("expected %v, got: %v", tc.expected, got) - } - }) - } -} - -// TestFindHomeInPasswd tests the findHomeInPasswd function's passwd file parsing. -func TestFindHomeInPasswd(t *testing.T) { - tests := map[string]struct { - uid uint32 - passwd string - expected string - def string - }{ - "empty": { - uid: 1000, - passwd: "", - expected: "/", - def: "/", - }, - "whitespace": { - uid: 1000, - passwd: " ", - expected: "/", - def: "/", - }, - "full": { - uid: 1000, - passwd: "adin::1000:1111::/home/adin:/bin/sh", - expected: "/home/adin", - def: "/", - }, - // For better or worse, this is how runc works. - "partial": { - uid: 1000, - passwd: "adin::1000:1111:", - expected: "", - def: "/", - }, - "multiple": { - uid: 1001, - passwd: "adin::1000:1111::/home/adin:/bin/sh\nian::1001:1111::/home/ian:/bin/sh", - expected: "/home/ian", - def: "/", - }, - "duplicate": { - uid: 1000, - passwd: "adin::1000:1111::/home/adin:/bin/sh\nian::1000:1111::/home/ian:/bin/sh", - expected: "/home/adin", - def: "/", - }, - "empty_lines": { - uid: 1001, - passwd: "adin::1000:1111::/home/adin:/bin/sh\n\n\nian::1001:1111::/home/ian:/bin/sh", - expected: "/home/ian", - def: "/", - }, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - got, err := findHomeInPasswd(tc.uid, strings.NewReader(tc.passwd), tc.def) - if err != nil { - t.Fatalf("error parsing passwd: %v", err) - } - if tc.expected != got { - t.Fatalf("expected %v, got: %v", tc.expected, got) - } - }) - } -} diff --git a/runsc/cgroup/BUILD b/runsc/cgroup/BUILD deleted file mode 100644 index d4c7bdfbb..000000000 --- a/runsc/cgroup/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "cgroup", - srcs = ["cgroup.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/log", - "//runsc/specutils", - "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", - ], -) - -go_test( - name = "cgroup_test", - size = "small", - srcs = ["cgroup_test.go"], - library = ":cgroup", - tags = ["local"], -) diff --git a/runsc/cgroup/cgroup_state_autogen.go b/runsc/cgroup/cgroup_state_autogen.go new file mode 100755 index 000000000..934ed169b --- /dev/null +++ b/runsc/cgroup/cgroup_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package cgroup diff --git a/runsc/cgroup/cgroup_test.go b/runsc/cgroup/cgroup_test.go deleted file mode 100644 index 548c80e9a..000000000 --- a/runsc/cgroup/cgroup_test.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cgroup - -import ( - "testing" -) - -func TestUninstallEnoent(t *testing.T) { - c := Cgroup{ - // set a non-existent name - Name: "runsc-test-uninstall-656e6f656e740a", - Own: true, - } - if err := c.Uninstall(); err != nil { - t.Errorf("Uninstall() failed: %v", err) - } -} - -func TestCountCpuset(t *testing.T) { - for _, tc := range []struct { - str string - want int - error bool - }{ - {str: "0", want: 1}, - {str: "0,1,2,8,9,10", want: 6}, - {str: "0-1", want: 2}, - {str: "0-7", want: 8}, - {str: "0-7,16,32-39,64,65", want: 19}, - {str: "a", error: true}, - {str: "5-a", error: true}, - {str: "a-5", error: true}, - {str: "-10", error: true}, - {str: "15-", error: true}, - {str: "-", error: true}, - {str: "--", error: true}, - } { - t.Run(tc.str, func(t *testing.T) { - got, err := countCpuset(tc.str) - if tc.error { - if err == nil { - t.Errorf("countCpuset(%q) should have failed", tc.str) - } - } else { - if err != nil { - t.Errorf("countCpuset(%q) failed: %v", tc.str, err) - } - if tc.want != got { - t.Errorf("countCpuset(%q) want: %d, got: %d", tc.str, tc.want, got) - } - } - }) - } -} diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD deleted file mode 100644 index d0bb4613a..000000000 --- a/runsc/cmd/BUILD +++ /dev/null @@ -1,95 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "cmd", - srcs = [ - "boot.go", - "capability.go", - "checkpoint.go", - "chroot.go", - "cmd.go", - "create.go", - "debug.go", - "delete.go", - "do.go", - "error.go", - "events.go", - "exec.go", - "gofer.go", - "help.go", - "install.go", - "kill.go", - "list.go", - "path.go", - "pause.go", - "ps.go", - "restore.go", - "resume.go", - "run.go", - "spec.go", - "start.go", - "state.go", - "statefile.go", - "syscalls.go", - "wait.go", - ], - visibility = [ - "//runsc:__subpackages__", - ], - deps = [ - "//pkg/log", - "//pkg/p9", - "//pkg/sentry/control", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/state", - "//pkg/state/statefile", - "//pkg/sync", - "//pkg/unet", - "//pkg/urpc", - "//runsc/boot", - "//runsc/boot/platforms", - "//runsc/console", - "//runsc/container", - "//runsc/flag", - "//runsc/fsgofer", - "//runsc/fsgofer/filter", - "//runsc/specutils", - "@com_github_google_subcommands//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", - "@com_github_syndtr_gocapability//capability:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "cmd_test", - size = "small", - srcs = [ - "capability_test.go", - "delete_test.go", - "exec_test.go", - "gofer_test.go", - ], - data = [ - "//runsc", - ], - library = ":cmd", - deps = [ - "//pkg/abi/linux", - "//pkg/log", - "//pkg/sentry/control", - "//pkg/sentry/kernel/auth", - "//pkg/urpc", - "//runsc/boot", - "//runsc/container", - "//runsc/specutils", - "//runsc/testutil", - "@com_github_google_go-cmp//cmp:go_default_library", - "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", - "@com_github_syndtr_gocapability//capability:go_default_library", - ], -) diff --git a/runsc/cmd/capability_test.go b/runsc/cmd/capability_test.go deleted file mode 100644 index 0c27f7313..000000000 --- a/runsc/cmd/capability_test.go +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "flag" - "fmt" - "os" - "testing" - - specs "github.com/opencontainers/runtime-spec/specs-go" - "github.com/syndtr/gocapability/capability" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/runsc/boot" - "gvisor.dev/gvisor/runsc/container" - "gvisor.dev/gvisor/runsc/specutils" - "gvisor.dev/gvisor/runsc/testutil" -) - -func init() { - log.SetLevel(log.Debug) - if err := testutil.ConfigureExePath(); err != nil { - panic(err.Error()) - } -} - -func checkProcessCaps(pid int, wantCaps *specs.LinuxCapabilities) error { - curCaps, err := capability.NewPid2(pid) - if err != nil { - return fmt.Errorf("capability.NewPid2(%d) failed: %v", pid, err) - } - if err := curCaps.Load(); err != nil { - return fmt.Errorf("unable to load capabilities: %v", err) - } - fmt.Printf("Capabilities (PID: %d): %v\n", pid, curCaps) - - for _, c := range allCapTypes { - if err := checkCaps(c, curCaps, wantCaps); err != nil { - return err - } - } - return nil -} - -func checkCaps(which capability.CapType, curCaps capability.Capabilities, wantCaps *specs.LinuxCapabilities) error { - wantNames := getCaps(which, wantCaps) - for name, c := range capFromName { - want := specutils.ContainsStr(wantNames, name) - got := curCaps.Get(which, c) - if want != got { - if want { - return fmt.Errorf("capability %v:%s should be set", which, name) - } - return fmt.Errorf("capability %v:%s should NOT be set", which, name) - } - } - return nil -} - -func TestCapabilities(t *testing.T) { - stop := testutil.StartReaper() - defer stop() - - spec := testutil.NewSpecWithArgs("/bin/sleep", "10000") - caps := []string{ - "CAP_CHOWN", - "CAP_SYS_PTRACE", // ptrace is added due to the platform choice. - } - spec.Process.Capabilities = &specs.LinuxCapabilities{ - Permitted: caps, - Bounding: caps, - Effective: caps, - Inheritable: caps, - } - - conf := testutil.TestConfig() - - // Use --network=host to make sandbox use spec's capabilities. - conf.Network = boot.NetworkHost - - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create and start the container. - args := container.Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - c, err := container.New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer c.Destroy() - if err := c.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // Check that sandbox and gofer have the proper capabilities. - if err := checkProcessCaps(c.Sandbox.Pid, spec.Process.Capabilities); err != nil { - t.Error(err) - } - if err := checkProcessCaps(c.GoferPid, goferCaps); err != nil { - t.Error(err) - } -} - -func TestMain(m *testing.M) { - flag.Parse() - specutils.MaybeRunAsRoot() - os.Exit(m.Run()) -} diff --git a/runsc/cmd/cmd_state_autogen.go b/runsc/cmd/cmd_state_autogen.go new file mode 100755 index 000000000..de8aa267b --- /dev/null +++ b/runsc/cmd/cmd_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package cmd diff --git a/runsc/cmd/delete_test.go b/runsc/cmd/delete_test.go deleted file mode 100644 index cb59516a3..000000000 --- a/runsc/cmd/delete_test.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "io/ioutil" - "testing" - - "gvisor.dev/gvisor/runsc/boot" -) - -func TestNotFound(t *testing.T) { - ids := []string{"123"} - dir, err := ioutil.TempDir("", "metadata") - if err != nil { - t.Fatalf("error creating dir: %v", err) - } - conf := &boot.Config{RootDir: dir} - - d := Delete{} - if err := d.execute(ids, conf); err == nil { - t.Error("Deleting non-existent container should have failed") - } - - d = Delete{force: true} - if err := d.execute(ids, conf); err != nil { - t.Errorf("Deleting non-existent container with --force should NOT have failed: %v", err) - } -} diff --git a/runsc/cmd/exec_test.go b/runsc/cmd/exec_test.go deleted file mode 100644 index a1e980d08..000000000 --- a/runsc/cmd/exec_test.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "os" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - specs "github.com/opencontainers/runtime-spec/specs-go" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/control" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/urpc" -) - -func TestUser(t *testing.T) { - testCases := []struct { - input string - want user - wantErr bool - }{ - {input: "0", want: user{kuid: 0, kgid: 0}}, - {input: "7", want: user{kuid: 7, kgid: 0}}, - {input: "49:343", want: user{kuid: 49, kgid: 343}}, - {input: "0:2401", want: user{kuid: 0, kgid: 2401}}, - {input: "", wantErr: true}, - {input: "foo", wantErr: true}, - {input: ":123", wantErr: true}, - {input: "1:2:3", wantErr: true}, - } - - for _, tc := range testCases { - var u user - if err := u.Set(tc.input); err != nil && tc.wantErr { - // We got an error and wanted one. - continue - } else if err == nil && tc.wantErr { - t.Errorf("user.Set(%s): got no error, but wanted one", tc.input) - } else if err != nil && !tc.wantErr { - t.Errorf("user.Set(%s): got error %v, but wanted none", tc.input, err) - } else if u != tc.want { - t.Errorf("user.Set(%s): got %+v, but wanted %+v", tc.input, u, tc.want) - } - } -} - -func TestCLIArgs(t *testing.T) { - testCases := []struct { - ex Exec - argv []string - expected control.ExecArgs - }{ - { - ex: Exec{ - cwd: "/foo/bar", - user: user{kuid: 0, kgid: 0}, - extraKGIDs: []string{"1", "2", "3"}, - caps: []string{"CAP_DAC_OVERRIDE"}, - processPath: "", - }, - argv: []string{"ls", "/"}, - expected: control.ExecArgs{ - Argv: []string{"ls", "/"}, - WorkingDirectory: "/foo/bar", - FilePayload: urpc.FilePayload{Files: []*os.File{os.Stdin, os.Stdout, os.Stderr}}, - KUID: 0, - KGID: 0, - ExtraKGIDs: []auth.KGID{1, 2, 3}, - Capabilities: &auth.TaskCapabilities{ - BoundingCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE), - EffectiveCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE), - InheritableCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE), - PermittedCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE), - }, - }, - }, - } - - for _, tc := range testCases { - e, err := tc.ex.argsFromCLI(tc.argv, true) - if err != nil { - t.Errorf("argsFromCLI(%+v): got error: %+v", tc.ex, err) - } else if !cmp.Equal(*e, tc.expected, cmpopts.IgnoreUnexported(os.File{})) { - t.Errorf("argsFromCLI(%+v): got %+v, but expected %+v", tc.ex, *e, tc.expected) - } - } -} - -func TestJSONArgs(t *testing.T) { - testCases := []struct { - // ex is provided to make sure it is overridden by p. - ex Exec - p specs.Process - expected control.ExecArgs - }{ - { - ex: Exec{ - cwd: "/baz/quux", - user: user{kuid: 1, kgid: 1}, - extraKGIDs: []string{"4", "5", "6"}, - caps: []string{"CAP_SETGID"}, - processPath: "/bin/foo", - }, - p: specs.Process{ - User: specs.User{UID: 0, GID: 0, AdditionalGids: []uint32{1, 2, 3}}, - Args: []string{"ls", "/"}, - Cwd: "/foo/bar", - Capabilities: &specs.LinuxCapabilities{ - Bounding: []string{"CAP_DAC_OVERRIDE"}, - Effective: []string{"CAP_DAC_OVERRIDE"}, - Inheritable: []string{"CAP_DAC_OVERRIDE"}, - Permitted: []string{"CAP_DAC_OVERRIDE"}, - }, - }, - expected: control.ExecArgs{ - Argv: []string{"ls", "/"}, - WorkingDirectory: "/foo/bar", - FilePayload: urpc.FilePayload{Files: []*os.File{os.Stdin, os.Stdout, os.Stderr}}, - KUID: 0, - KGID: 0, - ExtraKGIDs: []auth.KGID{1, 2, 3}, - Capabilities: &auth.TaskCapabilities{ - BoundingCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE), - EffectiveCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE), - InheritableCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE), - PermittedCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE), - }, - }, - }, - } - - for _, tc := range testCases { - e, err := argsFromProcess(&tc.p, true) - if err != nil { - t.Errorf("argsFromProcess(%+v): got error: %+v", tc.p, err) - } else if !cmp.Equal(*e, tc.expected, cmpopts.IgnoreUnexported(os.File{})) { - t.Errorf("argsFromProcess(%+v): got %+v, but expected %+v", tc.p, *e, tc.expected) - } - } -} diff --git a/runsc/cmd/gofer_test.go b/runsc/cmd/gofer_test.go deleted file mode 100644 index cbea7f127..000000000 --- a/runsc/cmd/gofer_test.go +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "fmt" - "io/ioutil" - "os" - "path" - "path/filepath" - "testing" -) - -func tmpDir() string { - dir := os.Getenv("TEST_TMPDIR") - if dir == "" { - dir = "/tmp" - } - return dir -} - -type dir struct { - rel string - link string -} - -func construct(root string, dirs []dir) error { - for _, d := range dirs { - p := path.Join(root, d.rel) - if d.link == "" { - if err := os.MkdirAll(p, 0755); err != nil { - return fmt.Errorf("error creating dir: %v", err) - } - } else { - if err := os.MkdirAll(path.Dir(p), 0755); err != nil { - return fmt.Errorf("error creating dir: %v", err) - } - if err := os.Symlink(d.link, p); err != nil { - return fmt.Errorf("error creating symlink: %v", err) - } - } - } - return nil -} - -func TestResolveSymlinks(t *testing.T) { - root, err := ioutil.TempDir(tmpDir(), "root") - if err != nil { - t.Fatal("ioutil.TempDir() failed:", err) - } - dirs := []dir{ - {"dir1/dir11/dir111/dir1111", ""}, // Just a boring dir - {"dir1/lnk12", "dir11"}, // Link to sibling - {"dir1/lnk13", "./dir11"}, // Link to sibling through self - {"dir1/lnk14", "../dir1/dir11"}, // Link to sibling through parent - {"dir1/dir15/lnk151", ".."}, // Link to parent - {"dir1/lnk16", "dir11/dir111"}, // Link to child - {"dir1/lnk17", "."}, // Link to self - {"dir1/lnk18", "lnk13"}, // Link to link - {"lnk2", "dir1/lnk13"}, // Link to link to link - {"dir3/dir21/lnk211", "../.."}, // Link to root relative - {"dir3/lnk22", "/"}, // Link to root absolute - {"dir3/lnk23", "/dir1"}, // Link to dir absolute - {"dir3/lnk24", "/dir1/lnk12"}, // Link to link absolute - {"lnk5", "../../.."}, // Link outside root - } - if err := construct(root, dirs); err != nil { - t.Fatal("construct failed:", err) - } - - tests := []struct { - name string - rel string - want string - compareHost bool - }{ - {name: "root", rel: "/", want: "/", compareHost: true}, - {name: "basic dir", rel: "/dir1/dir11/dir111", want: "/dir1/dir11/dir111", compareHost: true}, - {name: "dot 1", rel: "/dir1/dir11/./dir111", want: "/dir1/dir11/dir111", compareHost: true}, - {name: "dot 2", rel: "/dir1/././dir11/./././././dir111/.", want: "/dir1/dir11/dir111", compareHost: true}, - {name: "dotdot 1", rel: "/dir1/dir11/../dir15", want: "/dir1/dir15", compareHost: true}, - {name: "dotdot 2", rel: "/dir1/dir11/dir1111/../..", want: "/dir1", compareHost: true}, - - {name: "link sibling", rel: "/dir1/lnk12", want: "/dir1/dir11", compareHost: true}, - {name: "link sibling + dir", rel: "/dir1/lnk12/dir111", want: "/dir1/dir11/dir111", compareHost: true}, - {name: "link sibling through self", rel: "/dir1/lnk13", want: "/dir1/dir11", compareHost: true}, - {name: "link sibling through parent", rel: "/dir1/lnk14", want: "/dir1/dir11", compareHost: true}, - - {name: "link parent", rel: "/dir1/dir15/lnk151", want: "/dir1", compareHost: true}, - {name: "link parent + dir", rel: "/dir1/dir15/lnk151/dir11", want: "/dir1/dir11", compareHost: true}, - {name: "link child", rel: "/dir1/lnk16", want: "/dir1/dir11/dir111", compareHost: true}, - {name: "link child + dir", rel: "/dir1/lnk16/dir1111", want: "/dir1/dir11/dir111/dir1111", compareHost: true}, - {name: "link self", rel: "/dir1/lnk17", want: "/dir1", compareHost: true}, - {name: "link self + dir", rel: "/dir1/lnk17/dir11", want: "/dir1/dir11", compareHost: true}, - - {name: "link^2", rel: "/dir1/lnk18", want: "/dir1/dir11", compareHost: true}, - {name: "link^2 + dir", rel: "/dir1/lnk18/dir111", want: "/dir1/dir11/dir111", compareHost: true}, - {name: "link^3", rel: "/lnk2", want: "/dir1/dir11", compareHost: true}, - {name: "link^3 + dir", rel: "/lnk2/dir111", want: "/dir1/dir11/dir111", compareHost: true}, - - {name: "link abs", rel: "/dir3/lnk23", want: "/dir1"}, - {name: "link abs + dir", rel: "/dir3/lnk23/dir11", want: "/dir1/dir11"}, - {name: "link^2 abs", rel: "/dir3/lnk24", want: "/dir1/dir11"}, - {name: "link^2 abs + dir", rel: "/dir3/lnk24/dir111", want: "/dir1/dir11/dir111"}, - - {name: "root link rel", rel: "/dir3/dir21/lnk211", want: "/", compareHost: true}, - {name: "root link abs", rel: "/dir3/lnk22", want: "/"}, - {name: "root contain link", rel: "/lnk5/dir1", want: "/dir1"}, - {name: "root contain dotdot", rel: "/dir1/dir11/../../../../../../../..", want: "/"}, - - {name: "crazy", rel: "/dir3/dir21/lnk211/dir3/lnk22/dir1/dir11/../../lnk5/dir3/../dir3/lnk24/dir111/dir1111/..", want: "/dir1/dir11/dir111"}, - } - for _, tst := range tests { - t.Run(tst.name, func(t *testing.T) { - got, err := resolveSymlinks(root, tst.rel) - if err != nil { - t.Errorf("resolveSymlinks(root, %q) failed: %v", tst.rel, err) - } - want := path.Join(root, tst.want) - if got != want { - t.Errorf("resolveSymlinks(root, %q) got: %q, want: %q", tst.rel, got, want) - } - if tst.compareHost { - // Check that host got to the same end result. - host, err := filepath.EvalSymlinks(path.Join(root, tst.rel)) - if err != nil { - t.Errorf("path.EvalSymlinks(root, %q) failed: %v", tst.rel, err) - } - if host != got { - t.Errorf("resolveSymlinks(root, %q) got: %q, want: %q", tst.rel, host, got) - } - } - }) - } -} - -func TestResolveSymlinksLoop(t *testing.T) { - root, err := ioutil.TempDir(tmpDir(), "root") - if err != nil { - t.Fatal("ioutil.TempDir() failed:", err) - } - dirs := []dir{ - {"loop1", "loop2"}, - {"loop2", "loop1"}, - } - if err := construct(root, dirs); err != nil { - t.Fatal("construct failed:", err) - } - if _, err := resolveSymlinks(root, "loop1"); err == nil { - t.Errorf("resolveSymlinks() should have failed") - } -} diff --git a/runsc/cmd/statefile.go b/runsc/cmd/statefile.go index e6f1907da..e6f1907da 100644..100755 --- a/runsc/cmd/statefile.go +++ b/runsc/cmd/statefile.go diff --git a/runsc/console/BUILD b/runsc/console/BUILD deleted file mode 100644 index 06924bccd..000000000 --- a/runsc/console/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "console", - srcs = [ - "console.go", - ], - visibility = [ - "//runsc:__subpackages__", - ], - deps = [ - "@com_github_kr_pty//:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/runsc/console/console_state_autogen.go b/runsc/console/console_state_autogen.go new file mode 100755 index 000000000..80521cdb7 --- /dev/null +++ b/runsc/console/console_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package console diff --git a/runsc/container/BUILD b/runsc/container/BUILD deleted file mode 100644 index 0aaeea3a8..000000000 --- a/runsc/container/BUILD +++ /dev/null @@ -1,68 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "container", - srcs = [ - "container.go", - "hook.go", - "state_file.go", - "status.go", - ], - visibility = [ - "//runsc:__subpackages__", - "//test:__subpackages__", - ], - deps = [ - "//pkg/log", - "//pkg/sentry/control", - "//pkg/sync", - "//runsc/boot", - "//runsc/cgroup", - "//runsc/sandbox", - "//runsc/specutils", - "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_gofrs_flock//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", - ], -) - -go_test( - name = "container_test", - size = "large", - srcs = [ - "console_test.go", - "container_test.go", - "multi_container_test.go", - "shared_volume_test.go", - ], - data = [ - "//runsc", - "//runsc/container/test_app", - ], - library = ":container", - shard_count = 5, - tags = [ - "requires-kvm", - ], - deps = [ - "//pkg/abi/linux", - "//pkg/bits", - "//pkg/log", - "//pkg/sentry/control", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sync", - "//pkg/unet", - "//pkg/urpc", - "//runsc/boot", - "//runsc/boot/platforms", - "//runsc/specutils", - "//runsc/testutil", - "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_kr_pty//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go deleted file mode 100644 index 651615d4c..000000000 --- a/runsc/container/console_test.go +++ /dev/null @@ -1,487 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package container - -import ( - "bytes" - "fmt" - "io" - "os" - "path/filepath" - "syscall" - "testing" - "time" - - "github.com/kr/pty" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/sentry/control" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/unet" - "gvisor.dev/gvisor/pkg/urpc" - "gvisor.dev/gvisor/runsc/testutil" -) - -// socketPath creates a path inside bundleDir and ensures that the returned -// path is under 108 charactors (the unix socket path length limit), -// relativizing the path if necessary. -func socketPath(bundleDir string) (string, error) { - path := filepath.Join(bundleDir, "socket") - cwd, err := os.Getwd() - if err != nil { - return "", fmt.Errorf("error getting cwd: %v", err) - } - relPath, err := filepath.Rel(cwd, path) - if err != nil { - return "", fmt.Errorf("error getting relative path for %q from cwd %q: %v", path, cwd, err) - } - if len(path) > len(relPath) { - path = relPath - } - const maxPathLen = 108 - if len(path) > maxPathLen { - return "", fmt.Errorf("could not get socket path under length limit %d: %s", maxPathLen, path) - } - return path, nil -} - -// createConsoleSocket creates a socket at the given path that will receive a -// console fd from the sandbox. If no error occurs, it returns the server -// socket and a cleanup function. -func createConsoleSocket(path string) (*unet.ServerSocket, func() error, error) { - srv, err := unet.BindAndListen(path, false) - if err != nil { - return nil, nil, fmt.Errorf("error binding and listening to socket %q: %v", path, err) - } - - cleanup := func() error { - if err := srv.Close(); err != nil { - return fmt.Errorf("error closing socket %q: %v", path, err) - } - if err := os.Remove(path); err != nil { - return fmt.Errorf("error removing socket %q: %v", path, err) - } - return nil - } - - return srv, cleanup, nil -} - -// receiveConsolePTY accepts a connection on the server socket and reads fds. -// It fails if more than one FD is received, or if the FD is not a PTY. It -// returns the PTY master file. -func receiveConsolePTY(srv *unet.ServerSocket) (*os.File, error) { - sock, err := srv.Accept() - if err != nil { - return nil, fmt.Errorf("error accepting socket connection: %v", err) - } - - // Allow 3 fds to be received. We only expect 1. - r := sock.Reader(true /* blocking */) - r.EnableFDs(1) - - // The socket is closed right after sending the FD, so EOF is - // an allowed error. - b := [][]byte{{}} - if _, err := r.ReadVec(b); err != nil && err != io.EOF { - return nil, fmt.Errorf("error reading from socket connection: %v", err) - } - - // We should have gotten a control message. - fds, err := r.ExtractFDs() - if err != nil { - return nil, fmt.Errorf("error extracting fds from socket connection: %v", err) - } - if len(fds) != 1 { - return nil, fmt.Errorf("got %d fds from socket, wanted 1", len(fds)) - } - - // Verify that the fd is a terminal. - if _, err := unix.IoctlGetTermios(fds[0], unix.TCGETS); err != nil { - return nil, fmt.Errorf("fd is not a terminal (ioctl TGGETS got %v)", err) - } - - return os.NewFile(uintptr(fds[0]), "pty_master"), nil -} - -// Test that an pty FD is sent over the console socket if one is provided. -func TestConsoleSocket(t *testing.T) { - for _, conf := range configs(all...) { - t.Logf("Running test with conf: %+v", conf) - spec := testutil.NewSpecWithArgs("true") - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - sock, err := socketPath(bundleDir) - if err != nil { - t.Fatalf("error getting socket path: %v", err) - } - srv, cleanup, err := createConsoleSocket(sock) - if err != nil { - t.Fatalf("error creating socket at %q: %v", sock, err) - } - defer cleanup() - - // Create the container and pass the socket name. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - ConsoleSocket: sock, - } - c, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer c.Destroy() - - // Make sure we get a console PTY. - ptyMaster, err := receiveConsolePTY(srv) - if err != nil { - t.Fatalf("error receiving console FD: %v", err) - } - ptyMaster.Close() - } -} - -// Test that job control signals work on a console created with "exec -ti". -func TestJobControlSignalExec(t *testing.T) { - spec := testutil.NewSpecWithArgs("/bin/sleep", "10000") - conf := testutil.TestConfig() - - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create and start the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - c, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer c.Destroy() - if err := c.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // Create a pty master/slave. The slave will be passed to the exec - // process. - ptyMaster, ptySlave, err := pty.Open() - if err != nil { - t.Fatalf("error opening pty: %v", err) - } - defer ptyMaster.Close() - defer ptySlave.Close() - - // Exec bash and attach a terminal. Note that occasionally /bin/sh - // may be a different shell or have a different configuration (such - // as disabling interactive mode and job control). Since we want to - // explicitly test interactive mode, use /bin/bash. See b/116981926. - execArgs := &control.ExecArgs{ - Filename: "/bin/bash", - // Don't let bash execute from profile or rc files, otherwise - // our PID counts get messed up. - Argv: []string{"/bin/bash", "--noprofile", "--norc"}, - // Pass the pty slave as FD 0, 1, and 2. - FilePayload: urpc.FilePayload{ - Files: []*os.File{ptySlave, ptySlave, ptySlave}, - }, - StdioIsPty: true, - } - - pid, err := c.Execute(execArgs) - if err != nil { - t.Fatalf("error executing: %v", err) - } - if pid != 2 { - t.Fatalf("exec got pid %d, wanted %d", pid, 2) - } - - // Make sure all the processes are running. - expectedPL := []*control.Process{ - // Root container process. - {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, - // Bash from exec process. - {PID: 2, Cmd: "bash", Threads: []kernel.ThreadID{2}}, - } - if err := waitForProcessList(c, expectedPL); err != nil { - t.Error(err) - } - - // Execute sleep. - ptyMaster.Write([]byte("sleep 100\n")) - - // Wait for it to start. Sleep's PPID is bash's PID. - expectedPL = append(expectedPL, &control.Process{PID: 3, PPID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{3}}) - if err := waitForProcessList(c, expectedPL); err != nil { - t.Error(err) - } - - // Send a SIGTERM to the foreground process for the exec PID. Note that - // although we pass in the PID of "bash", it should actually terminate - // "sleep", since that is the foreground process. - if err := c.Sandbox.SignalProcess(c.ID, pid, syscall.SIGTERM, true /* fgProcess */); err != nil { - t.Fatalf("error signaling container: %v", err) - } - - // Sleep process should be gone. - expectedPL = expectedPL[:len(expectedPL)-1] - if err := waitForProcessList(c, expectedPL); err != nil { - t.Error(err) - } - - // Sleep is dead, but it may take more time for bash to notice and - // change the foreground process back to itself. We know it is done - // when bash writes "Terminated" to the pty. - if err := testutil.WaitUntilRead(ptyMaster, "Terminated", nil, 5*time.Second); err != nil { - t.Fatalf("bash did not take over pty: %v", err) - } - - // Send a SIGKILL to the foreground process again. This time "bash" - // should be killed. We use SIGKILL instead of SIGTERM or SIGINT - // because bash ignores those. - if err := c.Sandbox.SignalProcess(c.ID, pid, syscall.SIGKILL, true /* fgProcess */); err != nil { - t.Fatalf("error signaling container: %v", err) - } - expectedPL = expectedPL[:1] - if err := waitForProcessList(c, expectedPL); err != nil { - t.Error(err) - } - - // Make sure the process indicates it was killed by a SIGKILL. - ws, err := c.WaitPID(pid) - if err != nil { - t.Errorf("waiting on container failed: %v", err) - } - if !ws.Signaled() { - t.Error("ws.Signaled() got false, want true") - } - if got, want := ws.Signal(), syscall.SIGKILL; got != want { - t.Errorf("ws.Signal() got %v, want %v", got, want) - } -} - -// Test that job control signals work on a console created with "run -ti". -func TestJobControlSignalRootContainer(t *testing.T) { - conf := testutil.TestConfig() - // Don't let bash execute from profile or rc files, otherwise our PID - // counts get messed up. - spec := testutil.NewSpecWithArgs("/bin/bash", "--noprofile", "--norc") - spec.Process.Terminal = true - - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - sock, err := socketPath(bundleDir) - if err != nil { - t.Fatalf("error getting socket path: %v", err) - } - srv, cleanup, err := createConsoleSocket(sock) - if err != nil { - t.Fatalf("error creating socket at %q: %v", sock, err) - } - defer cleanup() - - // Create the container and pass the socket name. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - ConsoleSocket: sock, - } - c, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer c.Destroy() - - // Get the PTY master. - ptyMaster, err := receiveConsolePTY(srv) - if err != nil { - t.Fatalf("error receiving console FD: %v", err) - } - defer ptyMaster.Close() - - // Bash output as well as sandbox output will be written to the PTY - // file. Writes after a certain point will block unless we drain the - // PTY, so we must continually copy from it. - // - // We log the output to stderr for debugabilitly, and also to a buffer, - // since we wait on particular output from bash below. We use a custom - // blockingBuffer which is thread-safe and also blocks on Read calls, - // which makes this a suitable Reader for WaitUntilRead. - ptyBuf := newBlockingBuffer() - tee := io.TeeReader(ptyMaster, ptyBuf) - go io.Copy(os.Stderr, tee) - - // Start the container. - if err := c.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // Start waiting for the container to exit in a goroutine. We do this - // very early, otherwise it might exit before we have a chance to call - // Wait. - var ( - ws syscall.WaitStatus - wg sync.WaitGroup - ) - wg.Add(1) - go func() { - var err error - ws, err = c.Wait() - if err != nil { - t.Errorf("error waiting on container: %v", err) - } - wg.Done() - }() - - // Wait for bash to start. - expectedPL := []*control.Process{ - {PID: 1, Cmd: "bash", Threads: []kernel.ThreadID{1}}, - } - if err := waitForProcessList(c, expectedPL); err != nil { - t.Fatal(err) - } - - // Execute sleep via the terminal. - ptyMaster.Write([]byte("sleep 100\n")) - - // Wait for sleep to start. - expectedPL = append(expectedPL, &control.Process{PID: 2, PPID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{2}}) - if err := waitForProcessList(c, expectedPL); err != nil { - t.Fatal(err) - } - - // Reset the pty buffer, so there is less output for us to scan later. - ptyBuf.Reset() - - // Send a SIGTERM to the foreground process. We pass PID=0, indicating - // that the root process should be killed. However, by setting - // fgProcess=true, the signal should actually be sent to sleep. - if err := c.Sandbox.SignalProcess(c.ID, 0 /* PID */, syscall.SIGTERM, true /* fgProcess */); err != nil { - t.Fatalf("error signaling container: %v", err) - } - - // Sleep process should be gone. - expectedPL = expectedPL[:len(expectedPL)-1] - if err := waitForProcessList(c, expectedPL); err != nil { - t.Error(err) - } - - // Sleep is dead, but it may take more time for bash to notice and - // change the foreground process back to itself. We know it is done - // when bash writes "Terminated" to the pty. - if err := testutil.WaitUntilRead(ptyBuf, "Terminated", nil, 5*time.Second); err != nil { - t.Fatalf("bash did not take over pty: %v", err) - } - - // Send a SIGKILL to the foreground process again. This time "bash" - // should be killed. We use SIGKILL instead of SIGTERM or SIGINT - // because bash ignores those. - if err := c.Sandbox.SignalProcess(c.ID, 0 /* PID */, syscall.SIGKILL, true /* fgProcess */); err != nil { - t.Fatalf("error signaling container: %v", err) - } - - // Wait for the sandbox to exit. It should exit with a SIGKILL status. - wg.Wait() - if !ws.Signaled() { - t.Error("ws.Signaled() got false, want true") - } - if got, want := ws.Signal(), syscall.SIGKILL; got != want { - t.Errorf("ws.Signal() got %v, want %v", got, want) - } -} - -// blockingBuffer is a thread-safe buffer that blocks when reading if the -// buffer is empty. It implements io.ReadWriter. -type blockingBuffer struct { - // A send to readCh indicates that a previously empty buffer now has - // data for reading. - readCh chan struct{} - - // mu protects buf. - mu sync.Mutex - buf bytes.Buffer -} - -func newBlockingBuffer() *blockingBuffer { - return &blockingBuffer{ - readCh: make(chan struct{}, 1), - } -} - -// Write implements Writer.Write. -func (bb *blockingBuffer) Write(p []byte) (int, error) { - bb.mu.Lock() - defer bb.mu.Unlock() - l := bb.buf.Len() - n, err := bb.buf.Write(p) - if l == 0 && n > 0 { - // New data! - bb.readCh <- struct{}{} - } - return n, err -} - -// Read implements Reader.Read. It will block until data is available. -func (bb *blockingBuffer) Read(p []byte) (int, error) { - for { - bb.mu.Lock() - n, err := bb.buf.Read(p) - if n > 0 || err != io.EOF { - if bb.buf.Len() == 0 { - // Reset the readCh. - select { - case <-bb.readCh: - default: - } - } - bb.mu.Unlock() - return n, err - } - bb.mu.Unlock() - - // Wait for new data. - <-bb.readCh - } -} - -// Reset resets the buffer. -func (bb *blockingBuffer) Reset() { - bb.mu.Lock() - defer bb.mu.Unlock() - bb.buf.Reset() - // Reset the readCh. - select { - case <-bb.readCh: - default: - } -} diff --git a/runsc/container/container_state_autogen.go b/runsc/container/container_state_autogen.go new file mode 100755 index 000000000..5bc1c1aff --- /dev/null +++ b/runsc/container/container_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package container diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go deleted file mode 100644 index 442e80ac0..000000000 --- a/runsc/container/container_test.go +++ /dev/null @@ -1,2204 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package container - -import ( - "bytes" - "flag" - "fmt" - "io" - "io/ioutil" - "os" - "path" - "path/filepath" - "reflect" - "strconv" - "strings" - "syscall" - "testing" - "time" - - "github.com/cenkalti/backoff" - specs "github.com/opencontainers/runtime-spec/specs-go" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/bits" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/control" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/runsc/boot" - "gvisor.dev/gvisor/runsc/boot/platforms" - "gvisor.dev/gvisor/runsc/specutils" - "gvisor.dev/gvisor/runsc/testutil" -) - -// waitForProcessList waits for the given process list to show up in the container. -func waitForProcessList(cont *Container, want []*control.Process) error { - cb := func() error { - got, err := cont.Processes() - if err != nil { - err = fmt.Errorf("error getting process data from container: %v", err) - return &backoff.PermanentError{Err: err} - } - if r, err := procListsEqual(got, want); !r { - return fmt.Errorf("container got process list: %s, want: %s: error: %v", - procListToString(got), procListToString(want), err) - } - return nil - } - // Gives plenty of time as tests can run slow under --race. - return testutil.Poll(cb, 30*time.Second) -} - -func waitForProcessCount(cont *Container, want int) error { - cb := func() error { - pss, err := cont.Processes() - if err != nil { - err = fmt.Errorf("error getting process data from container: %v", err) - return &backoff.PermanentError{Err: err} - } - if got := len(pss); got != want { - log.Infof("Waiting for process count to reach %d. Current: %d", want, got) - return fmt.Errorf("wrong process count, got: %d, want: %d", got, want) - } - return nil - } - // Gives plenty of time as tests can run slow under --race. - return testutil.Poll(cb, 30*time.Second) -} - -func blockUntilWaitable(pid int) error { - _, _, err := specutils.RetryEintr(func() (uintptr, uintptr, error) { - var err error - _, _, err1 := syscall.Syscall6(syscall.SYS_WAITID, 1, uintptr(pid), 0, syscall.WEXITED|syscall.WNOWAIT, 0, 0) - if err1 != 0 { - err = err1 - } - return 0, 0, err - }) - return err -} - -// procListsEqual is used to check whether 2 Process lists are equal for all -// implemented fields. -func procListsEqual(got, want []*control.Process) (bool, error) { - if len(got) != len(want) { - return false, nil - } - for i := range got { - pd1 := got[i] - pd2 := want[i] - // Zero out timing dependant fields. - pd1.Time = "" - pd1.STime = "" - pd1.C = 0 - // Ignore TTY field too, since it's not relevant in the cases - // where we use this method. Tests that care about the TTY - // field should check for it themselves. - pd1.TTY = "" - pd1Json, err := control.ProcessListToJSON([]*control.Process{pd1}) - if err != nil { - return false, err - } - pd2Json, err := control.ProcessListToJSON([]*control.Process{pd2}) - if err != nil { - return false, err - } - if pd1Json != pd2Json { - return false, nil - } - } - return true, nil -} - -func procListToString(pl []*control.Process) string { - strs := make([]string, 0, len(pl)) - for _, p := range pl { - strs = append(strs, fmt.Sprintf("%+v", p)) - } - return fmt.Sprintf("[%s]", strings.Join(strs, ",")) -} - -// createWriteableOutputFile creates an output file that can be read and -// written to in the sandbox. -func createWriteableOutputFile(path string) (*os.File, error) { - outputFile, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0666) - if err != nil { - return nil, fmt.Errorf("error creating file: %q, %v", path, err) - } - - // Chmod to allow writing after umask. - if err := outputFile.Chmod(0666); err != nil { - return nil, fmt.Errorf("error chmoding file: %q, %v", path, err) - } - return outputFile, nil -} - -func waitForFileNotEmpty(f *os.File) error { - op := func() error { - fi, err := f.Stat() - if err != nil { - return err - } - if fi.Size() == 0 { - return fmt.Errorf("file %q is empty", f.Name()) - } - return nil - } - - return testutil.Poll(op, 30*time.Second) -} - -func waitForFileExist(path string) error { - op := func() error { - if _, err := os.Stat(path); os.IsNotExist(err) { - return err - } - return nil - } - - return testutil.Poll(op, 30*time.Second) -} - -// readOutputNum reads a file at given filepath and returns the int at the -// requested position. -func readOutputNum(file string, position int) (int, error) { - f, err := os.Open(file) - if err != nil { - return 0, fmt.Errorf("error opening file: %q, %v", file, err) - } - - // Ensure that there is content in output file. - if err := waitForFileNotEmpty(f); err != nil { - return 0, fmt.Errorf("error waiting for output file: %v", err) - } - - b, err := ioutil.ReadAll(f) - if err != nil { - return 0, fmt.Errorf("error reading file: %v", err) - } - if len(b) == 0 { - return 0, fmt.Errorf("error no content was read") - } - - // Strip leading null bytes caused by file offset not being 0 upon restore. - b = bytes.Trim(b, "\x00") - nums := strings.Split(string(b), "\n") - - if position >= len(nums) { - return 0, fmt.Errorf("position %v is not within the length of content %v", position, nums) - } - if position == -1 { - // Expectation of newline at the end of last position. - position = len(nums) - 2 - } - num, err := strconv.Atoi(nums[position]) - if err != nil { - return 0, fmt.Errorf("error getting number from file: %v", err) - } - return num, nil -} - -// run starts the sandbox and waits for it to exit, checking that the -// application succeeded. -func run(spec *specs.Spec, conf *boot.Config) error { - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - return fmt.Errorf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create, start and wait for the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - Attached: true, - } - ws, err := Run(conf, args) - if err != nil { - return fmt.Errorf("running container: %v", err) - } - if !ws.Exited() || ws.ExitStatus() != 0 { - return fmt.Errorf("container failed, waitStatus: %v", ws) - } - return nil -} - -type configOption int - -const ( - overlay configOption = iota - kvm - nonExclusiveFS -) - -var noOverlay = []configOption{kvm, nonExclusiveFS} -var all = append(noOverlay, overlay) - -// configs generates different configurations to run tests. -func configs(opts ...configOption) []*boot.Config { - // Always load the default config. - cs := []*boot.Config{testutil.TestConfig()} - - for _, o := range opts { - c := testutil.TestConfig() - switch o { - case overlay: - c.Overlay = true - case kvm: - // TODO(b/112165693): KVM tests are flaky. Disable until fixed. - continue - - c.Platform = platforms.KVM - case nonExclusiveFS: - c.FileAccess = boot.FileAccessShared - default: - panic(fmt.Sprintf("unknown config option %v", o)) - - } - cs = append(cs, c) - } - return cs -} - -// TestLifecycle tests the basic Create/Start/Signal/Destroy container lifecycle. -// It verifies after each step that the container can be loaded from disk, and -// has the correct status. -func TestLifecycle(t *testing.T) { - // Start the child reaper. - childReaper := &testutil.Reaper{} - childReaper.Start() - defer childReaper.Stop() - - for _, conf := range configs(all...) { - t.Logf("Running test with conf: %+v", conf) - // The container will just sleep for a long time. We will kill it before - // it finishes sleeping. - spec := testutil.NewSpecWithArgs("sleep", "100") - - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // expectedPL lists the expected process state of the container. - expectedPL := []*control.Process{ - { - UID: 0, - PID: 1, - PPID: 0, - C: 0, - Cmd: "sleep", - Threads: []kernel.ThreadID{1}, - }, - } - // Create the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - c, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer c.Destroy() - - // Load the container from disk and check the status. - c, err = Load(rootDir, args.ID) - if err != nil { - t.Fatalf("error loading container: %v", err) - } - if got, want := c.Status, Created; got != want { - t.Errorf("container status got %v, want %v", got, want) - } - - // List should return the container id. - ids, err := List(rootDir) - if err != nil { - t.Fatalf("error listing containers: %v", err) - } - if got, want := ids, []string{args.ID}; !reflect.DeepEqual(got, want) { - t.Errorf("container list got %v, want %v", got, want) - } - - // Start the container. - if err := c.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // Load the container from disk and check the status. - c, err = Load(rootDir, args.ID) - if err != nil { - t.Fatalf("error loading container: %v", err) - } - if got, want := c.Status, Running; got != want { - t.Errorf("container status got %v, want %v", got, want) - } - - // Verify that "sleep 100" is running. - if err := waitForProcessList(c, expectedPL); err != nil { - t.Error(err) - } - - // Wait on the container. - var wg sync.WaitGroup - wg.Add(1) - ch := make(chan struct{}) - go func() { - ch <- struct{}{} - ws, err := c.Wait() - if err != nil { - t.Fatalf("error waiting on container: %v", err) - } - if got, want := ws.Signal(), syscall.SIGTERM; got != want { - t.Fatalf("got signal %v, want %v", got, want) - } - wg.Done() - }() - - // Wait a bit to ensure that we've started waiting on the - // container before we signal. - <-ch - time.Sleep(100 * time.Millisecond) - // Send the container a SIGTERM which will cause it to stop. - if err := c.SignalContainer(syscall.SIGTERM, false); err != nil { - t.Fatalf("error sending signal %v to container: %v", syscall.SIGTERM, err) - } - // Wait for it to die. - wg.Wait() - - // Load the container from disk and check the status. - c, err = Load(rootDir, args.ID) - if err != nil { - t.Fatalf("error loading container: %v", err) - } - if got, want := c.Status, Stopped; got != want { - t.Errorf("container status got %v, want %v", got, want) - } - - // Destroy the container. - if err := c.Destroy(); err != nil { - t.Fatalf("error destroying container: %v", err) - } - - // List should not return the container id. - ids, err = List(rootDir) - if err != nil { - t.Fatalf("error listing containers: %v", err) - } - if len(ids) != 0 { - t.Errorf("expected container list to be empty, but got %v", ids) - } - - // Loading the container by id should fail. - if _, err = Load(rootDir, args.ID); err == nil { - t.Errorf("expected loading destroyed container to fail, but it did not") - } - } -} - -// Test the we can execute the application with different path formats. -func TestExePath(t *testing.T) { - // Create two directories that will be prepended to PATH. - firstPath, err := ioutil.TempDir(testutil.TmpDir(), "first") - if err != nil { - t.Fatal(err) - } - secondPath, err := ioutil.TempDir(testutil.TmpDir(), "second") - if err != nil { - t.Fatal(err) - } - - // Create two minimal executables in the second path, two of which - // will be masked by files in first path. - for _, p := range []string{"unmasked", "masked1", "masked2"} { - path := filepath.Join(secondPath, p) - f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0777) - if err != nil { - t.Fatal(err) - } - defer f.Close() - if _, err := io.WriteString(f, "#!/bin/true\n"); err != nil { - t.Fatal(err) - } - } - - // Create a non-executable file in the first path which masks a healthy - // executable in the second. - nonExecutable := filepath.Join(firstPath, "masked1") - f2, err := os.OpenFile(nonExecutable, os.O_CREATE|os.O_EXCL, 0666) - if err != nil { - t.Fatal(err) - } - f2.Close() - - // Create a non-regular file in the first path which masks a healthy - // executable in the second. - nonRegular := filepath.Join(firstPath, "masked2") - if err := os.Mkdir(nonRegular, 0777); err != nil { - t.Fatal(err) - } - - for _, conf := range configs(overlay) { - t.Logf("Running test with conf: %+v", conf) - for _, test := range []struct { - path string - success bool - }{ - {path: "true", success: true}, - {path: "bin/true", success: true}, - {path: "/bin/true", success: true}, - {path: "thisfiledoesntexit", success: false}, - {path: "bin/thisfiledoesntexit", success: false}, - {path: "/bin/thisfiledoesntexit", success: false}, - - {path: "unmasked", success: true}, - {path: filepath.Join(firstPath, "unmasked"), success: false}, - {path: filepath.Join(secondPath, "unmasked"), success: true}, - - {path: "masked1", success: true}, - {path: filepath.Join(firstPath, "masked1"), success: false}, - {path: filepath.Join(secondPath, "masked1"), success: true}, - - {path: "masked2", success: true}, - {path: filepath.Join(firstPath, "masked2"), success: false}, - {path: filepath.Join(secondPath, "masked2"), success: true}, - } { - spec := testutil.NewSpecWithArgs(test.path) - spec.Process.Env = []string{ - fmt.Sprintf("PATH=%s:%s:%s", firstPath, secondPath, os.Getenv("PATH")), - } - - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("exec: %s, error setting up container: %v", test.path, err) - } - - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - Attached: true, - } - ws, err := Run(conf, args) - - os.RemoveAll(rootDir) - os.RemoveAll(bundleDir) - - if test.success { - if err != nil { - t.Errorf("exec: %s, error running container: %v", test.path, err) - } - if ws.ExitStatus() != 0 { - t.Errorf("exec: %s, got exit status %v want %v", test.path, ws.ExitStatus(), 0) - } - } else { - if err == nil { - t.Errorf("exec: %s, got: no error, want: error", test.path) - } - } - } - } -} - -// Test the we can retrieve the application exit status from the container. -func TestAppExitStatus(t *testing.T) { - // First container will succeed. - succSpec := testutil.NewSpecWithArgs("true") - conf := testutil.TestConfig() - rootDir, bundleDir, err := testutil.SetupContainer(succSpec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: succSpec, - BundleDir: bundleDir, - Attached: true, - } - ws, err := Run(conf, args) - if err != nil { - t.Fatalf("error running container: %v", err) - } - if ws.ExitStatus() != 0 { - t.Errorf("got exit status %v want %v", ws.ExitStatus(), 0) - } - - // Second container exits with non-zero status. - wantStatus := 123 - errSpec := testutil.NewSpecWithArgs("bash", "-c", fmt.Sprintf("exit %d", wantStatus)) - - rootDir2, bundleDir2, err := testutil.SetupContainer(errSpec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir2) - defer os.RemoveAll(bundleDir2) - - args2 := Args{ - ID: testutil.UniqueContainerID(), - Spec: errSpec, - BundleDir: bundleDir2, - Attached: true, - } - ws, err = Run(conf, args2) - if err != nil { - t.Fatalf("error running container: %v", err) - } - if ws.ExitStatus() != wantStatus { - t.Errorf("got exit status %v want %v", ws.ExitStatus(), wantStatus) - } -} - -// TestExec verifies that a container can exec a new program. -func TestExec(t *testing.T) { - for _, conf := range configs(overlay) { - t.Logf("Running test with conf: %+v", conf) - - const uid = 343 - spec := testutil.NewSpecWithArgs("sleep", "100") - - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create and start the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - cont, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer cont.Destroy() - if err := cont.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // expectedPL lists the expected process state of the container. - expectedPL := []*control.Process{ - { - UID: 0, - PID: 1, - PPID: 0, - C: 0, - Cmd: "sleep", - Threads: []kernel.ThreadID{1}, - }, - { - UID: uid, - PID: 2, - PPID: 0, - C: 0, - Cmd: "sleep", - Threads: []kernel.ThreadID{2}, - }, - } - - // Verify that "sleep 100" is running. - if err := waitForProcessList(cont, expectedPL[:1]); err != nil { - t.Error(err) - } - - execArgs := &control.ExecArgs{ - Filename: "/bin/sleep", - Argv: []string{"/bin/sleep", "5"}, - WorkingDirectory: "/", - KUID: uid, - } - - // Verify that "sleep 100" and "sleep 5" are running after exec. - // First, start running exec (whick blocks). - status := make(chan error, 1) - go func() { - exitStatus, err := cont.executeSync(execArgs) - if err != nil { - log.Debugf("error executing: %v", err) - status <- err - } else if exitStatus != 0 { - log.Debugf("bad status: %d", exitStatus) - status <- fmt.Errorf("failed with exit status: %v", exitStatus) - } else { - status <- nil - } - }() - - if err := waitForProcessList(cont, expectedPL); err != nil { - t.Fatal(err) - } - - // Ensure that exec finished without error. - select { - case <-time.After(10 * time.Second): - t.Fatalf("container timed out waiting for exec to finish.") - case st := <-status: - if st != nil { - t.Errorf("container failed to exec %v: %v", args, err) - } - } - } -} - -// TestKillPid verifies that we can signal individual exec'd processes. -func TestKillPid(t *testing.T) { - for _, conf := range configs(overlay) { - t.Logf("Running test with conf: %+v", conf) - - app, err := testutil.FindFile("runsc/container/test_app/test_app") - if err != nil { - t.Fatal("error finding test_app:", err) - } - - const nProcs = 4 - spec := testutil.NewSpecWithArgs(app, "task-tree", "--depth", strconv.Itoa(nProcs-1), "--width=1", "--pause=true") - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create and start the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - cont, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer cont.Destroy() - if err := cont.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // Verify that all processes are running. - if err := waitForProcessCount(cont, nProcs); err != nil { - t.Fatalf("timed out waiting for processes to start: %v", err) - } - - // Kill the child process with the largest PID. - procs, err := cont.Processes() - if err != nil { - t.Fatalf("failed to get process list: %v", err) - } - var pid int32 - for _, p := range procs { - if pid < int32(p.PID) { - pid = int32(p.PID) - } - } - if err := cont.SignalProcess(syscall.SIGKILL, pid); err != nil { - t.Fatalf("failed to signal process %d: %v", pid, err) - } - - // Verify that one process is gone. - if err := waitForProcessCount(cont, nProcs-1); err != nil { - t.Fatal(err) - } - - procs, err = cont.Processes() - if err != nil { - t.Fatalf("failed to get process list: %v", err) - } - for _, p := range procs { - if pid == int32(p.PID) { - t.Fatalf("pid %d is still alive, which should be killed", pid) - } - } - } -} - -// TestCheckpointRestore creates a container that continuously writes successive integers -// to a file. To test checkpoint and restore functionality, the container is -// checkpointed and the last number printed to the file is recorded. Then, it is restored in two -// new containers and the first number printed from these containers is checked. Both should -// be the next consecutive number after the last number from the checkpointed container. -func TestCheckpointRestore(t *testing.T) { - // Skip overlay because test requires writing to host file. - for _, conf := range configs(noOverlay...) { - t.Logf("Running test with conf: %+v", conf) - - dir, err := ioutil.TempDir(testutil.TmpDir(), "checkpoint-test") - if err != nil { - t.Fatalf("ioutil.TempDir failed: %v", err) - } - if err := os.Chmod(dir, 0777); err != nil { - t.Fatalf("error chmoding file: %q, %v", dir, err) - } - - outputPath := filepath.Join(dir, "output") - outputFile, err := createWriteableOutputFile(outputPath) - if err != nil { - t.Fatalf("error creating output file: %v", err) - } - defer outputFile.Close() - - script := fmt.Sprintf("for ((i=0; ;i++)); do echo $i >> %q; sleep 1; done", outputPath) - spec := testutil.NewSpecWithArgs("bash", "-c", script) - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create and start the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - cont, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer cont.Destroy() - if err := cont.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // Set the image path, which is where the checkpoint image will be saved. - imagePath := filepath.Join(dir, "test-image-file") - - // Create the image file and open for writing. - file, err := os.OpenFile(imagePath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0644) - if err != nil { - t.Fatalf("error opening new file at imagePath: %v", err) - } - defer file.Close() - - // Wait until application has ran. - if err := waitForFileNotEmpty(outputFile); err != nil { - t.Fatalf("Failed to wait for output file: %v", err) - } - - // Checkpoint running container; save state into new file. - if err := cont.Checkpoint(file); err != nil { - t.Fatalf("error checkpointing container to empty file: %v", err) - } - defer os.RemoveAll(imagePath) - - lastNum, err := readOutputNum(outputPath, -1) - if err != nil { - t.Fatalf("error with outputFile: %v", err) - } - - // Delete and recreate file before restoring. - if err := os.Remove(outputPath); err != nil { - t.Fatalf("error removing file") - } - outputFile2, err := createWriteableOutputFile(outputPath) - if err != nil { - t.Fatalf("error creating output file: %v", err) - } - defer outputFile2.Close() - - // Restore into a new container. - args2 := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - cont2, err := New(conf, args2) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer cont2.Destroy() - - if err := cont2.Restore(spec, conf, imagePath); err != nil { - t.Fatalf("error restoring container: %v", err) - } - - // Wait until application has ran. - if err := waitForFileNotEmpty(outputFile2); err != nil { - t.Fatalf("Failed to wait for output file: %v", err) - } - - firstNum, err := readOutputNum(outputPath, 0) - if err != nil { - t.Fatalf("error with outputFile: %v", err) - } - - // Check that lastNum is one less than firstNum and that the container picks - // up from where it left off. - if lastNum+1 != firstNum { - t.Errorf("error numbers not in order, previous: %d, next: %d", lastNum, firstNum) - } - cont2.Destroy() - - // Restore into another container! - // Delete and recreate file before restoring. - if err := os.Remove(outputPath); err != nil { - t.Fatalf("error removing file") - } - outputFile3, err := createWriteableOutputFile(outputPath) - if err != nil { - t.Fatalf("error creating output file: %v", err) - } - defer outputFile3.Close() - - // Restore into a new container. - args3 := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - cont3, err := New(conf, args3) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer cont3.Destroy() - - if err := cont3.Restore(spec, conf, imagePath); err != nil { - t.Fatalf("error restoring container: %v", err) - } - - // Wait until application has ran. - if err := waitForFileNotEmpty(outputFile3); err != nil { - t.Fatalf("Failed to wait for output file: %v", err) - } - - firstNum2, err := readOutputNum(outputPath, 0) - if err != nil { - t.Fatalf("error with outputFile: %v", err) - } - - // Check that lastNum is one less than firstNum and that the container picks - // up from where it left off. - if lastNum+1 != firstNum2 { - t.Errorf("error numbers not in order, previous: %d, next: %d", lastNum, firstNum2) - } - cont3.Destroy() - } -} - -// TestUnixDomainSockets checks that Checkpoint/Restore works in cases -// with filesystem Unix Domain Socket use. -func TestUnixDomainSockets(t *testing.T) { - // Skip overlay because test requires writing to host file. - for _, conf := range configs(noOverlay...) { - t.Logf("Running test with conf: %+v", conf) - - // UDS path is limited to 108 chars for compatibility with older systems. - // Use '/tmp' (instead of testutil.TmpDir) to ensure the size limit is - // not exceeded. Assumes '/tmp' exists in the system. - dir, err := ioutil.TempDir("/tmp", "uds-test") - if err != nil { - t.Fatalf("ioutil.TempDir failed: %v", err) - } - defer os.RemoveAll(dir) - - outputPath := filepath.Join(dir, "uds_output") - outputFile, err := os.OpenFile(outputPath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0666) - if err != nil { - t.Fatalf("error creating output file: %v", err) - } - defer outputFile.Close() - - app, err := testutil.FindFile("runsc/container/test_app/test_app") - if err != nil { - t.Fatal("error finding test_app:", err) - } - - socketPath := filepath.Join(dir, "uds_socket") - defer os.Remove(socketPath) - - spec := testutil.NewSpecWithArgs(app, "uds", "--file", outputPath, "--socket", socketPath) - spec.Process.User = specs.User{ - UID: uint32(os.Getuid()), - GID: uint32(os.Getgid()), - } - spec.Mounts = []specs.Mount{{ - Type: "bind", - Destination: dir, - Source: dir, - }} - - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create and start the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - cont, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer cont.Destroy() - if err := cont.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // Set the image path, the location where the checkpoint image will be saved. - imagePath := filepath.Join(dir, "test-image-file") - - // Create the image file and open for writing. - file, err := os.OpenFile(imagePath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0644) - if err != nil { - t.Fatalf("error opening new file at imagePath: %v", err) - } - defer file.Close() - defer os.RemoveAll(imagePath) - - // Wait until application has ran. - if err := waitForFileNotEmpty(outputFile); err != nil { - t.Fatalf("Failed to wait for output file: %v", err) - } - - // Checkpoint running container; save state into new file. - if err := cont.Checkpoint(file); err != nil { - t.Fatalf("error checkpointing container to empty file: %v", err) - } - - // Read last number outputted before checkpoint. - lastNum, err := readOutputNum(outputPath, -1) - if err != nil { - t.Fatalf("error with outputFile: %v", err) - } - - // Delete and recreate file before restoring. - if err := os.Remove(outputPath); err != nil { - t.Fatalf("error removing file") - } - outputFile2, err := os.OpenFile(outputPath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0666) - if err != nil { - t.Fatalf("error creating output file: %v", err) - } - defer outputFile2.Close() - - // Restore into a new container. - argsRestore := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - contRestore, err := New(conf, argsRestore) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer contRestore.Destroy() - - if err := contRestore.Restore(spec, conf, imagePath); err != nil { - t.Fatalf("error restoring container: %v", err) - } - - // Wait until application has ran. - if err := waitForFileNotEmpty(outputFile2); err != nil { - t.Fatalf("Failed to wait for output file: %v", err) - } - - // Read first number outputted after restore. - firstNum, err := readOutputNum(outputPath, 0) - if err != nil { - t.Fatalf("error with outputFile: %v", err) - } - - // Check that lastNum is one less than firstNum. - if lastNum+1 != firstNum { - t.Errorf("error numbers not consecutive, previous: %d, next: %d", lastNum, firstNum) - } - contRestore.Destroy() - } -} - -// TestPauseResume tests that we can successfully pause and resume a container. -// The container will keep touching a file to indicate it's running. The test -// pauses the container, removes the file, and checks that it doesn't get -// recreated. Then it resumes the container, verify that the file gets created -// again. -func TestPauseResume(t *testing.T) { - for _, conf := range configs(noOverlay...) { - t.Run(fmt.Sprintf("conf: %+v", conf), func(t *testing.T) { - t.Logf("Running test with conf: %+v", conf) - - tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "lock") - if err != nil { - t.Fatalf("error creating temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - running := path.Join(tmpDir, "running") - script := fmt.Sprintf("while [[ true ]]; do touch %q; sleep 0.1; done", running) - spec := testutil.NewSpecWithArgs("/bin/bash", "-c", script) - - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create and start the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - cont, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer cont.Destroy() - if err := cont.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // Wait until container starts running, observed by the existence of running - // file. - if err := waitForFileExist(running); err != nil { - t.Errorf("error waiting for container to start: %v", err) - } - - // Pause the running container. - if err := cont.Pause(); err != nil { - t.Errorf("error pausing container: %v", err) - } - if got, want := cont.Status, Paused; got != want { - t.Errorf("container status got %v, want %v", got, want) - } - - if err := os.Remove(running); err != nil { - t.Fatalf("os.Remove(%q) failed: %v", running, err) - } - // Script touches the file every 100ms. Give a bit a time for it to run to - // catch the case that pause didn't work. - time.Sleep(200 * time.Millisecond) - if _, err := os.Stat(running); !os.IsNotExist(err) { - t.Fatalf("container did not pause: file exist check: %v", err) - } - - // Resume the running container. - if err := cont.Resume(); err != nil { - t.Errorf("error pausing container: %v", err) - } - if got, want := cont.Status, Running; got != want { - t.Errorf("container status got %v, want %v", got, want) - } - - // Verify that the file is once again created by container. - if err := waitForFileExist(running); err != nil { - t.Fatalf("error resuming container: file exist check: %v", err) - } - }) - } -} - -// TestPauseResumeStatus makes sure that the statuses are set correctly -// with calls to pause and resume and that pausing and resuming only -// occurs given the correct state. -func TestPauseResumeStatus(t *testing.T) { - spec := testutil.NewSpecWithArgs("sleep", "20") - conf := testutil.TestConfig() - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create and start the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - cont, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer cont.Destroy() - if err := cont.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // Pause the running container. - if err := cont.Pause(); err != nil { - t.Errorf("error pausing container: %v", err) - } - if got, want := cont.Status, Paused; got != want { - t.Errorf("container status got %v, want %v", got, want) - } - - // Try to Pause again. Should cause error. - if err := cont.Pause(); err == nil { - t.Errorf("error pausing container that was already paused: %v", err) - } - if got, want := cont.Status, Paused; got != want { - t.Errorf("container status got %v, want %v", got, want) - } - - // Resume the running container. - if err := cont.Resume(); err != nil { - t.Errorf("error resuming container: %v", err) - } - if got, want := cont.Status, Running; got != want { - t.Errorf("container status got %v, want %v", got, want) - } - - // Try to resume again. Should cause error. - if err := cont.Resume(); err == nil { - t.Errorf("error resuming container already running: %v", err) - } - if got, want := cont.Status, Running; got != want { - t.Errorf("container status got %v, want %v", got, want) - } -} - -// TestCapabilities verifies that: -// - Running exec as non-root UID and GID will result in an error (because the -// executable file can't be read). -// - Running exec as non-root with CAP_DAC_OVERRIDE succeeds because it skips -// this check. -func TestCapabilities(t *testing.T) { - // Pick uid/gid different than ours. - uid := auth.KUID(os.Getuid() + 1) - gid := auth.KGID(os.Getgid() + 1) - - for _, conf := range configs(all...) { - t.Logf("Running test with conf: %+v", conf) - - spec := testutil.NewSpecWithArgs("sleep", "100") - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create and start the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - cont, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer cont.Destroy() - if err := cont.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // expectedPL lists the expected process state of the container. - expectedPL := []*control.Process{ - { - UID: 0, - PID: 1, - PPID: 0, - C: 0, - Cmd: "sleep", - Threads: []kernel.ThreadID{1}, - }, - { - UID: uid, - PID: 2, - PPID: 0, - C: 0, - Cmd: "exe", - Threads: []kernel.ThreadID{2}, - }, - } - if err := waitForProcessList(cont, expectedPL[:1]); err != nil { - t.Fatalf("Failed to wait for sleep to start, err: %v", err) - } - - // Create an executable that can't be run with the specified UID:GID. - // This shouldn't be callable within the container until we add the - // CAP_DAC_OVERRIDE capability to skip the access check. - exePath := filepath.Join(rootDir, "exe") - if err := ioutil.WriteFile(exePath, []byte("#!/bin/sh\necho hello"), 0770); err != nil { - t.Fatalf("couldn't create executable: %v", err) - } - defer os.Remove(exePath) - - // Need to traverse the intermediate directory. - os.Chmod(rootDir, 0755) - - execArgs := &control.ExecArgs{ - Filename: exePath, - Argv: []string{exePath}, - WorkingDirectory: "/", - KUID: uid, - KGID: gid, - Capabilities: &auth.TaskCapabilities{}, - } - - // "exe" should fail because we don't have the necessary permissions. - if _, err := cont.executeSync(execArgs); err == nil { - t.Fatalf("container executed without error, but an error was expected") - } - - // Now we run with the capability enabled and should succeed. - execArgs.Capabilities = &auth.TaskCapabilities{ - EffectiveCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE), - } - // "exe" should not fail this time. - if _, err := cont.executeSync(execArgs); err != nil { - t.Fatalf("container failed to exec %v: %v", args, err) - } - } -} - -// TestRunNonRoot checks that sandbox can be configured when running as -// non-privileged user. -func TestRunNonRoot(t *testing.T) { - for _, conf := range configs(noOverlay...) { - t.Logf("Running test with conf: %+v", conf) - - spec := testutil.NewSpecWithArgs("/bin/true") - - // Set a random user/group with no access to "blocked" dir. - spec.Process.User.UID = 343 - spec.Process.User.GID = 2401 - spec.Process.Capabilities = nil - - // User running inside container can't list '$TMP/blocked' and would fail to - // mount it. - dir, err := ioutil.TempDir(testutil.TmpDir(), "blocked") - if err != nil { - t.Fatalf("ioutil.TempDir() failed: %v", err) - } - if err := os.Chmod(dir, 0700); err != nil { - t.Fatalf("os.MkDir(%q) failed: %v", dir, err) - } - dir = path.Join(dir, "test") - if err := os.Mkdir(dir, 0755); err != nil { - t.Fatalf("os.MkDir(%q) failed: %v", dir, err) - } - - src, err := ioutil.TempDir(testutil.TmpDir(), "src") - if err != nil { - t.Fatalf("ioutil.TempDir() failed: %v", err) - } - - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: dir, - Source: src, - Type: "bind", - }) - - if err := run(spec, conf); err != nil { - t.Fatalf("error running sandbox: %v", err) - } - } -} - -// TestMountNewDir checks that runsc will create destination directory if it -// doesn't exit. -func TestMountNewDir(t *testing.T) { - for _, conf := range configs(overlay) { - t.Logf("Running test with conf: %+v", conf) - - root, err := ioutil.TempDir(testutil.TmpDir(), "root") - if err != nil { - t.Fatal("ioutil.TempDir() failed:", err) - } - - srcDir := path.Join(root, "src", "dir", "anotherdir") - if err := os.MkdirAll(srcDir, 0755); err != nil { - t.Fatalf("os.MkDir(%q) failed: %v", srcDir, err) - } - - mountDir := path.Join(root, "dir", "anotherdir") - - spec := testutil.NewSpecWithArgs("/bin/ls", mountDir) - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: mountDir, - Source: srcDir, - Type: "bind", - }) - - if err := run(spec, conf); err != nil { - t.Fatalf("error running sandbox: %v", err) - } - } -} - -func TestReadonlyRoot(t *testing.T) { - for _, conf := range configs(overlay) { - t.Logf("Running test with conf: %+v", conf) - - spec := testutil.NewSpecWithArgs("/bin/touch", "/foo") - spec.Root.Readonly = true - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create, start and wait for the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - c, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer c.Destroy() - if err := c.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - ws, err := c.Wait() - if err != nil { - t.Fatalf("error waiting on container: %v", err) - } - if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM { - t.Fatalf("container failed, waitStatus: %v", ws) - } - } -} - -func TestUIDMap(t *testing.T) { - for _, conf := range configs(noOverlay...) { - t.Logf("Running test with conf: %+v", conf) - testDir, err := ioutil.TempDir(testutil.TmpDir(), "test-mount") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(testDir) - testFile := path.Join(testDir, "testfile") - - spec := testutil.NewSpecWithArgs("touch", "/tmp/testfile") - uid := os.Getuid() - gid := os.Getgid() - spec.Linux = &specs.Linux{ - Namespaces: []specs.LinuxNamespace{ - {Type: specs.UserNamespace}, - {Type: specs.PIDNamespace}, - {Type: specs.MountNamespace}, - }, - UIDMappings: []specs.LinuxIDMapping{ - { - ContainerID: 0, - HostID: uint32(uid), - Size: 1, - }, - }, - GIDMappings: []specs.LinuxIDMapping{ - { - ContainerID: 0, - HostID: uint32(gid), - Size: 1, - }, - }, - } - - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp", - Source: testDir, - Type: "bind", - }) - - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create, start and wait for the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - c, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer c.Destroy() - if err := c.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - ws, err := c.Wait() - if err != nil { - t.Fatalf("error waiting on container: %v", err) - } - if !ws.Exited() || ws.ExitStatus() != 0 { - t.Fatalf("container failed, waitStatus: %v", ws) - } - st := syscall.Stat_t{} - if err := syscall.Stat(testFile, &st); err != nil { - t.Fatalf("error stat /testfile: %v", err) - } - - if st.Uid != uint32(uid) || st.Gid != uint32(gid) { - t.Fatalf("UID: %d (%d) GID: %d (%d)", st.Uid, uid, st.Gid, gid) - } - } -} - -func TestReadonlyMount(t *testing.T) { - for _, conf := range configs(overlay) { - t.Logf("Running test with conf: %+v", conf) - - dir, err := ioutil.TempDir(testutil.TmpDir(), "ro-mount") - spec := testutil.NewSpecWithArgs("/bin/touch", path.Join(dir, "file")) - if err != nil { - t.Fatalf("ioutil.TempDir() failed: %v", err) - } - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: dir, - Source: dir, - Type: "bind", - Options: []string{"ro"}, - }) - spec.Root.Readonly = false - - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create, start and wait for the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - c, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer c.Destroy() - if err := c.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - ws, err := c.Wait() - if err != nil { - t.Fatalf("error waiting on container: %v", err) - } - if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM { - t.Fatalf("container failed, waitStatus: %v", ws) - } - } -} - -// TestAbbreviatedIDs checks that runsc supports using abbreviated container -// IDs in place of full IDs. -func TestAbbreviatedIDs(t *testing.T) { - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - - cids := []string{ - "foo-" + testutil.UniqueContainerID(), - "bar-" + testutil.UniqueContainerID(), - "baz-" + testutil.UniqueContainerID(), - } - for _, cid := range cids { - spec := testutil.NewSpecWithArgs("sleep", "100") - bundleDir, err := testutil.SetupBundleDir(spec) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(bundleDir) - - // Create and start the container. - args := Args{ - ID: cid, - Spec: spec, - BundleDir: bundleDir, - } - cont, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer cont.Destroy() - } - - // These should all be unambigious. - unambiguous := map[string]string{ - "f": cids[0], - cids[0]: cids[0], - "bar": cids[1], - cids[1]: cids[1], - "baz": cids[2], - cids[2]: cids[2], - } - for shortid, longid := range unambiguous { - if _, err := Load(rootDir, shortid); err != nil { - t.Errorf("%q should resolve to %q: %v", shortid, longid, err) - } - } - - // These should be ambiguous. - ambiguous := []string{ - "b", - "ba", - } - for _, shortid := range ambiguous { - if s, err := Load(rootDir, shortid); err == nil { - t.Errorf("%q should be ambiguous, but resolved to %q", shortid, s.ID) - } - } -} - -func TestGoferExits(t *testing.T) { - spec := testutil.NewSpecWithArgs("/bin/sleep", "10000") - conf := testutil.TestConfig() - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create and start the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - c, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer c.Destroy() - if err := c.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // Kill sandbox and expect gofer to exit on its own. - sandboxProc, err := os.FindProcess(c.Sandbox.Pid) - if err != nil { - t.Fatalf("error finding sandbox process: %v", err) - } - if err := sandboxProc.Kill(); err != nil { - t.Fatalf("error killing sandbox process: %v", err) - } - - err = blockUntilWaitable(c.GoferPid) - if err != nil && err != syscall.ECHILD { - t.Errorf("error waiting for gofer to exit: %v", err) - } -} - -func TestRootNotMount(t *testing.T) { - appSym, err := testutil.FindFile("runsc/container/test_app/test_app") - if err != nil { - t.Fatal("error finding test_app:", err) - } - - app, err := filepath.EvalSymlinks(appSym) - if err != nil { - t.Fatalf("error resolving %q symlink: %v", appSym, err) - } - log.Infof("App path %q is a symlink to %q", appSym, app) - - static, err := testutil.IsStatic(app) - if err != nil { - t.Fatalf("error reading application binary: %v", err) - } - if !static { - // This happens during race builds; we cannot map in shared - // libraries also, so we need to skip the test. - t.Skip() - } - - root := filepath.Dir(app) - exe := "/" + filepath.Base(app) - log.Infof("Executing %q in %q", exe, root) - - spec := testutil.NewSpecWithArgs(exe, "help") - spec.Root.Path = root - spec.Root.Readonly = true - spec.Mounts = nil - - conf := testutil.TestConfig() - if err := run(spec, conf); err != nil { - t.Fatalf("error running sandbox: %v", err) - } -} - -func TestUserLog(t *testing.T) { - app, err := testutil.FindFile("runsc/container/test_app/test_app") - if err != nil { - t.Fatal("error finding test_app:", err) - } - - // sched_rr_get_interval = 148 - not implemented in gvisor. - spec := testutil.NewSpecWithArgs(app, "syscall", "--syscall=148") - conf := testutil.TestConfig() - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - dir, err := ioutil.TempDir(testutil.TmpDir(), "user_log_test") - if err != nil { - t.Fatalf("error creating tmp dir: %v", err) - } - userLog := filepath.Join(dir, "user.log") - - // Create, start and wait for the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - UserLog: userLog, - Attached: true, - } - ws, err := Run(conf, args) - if err != nil { - t.Fatalf("error running container: %v", err) - } - if !ws.Exited() || ws.ExitStatus() != 0 { - t.Fatalf("container failed, waitStatus: %v", ws) - } - - out, err := ioutil.ReadFile(userLog) - if err != nil { - t.Fatalf("error opening user log file %q: %v", userLog, err) - } - if want := "Unsupported syscall: sched_rr_get_interval"; !strings.Contains(string(out), want) { - t.Errorf("user log file doesn't contain %q, out: %s", want, string(out)) - } -} - -func TestWaitOnExitedSandbox(t *testing.T) { - for _, conf := range configs(all...) { - t.Logf("Running test with conf: %+v", conf) - - // Run a shell that sleeps for 1 second and then exits with a - // non-zero code. - const wantExit = 17 - cmd := fmt.Sprintf("sleep 1; exit %d", wantExit) - spec := testutil.NewSpecWithArgs("/bin/sh", "-c", cmd) - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create and Start the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - c, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer c.Destroy() - if err := c.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // Wait on the sandbox. This will make an RPC to the sandbox - // and get the actual exit status of the application. - ws, err := c.Wait() - if err != nil { - t.Fatalf("error waiting on container: %v", err) - } - if got := ws.ExitStatus(); got != wantExit { - t.Errorf("got exit status %d, want %d", got, wantExit) - } - - // Now the sandbox has exited, but the zombie sandbox process - // still exists. Calling Wait() now will return the sandbox - // exit status. - ws, err = c.Wait() - if err != nil { - t.Fatalf("error waiting on container: %v", err) - } - if got := ws.ExitStatus(); got != wantExit { - t.Errorf("got exit status %d, want %d", got, wantExit) - } - } -} - -func TestDestroyNotStarted(t *testing.T) { - spec := testutil.NewSpecWithArgs("/bin/sleep", "100") - conf := testutil.TestConfig() - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create the container and check that it can be destroyed. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - c, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - if err := c.Destroy(); err != nil { - t.Fatalf("deleting non-started container failed: %v", err) - } -} - -// TestDestroyStarting attempts to force a race between start and destroy. -func TestDestroyStarting(t *testing.T) { - for i := 0; i < 10; i++ { - spec := testutil.NewSpecWithArgs("/bin/sleep", "100") - conf := testutil.TestConfig() - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create the container and check that it can be destroyed. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - c, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - - // Container is not thread safe, so load another instance to run in - // concurrently. - startCont, err := Load(rootDir, args.ID) - if err != nil { - t.Fatalf("error loading container: %v", err) - } - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - defer wg.Done() - // Ignore failures, start can fail if destroy runs first. - startCont.Start(conf) - }() - - wg.Add(1) - go func() { - defer wg.Done() - if err := c.Destroy(); err != nil { - t.Errorf("deleting non-started container failed: %v", err) - } - }() - wg.Wait() - } -} - -func TestCreateWorkingDir(t *testing.T) { - for _, conf := range configs(overlay) { - t.Logf("Running test with conf: %+v", conf) - - tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "cwd-create") - if err != nil { - t.Fatalf("ioutil.TempDir() failed: %v", err) - } - dir := path.Join(tmpDir, "new/working/dir") - - // touch will fail if the directory doesn't exist. - spec := testutil.NewSpecWithArgs("/bin/touch", path.Join(dir, "file")) - spec.Process.Cwd = dir - spec.Root.Readonly = true - - if err := run(spec, conf); err != nil { - t.Fatalf("Error running container: %v", err) - } - } -} - -// TestMountPropagation verifies that mount propagates to slave but not to -// private mounts. -func TestMountPropagation(t *testing.T) { - // Setup dir structure: - // - src: is mounted as shared and is used as source for both private and - // slave mounts - // - dir: will be bind mounted inside src and should propagate to slave - tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "mount") - if err != nil { - t.Fatalf("ioutil.TempDir() failed: %v", err) - } - src := filepath.Join(tmpDir, "src") - srcMnt := filepath.Join(src, "mnt") - dir := filepath.Join(tmpDir, "dir") - for _, path := range []string{src, srcMnt, dir} { - if err := os.MkdirAll(path, 0777); err != nil { - t.Fatalf("MkdirAll(%q): %v", path, err) - } - } - dirFile := filepath.Join(dir, "file") - f, err := os.Create(dirFile) - if err != nil { - t.Fatalf("os.Create(%q): %v", dirFile, err) - } - f.Close() - - // Setup src as a shared mount. - if err := syscall.Mount(src, src, "bind", syscall.MS_BIND, ""); err != nil { - t.Fatalf("mount(%q, %q, MS_BIND): %v", dir, srcMnt, err) - } - if err := syscall.Mount("", src, "", syscall.MS_SHARED, ""); err != nil { - t.Fatalf("mount(%q, MS_SHARED): %v", srcMnt, err) - } - - spec := testutil.NewSpecWithArgs("sleep", "1000") - - priv := filepath.Join(tmpDir, "priv") - slave := filepath.Join(tmpDir, "slave") - spec.Mounts = []specs.Mount{ - { - Source: src, - Destination: priv, - Type: "bind", - Options: []string{"private"}, - }, - { - Source: src, - Destination: slave, - Type: "bind", - Options: []string{"slave"}, - }, - } - - conf := testutil.TestConfig() - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - cont, err := New(conf, args) - if err != nil { - t.Fatalf("creating container: %v", err) - } - defer cont.Destroy() - - if err := cont.Start(conf); err != nil { - t.Fatalf("starting container: %v", err) - } - - // After the container is started, mount dir inside source and check what - // happens to both destinations. - if err := syscall.Mount(dir, srcMnt, "bind", syscall.MS_BIND, ""); err != nil { - t.Fatalf("mount(%q, %q, MS_BIND): %v", dir, srcMnt, err) - } - - // Check that mount didn't propagate to private mount. - privFile := filepath.Join(priv, "mnt", "file") - execArgs := &control.ExecArgs{ - Filename: "/usr/bin/test", - Argv: []string{"test", "!", "-f", privFile}, - } - if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 { - t.Fatalf("exec: test ! -f %q, ws: %v, err: %v", privFile, ws, err) - } - - // Check that mount propagated to slave mount. - slaveFile := filepath.Join(slave, "mnt", "file") - execArgs = &control.ExecArgs{ - Filename: "/usr/bin/test", - Argv: []string{"test", "-f", slaveFile}, - } - if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 { - t.Fatalf("exec: test -f %q, ws: %v, err: %v", privFile, ws, err) - } -} - -func TestMountSymlink(t *testing.T) { - for _, conf := range configs(overlay) { - t.Logf("Running test with conf: %+v", conf) - - dir, err := ioutil.TempDir(testutil.TmpDir(), "mount-symlink") - if err != nil { - t.Fatalf("ioutil.TempDir() failed: %v", err) - } - - source := path.Join(dir, "source") - target := path.Join(dir, "target") - for _, path := range []string{source, target} { - if err := os.MkdirAll(path, 0777); err != nil { - t.Fatalf("os.MkdirAll(): %v", err) - } - } - f, err := os.Create(path.Join(source, "file")) - if err != nil { - t.Fatalf("os.Create(): %v", err) - } - f.Close() - - link := path.Join(dir, "link") - if err := os.Symlink(target, link); err != nil { - t.Fatalf("os.Symlink(%q, %q): %v", target, link, err) - } - - spec := testutil.NewSpecWithArgs("/bin/sleep", "1000") - - // Mount to a symlink to ensure the mount code will follow it and mount - // at the symlink target. - spec.Mounts = append(spec.Mounts, specs.Mount{ - Type: "bind", - Destination: link, - Source: source, - }) - - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - cont, err := New(conf, args) - if err != nil { - t.Fatalf("creating container: %v", err) - } - defer cont.Destroy() - - if err := cont.Start(conf); err != nil { - t.Fatalf("starting container: %v", err) - } - - // Check that symlink was resolved and mount was created where the symlink - // is pointing to. - file := path.Join(target, "file") - execArgs := &control.ExecArgs{ - Filename: "/usr/bin/test", - Argv: []string{"test", "-f", file}, - } - if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 { - t.Fatalf("exec: test -f %q, ws: %v, err: %v", file, ws, err) - } - } -} - -// Check that --net-raw disables the CAP_NET_RAW capability. -func TestNetRaw(t *testing.T) { - capNetRaw := strconv.FormatUint(bits.MaskOf64(int(linux.CAP_NET_RAW)), 10) - app, err := testutil.FindFile("runsc/container/test_app/test_app") - if err != nil { - t.Fatal("error finding test_app:", err) - } - - for _, enableRaw := range []bool{true, false} { - conf := testutil.TestConfig() - conf.EnableRaw = enableRaw - - test := "--enabled" - if !enableRaw { - test = "--disabled" - } - - spec := testutil.NewSpecWithArgs(app, "capability", test, capNetRaw) - if err := run(spec, conf); err != nil { - t.Fatalf("Error running container: %v", err) - } - } -} - -// TestOverlayfsStaleRead most basic test that '--overlayfs-stale-read' works. -func TestOverlayfsStaleRead(t *testing.T) { - conf := testutil.TestConfig() - conf.OverlayfsStaleRead = true - - in, err := ioutil.TempFile(testutil.TmpDir(), "stale-read.in") - if err != nil { - t.Fatalf("ioutil.TempFile() failed: %v", err) - } - defer in.Close() - if _, err := in.WriteString("stale data"); err != nil { - t.Fatalf("in.Write() failed: %v", err) - } - - out, err := ioutil.TempFile(testutil.TmpDir(), "stale-read.out") - if err != nil { - t.Fatalf("ioutil.TempFile() failed: %v", err) - } - defer out.Close() - - const want = "foobar" - cmd := fmt.Sprintf("cat %q >&2 && echo %q> %q && cp %q %q", in.Name(), want, in.Name(), in.Name(), out.Name()) - spec := testutil.NewSpecWithArgs("/bin/bash", "-c", cmd) - if err := run(spec, conf); err != nil { - t.Fatalf("Error running container: %v", err) - } - - gotBytes, err := ioutil.ReadAll(out) - if err != nil { - t.Fatalf("out.Read() failed: %v", err) - } - got := strings.TrimSpace(string(gotBytes)) - if want != got { - t.Errorf("Wrong content in out file, got: %q. want: %q", got, want) - } -} - -// TestTTYField checks TTY field returned by container.Processes(). -func TestTTYField(t *testing.T) { - stop := testutil.StartReaper() - defer stop() - - testApp, err := testutil.FindFile("runsc/container/test_app/test_app") - if err != nil { - t.Fatal("error finding test_app:", err) - } - - testCases := []struct { - name string - useTTY bool - wantTTYField string - }{ - { - name: "no tty", - useTTY: false, - wantTTYField: "?", - }, - { - name: "tty used", - useTTY: true, - wantTTYField: "pts/0", - }, - } - - for _, test := range testCases { - t.Run(test.name, func(t *testing.T) { - conf := testutil.TestConfig() - - // We will run /bin/sleep, possibly with an open TTY. - cmd := []string{"/bin/sleep", "10000"} - if test.useTTY { - // Run inside the "pty-runner". - cmd = append([]string{testApp, "pty-runner"}, cmd...) - } - - spec := testutil.NewSpecWithArgs(cmd...) - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create and start the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - c, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer c.Destroy() - if err := c.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // Wait for sleep to be running, and check the TTY - // field. - var gotTTYField string - cb := func() error { - ps, err := c.Processes() - if err != nil { - err = fmt.Errorf("error getting process data from container: %v", err) - return &backoff.PermanentError{Err: err} - } - for _, p := range ps { - if strings.Contains(p.Cmd, "sleep") { - gotTTYField = p.TTY - return nil - } - } - return fmt.Errorf("sleep not running") - } - if err := testutil.Poll(cb, 30*time.Second); err != nil { - t.Fatalf("error waiting for sleep process: %v", err) - } - - if gotTTYField != test.wantTTYField { - t.Errorf("tty field got %q, want %q", gotTTYField, test.wantTTYField) - } - }) - } -} - -// executeSync synchronously executes a new process. -func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) { - pid, err := cont.Execute(args) - if err != nil { - return 0, fmt.Errorf("error executing: %v", err) - } - ws, err := cont.WaitPID(pid) - if err != nil { - return 0, fmt.Errorf("error waiting: %v", err) - } - return ws, nil -} - -func TestMain(m *testing.M) { - log.SetLevel(log.Debug) - flag.Parse() - if err := testutil.ConfigureExePath(); err != nil { - panic(err.Error()) - } - specutils.MaybeRunAsRoot() - os.Exit(m.Run()) -} diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go deleted file mode 100644 index 2da93ec5b..000000000 --- a/runsc/container/multi_container_test.go +++ /dev/null @@ -1,1708 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package container - -import ( - "fmt" - "io/ioutil" - "math" - "os" - "path" - "path/filepath" - "strings" - "syscall" - "testing" - "time" - - specs "github.com/opencontainers/runtime-spec/specs-go" - "gvisor.dev/gvisor/pkg/sentry/control" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/runsc/boot" - "gvisor.dev/gvisor/runsc/specutils" - "gvisor.dev/gvisor/runsc/testutil" -) - -func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) { - var specs []*specs.Spec - var ids []string - rootID := testutil.UniqueContainerID() - - for i, cmd := range cmds { - spec := testutil.NewSpecWithArgs(cmd...) - if i == 0 { - spec.Annotations = map[string]string{ - specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeSandbox, - } - ids = append(ids, rootID) - } else { - spec.Annotations = map[string]string{ - specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeContainer, - specutils.ContainerdSandboxIDAnnotation: rootID, - } - ids = append(ids, testutil.UniqueContainerID()) - } - specs = append(specs, spec) - } - return specs, ids -} - -func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*Container, func(), error) { - if len(conf.RootDir) == 0 { - panic("conf.RootDir not set. Call testutil.SetupRootDir() to set.") - } - - var containers []*Container - var bundles []string - cleanup := func() { - for _, c := range containers { - c.Destroy() - } - for _, b := range bundles { - os.RemoveAll(b) - } - } - for i, spec := range specs { - bundleDir, err := testutil.SetupBundleDir(spec) - if err != nil { - cleanup() - return nil, nil, fmt.Errorf("error setting up container: %v", err) - } - bundles = append(bundles, bundleDir) - - args := Args{ - ID: ids[i], - Spec: spec, - BundleDir: bundleDir, - } - cont, err := New(conf, args) - if err != nil { - cleanup() - return nil, nil, fmt.Errorf("error creating container: %v", err) - } - containers = append(containers, cont) - - if err := cont.Start(conf); err != nil { - cleanup() - return nil, nil, fmt.Errorf("error starting container: %v", err) - } - } - return containers, cleanup, nil -} - -type execDesc struct { - c *Container - cmd []string - want int - desc string -} - -func execMany(execs []execDesc) error { - for _, exec := range execs { - args := &control.ExecArgs{Argv: exec.cmd} - if ws, err := exec.c.executeSync(args); err != nil { - return fmt.Errorf("error executing %+v: %v", args, err) - } else if ws.ExitStatus() != exec.want { - return fmt.Errorf("%q: exec %q got exit status: %d, want: %d", exec.desc, exec.cmd, ws.ExitStatus(), exec.want) - } - } - return nil -} - -func createSharedMount(mount specs.Mount, name string, pod ...*specs.Spec) { - for _, spec := range pod { - spec.Annotations[boot.MountPrefix+name+".source"] = mount.Source - spec.Annotations[boot.MountPrefix+name+".type"] = mount.Type - spec.Annotations[boot.MountPrefix+name+".share"] = "pod" - if len(mount.Options) > 0 { - spec.Annotations[boot.MountPrefix+name+".options"] = strings.Join(mount.Options, ",") - } - } -} - -// TestMultiContainerSanity checks that it is possible to run 2 dead-simple -// containers in the same sandbox. -func TestMultiContainerSanity(t *testing.T) { - for _, conf := range configs(all...) { - t.Logf("Running test with conf: %+v", conf) - - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - conf.RootDir = rootDir - - // Setup the containers. - sleep := []string{"sleep", "100"} - specs, ids := createSpecs(sleep, sleep) - containers, cleanup, err := startContainers(conf, specs, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - // Check via ps that multiple processes are running. - expectedPL := []*control.Process{ - {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, - } - if err := waitForProcessList(containers[0], expectedPL); err != nil { - t.Errorf("failed to wait for sleep to start: %v", err) - } - expectedPL = []*control.Process{ - {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}}, - } - if err := waitForProcessList(containers[1], expectedPL); err != nil { - t.Errorf("failed to wait for sleep to start: %v", err) - } - } -} - -// TestMultiPIDNS checks that it is possible to run 2 dead-simple -// containers in the same sandbox with different pidns. -func TestMultiPIDNS(t *testing.T) { - for _, conf := range configs(all...) { - t.Logf("Running test with conf: %+v", conf) - - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - conf.RootDir = rootDir - - // Setup the containers. - sleep := []string{"sleep", "100"} - testSpecs, ids := createSpecs(sleep, sleep) - testSpecs[1].Linux = &specs.Linux{ - Namespaces: []specs.LinuxNamespace{ - { - Type: "pid", - }, - }, - } - - containers, cleanup, err := startContainers(conf, testSpecs, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - // Check via ps that multiple processes are running. - expectedPL := []*control.Process{ - {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, - } - if err := waitForProcessList(containers[0], expectedPL); err != nil { - t.Errorf("failed to wait for sleep to start: %v", err) - } - expectedPL = []*control.Process{ - {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, - } - if err := waitForProcessList(containers[1], expectedPL); err != nil { - t.Errorf("failed to wait for sleep to start: %v", err) - } - } -} - -// TestMultiPIDNSPath checks the pidns path. -func TestMultiPIDNSPath(t *testing.T) { - for _, conf := range configs(all...) { - t.Logf("Running test with conf: %+v", conf) - - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - conf.RootDir = rootDir - - // Setup the containers. - sleep := []string{"sleep", "100"} - testSpecs, ids := createSpecs(sleep, sleep, sleep) - testSpecs[0].Linux = &specs.Linux{ - Namespaces: []specs.LinuxNamespace{ - { - Type: "pid", - Path: "/proc/1/ns/pid", - }, - }, - } - testSpecs[1].Linux = &specs.Linux{ - Namespaces: []specs.LinuxNamespace{ - { - Type: "pid", - Path: "/proc/1/ns/pid", - }, - }, - } - testSpecs[2].Linux = &specs.Linux{ - Namespaces: []specs.LinuxNamespace{ - { - Type: "pid", - Path: "/proc/2/ns/pid", - }, - }, - } - - containers, cleanup, err := startContainers(conf, testSpecs, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - // Check via ps that multiple processes are running. - expectedPL := []*control.Process{ - {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, - } - if err := waitForProcessList(containers[0], expectedPL); err != nil { - t.Errorf("failed to wait for sleep to start: %v", err) - } - if err := waitForProcessList(containers[2], expectedPL); err != nil { - t.Errorf("failed to wait for sleep to start: %v", err) - } - - expectedPL = []*control.Process{ - {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}}, - } - if err := waitForProcessList(containers[1], expectedPL); err != nil { - t.Errorf("failed to wait for sleep to start: %v", err) - } - } -} - -func TestMultiContainerWait(t *testing.T) { - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - - // The first container should run the entire duration of the test. - cmd1 := []string{"sleep", "100"} - // We'll wait on the second container, which is much shorter lived. - cmd2 := []string{"sleep", "1"} - specs, ids := createSpecs(cmd1, cmd2) - - containers, cleanup, err := startContainers(conf, specs, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - // Check via ps that multiple processes are running. - expectedPL := []*control.Process{ - {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}}, - } - if err := waitForProcessList(containers[1], expectedPL); err != nil { - t.Errorf("failed to wait for sleep to start: %v", err) - } - - // Wait on the short lived container from multiple goroutines. - wg := sync.WaitGroup{} - for i := 0; i < 3; i++ { - wg.Add(1) - go func(c *Container) { - defer wg.Done() - if ws, err := c.Wait(); err != nil { - t.Errorf("failed to wait for process %s: %v", c.Spec.Process.Args, err) - } else if es := ws.ExitStatus(); es != 0 { - t.Errorf("process %s exited with non-zero status %d", c.Spec.Process.Args, es) - } - if _, err := c.Wait(); err != nil { - t.Errorf("wait for stopped container %s shouldn't fail: %v", c.Spec.Process.Args, err) - } - }(containers[1]) - } - - // Also wait via PID. - for i := 0; i < 3; i++ { - wg.Add(1) - go func(c *Container) { - defer wg.Done() - const pid = 2 - if ws, err := c.WaitPID(pid); err != nil { - t.Errorf("failed to wait for PID %d: %v", pid, err) - } else if es := ws.ExitStatus(); es != 0 { - t.Errorf("PID %d exited with non-zero status %d", pid, es) - } - if _, err := c.WaitPID(pid); err == nil { - t.Errorf("wait for stopped PID %d should fail", pid) - } - }(containers[1]) - } - - wg.Wait() - - // After Wait returns, ensure that the root container is running and - // the child has finished. - expectedPL = []*control.Process{ - {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, - } - if err := waitForProcessList(containers[0], expectedPL); err != nil { - t.Errorf("failed to wait for %q to start: %v", strings.Join(containers[0].Spec.Process.Args, " "), err) - } -} - -// TestExecWait ensures what we can wait containers and individual processes in the -// sandbox that have already exited. -func TestExecWait(t *testing.T) { - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - - // The first container should run the entire duration of the test. - cmd1 := []string{"sleep", "100"} - // We'll wait on the second container, which is much shorter lived. - cmd2 := []string{"sleep", "1"} - specs, ids := createSpecs(cmd1, cmd2) - containers, cleanup, err := startContainers(conf, specs, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - // Check via ps that process is running. - expectedPL := []*control.Process{ - {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}}, - } - if err := waitForProcessList(containers[1], expectedPL); err != nil { - t.Fatalf("failed to wait for sleep to start: %v", err) - } - - // Wait for the second container to finish. - if err := waitForProcessCount(containers[1], 0); err != nil { - t.Fatalf("failed to wait for second container to stop: %v", err) - } - - // Get the second container exit status. - if ws, err := containers[1].Wait(); err != nil { - t.Fatalf("failed to wait for process %s: %v", containers[1].Spec.Process.Args, err) - } else if es := ws.ExitStatus(); es != 0 { - t.Fatalf("process %s exited with non-zero status %d", containers[1].Spec.Process.Args, es) - } - if _, err := containers[1].Wait(); err != nil { - t.Fatalf("wait for stopped container %s shouldn't fail: %v", containers[1].Spec.Process.Args, err) - } - - // Execute another process in the first container. - args := &control.ExecArgs{ - Filename: "/bin/sleep", - Argv: []string{"/bin/sleep", "1"}, - WorkingDirectory: "/", - KUID: 0, - } - pid, err := containers[0].Execute(args) - if err != nil { - t.Fatalf("error executing: %v", err) - } - - // Wait for the exec'd process to exit. - expectedPL = []*control.Process{ - {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, - } - if err := waitForProcessList(containers[0], expectedPL); err != nil { - t.Fatalf("failed to wait for second container to stop: %v", err) - } - - // Get the exit status from the exec'd process. - if ws, err := containers[0].WaitPID(pid); err != nil { - t.Fatalf("failed to wait for process %+v with pid %d: %v", args, pid, err) - } else if es := ws.ExitStatus(); es != 0 { - t.Fatalf("process %+v exited with non-zero status %d", args, es) - } - if _, err := containers[0].WaitPID(pid); err == nil { - t.Fatalf("wait for stopped process %+v should fail", args) - } -} - -// TestMultiContainerMount tests that bind mounts can be used with multiple -// containers. -func TestMultiContainerMount(t *testing.T) { - cmd1 := []string{"sleep", "100"} - - // 'src != dst' ensures that 'dst' doesn't exist in the host and must be - // properly mapped inside the container to work. - src, err := ioutil.TempDir(testutil.TmpDir(), "container") - if err != nil { - t.Fatal("ioutil.TempDir failed:", err) - } - dst := src + ".dst" - cmd2 := []string{"touch", filepath.Join(dst, "file")} - - sps, ids := createSpecs(cmd1, cmd2) - sps[1].Mounts = append(sps[1].Mounts, specs.Mount{ - Source: src, - Destination: dst, - Type: "bind", - }) - - // Setup the containers. - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - - containers, cleanup, err := startContainers(conf, sps, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - ws, err := containers[1].Wait() - if err != nil { - t.Error("error waiting on container:", err) - } - if !ws.Exited() || ws.ExitStatus() != 0 { - t.Error("container failed, waitStatus:", ws) - } -} - -// TestMultiContainerSignal checks that it is possible to signal individual -// containers without killing the entire sandbox. -func TestMultiContainerSignal(t *testing.T) { - for _, conf := range configs(all...) { - t.Logf("Running test with conf: %+v", conf) - - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - conf.RootDir = rootDir - - // Setup the containers. - sleep := []string{"sleep", "100"} - specs, ids := createSpecs(sleep, sleep) - containers, cleanup, err := startContainers(conf, specs, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - // Check via ps that container 1 process is running. - expectedPL := []*control.Process{ - {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}}, - } - - if err := waitForProcessList(containers[1], expectedPL); err != nil { - t.Errorf("failed to wait for sleep to start: %v", err) - } - - // Kill process 2. - if err := containers[1].SignalContainer(syscall.SIGKILL, false); err != nil { - t.Errorf("failed to kill process 2: %v", err) - } - - // Make sure process 1 is still running. - expectedPL = []*control.Process{ - {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, - } - if err := waitForProcessList(containers[0], expectedPL); err != nil { - t.Errorf("failed to wait for sleep to start: %v", err) - } - - // goferPid is reset when container is destroyed. - goferPid := containers[1].GoferPid - - // Destroy container and ensure container's gofer process has exited. - if err := containers[1].Destroy(); err != nil { - t.Errorf("failed to destroy container: %v", err) - } - _, _, err = specutils.RetryEintr(func() (uintptr, uintptr, error) { - cpid, err := syscall.Wait4(goferPid, nil, 0, nil) - return uintptr(cpid), 0, err - }) - if err != syscall.ECHILD { - t.Errorf("error waiting for gofer to exit: %v", err) - } - // Make sure process 1 is still running. - if err := waitForProcessList(containers[0], expectedPL); err != nil { - t.Errorf("failed to wait for sleep to start: %v", err) - } - - // Now that process 2 is gone, ensure we get an error trying to - // signal it again. - if err := containers[1].SignalContainer(syscall.SIGKILL, false); err == nil { - t.Errorf("container %q shouldn't exist, but we were able to signal it", containers[1].ID) - } - - // Kill process 1. - if err := containers[0].SignalContainer(syscall.SIGKILL, false); err != nil { - t.Errorf("failed to kill process 1: %v", err) - } - - // Ensure that container's gofer and sandbox process are no more. - err = blockUntilWaitable(containers[0].GoferPid) - if err != nil && err != syscall.ECHILD { - t.Errorf("error waiting for gofer to exit: %v", err) - } - - err = blockUntilWaitable(containers[0].Sandbox.Pid) - if err != nil && err != syscall.ECHILD { - t.Errorf("error waiting for sandbox to exit: %v", err) - } - - // The sentry should be gone, so signaling should yield an error. - if err := containers[0].SignalContainer(syscall.SIGKILL, false); err == nil { - t.Errorf("sandbox %q shouldn't exist, but we were able to signal it", containers[0].Sandbox.ID) - } - - if err := containers[0].Destroy(); err != nil { - t.Errorf("failed to destroy container: %v", err) - } - } -} - -// TestMultiContainerDestroy checks that container are properly cleaned-up when -// they are destroyed. -func TestMultiContainerDestroy(t *testing.T) { - app, err := testutil.FindFile("runsc/container/test_app/test_app") - if err != nil { - t.Fatal("error finding test_app:", err) - } - - for _, conf := range configs(all...) { - t.Logf("Running test with conf: %+v", conf) - - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - conf.RootDir = rootDir - - // First container will remain intact while the second container is killed. - podSpecs, ids := createSpecs( - []string{"sleep", "100"}, - []string{app, "fork-bomb"}) - - // Run the fork bomb in a PID namespace to prevent processes to be - // re-parented to PID=1 in the root container. - podSpecs[1].Linux = &specs.Linux{ - Namespaces: []specs.LinuxNamespace{{Type: "pid"}}, - } - containers, cleanup, err := startContainers(conf, podSpecs, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - // Exec more processes to ensure signal all works for exec'd processes too. - args := &control.ExecArgs{ - Filename: app, - Argv: []string{app, "fork-bomb"}, - } - if _, err := containers[1].Execute(args); err != nil { - t.Fatalf("error exec'ing: %v", err) - } - - // Let it brew... - time.Sleep(500 * time.Millisecond) - - if err := containers[1].Destroy(); err != nil { - t.Fatalf("error destroying container: %v", err) - } - - // Check that destroy killed all processes belonging to the container and - // waited for them to exit before returning. - pss, err := containers[0].Sandbox.Processes("") - if err != nil { - t.Fatalf("error getting process data from sandbox: %v", err) - } - expectedPL := []*control.Process{{PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}} - if r, err := procListsEqual(pss, expectedPL); !r { - t.Errorf("container got process list: %s, want: %s: error: %v", - procListToString(pss), procListToString(expectedPL), err) - } - - // Check that cont.Destroy is safe to call multiple times. - if err := containers[1].Destroy(); err != nil { - t.Errorf("error destroying container: %v", err) - } - } -} - -func TestMultiContainerProcesses(t *testing.T) { - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - - // Note: use curly braces to keep 'sh' process around. Otherwise, shell - // will just execve into 'sleep' and both containers will look the - // same. - specs, ids := createSpecs( - []string{"sleep", "100"}, - []string{"sh", "-c", "{ sleep 100; }"}) - containers, cleanup, err := startContainers(conf, specs, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - // Check root's container process list doesn't include other containers. - expectedPL0 := []*control.Process{ - {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, - } - if err := waitForProcessList(containers[0], expectedPL0); err != nil { - t.Errorf("failed to wait for process to start: %v", err) - } - - // Same for the other container. - expectedPL1 := []*control.Process{ - {PID: 2, Cmd: "sh", Threads: []kernel.ThreadID{2}}, - {PID: 3, PPID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{3}}, - } - if err := waitForProcessList(containers[1], expectedPL1); err != nil { - t.Errorf("failed to wait for process to start: %v", err) - } - - // Now exec into the second container and verify it shows up in the container. - args := &control.ExecArgs{ - Filename: "/bin/sleep", - Argv: []string{"/bin/sleep", "100"}, - } - if _, err := containers[1].Execute(args); err != nil { - t.Fatalf("error exec'ing: %v", err) - } - expectedPL1 = append(expectedPL1, &control.Process{PID: 4, Cmd: "sleep", Threads: []kernel.ThreadID{4}}) - if err := waitForProcessList(containers[1], expectedPL1); err != nil { - t.Errorf("failed to wait for process to start: %v", err) - } - // Root container should remain unchanged. - if err := waitForProcessList(containers[0], expectedPL0); err != nil { - t.Errorf("failed to wait for process to start: %v", err) - } -} - -// TestMultiContainerKillAll checks that all process that belong to a container -// are killed when SIGKILL is sent to *all* processes in that container. -func TestMultiContainerKillAll(t *testing.T) { - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - - for _, tc := range []struct { - killContainer bool - }{ - {killContainer: true}, - {killContainer: false}, - } { - app, err := testutil.FindFile("runsc/container/test_app/test_app") - if err != nil { - t.Fatal("error finding test_app:", err) - } - - // First container will remain intact while the second container is killed. - specs, ids := createSpecs( - []string{app, "task-tree", "--depth=2", "--width=2"}, - []string{app, "task-tree", "--depth=4", "--width=2"}) - containers, cleanup, err := startContainers(conf, specs, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - // Wait until all processes are created. - rootProcCount := int(math.Pow(2, 3) - 1) - if err := waitForProcessCount(containers[0], rootProcCount); err != nil { - t.Fatal(err) - } - procCount := int(math.Pow(2, 5) - 1) - if err := waitForProcessCount(containers[1], procCount); err != nil { - t.Fatal(err) - } - - // Exec more processes to ensure signal works for exec'd processes too. - args := &control.ExecArgs{ - Filename: app, - Argv: []string{app, "task-tree", "--depth=2", "--width=2"}, - } - if _, err := containers[1].Execute(args); err != nil { - t.Fatalf("error exec'ing: %v", err) - } - // Wait for these new processes to start. - procCount += int(math.Pow(2, 3) - 1) - if err := waitForProcessCount(containers[1], procCount); err != nil { - t.Fatal(err) - } - - if tc.killContainer { - // First kill the init process to make the container be stopped with - // processes still running inside. - containers[1].SignalContainer(syscall.SIGKILL, false) - op := func() error { - c, err := Load(conf.RootDir, ids[1]) - if err != nil { - return err - } - if c.Status != Stopped { - return fmt.Errorf("container is not stopped") - } - return nil - } - if err := testutil.Poll(op, 5*time.Second); err != nil { - t.Fatalf("container did not stop %q: %v", containers[1].ID, err) - } - } - - c, err := Load(conf.RootDir, ids[1]) - if err != nil { - t.Fatalf("failed to load child container %q: %v", c.ID, err) - } - // Kill'Em All - if err := c.SignalContainer(syscall.SIGKILL, true); err != nil { - t.Fatalf("failed to send SIGKILL to container %q: %v", c.ID, err) - } - - // Check that all processes are gone. - if err := waitForProcessCount(containers[1], 0); err != nil { - t.Fatal(err) - } - // Check that root container was not affected. - if err := waitForProcessCount(containers[0], rootProcCount); err != nil { - t.Fatal(err) - } - } -} - -func TestMultiContainerDestroyNotStarted(t *testing.T) { - specs, ids := createSpecs( - []string{"/bin/sleep", "100"}, - []string{"/bin/sleep", "100"}) - - conf := testutil.TestConfig() - rootDir, rootBundleDir, err := testutil.SetupContainer(specs[0], conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(rootBundleDir) - - rootArgs := Args{ - ID: ids[0], - Spec: specs[0], - BundleDir: rootBundleDir, - } - root, err := New(conf, rootArgs) - if err != nil { - t.Fatalf("error creating root container: %v", err) - } - defer root.Destroy() - if err := root.Start(conf); err != nil { - t.Fatalf("error starting root container: %v", err) - } - - // Create and destroy sub-container. - bundleDir, err := testutil.SetupBundleDir(specs[1]) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(bundleDir) - - args := Args{ - ID: ids[1], - Spec: specs[1], - BundleDir: bundleDir, - } - cont, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - - // Check that container can be destroyed. - if err := cont.Destroy(); err != nil { - t.Fatalf("deleting non-started container failed: %v", err) - } -} - -// TestMultiContainerDestroyStarting attempts to force a race between start -// and destroy. -func TestMultiContainerDestroyStarting(t *testing.T) { - cmds := make([][]string, 10) - for i := range cmds { - cmds[i] = []string{"/bin/sleep", "100"} - } - specs, ids := createSpecs(cmds...) - - conf := testutil.TestConfig() - rootDir, rootBundleDir, err := testutil.SetupContainer(specs[0], conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(rootBundleDir) - - rootArgs := Args{ - ID: ids[0], - Spec: specs[0], - BundleDir: rootBundleDir, - } - root, err := New(conf, rootArgs) - if err != nil { - t.Fatalf("error creating root container: %v", err) - } - defer root.Destroy() - if err := root.Start(conf); err != nil { - t.Fatalf("error starting root container: %v", err) - } - - wg := sync.WaitGroup{} - for i := range cmds { - if i == 0 { - continue // skip root container - } - - bundleDir, err := testutil.SetupBundleDir(specs[i]) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(bundleDir) - - rootArgs := Args{ - ID: ids[i], - Spec: specs[i], - BundleDir: rootBundleDir, - } - cont, err := New(conf, rootArgs) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - - // Container is not thread safe, so load another instance to run in - // concurrently. - startCont, err := Load(rootDir, ids[i]) - if err != nil { - t.Fatalf("error loading container: %v", err) - } - wg.Add(1) - go func() { - defer wg.Done() - startCont.Start(conf) // ignore failures, start can fail if destroy runs first. - }() - - wg.Add(1) - go func() { - defer wg.Done() - if err := cont.Destroy(); err != nil { - t.Errorf("deleting non-started container failed: %v", err) - } - }() - } - wg.Wait() -} - -// TestMultiContainerDifferentFilesystems tests that different containers have -// different root filesystems. -func TestMultiContainerDifferentFilesystems(t *testing.T) { - filename := "/foo" - // Root container will create file and then sleep. - cmdRoot := []string{"sh", "-c", fmt.Sprintf("touch %q && sleep 100", filename)} - - // Child containers will assert that the file does not exist, and will - // then create it. - script := fmt.Sprintf("if [ -f %q ]; then exit 1; else touch %q; fi", filename, filename) - cmd := []string{"sh", "-c", script} - - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - - // Make sure overlay is enabled, and none of the root filesystems are - // read-only, otherwise we won't be able to create the file. - conf.Overlay = true - specs, ids := createSpecs(cmdRoot, cmd, cmd) - for _, s := range specs { - s.Root.Readonly = false - } - - containers, cleanup, err := startContainers(conf, specs, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - // Both child containers should exit successfully. - for i, c := range containers { - if i == 0 { - // Don't wait on the root. - continue - } - if ws, err := c.Wait(); err != nil { - t.Errorf("failed to wait for process %s: %v", c.Spec.Process.Args, err) - } else if es := ws.ExitStatus(); es != 0 { - t.Errorf("process %s exited with non-zero status %d", c.Spec.Process.Args, es) - } - } -} - -// TestMultiContainerContainerDestroyStress tests that IO operations continue -// to work after containers have been stopped and gofers killed. -func TestMultiContainerContainerDestroyStress(t *testing.T) { - app, err := testutil.FindFile("runsc/container/test_app/test_app") - if err != nil { - t.Fatal("error finding test_app:", err) - } - - // Setup containers. Root container just reaps children, while the others - // perform some IOs. Children are executed in 3 batches of 10. Within the - // batch there is overlap between containers starting and being destroyed. In - // between batches all containers stop before starting another batch. - cmds := [][]string{{app, "reaper"}} - const batchSize = 10 - for i := 0; i < 3*batchSize; i++ { - dir, err := ioutil.TempDir(testutil.TmpDir(), "gofer-stop-test") - if err != nil { - t.Fatal("ioutil.TempDir failed:", err) - } - defer os.RemoveAll(dir) - - cmd := "find /bin -type f | head | xargs -I SRC cp SRC " + dir - cmds = append(cmds, []string{"sh", "-c", cmd}) - } - allSpecs, allIDs := createSpecs(cmds...) - - // Split up the specs and IDs. - rootSpec := allSpecs[0] - rootID := allIDs[0] - childrenSpecs := allSpecs[1:] - childrenIDs := allIDs[1:] - - conf := testutil.TestConfig() - rootDir, bundleDir, err := testutil.SetupContainer(rootSpec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Start root container. - rootArgs := Args{ - ID: rootID, - Spec: rootSpec, - BundleDir: bundleDir, - } - root, err := New(conf, rootArgs) - if err != nil { - t.Fatalf("error creating root container: %v", err) - } - if err := root.Start(conf); err != nil { - t.Fatalf("error starting root container: %v", err) - } - defer root.Destroy() - - // Run batches. Each batch starts containers in parallel, then wait and - // destroy them before starting another batch. - for i := 0; i < len(childrenSpecs); i += batchSize { - t.Logf("Starting batch from %d to %d", i, i+batchSize) - specs := childrenSpecs[i : i+batchSize] - ids := childrenIDs[i : i+batchSize] - - var children []*Container - for j, spec := range specs { - bundleDir, err := testutil.SetupBundleDir(spec) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(bundleDir) - - args := Args{ - ID: ids[j], - Spec: spec, - BundleDir: bundleDir, - } - child, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - children = append(children, child) - - if err := child.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // Give a small gap between containers. - time.Sleep(50 * time.Millisecond) - } - for _, child := range children { - ws, err := child.Wait() - if err != nil { - t.Fatalf("waiting for container: %v", err) - } - if !ws.Exited() || ws.ExitStatus() != 0 { - t.Fatalf("container failed, waitStatus: %x (%d)", ws, ws.ExitStatus()) - } - if err := child.Destroy(); err != nil { - t.Fatalf("error destroying container: %v", err) - } - } - } -} - -// Test that pod shared mounts are properly mounted in 2 containers and that -// changes from one container is reflected in the other. -func TestMultiContainerSharedMount(t *testing.T) { - for _, conf := range configs(all...) { - t.Logf("Running test with conf: %+v", conf) - - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - conf.RootDir = rootDir - - // Setup the containers. - sleep := []string{"sleep", "100"} - podSpec, ids := createSpecs(sleep, sleep) - mnt0 := specs.Mount{ - Destination: "/mydir/test", - Source: "/some/dir", - Type: "tmpfs", - Options: nil, - } - podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0) - - mnt1 := mnt0 - mnt1.Destination = "/mydir2/test2" - podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1) - - createSharedMount(mnt0, "test-mount", podSpec...) - - containers, cleanup, err := startContainers(conf, podSpec, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - file0 := path.Join(mnt0.Destination, "abc") - file1 := path.Join(mnt1.Destination, "abc") - execs := []execDesc{ - { - c: containers[0], - cmd: []string{"/usr/bin/test", "-d", mnt0.Destination}, - desc: "directory is mounted in container0", - }, - { - c: containers[1], - cmd: []string{"/usr/bin/test", "-d", mnt1.Destination}, - desc: "directory is mounted in container1", - }, - { - c: containers[0], - cmd: []string{"/usr/bin/touch", file0}, - desc: "create file in container0", - }, - { - c: containers[0], - cmd: []string{"/usr/bin/test", "-f", file0}, - desc: "file appears in container0", - }, - { - c: containers[1], - cmd: []string{"/usr/bin/test", "-f", file1}, - desc: "file appears in container1", - }, - { - c: containers[1], - cmd: []string{"/bin/rm", file1}, - desc: "file removed from container1", - }, - { - c: containers[0], - cmd: []string{"/usr/bin/test", "!", "-f", file0}, - desc: "file removed from container0", - }, - { - c: containers[1], - cmd: []string{"/usr/bin/test", "!", "-f", file1}, - desc: "file removed from container1", - }, - { - c: containers[1], - cmd: []string{"/bin/mkdir", file1}, - desc: "create directory in container1", - }, - { - c: containers[0], - cmd: []string{"/usr/bin/test", "-d", file0}, - desc: "dir appears in container0", - }, - { - c: containers[1], - cmd: []string{"/usr/bin/test", "-d", file1}, - desc: "dir appears in container1", - }, - { - c: containers[0], - cmd: []string{"/bin/rmdir", file0}, - desc: "create directory in container0", - }, - { - c: containers[0], - cmd: []string{"/usr/bin/test", "!", "-d", file0}, - desc: "dir removed from container0", - }, - { - c: containers[1], - cmd: []string{"/usr/bin/test", "!", "-d", file1}, - desc: "dir removed from container1", - }, - } - if err := execMany(execs); err != nil { - t.Fatal(err.Error()) - } - } -} - -// Test that pod mounts are mounted as readonly when requested. -func TestMultiContainerSharedMountReadonly(t *testing.T) { - for _, conf := range configs(all...) { - t.Logf("Running test with conf: %+v", conf) - - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - conf.RootDir = rootDir - - // Setup the containers. - sleep := []string{"sleep", "100"} - podSpec, ids := createSpecs(sleep, sleep) - mnt0 := specs.Mount{ - Destination: "/mydir/test", - Source: "/some/dir", - Type: "tmpfs", - Options: []string{"ro"}, - } - podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0) - - mnt1 := mnt0 - mnt1.Destination = "/mydir2/test2" - podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1) - - createSharedMount(mnt0, "test-mount", podSpec...) - - containers, cleanup, err := startContainers(conf, podSpec, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - file0 := path.Join(mnt0.Destination, "abc") - file1 := path.Join(mnt1.Destination, "abc") - execs := []execDesc{ - { - c: containers[0], - cmd: []string{"/usr/bin/test", "-d", mnt0.Destination}, - desc: "directory is mounted in container0", - }, - { - c: containers[1], - cmd: []string{"/usr/bin/test", "-d", mnt1.Destination}, - desc: "directory is mounted in container1", - }, - { - c: containers[0], - cmd: []string{"/usr/bin/touch", file0}, - want: 1, - desc: "fails to write to container0", - }, - { - c: containers[1], - cmd: []string{"/usr/bin/touch", file1}, - want: 1, - desc: "fails to write to container1", - }, - } - if err := execMany(execs); err != nil { - t.Fatal(err.Error()) - } - } -} - -// Test that shared pod mounts continue to work after container is restarted. -func TestMultiContainerSharedMountRestart(t *testing.T) { - for _, conf := range configs(all...) { - t.Logf("Running test with conf: %+v", conf) - - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - conf.RootDir = rootDir - - // Setup the containers. - sleep := []string{"sleep", "100"} - podSpec, ids := createSpecs(sleep, sleep) - mnt0 := specs.Mount{ - Destination: "/mydir/test", - Source: "/some/dir", - Type: "tmpfs", - Options: nil, - } - podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0) - - mnt1 := mnt0 - mnt1.Destination = "/mydir2/test2" - podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1) - - createSharedMount(mnt0, "test-mount", podSpec...) - - containers, cleanup, err := startContainers(conf, podSpec, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - file0 := path.Join(mnt0.Destination, "abc") - file1 := path.Join(mnt1.Destination, "abc") - execs := []execDesc{ - { - c: containers[0], - cmd: []string{"/usr/bin/touch", file0}, - desc: "create file in container0", - }, - { - c: containers[0], - cmd: []string{"/usr/bin/test", "-f", file0}, - desc: "file appears in container0", - }, - { - c: containers[1], - cmd: []string{"/usr/bin/test", "-f", file1}, - desc: "file appears in container1", - }, - } - if err := execMany(execs); err != nil { - t.Fatal(err.Error()) - } - - containers[1].Destroy() - - bundleDir, err := testutil.SetupBundleDir(podSpec[1]) - if err != nil { - t.Fatalf("error restarting container: %v", err) - } - defer os.RemoveAll(bundleDir) - - args := Args{ - ID: ids[1], - Spec: podSpec[1], - BundleDir: bundleDir, - } - containers[1], err = New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - if err := containers[1].Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - execs = []execDesc{ - { - c: containers[0], - cmd: []string{"/usr/bin/test", "-f", file0}, - desc: "file is still in container0", - }, - { - c: containers[1], - cmd: []string{"/usr/bin/test", "-f", file1}, - desc: "file is still in container1", - }, - { - c: containers[1], - cmd: []string{"/bin/rm", file1}, - desc: "file removed from container1", - }, - { - c: containers[0], - cmd: []string{"/usr/bin/test", "!", "-f", file0}, - desc: "file removed from container0", - }, - { - c: containers[1], - cmd: []string{"/usr/bin/test", "!", "-f", file1}, - desc: "file removed from container1", - }, - } - if err := execMany(execs); err != nil { - t.Fatal(err.Error()) - } - } -} - -// Test that unsupported pod mounts options are ignored when matching master and -// slave mounts. -func TestMultiContainerSharedMountUnsupportedOptions(t *testing.T) { - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - - // Setup the containers. - sleep := []string{"/bin/sleep", "100"} - podSpec, ids := createSpecs(sleep, sleep) - mnt0 := specs.Mount{ - Destination: "/mydir/test", - Source: "/some/dir", - Type: "tmpfs", - Options: []string{"rw", "rbind", "relatime"}, - } - podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0) - - mnt1 := mnt0 - mnt1.Destination = "/mydir2/test2" - mnt1.Options = []string{"rw", "nosuid"} - podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1) - - createSharedMount(mnt0, "test-mount", podSpec...) - - containers, cleanup, err := startContainers(conf, podSpec, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - execs := []execDesc{ - { - c: containers[0], - cmd: []string{"/usr/bin/test", "-d", mnt0.Destination}, - desc: "directory is mounted in container0", - }, - { - c: containers[1], - cmd: []string{"/usr/bin/test", "-d", mnt1.Destination}, - desc: "directory is mounted in container1", - }, - } - if err := execMany(execs); err != nil { - t.Fatal(err.Error()) - } -} - -// Test that one container can send an FD to another container, even though -// they have distinct MountNamespaces. -func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) { - app, err := testutil.FindFile("runsc/container/test_app/test_app") - if err != nil { - t.Fatal("error finding test_app:", err) - } - - // We set up two containers with one shared mount that is used for a - // shared socket. The first container will send an FD over the socket - // to the second container. The FD corresponds to a file in the first - // container's mount namespace that is not part of the second - // container's mount namespace. However, the second container still - // should be able to read the FD. - - // Create a shared mount where we will put the socket. - sharedMnt := specs.Mount{ - Destination: "/mydir/test", - Type: "tmpfs", - // Shared mounts need a Source, even for tmpfs. It is only used - // to match up different shared mounts inside the pod. - Source: "/some/dir", - } - socketPath := filepath.Join(sharedMnt.Destination, "socket") - - // Create a writeable tmpfs mount where the FD sender app will create - // files to send. This will only be mounted in the FD sender. - writeableMnt := specs.Mount{ - Destination: "/tmp", - Type: "tmpfs", - } - - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - - // Create the specs. - specs, ids := createSpecs( - []string{"sleep", "1000"}, - []string{app, "fd_sender", "--socket", socketPath}, - []string{app, "fd_receiver", "--socket", socketPath}, - ) - createSharedMount(sharedMnt, "shared-mount", specs...) - specs[1].Mounts = append(specs[2].Mounts, sharedMnt, writeableMnt) - specs[2].Mounts = append(specs[1].Mounts, sharedMnt) - - containers, cleanup, err := startContainers(conf, specs, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - // Both containers should exit successfully. - for _, c := range containers[1:] { - if ws, err := c.Wait(); err != nil { - t.Errorf("failed to wait for process %s: %v", c.Spec.Process.Args, err) - } else if es := ws.ExitStatus(); es != 0 { - t.Errorf("process %s exited with non-zero status %d", c.Spec.Process.Args, es) - } - } -} - -// Test that container is destroyed when Gofer is killed. -func TestMultiContainerGoferKilled(t *testing.T) { - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - - sleep := []string{"sleep", "100"} - specs, ids := createSpecs(sleep, sleep, sleep) - containers, cleanup, err := startContainers(conf, specs, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - // Ensure container is running - c := containers[2] - expectedPL := []*control.Process{ - {PID: 3, Cmd: "sleep", Threads: []kernel.ThreadID{3}}, - } - if err := waitForProcessList(c, expectedPL); err != nil { - t.Errorf("failed to wait for sleep to start: %v", err) - } - - // Kill container's gofer. - if err := syscall.Kill(c.GoferPid, syscall.SIGKILL); err != nil { - t.Fatalf("syscall.Kill(%d, SIGKILL)=%v", c.GoferPid, err) - } - - // Wait until container stops. - if err := waitForProcessList(c, nil); err != nil { - t.Errorf("Container %q was not stopped after gofer death: %v", c.ID, err) - } - - // Check that container isn't running anymore. - args := &control.ExecArgs{Argv: []string{"/bin/true"}} - if _, err := c.executeSync(args); err == nil { - t.Fatalf("Container %q was not stopped after gofer death", c.ID) - } - - // Check that other containers are unaffected. - for i, c := range containers { - if i == 2 { - continue // container[2] has been killed. - } - pl := []*control.Process{ - {PID: kernel.ThreadID(i + 1), Cmd: "sleep", Threads: []kernel.ThreadID{kernel.ThreadID(i + 1)}}, - } - if err := waitForProcessList(c, pl); err != nil { - t.Errorf("Container %q was affected by another container: %v", c.ID, err) - } - args := &control.ExecArgs{Argv: []string{"/bin/true"}} - if _, err := c.executeSync(args); err != nil { - t.Fatalf("Container %q was affected by another container: %v", c.ID, err) - } - } - - // Kill root container's gofer to bring entire sandbox down. - c = containers[0] - if err := syscall.Kill(c.GoferPid, syscall.SIGKILL); err != nil { - t.Fatalf("syscall.Kill(%d, SIGKILL)=%v", c.GoferPid, err) - } - - // Wait until sandbox stops. waitForProcessList will loop until sandbox exits - // and RPC errors out. - impossiblePL := []*control.Process{ - {PID: 100, Cmd: "non-existent-process", Threads: []kernel.ThreadID{100}}, - } - if err := waitForProcessList(c, impossiblePL); err == nil { - t.Fatalf("Sandbox was not killed after gofer death") - } - - // Check that entire sandbox isn't running anymore. - for _, c := range containers { - args := &control.ExecArgs{Argv: []string{"/bin/true"}} - if _, err := c.executeSync(args); err == nil { - t.Fatalf("Container %q was not stopped after gofer death", c.ID) - } - } -} - -func TestMultiContainerLoadSandbox(t *testing.T) { - sleep := []string{"sleep", "100"} - specs, ids := createSpecs(sleep, sleep, sleep) - - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - - // Create containers for the sandbox. - wants, cleanup, err := startContainers(conf, specs, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - // Then create unrelated containers. - for i := 0; i < 3; i++ { - specs, ids = createSpecs(sleep, sleep, sleep) - _, cleanup, err = startContainers(conf, specs, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - } - - // Create an unrelated directory under root. - dir := filepath.Join(conf.RootDir, "not-a-container") - if err := os.MkdirAll(dir, 0755); err != nil { - t.Fatalf("os.MkdirAll(%q)=%v", dir, err) - } - - // Create a valid but empty container directory. - randomCID := testutil.UniqueContainerID() - dir = filepath.Join(conf.RootDir, randomCID) - if err := os.MkdirAll(dir, 0755); err != nil { - t.Fatalf("os.MkdirAll(%q)=%v", dir, err) - } - - // Load the sandbox and check that the correct containers were returned. - id := wants[0].Sandbox.ID - gots, err := loadSandbox(conf.RootDir, id) - if err != nil { - t.Fatalf("loadSandbox()=%v", err) - } - wantIDs := make(map[string]struct{}) - for _, want := range wants { - wantIDs[want.ID] = struct{}{} - } - for _, got := range gots { - if got.Sandbox.ID != id { - t.Errorf("wrong sandbox ID, got: %v, want: %v", got.Sandbox.ID, id) - } - if _, ok := wantIDs[got.ID]; !ok { - t.Errorf("wrong container ID, got: %v, wants: %v", got.ID, wantIDs) - } - delete(wantIDs, got.ID) - } - if len(wantIDs) != 0 { - t.Errorf("containers not found: %v", wantIDs) - } -} - -// TestMultiContainerRunNonRoot checks that child container can be configured -// when running as non-privileged user. -func TestMultiContainerRunNonRoot(t *testing.T) { - cmdRoot := []string{"/bin/sleep", "100"} - cmdSub := []string{"/bin/true"} - podSpecs, ids := createSpecs(cmdRoot, cmdSub) - - // User running inside container can't list '$TMP/blocked' and would fail to - // mount it. - blocked, err := ioutil.TempDir(testutil.TmpDir(), "blocked") - if err != nil { - t.Fatalf("ioutil.TempDir() failed: %v", err) - } - if err := os.Chmod(blocked, 0700); err != nil { - t.Fatalf("os.MkDir(%q) failed: %v", blocked, err) - } - dir := path.Join(blocked, "test") - if err := os.Mkdir(dir, 0755); err != nil { - t.Fatalf("os.MkDir(%q) failed: %v", dir, err) - } - - src, err := ioutil.TempDir(testutil.TmpDir(), "src") - if err != nil { - t.Fatalf("ioutil.TempDir() failed: %v", err) - } - - // Set a random user/group with no access to "blocked" dir. - podSpecs[1].Process.User.UID = 343 - podSpecs[1].Process.User.GID = 2401 - podSpecs[1].Process.Capabilities = nil - - podSpecs[1].Mounts = append(podSpecs[1].Mounts, specs.Mount{ - Destination: dir, - Source: src, - Type: "bind", - }) - - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - - pod, cleanup, err := startContainers(conf, podSpecs, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - // Once all containers are started, wait for the child container to exit. - // This means that the volume was mounted properly. - ws, err := pod[1].Wait() - if err != nil { - t.Fatalf("running child container: %v", err) - } - if !ws.Exited() || ws.ExitStatus() != 0 { - t.Fatalf("child container failed, waitStatus: %v", ws) - } -} diff --git a/runsc/container/shared_volume_test.go b/runsc/container/shared_volume_test.go deleted file mode 100644 index dc4194134..000000000 --- a/runsc/container/shared_volume_test.go +++ /dev/null @@ -1,277 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package container - -import ( - "bytes" - "fmt" - "io/ioutil" - "os" - "path/filepath" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/control" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/runsc/boot" - "gvisor.dev/gvisor/runsc/testutil" -) - -// TestSharedVolume checks that modifications to a volume mount are propagated -// into and out of the sandbox. -func TestSharedVolume(t *testing.T) { - conf := testutil.TestConfig() - conf.FileAccess = boot.FileAccessShared - t.Logf("Running test with conf: %+v", conf) - - // Main process just sleeps. We will use "exec" to probe the state of - // the filesystem. - spec := testutil.NewSpecWithArgs("sleep", "1000") - - dir, err := ioutil.TempDir(testutil.TmpDir(), "shared-volume-test") - if err != nil { - t.Fatalf("TempDir failed: %v", err) - } - - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create and start the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - c, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer c.Destroy() - if err := c.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // File that will be used to check consistency inside/outside sandbox. - filename := filepath.Join(dir, "file") - - // File does not exist yet. Reading from the sandbox should fail. - argsTestFile := &control.ExecArgs{ - Filename: "/usr/bin/test", - Argv: []string{"test", "-f", filename}, - } - if ws, err := c.executeSync(argsTestFile); err != nil { - t.Fatalf("unexpected error testing file %q: %v", filename, err) - } else if ws.ExitStatus() == 0 { - t.Errorf("test %q exited with code %v, wanted not zero", ws.ExitStatus(), err) - } - - // Create the file from outside of the sandbox. - if err := ioutil.WriteFile(filename, []byte("foobar"), 0777); err != nil { - t.Fatalf("error writing to file %q: %v", filename, err) - } - - // Now we should be able to test the file from within the sandbox. - if ws, err := c.executeSync(argsTestFile); err != nil { - t.Fatalf("unexpected error testing file %q: %v", filename, err) - } else if ws.ExitStatus() != 0 { - t.Errorf("test %q exited with code %v, wanted zero", filename, ws.ExitStatus()) - } - - // Rename the file from outside of the sandbox. - newFilename := filepath.Join(dir, "newfile") - if err := os.Rename(filename, newFilename); err != nil { - t.Fatalf("os.Rename(%q, %q) failed: %v", filename, newFilename, err) - } - - // File should no longer exist at the old path within the sandbox. - if ws, err := c.executeSync(argsTestFile); err != nil { - t.Fatalf("unexpected error testing file %q: %v", filename, err) - } else if ws.ExitStatus() == 0 { - t.Errorf("test %q exited with code %v, wanted not zero", filename, ws.ExitStatus()) - } - - // We should be able to test the new filename from within the sandbox. - argsTestNewFile := &control.ExecArgs{ - Filename: "/usr/bin/test", - Argv: []string{"test", "-f", newFilename}, - } - if ws, err := c.executeSync(argsTestNewFile); err != nil { - t.Fatalf("unexpected error testing file %q: %v", newFilename, err) - } else if ws.ExitStatus() != 0 { - t.Errorf("test %q exited with code %v, wanted zero", newFilename, ws.ExitStatus()) - } - - // Delete the renamed file from outside of the sandbox. - if err := os.Remove(newFilename); err != nil { - t.Fatalf("error removing file %q: %v", filename, err) - } - - // Renamed file should no longer exist at the old path within the sandbox. - if ws, err := c.executeSync(argsTestNewFile); err != nil { - t.Fatalf("unexpected error testing file %q: %v", newFilename, err) - } else if ws.ExitStatus() == 0 { - t.Errorf("test %q exited with code %v, wanted not zero", newFilename, ws.ExitStatus()) - } - - // Now create the file from WITHIN the sandbox. - argsTouch := &control.ExecArgs{ - Filename: "/usr/bin/touch", - Argv: []string{"touch", filename}, - KUID: auth.KUID(os.Getuid()), - KGID: auth.KGID(os.Getgid()), - } - if ws, err := c.executeSync(argsTouch); err != nil { - t.Fatalf("unexpected error touching file %q: %v", filename, err) - } else if ws.ExitStatus() != 0 { - t.Errorf("touch %q exited with code %v, wanted zero", filename, ws.ExitStatus()) - } - - // File should exist outside the sandbox. - if _, err := os.Stat(filename); err != nil { - t.Errorf("stat %q got error %v, wanted nil", filename, err) - } - - // File should exist outside the sandbox. - if _, err := os.Stat(filename); err != nil { - t.Errorf("stat %q got error %v, wanted nil", filename, err) - } - - // Delete the file from within the sandbox. - argsRemove := &control.ExecArgs{ - Filename: "/bin/rm", - Argv: []string{"rm", filename}, - } - if ws, err := c.executeSync(argsRemove); err != nil { - t.Fatalf("unexpected error removing file %q: %v", filename, err) - } else if ws.ExitStatus() != 0 { - t.Errorf("remove %q exited with code %v, wanted zero", filename, ws.ExitStatus()) - } - - // File should not exist outside the sandbox. - if _, err := os.Stat(filename); !os.IsNotExist(err) { - t.Errorf("stat %q got error %v, wanted ErrNotExist", filename, err) - } -} - -func checkFile(c *Container, filename string, want []byte) error { - cpy := filename + ".copy" - argsCp := &control.ExecArgs{ - Filename: "/bin/cp", - Argv: []string{"cp", "-f", filename, cpy}, - } - if _, err := c.executeSync(argsCp); err != nil { - return fmt.Errorf("unexpected error copying file %q to %q: %v", filename, cpy, err) - } - got, err := ioutil.ReadFile(cpy) - if err != nil { - return fmt.Errorf("Error reading file %q: %v", filename, err) - } - if !bytes.Equal(got, want) { - return fmt.Errorf("file content inside the sandbox is wrong, got: %q, want: %q", got, want) - } - return nil -} - -// TestSharedVolumeFile tests that changes to file content outside the sandbox -// is reflected inside. -func TestSharedVolumeFile(t *testing.T) { - conf := testutil.TestConfig() - conf.FileAccess = boot.FileAccessShared - t.Logf("Running test with conf: %+v", conf) - - // Main process just sleeps. We will use "exec" to probe the state of - // the filesystem. - spec := testutil.NewSpecWithArgs("sleep", "1000") - - dir, err := ioutil.TempDir(testutil.TmpDir(), "shared-volume-test") - if err != nil { - t.Fatalf("TempDir failed: %v", err) - } - - rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer os.RemoveAll(rootDir) - defer os.RemoveAll(bundleDir) - - // Create and start the container. - args := Args{ - ID: testutil.UniqueContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - c, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer c.Destroy() - if err := c.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - // File that will be used to check consistency inside/outside sandbox. - filename := filepath.Join(dir, "file") - - // Write file from outside the container and check that the same content is - // read inside. - want := []byte("host-") - if err := ioutil.WriteFile(filename, []byte(want), 0666); err != nil { - t.Fatalf("Error writing to %q: %v", filename, err) - } - if err := checkFile(c, filename, want); err != nil { - t.Fatal(err.Error()) - } - - // Append to file inside the container and check that content is not lost. - argsAppend := &control.ExecArgs{ - Filename: "/bin/bash", - Argv: []string{"bash", "-c", "echo -n sandbox- >> " + filename}, - } - if _, err := c.executeSync(argsAppend); err != nil { - t.Fatalf("unexpected error appending file %q: %v", filename, err) - } - want = []byte("host-sandbox-") - if err := checkFile(c, filename, want); err != nil { - t.Fatal(err.Error()) - } - - // Write again from outside the container and check that the same content is - // read inside. - f, err := os.OpenFile(filename, os.O_APPEND|os.O_WRONLY, 0) - if err != nil { - t.Fatalf("Error openning file %q: %v", filename, err) - } - defer f.Close() - if _, err := f.Write([]byte("host")); err != nil { - t.Fatalf("Error writing to file %q: %v", filename, err) - } - want = []byte("host-sandbox-host") - if err := checkFile(c, filename, want); err != nil { - t.Fatal(err.Error()) - } - - // Shrink file outside and check that the same content is read inside. - if err := f.Truncate(5); err != nil { - t.Fatalf("Error truncating file %q: %v", filename, err) - } - want = want[:5] - if err := checkFile(c, filename, want); err != nil { - t.Fatal(err.Error()) - } -} diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go index 17a251530..17a251530 100644..100755 --- a/runsc/container/state_file.go +++ b/runsc/container/state_file.go diff --git a/runsc/container/test_app/BUILD b/runsc/container/test_app/BUILD deleted file mode 100644 index 0defbd9fc..000000000 --- a/runsc/container/test_app/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "test_app", - testonly = 1, - srcs = [ - "fds.go", - "test_app.go", - ], - pure = True, - visibility = ["//runsc/container:__pkg__"], - deps = [ - "//pkg/unet", - "//runsc/flag", - "//runsc/testutil", - "@com_github_google_subcommands//:go_default_library", - "@com_github_kr_pty//:go_default_library", - ], -) diff --git a/runsc/container/test_app/fds.go b/runsc/container/test_app/fds.go deleted file mode 100644 index 2a146a2c3..000000000 --- a/runsc/container/test_app/fds.go +++ /dev/null @@ -1,185 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "context" - "io/ioutil" - "log" - "os" - "time" - - "github.com/google/subcommands" - "gvisor.dev/gvisor/pkg/unet" - "gvisor.dev/gvisor/runsc/flag" - "gvisor.dev/gvisor/runsc/testutil" -) - -const fileContents = "foobarbaz" - -// fdSender will open a file and send the FD over a unix domain socket. -type fdSender struct { - socketPath string -} - -// Name implements subcommands.Command.Name. -func (*fdSender) Name() string { - return "fd_sender" -} - -// Synopsis implements subcommands.Command.Synopsys. -func (*fdSender) Synopsis() string { - return "creates a file and sends the FD over the socket" -} - -// Usage implements subcommands.Command.Usage. -func (*fdSender) Usage() string { - return "fd_sender <flags>" -} - -// SetFlags implements subcommands.Command.SetFlags. -func (fds *fdSender) SetFlags(f *flag.FlagSet) { - f.StringVar(&fds.socketPath, "socket", "", "path to socket") -} - -// Execute implements subcommands.Command.Execute. -func (fds *fdSender) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { - if fds.socketPath == "" { - log.Fatalf("socket flag must be set") - } - - dir, err := ioutil.TempDir("", "") - if err != nil { - log.Fatalf("TempDir failed: %v", err) - } - - fileToSend, err := ioutil.TempFile(dir, "") - if err != nil { - log.Fatalf("TempFile failed: %v", err) - } - defer fileToSend.Close() - - if _, err := fileToSend.WriteString(fileContents); err != nil { - log.Fatalf("Write(%q) failed: %v", fileContents, err) - } - - // Receiver may not be started yet, so try connecting in a poll loop. - var s *unet.Socket - if err := testutil.Poll(func() error { - var err error - s, err = unet.Connect(fds.socketPath, true /* SEQPACKET, so we can send empty message with FD */) - return err - }, 10*time.Second); err != nil { - log.Fatalf("Error connecting to socket %q: %v", fds.socketPath, err) - } - defer s.Close() - - w := s.Writer(true) - w.ControlMessage.PackFDs(int(fileToSend.Fd())) - if _, err := w.WriteVec([][]byte{[]byte{'a'}}); err != nil { - log.Fatalf("Error sending FD %q over socket %q: %v", fileToSend.Fd(), fds.socketPath, err) - } - - log.Print("FD SENDER exiting successfully") - return subcommands.ExitSuccess -} - -// fdReceiver receives an FD from a unix domain socket and does things to it. -type fdReceiver struct { - socketPath string -} - -// Name implements subcommands.Command.Name. -func (*fdReceiver) Name() string { - return "fd_receiver" -} - -// Synopsis implements subcommands.Command.Synopsys. -func (*fdReceiver) Synopsis() string { - return "reads an FD from a unix socket, and then does things to it" -} - -// Usage implements subcommands.Command.Usage. -func (*fdReceiver) Usage() string { - return "fd_receiver <flags>" -} - -// SetFlags implements subcommands.Command.SetFlags. -func (fdr *fdReceiver) SetFlags(f *flag.FlagSet) { - f.StringVar(&fdr.socketPath, "socket", "", "path to socket") -} - -// Execute implements subcommands.Command.Execute. -func (fdr *fdReceiver) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { - if fdr.socketPath == "" { - log.Fatalf("Flags cannot be empty, given: socket: %q", fdr.socketPath) - } - - ss, err := unet.BindAndListen(fdr.socketPath, true /* packet */) - if err != nil { - log.Fatalf("BindAndListen(%q) failed: %v", fdr.socketPath, err) - } - defer ss.Close() - - var s *unet.Socket - c := make(chan error, 1) - go func() { - var err error - s, err = ss.Accept() - c <- err - }() - - select { - case err := <-c: - if err != nil { - log.Fatalf("Accept() failed: %v", err) - } - case <-time.After(10 * time.Second): - log.Fatalf("Timeout waiting for accept") - } - - r := s.Reader(true) - r.EnableFDs(1) - b := [][]byte{{'a'}} - if n, err := r.ReadVec(b); n != 1 || err != nil { - log.Fatalf("ReadVec got n=%d err %v (wanted 0, nil)", n, err) - } - - fds, err := r.ExtractFDs() - if err != nil { - log.Fatalf("ExtractFD() got err %v", err) - } - if len(fds) != 1 { - log.Fatalf("ExtractFD() got %d FDs, wanted 1", len(fds)) - } - fd := fds[0] - - file := os.NewFile(uintptr(fd), "received file") - defer file.Close() - if _, err := file.Seek(0, os.SEEK_SET); err != nil { - log.Fatalf("Seek(0, 0) failed: %v", err) - } - - got, err := ioutil.ReadAll(file) - if err != nil { - log.Fatalf("ReadAll failed: %v", err) - } - if string(got) != fileContents { - log.Fatalf("ReadAll got %q want %q", string(got), fileContents) - } - - log.Print("FD RECEIVER exiting successfully") - return subcommands.ExitSuccess -} diff --git a/runsc/container/test_app/test_app.go b/runsc/container/test_app/test_app.go deleted file mode 100644 index 01c47c79f..000000000 --- a/runsc/container/test_app/test_app.go +++ /dev/null @@ -1,394 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Binary test_app is like a swiss knife for tests that need to run anything -// inside the sandbox. New functionality can be added with new commands. -package main - -import ( - "context" - "fmt" - "io" - "io/ioutil" - "log" - "net" - "os" - "os/exec" - "regexp" - "strconv" - sys "syscall" - "time" - - "github.com/google/subcommands" - "github.com/kr/pty" - "gvisor.dev/gvisor/runsc/flag" - "gvisor.dev/gvisor/runsc/testutil" -) - -func main() { - subcommands.Register(subcommands.HelpCommand(), "") - subcommands.Register(subcommands.FlagsCommand(), "") - subcommands.Register(new(capability), "") - subcommands.Register(new(fdReceiver), "") - subcommands.Register(new(fdSender), "") - subcommands.Register(new(forkBomb), "") - subcommands.Register(new(ptyRunner), "") - subcommands.Register(new(reaper), "") - subcommands.Register(new(syscall), "") - subcommands.Register(new(taskTree), "") - subcommands.Register(new(uds), "") - - flag.Parse() - - exitCode := subcommands.Execute(context.Background()) - os.Exit(int(exitCode)) -} - -type uds struct { - fileName string - socketPath string -} - -// Name implements subcommands.Command.Name. -func (*uds) Name() string { - return "uds" -} - -// Synopsis implements subcommands.Command.Synopsys. -func (*uds) Synopsis() string { - return "creates unix domain socket client and server. Client sends a contant flow of sequential numbers. Server prints them to --file" -} - -// Usage implements subcommands.Command.Usage. -func (*uds) Usage() string { - return "uds <flags>" -} - -// SetFlags implements subcommands.Command.SetFlags. -func (c *uds) SetFlags(f *flag.FlagSet) { - f.StringVar(&c.fileName, "file", "", "name of output file") - f.StringVar(&c.socketPath, "socket", "", "path to socket") -} - -// Execute implements subcommands.Command.Execute. -func (c *uds) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { - if c.fileName == "" || c.socketPath == "" { - log.Fatalf("Flags cannot be empty, given: fileName: %q, socketPath: %q", c.fileName, c.socketPath) - return subcommands.ExitFailure - } - outputFile, err := os.OpenFile(c.fileName, os.O_WRONLY|os.O_CREATE, 0666) - if err != nil { - log.Fatal("error opening output file:", err) - } - - defer os.Remove(c.socketPath) - - listener, err := net.Listen("unix", c.socketPath) - if err != nil { - log.Fatal("error listening on socket %q:", c.socketPath, err) - } - - go server(listener, outputFile) - for i := 0; ; i++ { - conn, err := net.Dial("unix", c.socketPath) - if err != nil { - log.Fatal("error dialing:", err) - } - if _, err := conn.Write([]byte(strconv.Itoa(i))); err != nil { - log.Fatal("error writing:", err) - } - conn.Close() - time.Sleep(100 * time.Millisecond) - } -} - -func server(listener net.Listener, out *os.File) { - buf := make([]byte, 16) - - for { - c, err := listener.Accept() - if err != nil { - log.Fatal("error accepting connection:", err) - } - nr, err := c.Read(buf) - if err != nil { - log.Fatal("error reading from buf:", err) - } - data := buf[0:nr] - fmt.Fprint(out, string(data)+"\n") - } -} - -type taskTree struct { - depth int - width int - pause bool -} - -// Name implements subcommands.Command. -func (*taskTree) Name() string { - return "task-tree" -} - -// Synopsis implements subcommands.Command. -func (*taskTree) Synopsis() string { - return "creates a tree of tasks" -} - -// Usage implements subcommands.Command. -func (*taskTree) Usage() string { - return "task-tree <flags>" -} - -// SetFlags implements subcommands.Command. -func (c *taskTree) SetFlags(f *flag.FlagSet) { - f.IntVar(&c.depth, "depth", 1, "number of levels to create") - f.IntVar(&c.width, "width", 1, "number of tasks at each level") - f.BoolVar(&c.pause, "pause", false, "whether the tasks should pause perpetually") -} - -// Execute implements subcommands.Command. -func (c *taskTree) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { - stop := testutil.StartReaper() - defer stop() - - if c.depth == 0 { - log.Printf("Child sleeping, PID: %d\n", os.Getpid()) - select {} - } - log.Printf("Parent %d sleeping, PID: %d\n", c.depth, os.Getpid()) - - var cmds []*exec.Cmd - for i := 0; i < c.width; i++ { - cmd := exec.Command( - "/proc/self/exe", c.Name(), - "--depth", strconv.Itoa(c.depth-1), - "--width", strconv.Itoa(c.width), - "--pause", strconv.FormatBool(c.pause)) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - if err := cmd.Start(); err != nil { - log.Fatal("failed to call self:", err) - } - cmds = append(cmds, cmd) - } - - for _, c := range cmds { - c.Wait() - } - - if c.pause { - select {} - } - - return subcommands.ExitSuccess -} - -type forkBomb struct { - delay time.Duration -} - -// Name implements subcommands.Command. -func (*forkBomb) Name() string { - return "fork-bomb" -} - -// Synopsis implements subcommands.Command. -func (*forkBomb) Synopsis() string { - return "creates child process until the end of times" -} - -// Usage implements subcommands.Command. -func (*forkBomb) Usage() string { - return "fork-bomb <flags>" -} - -// SetFlags implements subcommands.Command. -func (c *forkBomb) SetFlags(f *flag.FlagSet) { - f.DurationVar(&c.delay, "delay", 100*time.Millisecond, "amount of time to delay creation of child") -} - -// Execute implements subcommands.Command. -func (c *forkBomb) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { - time.Sleep(c.delay) - - cmd := exec.Command("/proc/self/exe", c.Name()) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - log.Fatal("failed to call self:", err) - } - return subcommands.ExitSuccess -} - -type reaper struct{} - -// Name implements subcommands.Command. -func (*reaper) Name() string { - return "reaper" -} - -// Synopsis implements subcommands.Command. -func (*reaper) Synopsis() string { - return "reaps all children in a loop" -} - -// Usage implements subcommands.Command. -func (*reaper) Usage() string { - return "reaper <flags>" -} - -// SetFlags implements subcommands.Command. -func (*reaper) SetFlags(*flag.FlagSet) {} - -// Execute implements subcommands.Command. -func (c *reaper) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { - stop := testutil.StartReaper() - defer stop() - select {} -} - -type syscall struct { - sysno uint64 -} - -// Name implements subcommands.Command. -func (*syscall) Name() string { - return "syscall" -} - -// Synopsis implements subcommands.Command. -func (*syscall) Synopsis() string { - return "syscall makes a syscall" -} - -// Usage implements subcommands.Command. -func (*syscall) Usage() string { - return "syscall <flags>" -} - -// SetFlags implements subcommands.Command. -func (s *syscall) SetFlags(f *flag.FlagSet) { - f.Uint64Var(&s.sysno, "syscall", 0, "syscall to call") -} - -// Execute implements subcommands.Command. -func (s *syscall) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { - if _, _, errno := sys.Syscall(uintptr(s.sysno), 0, 0, 0); errno != 0 { - fmt.Printf("syscall(%d, 0, 0...) failed: %v\n", s.sysno, errno) - } else { - fmt.Printf("syscall(%d, 0, 0...) success\n", s.sysno) - } - return subcommands.ExitSuccess -} - -type capability struct { - enabled uint64 - disabled uint64 -} - -// Name implements subcommands.Command. -func (*capability) Name() string { - return "capability" -} - -// Synopsis implements subcommands.Command. -func (*capability) Synopsis() string { - return "checks if effective capabilities are set/unset" -} - -// Usage implements subcommands.Command. -func (*capability) Usage() string { - return "capability [--enabled=number] [--disabled=number]" -} - -// SetFlags implements subcommands.Command. -func (c *capability) SetFlags(f *flag.FlagSet) { - f.Uint64Var(&c.enabled, "enabled", 0, "") - f.Uint64Var(&c.disabled, "disabled", 0, "") -} - -// Execute implements subcommands.Command. -func (c *capability) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { - if c.enabled == 0 && c.disabled == 0 { - fmt.Println("One of the flags must be set") - return subcommands.ExitUsageError - } - - status, err := ioutil.ReadFile("/proc/self/status") - if err != nil { - fmt.Printf("Error reading %q: %v\n", "proc/self/status", err) - return subcommands.ExitFailure - } - re := regexp.MustCompile("CapEff:\t([0-9a-f]+)\n") - matches := re.FindStringSubmatch(string(status)) - if matches == nil || len(matches) != 2 { - fmt.Printf("Effective capabilities not found in\n%s\n", status) - return subcommands.ExitFailure - } - caps, err := strconv.ParseUint(matches[1], 16, 64) - if err != nil { - fmt.Printf("failed to convert capabilities %q: %v\n", matches[1], err) - return subcommands.ExitFailure - } - - if c.enabled != 0 && (caps&c.enabled) != c.enabled { - fmt.Printf("Missing capabilities, want: %#x: got: %#x\n", c.enabled, caps) - return subcommands.ExitFailure - } - if c.disabled != 0 && (caps&c.disabled) != 0 { - fmt.Printf("Extra capabilities found, dont_want: %#x: got: %#x\n", c.disabled, caps) - return subcommands.ExitFailure - } - - return subcommands.ExitSuccess -} - -type ptyRunner struct{} - -// Name implements subcommands.Command. -func (*ptyRunner) Name() string { - return "pty-runner" -} - -// Synopsis implements subcommands.Command. -func (*ptyRunner) Synopsis() string { - return "runs the given command with an open pty terminal" -} - -// Usage implements subcommands.Command. -func (*ptyRunner) Usage() string { - return "pty-runner [command]" -} - -// SetFlags implements subcommands.Command.SetFlags. -func (*ptyRunner) SetFlags(f *flag.FlagSet) {} - -// Execute implements subcommands.Command. -func (*ptyRunner) Execute(_ context.Context, fs *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { - c := exec.Command(fs.Args()[0], fs.Args()[1:]...) - f, err := pty.Start(c) - if err != nil { - fmt.Printf("pty.Start failed: %v", err) - return subcommands.ExitFailure - } - defer f.Close() - - // Copy stdout from the command to keep this process alive until the - // subprocess exits. - io.Copy(os.Stdout, f) - - return subcommands.ExitSuccess -} diff --git a/runsc/criutil/BUILD b/runsc/criutil/BUILD deleted file mode 100644 index 8a571a000..000000000 --- a/runsc/criutil/BUILD +++ /dev/null @@ -1,11 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "criutil", - testonly = 1, - srcs = ["criutil.go"], - visibility = ["//:sandbox"], - deps = ["//runsc/testutil"], -) diff --git a/runsc/criutil/criutil.go b/runsc/criutil/criutil.go deleted file mode 100644 index 773f5a1c4..000000000 --- a/runsc/criutil/criutil.go +++ /dev/null @@ -1,277 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package criutil contains utility functions for interacting with the -// Container Runtime Interface (CRI), principally via the crictl command line -// tool. This requires critools to be installed on the local system. -package criutil - -import ( - "encoding/json" - "fmt" - "os" - "os/exec" - "strings" - "time" - - "gvisor.dev/gvisor/runsc/testutil" -) - -const endpointPrefix = "unix://" - -// Crictl contains information required to run the crictl utility. -type Crictl struct { - executable string - timeout time.Duration - imageEndpoint string - runtimeEndpoint string -} - -// NewCrictl returns a Crictl configured with a timeout and an endpoint over -// which it will talk to containerd. -func NewCrictl(timeout time.Duration, endpoint string) *Crictl { - // Bazel doesn't pass PATH through, assume the location of crictl - // unless specified by environment variable. - executable := os.Getenv("CRICTL_PATH") - if executable == "" { - executable = "/usr/local/bin/crictl" - } - return &Crictl{ - executable: executable, - timeout: timeout, - imageEndpoint: endpointPrefix + endpoint, - runtimeEndpoint: endpointPrefix + endpoint, - } -} - -// Pull pulls an container image. It corresponds to `crictl pull`. -func (cc *Crictl) Pull(imageName string) error { - _, err := cc.run("pull", imageName) - return err -} - -// RunPod creates a sandbox. It corresponds to `crictl runp`. -func (cc *Crictl) RunPod(sbSpecFile string) (string, error) { - podID, err := cc.run("runp", sbSpecFile) - if err != nil { - return "", fmt.Errorf("runp failed: %v", err) - } - // Strip the trailing newline from crictl output. - return strings.TrimSpace(podID), nil -} - -// Create creates a container within a sandbox. It corresponds to `crictl -// create`. -func (cc *Crictl) Create(podID, contSpecFile, sbSpecFile string) (string, error) { - podID, err := cc.run("create", podID, contSpecFile, sbSpecFile) - if err != nil { - return "", fmt.Errorf("create failed: %v", err) - } - // Strip the trailing newline from crictl output. - return strings.TrimSpace(podID), nil -} - -// Start starts a container. It corresponds to `crictl start`. -func (cc *Crictl) Start(contID string) (string, error) { - output, err := cc.run("start", contID) - if err != nil { - return "", fmt.Errorf("start failed: %v", err) - } - return output, nil -} - -// Stop stops a container. It corresponds to `crictl stop`. -func (cc *Crictl) Stop(contID string) error { - _, err := cc.run("stop", contID) - return err -} - -// Exec execs a program inside a container. It corresponds to `crictl exec`. -func (cc *Crictl) Exec(contID string, args ...string) (string, error) { - a := []string{"exec", contID} - a = append(a, args...) - output, err := cc.run(a...) - if err != nil { - return "", fmt.Errorf("exec failed: %v", err) - } - return output, nil -} - -// Rm removes a container. It corresponds to `crictl rm`. -func (cc *Crictl) Rm(contID string) error { - _, err := cc.run("rm", contID) - return err -} - -// StopPod stops a pod. It corresponds to `crictl stopp`. -func (cc *Crictl) StopPod(podID string) error { - _, err := cc.run("stopp", podID) - return err -} - -// containsConfig is a minimal copy of -// https://github.com/kubernetes/kubernetes/blob/master/pkg/kubelet/apis/cri/runtime/v1alpha2/api.proto -// It only contains fields needed for testing. -type containerConfig struct { - Status containerStatus -} - -type containerStatus struct { - Network containerNetwork -} - -type containerNetwork struct { - IP string -} - -// PodIP returns a pod's IP address. -func (cc *Crictl) PodIP(podID string) (string, error) { - output, err := cc.run("inspectp", podID) - if err != nil { - return "", err - } - conf := &containerConfig{} - if err := json.Unmarshal([]byte(output), conf); err != nil { - return "", fmt.Errorf("failed to unmarshal JSON: %v, %s", err, output) - } - if conf.Status.Network.IP == "" { - return "", fmt.Errorf("no IP found in config: %s", output) - } - return conf.Status.Network.IP, nil -} - -// RmPod removes a container. It corresponds to `crictl rmp`. -func (cc *Crictl) RmPod(podID string) error { - _, err := cc.run("rmp", podID) - return err -} - -// StartContainer pulls the given image ands starts the container in the -// sandbox with the given podID. -func (cc *Crictl) StartContainer(podID, image, sbSpec, contSpec string) (string, error) { - // Write the specs to files that can be read by crictl. - sbSpecFile, err := testutil.WriteTmpFile("sbSpec", sbSpec) - if err != nil { - return "", fmt.Errorf("failed to write sandbox spec: %v", err) - } - contSpecFile, err := testutil.WriteTmpFile("contSpec", contSpec) - if err != nil { - return "", fmt.Errorf("failed to write container spec: %v", err) - } - - return cc.startContainer(podID, image, sbSpecFile, contSpecFile) -} - -func (cc *Crictl) startContainer(podID, image, sbSpecFile, contSpecFile string) (string, error) { - if err := cc.Pull(image); err != nil { - return "", fmt.Errorf("failed to pull %s: %v", image, err) - } - - contID, err := cc.Create(podID, contSpecFile, sbSpecFile) - if err != nil { - return "", fmt.Errorf("failed to create container in pod %q: %v", podID, err) - } - - if _, err := cc.Start(contID); err != nil { - return "", fmt.Errorf("failed to start container %q in pod %q: %v", contID, podID, err) - } - - return contID, nil -} - -// StopContainer stops and deletes the container with the given container ID. -func (cc *Crictl) StopContainer(contID string) error { - if err := cc.Stop(contID); err != nil { - return fmt.Errorf("failed to stop container %q: %v", contID, err) - } - - if err := cc.Rm(contID); err != nil { - return fmt.Errorf("failed to remove container %q: %v", contID, err) - } - - return nil -} - -// StartPodAndContainer pulls an image, then starts a sandbox and container in -// that sandbox. It returns the pod ID and container ID. -func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, string, error) { - // Write the specs to files that can be read by crictl. - sbSpecFile, err := testutil.WriteTmpFile("sbSpec", sbSpec) - if err != nil { - return "", "", fmt.Errorf("failed to write sandbox spec: %v", err) - } - contSpecFile, err := testutil.WriteTmpFile("contSpec", contSpec) - if err != nil { - return "", "", fmt.Errorf("failed to write container spec: %v", err) - } - - podID, err := cc.RunPod(sbSpecFile) - if err != nil { - return "", "", err - } - - contID, err := cc.startContainer(podID, image, sbSpecFile, contSpecFile) - - return podID, contID, err -} - -// StopPodAndContainer stops a container and pod. -func (cc *Crictl) StopPodAndContainer(podID, contID string) error { - if err := cc.StopContainer(contID); err != nil { - return fmt.Errorf("failed to stop container %q in pod %q: %v", contID, podID, err) - } - - if err := cc.StopPod(podID); err != nil { - return fmt.Errorf("failed to stop pod %q: %v", podID, err) - } - - if err := cc.RmPod(podID); err != nil { - return fmt.Errorf("failed to remove pod %q: %v", podID, err) - } - - return nil -} - -// run runs crictl with the given args and returns an error if it takes longer -// than cc.Timeout to run. -func (cc *Crictl) run(args ...string) (string, error) { - defaultArgs := []string{ - "--image-endpoint", cc.imageEndpoint, - "--runtime-endpoint", cc.runtimeEndpoint, - } - cmd := exec.Command(cc.executable, append(defaultArgs, args...)...) - - // Run the command with a timeout. - done := make(chan string) - errCh := make(chan error) - go func() { - output, err := cmd.CombinedOutput() - if err != nil { - errCh <- fmt.Errorf("error: \"%v\", output: %s", err, string(output)) - return - } - done <- string(output) - }() - select { - case output := <-done: - return output, nil - case err := <-errCh: - return "", err - case <-time.After(cc.timeout): - if err := testutil.KillCommand(cmd); err != nil { - return "", fmt.Errorf("timed out, then couldn't kill process %+v: %v", cmd, err) - } - return "", fmt.Errorf("timed out: %+v", cmd) - } -} diff --git a/runsc/debian/description b/runsc/debian/description deleted file mode 100644 index 9e8e08805..000000000 --- a/runsc/debian/description +++ /dev/null @@ -1 +0,0 @@ -gVisor container sandbox runtime diff --git a/runsc/debian/postinst.sh b/runsc/debian/postinst.sh deleted file mode 100755 index dc7aeee87..000000000 --- a/runsc/debian/postinst.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/bin/sh -e - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -if [ "$1" != configure ]; then - exit 0 -fi - -if [ -f /etc/docker/daemon.json ]; then - runsc install - systemctl restart docker || echo "unable to restart docker; you must do so manually." >&2 -fi diff --git a/runsc/dockerutil/BUILD b/runsc/dockerutil/BUILD deleted file mode 100644 index 8621af901..000000000 --- a/runsc/dockerutil/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "dockerutil", - testonly = 1, - srcs = ["dockerutil.go"], - visibility = ["//:sandbox"], - deps = [ - "//runsc/testutil", - "@com_github_kr_pty//:go_default_library", - ], -) diff --git a/runsc/dockerutil/dockerutil.go b/runsc/dockerutil/dockerutil.go deleted file mode 100644 index 1ff5e8cc3..000000000 --- a/runsc/dockerutil/dockerutil.go +++ /dev/null @@ -1,476 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package dockerutil is a collection of utility functions, primarily for -// testing. -package dockerutil - -import ( - "encoding/json" - "flag" - "fmt" - "io/ioutil" - "log" - "os" - "os/exec" - "path" - "regexp" - "strconv" - "strings" - "syscall" - "time" - - "github.com/kr/pty" - "gvisor.dev/gvisor/runsc/testutil" -) - -var ( - runtime = flag.String("runtime", "runsc", "specify which runtime to use") - config = flag.String("config_path", "/etc/docker/daemon.json", "configuration file for reading paths") -) - -// EnsureSupportedDockerVersion checks if correct docker is installed. -func EnsureSupportedDockerVersion() { - cmd := exec.Command("docker", "version") - out, err := cmd.CombinedOutput() - if err != nil { - log.Fatalf("Error running %q: %v", "docker version", err) - } - re := regexp.MustCompile(`Version:\s+(\d+)\.(\d+)\.\d.*`) - matches := re.FindStringSubmatch(string(out)) - if len(matches) != 3 { - log.Fatalf("Invalid docker output: %s", out) - } - major, _ := strconv.Atoi(matches[1]) - minor, _ := strconv.Atoi(matches[2]) - if major < 17 || (major == 17 && minor < 9) { - log.Fatalf("Docker version 17.09.0 or greater is required, found: %02d.%02d", major, minor) - } -} - -// RuntimePath returns the binary path for the current runtime. -func RuntimePath() (string, error) { - // Read the configuration data; the file must exist. - configBytes, err := ioutil.ReadFile(*config) - if err != nil { - return "", err - } - - // Unmarshal the configuration. - c := make(map[string]interface{}) - if err := json.Unmarshal(configBytes, &c); err != nil { - return "", err - } - - // Decode the expected configuration. - r, ok := c["runtimes"] - if !ok { - return "", fmt.Errorf("no runtimes declared: %v", c) - } - rs, ok := r.(map[string]interface{}) - if !ok { - // The runtimes are not a map. - return "", fmt.Errorf("unexpected format: %v", c) - } - r, ok = rs[*runtime] - if !ok { - // The expected runtime is not declared. - return "", fmt.Errorf("runtime %q not found: %v", *runtime, c) - } - rs, ok = r.(map[string]interface{}) - if !ok { - // The runtime is not a map. - return "", fmt.Errorf("unexpected format: %v", c) - } - p, ok := rs["path"].(string) - if !ok { - // The runtime does not declare a path. - return "", fmt.Errorf("unexpected format: %v", c) - } - return p, nil -} - -// MountMode describes if the mount should be ro or rw. -type MountMode int - -const ( - // ReadOnly is what the name says. - ReadOnly MountMode = iota - // ReadWrite is what the name says. - ReadWrite -) - -// String returns the mount mode argument for this MountMode. -func (m MountMode) String() string { - switch m { - case ReadOnly: - return "ro" - case ReadWrite: - return "rw" - } - panic(fmt.Sprintf("invalid mode: %d", m)) -} - -// MountArg formats the volume argument to mount in the container. -func MountArg(source, target string, mode MountMode) string { - return fmt.Sprintf("-v=%s:%s:%v", source, target, mode) -} - -// LinkArg formats the link argument. -func LinkArg(source *Docker, target string) string { - return fmt.Sprintf("--link=%s:%s", source.Name, target) -} - -// PrepareFiles creates temp directory to copy files there. The sandbox doesn't -// have access to files in the test dir. -func PrepareFiles(names ...string) (string, error) { - dir, err := ioutil.TempDir("", "image-test") - if err != nil { - return "", fmt.Errorf("ioutil.TempDir failed: %v", err) - } - if err := os.Chmod(dir, 0777); err != nil { - return "", fmt.Errorf("os.Chmod(%q, 0777) failed: %v", dir, err) - } - for _, name := range names { - src, err := testutil.FindFile(name) - if err != nil { - return "", fmt.Errorf("testutil.Preparefiles(%q) failed: %v", name, err) - } - dst := path.Join(dir, path.Base(name)) - if err := testutil.Copy(src, dst); err != nil { - return "", fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err) - } - } - return dir, nil -} - -// do executes docker command. -func do(args ...string) (string, error) { - log.Printf("Running: docker %s\n", args) - cmd := exec.Command("docker", args...) - out, err := cmd.CombinedOutput() - if err != nil { - return "", fmt.Errorf("error executing docker %s: %v\nout: %s", args, err, out) - } - return string(out), nil -} - -// doWithPty executes docker command with stdio attached to a pty. -func doWithPty(args ...string) (*exec.Cmd, *os.File, error) { - log.Printf("Running with pty: docker %s\n", args) - cmd := exec.Command("docker", args...) - ptmx, err := pty.Start(cmd) - if err != nil { - return nil, nil, fmt.Errorf("error executing docker %s with a pty: %v", args, err) - } - return cmd, ptmx, nil -} - -// Pull pulls a docker image. This is used in tests to isolate the -// time to pull the image off the network from the time to actually -// start the container, to avoid timeouts over slow networks. -func Pull(image string) error { - _, err := do("pull", image) - return err -} - -// Docker contains the name and the runtime of a docker container. -type Docker struct { - Runtime string - Name string -} - -// MakeDocker sets up the struct for a Docker container. -// Names of containers will be unique. -func MakeDocker(namePrefix string) Docker { - return Docker{ - Name: testutil.RandomName(namePrefix), - Runtime: *runtime, - } -} - -// logDockerID logs a container id, which is needed to find container runsc logs. -func (d *Docker) logDockerID() { - id, err := d.ID() - if err != nil { - log.Printf("%v\n", err) - } - log.Printf("Name: %s ID: %v\n", d.Name, id) -} - -// Create calls 'docker create' with the arguments provided. -func (d *Docker) Create(args ...string) error { - a := []string{"create", "--runtime", d.Runtime, "--name", d.Name} - a = append(a, args...) - _, err := do(a...) - if err == nil { - d.logDockerID() - } - return err -} - -// Start calls 'docker start'. -func (d *Docker) Start() error { - if _, err := do("start", d.Name); err != nil { - return fmt.Errorf("error starting container %q: %v", d.Name, err) - } - return nil -} - -// Stop calls 'docker stop'. -func (d *Docker) Stop() error { - if _, err := do("stop", d.Name); err != nil { - return fmt.Errorf("error stopping container %q: %v", d.Name, err) - } - return nil -} - -// Run calls 'docker run' with the arguments provided. The container starts -// running in the background and the call returns immediately. -func (d *Docker) Run(args ...string) error { - a := d.runArgs("-d") - a = append(a, args...) - _, err := do(a...) - if err == nil { - d.logDockerID() - } - return err -} - -// RunWithPty is like Run but with an attached pty. -func (d *Docker) RunWithPty(args ...string) (*exec.Cmd, *os.File, error) { - a := d.runArgs("-it") - a = append(a, args...) - return doWithPty(a...) -} - -// RunFg calls 'docker run' with the arguments provided in the foreground. It -// blocks until the container exits and returns the output. -func (d *Docker) RunFg(args ...string) (string, error) { - a := d.runArgs(args...) - out, err := do(a...) - if err == nil { - d.logDockerID() - } - return string(out), err -} - -func (d *Docker) runArgs(args ...string) []string { - // Environment variable RUNSC_TEST_NAME is picked up by the runtime and added - // to the log name, so one can easily identify the corresponding logs for - // this test. - rv := []string{"run", "--runtime", d.Runtime, "--name", d.Name, "-e", "RUNSC_TEST_NAME=" + d.Name} - return append(rv, args...) -} - -// Logs calls 'docker logs'. -func (d *Docker) Logs() (string, error) { - return do("logs", d.Name) -} - -// Exec calls 'docker exec' with the arguments provided. -func (d *Docker) Exec(args ...string) (string, error) { - return d.ExecWithFlags(nil, args...) -} - -// ExecWithFlags calls 'docker exec <flags> name <args>'. -func (d *Docker) ExecWithFlags(flags []string, args ...string) (string, error) { - a := []string{"exec"} - a = append(a, flags...) - a = append(a, d.Name) - a = append(a, args...) - return do(a...) -} - -// ExecAsUser calls 'docker exec' as the given user with the arguments -// provided. -func (d *Docker) ExecAsUser(user string, args ...string) (string, error) { - a := []string{"exec", "--user", user, d.Name} - a = append(a, args...) - return do(a...) -} - -// ExecWithTerminal calls 'docker exec -it' with the arguments provided and -// attaches a pty to stdio. -func (d *Docker) ExecWithTerminal(args ...string) (*exec.Cmd, *os.File, error) { - a := []string{"exec", "-it", d.Name} - a = append(a, args...) - return doWithPty(a...) -} - -// Pause calls 'docker pause'. -func (d *Docker) Pause() error { - if _, err := do("pause", d.Name); err != nil { - return fmt.Errorf("error pausing container %q: %v", d.Name, err) - } - return nil -} - -// Unpause calls 'docker pause'. -func (d *Docker) Unpause() error { - if _, err := do("unpause", d.Name); err != nil { - return fmt.Errorf("error unpausing container %q: %v", d.Name, err) - } - return nil -} - -// Checkpoint calls 'docker checkpoint'. -func (d *Docker) Checkpoint(name string) error { - if _, err := do("checkpoint", "create", d.Name, name); err != nil { - return fmt.Errorf("error pausing container %q: %v", d.Name, err) - } - return nil -} - -// Restore calls 'docker start --checkname [name]'. -func (d *Docker) Restore(name string) error { - if _, err := do("start", "--checkpoint", name, d.Name); err != nil { - return fmt.Errorf("error starting container %q: %v", d.Name, err) - } - return nil -} - -// Remove calls 'docker rm'. -func (d *Docker) Remove() error { - if _, err := do("rm", d.Name); err != nil { - return fmt.Errorf("error deleting container %q: %v", d.Name, err) - } - return nil -} - -// CleanUp kills and deletes the container (best effort). -func (d *Docker) CleanUp() { - d.logDockerID() - if _, err := do("kill", d.Name); err != nil { - if strings.Contains(err.Error(), "is not running") { - // Nothing to kill. Don't log the error in this case. - } else { - log.Printf("error killing container %q: %v", d.Name, err) - } - } - if err := d.Remove(); err != nil { - log.Print(err) - } -} - -// FindPort returns the host port that is mapped to 'sandboxPort'. This calls -// docker to allocate a free port in the host and prevent conflicts. -func (d *Docker) FindPort(sandboxPort int) (int, error) { - format := fmt.Sprintf(`{{ (index (index .NetworkSettings.Ports "%d/tcp") 0).HostPort }}`, sandboxPort) - out, err := do("inspect", "-f", format, d.Name) - if err != nil { - return -1, fmt.Errorf("error retrieving port: %v", err) - } - port, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n")) - if err != nil { - return -1, fmt.Errorf("error parsing port %q: %v", out, err) - } - return port, nil -} - -// FindIP returns the IP address of the container as a string. -func (d *Docker) FindIP() (string, error) { - const format = `{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}` - out, err := do("inspect", "-f", format, d.Name) - if err != nil { - return "", fmt.Errorf("error retrieving IP: %v", err) - } - return strings.TrimSpace(out), nil -} - -// SandboxPid returns the PID to the sandbox process. -func (d *Docker) SandboxPid() (int, error) { - out, err := do("inspect", "-f={{.State.Pid}}", d.Name) - if err != nil { - return -1, fmt.Errorf("error retrieving pid: %v", err) - } - pid, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n")) - if err != nil { - return -1, fmt.Errorf("error parsing pid %q: %v", out, err) - } - return pid, nil -} - -// ID returns the container ID. -func (d *Docker) ID() (string, error) { - out, err := do("inspect", "-f={{.Id}}", d.Name) - if err != nil { - return "", fmt.Errorf("error retrieving ID: %v", err) - } - return strings.TrimSpace(string(out)), nil -} - -// Wait waits for container to exit, up to the given timeout. Returns error if -// wait fails or timeout is hit. Returns the application return code otherwise. -// Note that the application may have failed even if err == nil, always check -// the exit code. -func (d *Docker) Wait(timeout time.Duration) (syscall.WaitStatus, error) { - timeoutChan := time.After(timeout) - waitChan := make(chan (syscall.WaitStatus)) - errChan := make(chan (error)) - - go func() { - out, err := do("wait", d.Name) - if err != nil { - errChan <- fmt.Errorf("error waiting for container %q: %v", d.Name, err) - } - exit, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n")) - if err != nil { - errChan <- fmt.Errorf("error parsing exit code %q: %v", out, err) - } - waitChan <- syscall.WaitStatus(uint32(exit)) - }() - - select { - case ws := <-waitChan: - return ws, nil - case err := <-errChan: - return syscall.WaitStatus(1), err - case <-timeoutChan: - return syscall.WaitStatus(1), fmt.Errorf("timeout waiting for container %q", d.Name) - } -} - -// WaitForOutput calls 'docker logs' to retrieve containers output and searches -// for the given pattern. -func (d *Docker) WaitForOutput(pattern string, timeout time.Duration) (string, error) { - matches, err := d.WaitForOutputSubmatch(pattern, timeout) - if err != nil { - return "", err - } - if len(matches) == 0 { - return "", nil - } - return matches[0], nil -} - -// WaitForOutputSubmatch calls 'docker logs' to retrieve containers output and -// searches for the given pattern. It returns any regexp submatches as well. -func (d *Docker) WaitForOutputSubmatch(pattern string, timeout time.Duration) ([]string, error) { - re := regexp.MustCompile(pattern) - var out string - for exp := time.Now().Add(timeout); time.Now().Before(exp); { - var err error - out, err = d.Logs() - if err != nil { - return nil, err - } - if matches := re.FindStringSubmatch(out); matches != nil { - // Success! - return matches, nil - } - time.Sleep(100 * time.Millisecond) - } - return nil, fmt.Errorf("timeout waiting for output %q: %s", re.String(), out) -} diff --git a/runsc/flag/BUILD b/runsc/flag/BUILD deleted file mode 100644 index 5cb7604a8..000000000 --- a/runsc/flag/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "flag", - srcs = ["flag.go"], - visibility = ["//:sandbox"], -) diff --git a/runsc/flag/flag.go b/runsc/flag/flag.go index 0ca4829d7..0ca4829d7 100644..100755 --- a/runsc/flag/flag.go +++ b/runsc/flag/flag.go diff --git a/runsc/flag/flag_state_autogen.go b/runsc/flag/flag_state_autogen.go new file mode 100755 index 000000000..933063e6c --- /dev/null +++ b/runsc/flag/flag_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package flag diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD deleted file mode 100644 index 64a406ae2..000000000 --- a/runsc/fsgofer/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "fsgofer", - srcs = [ - "fsgofer.go", - "fsgofer_amd64_unsafe.go", - "fsgofer_arm64_unsafe.go", - "fsgofer_unsafe.go", - ], - visibility = ["//runsc:__subpackages__"], - deps = [ - "//pkg/abi/linux", - "//pkg/fd", - "//pkg/log", - "//pkg/p9", - "//pkg/sync", - "//pkg/syserr", - "//runsc/specutils", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "fsgofer_test", - size = "small", - srcs = ["fsgofer_test.go"], - library = ":fsgofer", - deps = [ - "//pkg/log", - "//pkg/p9", - ], -) diff --git a/runsc/fsgofer/filter/BUILD b/runsc/fsgofer/filter/BUILD deleted file mode 100644 index 82b48ef32..000000000 --- a/runsc/fsgofer/filter/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "filter", - srcs = [ - "config.go", - "config_amd64.go", - "config_arm64.go", - "extra_filters.go", - "extra_filters_msan.go", - "extra_filters_race.go", - "filter.go", - ], - visibility = [ - "//runsc:__subpackages__", - ], - deps = [ - "//pkg/abi/linux", - "//pkg/flipcall", - "//pkg/log", - "//pkg/seccomp", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/runsc/fsgofer/filter/config_amd64.go b/runsc/fsgofer/filter/config_amd64.go index a4b28cb8b..a4b28cb8b 100644..100755 --- a/runsc/fsgofer/filter/config_amd64.go +++ b/runsc/fsgofer/filter/config_amd64.go diff --git a/runsc/fsgofer/filter/config_arm64.go b/runsc/fsgofer/filter/config_arm64.go index d2697deb7..d2697deb7 100644..100755 --- a/runsc/fsgofer/filter/config_arm64.go +++ b/runsc/fsgofer/filter/config_arm64.go diff --git a/runsc/fsgofer/filter/filter_amd64_state_autogen.go b/runsc/fsgofer/filter/filter_amd64_state_autogen.go new file mode 100755 index 000000000..0f27e5568 --- /dev/null +++ b/runsc/fsgofer/filter/filter_amd64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build amd64 + +package filter diff --git a/runsc/fsgofer/filter/filter_arm64_state_autogen.go b/runsc/fsgofer/filter/filter_arm64_state_autogen.go new file mode 100755 index 000000000..e87cf5af7 --- /dev/null +++ b/runsc/fsgofer/filter/filter_arm64_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build arm64 + +package filter diff --git a/runsc/fsgofer/filter/filter_state_autogen.go b/runsc/fsgofer/filter/filter_state_autogen.go new file mode 100755 index 000000000..545d526ae --- /dev/null +++ b/runsc/fsgofer/filter/filter_state_autogen.go @@ -0,0 +1,7 @@ +// automatically generated by stateify. + +// +build !msan,!race +// +build msan +// +build race + +package filter diff --git a/runsc/fsgofer/fsgofer_amd64_unsafe.go b/runsc/fsgofer/fsgofer_amd64_unsafe.go index 5d4aab597..5d4aab597 100644..100755 --- a/runsc/fsgofer/fsgofer_amd64_unsafe.go +++ b/runsc/fsgofer/fsgofer_amd64_unsafe.go diff --git a/runsc/fsgofer/fsgofer_amd64_unsafe_state_autogen.go b/runsc/fsgofer/fsgofer_amd64_unsafe_state_autogen.go new file mode 100755 index 000000000..df6721aaa --- /dev/null +++ b/runsc/fsgofer/fsgofer_amd64_unsafe_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build amd64 + +package fsgofer diff --git a/runsc/fsgofer/fsgofer_arm64_unsafe.go b/runsc/fsgofer/fsgofer_arm64_unsafe.go index 8041fd352..8041fd352 100644..100755 --- a/runsc/fsgofer/fsgofer_arm64_unsafe.go +++ b/runsc/fsgofer/fsgofer_arm64_unsafe.go diff --git a/runsc/fsgofer/fsgofer_arm64_unsafe_state_autogen.go b/runsc/fsgofer/fsgofer_arm64_unsafe_state_autogen.go new file mode 100755 index 000000000..d2a18c61c --- /dev/null +++ b/runsc/fsgofer/fsgofer_arm64_unsafe_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build arm64 + +package fsgofer diff --git a/runsc/fsgofer/fsgofer_state_autogen.go b/runsc/fsgofer/fsgofer_state_autogen.go new file mode 100755 index 000000000..d2f978fb9 --- /dev/null +++ b/runsc/fsgofer/fsgofer_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package fsgofer diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go deleted file mode 100644 index 05af7e397..000000000 --- a/runsc/fsgofer/fsgofer_test.go +++ /dev/null @@ -1,692 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fsgofer - -import ( - "fmt" - "io/ioutil" - "net" - "os" - "path" - "path/filepath" - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/p9" -) - -func init() { - log.SetLevel(log.Debug) - - allConfs = append(allConfs, rwConfs...) - allConfs = append(allConfs, roConfs...) - - if err := OpenProcSelfFD(); err != nil { - panic(err) - } -} - -func assertPanic(t *testing.T, f func()) { - defer func() { - if r := recover(); r == nil { - t.Errorf("function did not panic") - } - }() - f() -} - -func testReadWrite(f p9.File, flags p9.OpenFlags, content []byte) error { - want := make([]byte, len(content)) - copy(want, content) - - b := []byte("test-1-2-3") - w, err := f.WriteAt(b, uint64(len(content))) - if flags == p9.WriteOnly || flags == p9.ReadWrite { - if err != nil { - return fmt.Errorf("WriteAt(): %v", err) - } - if w != len(b) { - return fmt.Errorf("WriteAt() was partial, got: %d, want: %d", w, len(b)) - } - want = append(want, b...) - } else { - if e, ok := err.(syscall.Errno); !ok || e != syscall.EBADF { - return fmt.Errorf("WriteAt() should have failed, got: %d, want: EBADFD", err) - } - } - - rBuf := make([]byte, len(want)) - r, err := f.ReadAt(rBuf, 0) - if flags == p9.ReadOnly || flags == p9.ReadWrite { - if err != nil { - return fmt.Errorf("ReadAt(): %v", err) - } - if r != len(rBuf) { - return fmt.Errorf("ReadAt() was partial, got: %d, want: %d", r, len(rBuf)) - } - if string(rBuf) != string(want) { - return fmt.Errorf("ReadAt() wrong data, got: %s, want: %s", string(rBuf), want) - } - } else { - if e, ok := err.(syscall.Errno); !ok || e != syscall.EBADF { - return fmt.Errorf("ReadAt() should have failed, got: %d, want: EBADFD", err) - } - } - return nil -} - -var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} - -var ( - allTypes = []fileType{regular, directory, symlink} - - // allConfs is set in init() above. - allConfs []Config - - rwConfs = []Config{{ROMount: false}} - roConfs = []Config{{ROMount: true}} -) - -type state struct { - root *localFile - file *localFile - conf Config - ft fileType -} - -func (s state) String() string { - return fmt.Sprintf("type(%v)", s.ft) -} - -func runAll(t *testing.T, test func(*testing.T, state)) { - runCustom(t, allTypes, allConfs, test) -} - -func runCustom(t *testing.T, types []fileType, confs []Config, test func(*testing.T, state)) { - for _, c := range confs { - t.Logf("Config: %+v", c) - - for _, ft := range types { - t.Logf("File type: %v", ft) - - path, name, err := setup(ft) - if err != nil { - t.Fatalf("%v", err) - } - defer os.RemoveAll(path) - - a, err := NewAttachPoint(path, c) - if err != nil { - t.Fatalf("NewAttachPoint failed: %v", err) - } - root, err := a.Attach() - if err != nil { - t.Fatalf("Attach failed, err: %v", err) - } - - _, file, err := root.Walk([]string{name}) - if err != nil { - root.Close() - t.Fatalf("root.Walk({%q}) failed, err: %v", "symlink", err) - } - - st := state{root: root.(*localFile), file: file.(*localFile), conf: c, ft: ft} - test(t, st) - file.Close() - root.Close() - } - } -} - -func setup(ft fileType) (string, string, error) { - path, err := ioutil.TempDir("", "root-") - if err != nil { - return "", "", fmt.Errorf("ioutil.TempDir() failed, err: %v", err) - } - - // First attach with writable configuration to setup tree. - a, err := NewAttachPoint(path, Config{}) - if err != nil { - return "", "", err - } - root, err := a.Attach() - if err != nil { - return "", "", fmt.Errorf("Attach failed, err: %v", err) - } - defer root.Close() - - var name string - switch ft { - case regular: - name = "file" - _, f, _, _, err := root.Create(name, p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) - if err != nil { - return "", "", fmt.Errorf("createFile(root, %q) failed, err: %v", "test", err) - } - defer f.Close() - case directory: - name = "dir" - if _, err := root.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { - return "", "", fmt.Errorf("root.MkDir(%q) failed, err: %v", name, err) - } - case symlink: - name = "symlink" - if _, err := root.Symlink("/some/target", name, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { - return "", "", fmt.Errorf("root.Symlink(%q) failed, err: %v", name, err) - } - default: - panic(fmt.Sprintf("unknown file type %v", ft)) - } - return path, name, nil -} - -func createFile(dir *localFile, name string) (*localFile, error) { - _, f, _, _, err := dir.Create(name, p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) - if err != nil { - return nil, err - } - return f.(*localFile), nil -} - -func TestReadWrite(t *testing.T) { - runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) { - child, err := createFile(s.file, "test") - if err != nil { - t.Fatalf("%v: createFile() failed, err: %v", s, err) - } - defer child.Close() - want := []byte("foobar") - w, err := child.WriteAt(want, 0) - if err != nil { - t.Fatalf("%v: Write() failed, err: %v", s, err) - } - if w != len(want) { - t.Fatalf("%v: Write() was partial, got: %d, expected: %d", s, w, len(want)) - } - for _, flags := range allOpenFlags { - _, l, err := s.file.Walk([]string{"test"}) - if err != nil { - t.Fatalf("%v: Walk(%s) failed, err: %v", s, "test", err) - } - if _, _, _, err := l.Open(flags); err != nil { - t.Fatalf("%v: Open(%v) failed, err: %v", s, flags, err) - } - if err := testReadWrite(l, flags, want); err != nil { - t.Fatalf("%v: testReadWrite(%v) failed: %v", s, flags, err) - } - } - }) -} - -func TestCreate(t *testing.T) { - runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) { - for i, flags := range allOpenFlags { - _, l, _, _, err := s.file.Create(fmt.Sprintf("test-%d", i), flags, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) - if err != nil { - t.Fatalf("%v, %v: WriteAt() failed, err: %v", s, flags, err) - } - - if err := testReadWrite(l, flags, []byte{}); err != nil { - t.Fatalf("%v: testReadWrite(%v) failed: %v", s, flags, err) - } - } - }) -} - -// TestReadWriteDup tests that a file opened in any mode can be dup'ed and -// reopened in any other mode. -func TestReadWriteDup(t *testing.T) { - runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) { - child, err := createFile(s.file, "test") - if err != nil { - t.Fatalf("%v: createFile() failed, err: %v", s, err) - } - defer child.Close() - want := []byte("foobar") - w, err := child.WriteAt(want, 0) - if err != nil { - t.Fatalf("%v: Write() failed, err: %v", s, err) - } - if w != len(want) { - t.Fatalf("%v: Write() was partial, got: %d, expected: %d", s, w, len(want)) - } - for _, flags := range allOpenFlags { - _, l, err := s.file.Walk([]string{"test"}) - if err != nil { - t.Fatalf("%v: Walk(%s) failed, err: %v", s, "test", err) - } - defer l.Close() - if _, _, _, err := l.Open(flags); err != nil { - t.Fatalf("%v: Open(%v) failed, err: %v", s, flags, err) - } - for _, dupFlags := range allOpenFlags { - t.Logf("Original flags: %v, dup flags: %v", flags, dupFlags) - _, dup, err := l.Walk([]string{}) - if err != nil { - t.Fatalf("%v: Walk(<empty>) failed: %v", s, err) - } - defer dup.Close() - if _, _, _, err := dup.Open(dupFlags); err != nil { - t.Fatalf("%v: Open(%v) failed: %v", s, flags, err) - } - if err := testReadWrite(dup, dupFlags, want); err != nil { - t.Fatalf("%v: testReadWrite(%v) failed: %v", s, dupFlags, err) - } - } - } - }) -} - -func TestUnopened(t *testing.T) { - runCustom(t, []fileType{regular}, allConfs, func(t *testing.T, s state) { - b := []byte("foobar") - if _, err := s.file.WriteAt(b, 0); err != syscall.EBADF { - t.Errorf("%v: WriteAt() should have failed, got: %v, expected: syscall.EBADF", s, err) - } - if _, err := s.file.ReadAt(b, 0); err != syscall.EBADF { - t.Errorf("%v: ReadAt() should have failed, got: %v, expected: syscall.EBADF", s, err) - } - if _, err := s.file.Readdir(0, 100); err != syscall.EBADF { - t.Errorf("%v: Readdir() should have failed, got: %v, expected: syscall.EBADF", s, err) - } - if err := s.file.FSync(); err != syscall.EBADF { - t.Errorf("%v: FSync() should have failed, got: %v, expected: syscall.EBADF", s, err) - } - }) -} - -func SetGetAttr(l *localFile, valid p9.SetAttrMask, attr p9.SetAttr) (p9.Attr, error) { - if err := l.SetAttr(valid, attr); err != nil { - return p9.Attr{}, err - } - _, _, a, err := l.GetAttr(p9.AttrMask{}) - if err != nil { - return p9.Attr{}, err - } - return a, nil -} - -func TestSetAttrPerm(t *testing.T) { - runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) { - valid := p9.SetAttrMask{Permissions: true} - attr := p9.SetAttr{Permissions: 0777} - got, err := SetGetAttr(s.file, valid, attr) - if s.ft == symlink { - if err == nil { - t.Fatalf("%v: SetGetAttr(valid, %v) should have failed", s, attr.Permissions) - } - } else { - if err != nil { - t.Fatalf("%v: SetGetAttr(valid, %v) failed, err: %v", s, attr.Permissions, err) - } - if got.Mode.Permissions() != attr.Permissions { - t.Errorf("%v: wrong permission, got: %v, expected: %v", s, got.Mode.Permissions(), attr.Permissions) - } - } - }) -} - -func TestSetAttrSize(t *testing.T) { - runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) { - for _, size := range []uint64{1024, 0, 1024 * 1024} { - valid := p9.SetAttrMask{Size: true} - attr := p9.SetAttr{Size: size} - got, err := SetGetAttr(s.file, valid, attr) - if s.ft == symlink || s.ft == directory { - if err == nil { - t.Fatalf("%v: SetGetAttr(valid, %v) should have failed", s, attr.Permissions) - } - // Run for one size only, they will all fail the same way. - return - } - if err != nil { - t.Fatalf("%v: SetGetAttr(valid, %v) failed, err: %v", s, attr.Size, err) - } - if got.Size != size { - t.Errorf("%v: wrong size, got: %v, expected: %v", s, got.Size, size) - } - } - }) -} - -func TestSetAttrTime(t *testing.T) { - runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) { - valid := p9.SetAttrMask{ATime: true, ATimeNotSystemTime: true} - attr := p9.SetAttr{ATimeSeconds: 123, ATimeNanoSeconds: 456} - got, err := SetGetAttr(s.file, valid, attr) - if err != nil { - t.Fatalf("%v: SetGetAttr(valid, %v:%v) failed, err: %v", s, attr.ATimeSeconds, attr.ATimeNanoSeconds, err) - } - if got.ATimeSeconds != 123 { - t.Errorf("%v: wrong ATimeSeconds, got: %v, expected: %v", s, got.ATimeSeconds, 123) - } - if got.ATimeNanoSeconds != 456 { - t.Errorf("%v: wrong ATimeNanoSeconds, got: %v, expected: %v", s, got.ATimeNanoSeconds, 456) - } - - valid = p9.SetAttrMask{MTime: true, MTimeNotSystemTime: true} - attr = p9.SetAttr{MTimeSeconds: 789, MTimeNanoSeconds: 012} - got, err = SetGetAttr(s.file, valid, attr) - if err != nil { - t.Fatalf("%v: SetGetAttr(valid, %v:%v) failed, err: %v", s, attr.MTimeSeconds, attr.MTimeNanoSeconds, err) - } - if got.MTimeSeconds != 789 { - t.Errorf("%v: wrong MTimeSeconds, got: %v, expected: %v", s, got.MTimeSeconds, 789) - } - if got.MTimeNanoSeconds != 012 { - t.Errorf("%v: wrong MTimeNanoSeconds, got: %v, expected: %v", s, got.MTimeNanoSeconds, 012) - } - }) -} - -func TestSetAttrOwner(t *testing.T) { - if os.Getuid() != 0 { - t.Skipf("SetAttr(owner) test requires CAP_CHOWN, running as %d", os.Getuid()) - } - - runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) { - newUID := os.Getuid() + 1 - valid := p9.SetAttrMask{UID: true} - attr := p9.SetAttr{UID: p9.UID(newUID)} - got, err := SetGetAttr(s.file, valid, attr) - if err != nil { - t.Fatalf("%v: SetGetAttr(valid, %v) failed, err: %v", s, attr.UID, err) - } - if got.UID != p9.UID(newUID) { - t.Errorf("%v: wrong uid, got: %v, expected: %v", s, got.UID, newUID) - } - }) -} - -func TestLink(t *testing.T) { - if os.Getuid() != 0 { - t.Skipf("Link test requires CAP_DAC_READ_SEARCH, running as %d", os.Getuid()) - } - runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) { - const dirName = "linkdir" - const linkFile = "link" - if _, err := s.root.Mkdir(dirName, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { - t.Fatalf("%v: MkDir(%s) failed, err: %v", s, dirName, err) - } - _, dir, err := s.root.Walk([]string{dirName}) - if err != nil { - t.Fatalf("%v: Walk({%s}) failed, err: %v", s, dirName, err) - } - - err = dir.Link(s.file, linkFile) - if s.ft == directory { - if err != syscall.EPERM { - t.Errorf("%v: Link(target, %s) should have failed, got: %v, expected: syscall.EPERM", s, linkFile, err) - } - return - } - if err != nil { - t.Errorf("%v: Link(target, %s) failed, err: %v", s, linkFile, err) - } - }) -} - -func TestROMountChecks(t *testing.T) { - runCustom(t, allTypes, roConfs, func(t *testing.T, s state) { - if _, _, _, _, err := s.file.Create("some_file", p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != syscall.EBADF { - t.Errorf("%v: Create() should have failed, got: %v, expected: syscall.EBADF", s, err) - } - if _, err := s.file.Mkdir("some_dir", 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != syscall.EBADF { - t.Errorf("%v: MkDir() should have failed, got: %v, expected: syscall.EBADF", s, err) - } - if err := s.file.RenameAt("some_file", s.file, "other_file"); err != syscall.EBADF { - t.Errorf("%v: Rename() should have failed, got: %v, expected: syscall.EBADF", s, err) - } - if _, err := s.file.Symlink("some_place", "some_symlink", p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != syscall.EBADF { - t.Errorf("%v: Symlink() should have failed, got: %v, expected: syscall.EBADF", s, err) - } - if err := s.file.UnlinkAt("some_file", 0); err != syscall.EBADF { - t.Errorf("%v: UnlinkAt() should have failed, got: %v, expected: syscall.EBADF", s, err) - } - if err := s.file.Link(s.file, "some_link"); err != syscall.EBADF { - t.Errorf("%v: Link() should have failed, got: %v, expected: syscall.EBADF", s, err) - } - - valid := p9.SetAttrMask{Size: true} - attr := p9.SetAttr{Size: 0} - if err := s.file.SetAttr(valid, attr); err != syscall.EBADF { - t.Errorf("%v: SetAttr() should have failed, got: %v, expected: syscall.EBADF", s, err) - } - }) -} - -func TestROMountPanics(t *testing.T) { - conf := Config{ROMount: true, PanicOnWrite: true} - runCustom(t, allTypes, []Config{conf}, func(t *testing.T, s state) { - assertPanic(t, func() { s.file.Create("some_file", p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) }) - assertPanic(t, func() { s.file.Mkdir("some_dir", 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) }) - assertPanic(t, func() { s.file.RenameAt("some_file", s.file, "other_file") }) - assertPanic(t, func() { s.file.Symlink("some_place", "some_symlink", p9.UID(os.Getuid()), p9.GID(os.Getgid())) }) - assertPanic(t, func() { s.file.UnlinkAt("some_file", 0) }) - assertPanic(t, func() { s.file.Link(s.file, "some_link") }) - - valid := p9.SetAttrMask{Size: true} - attr := p9.SetAttr{Size: 0} - assertPanic(t, func() { s.file.SetAttr(valid, attr) }) - }) -} - -func TestWalkNotFound(t *testing.T) { - runCustom(t, []fileType{directory}, allConfs, func(t *testing.T, s state) { - if _, _, err := s.file.Walk([]string{"nobody-here"}); err != syscall.ENOENT { - t.Errorf("%v: Walk(%q) should have failed, got: %v, expected: syscall.ENOENT", s, "nobody-here", err) - } - }) -} - -func TestWalkDup(t *testing.T) { - runAll(t, func(t *testing.T, s state) { - _, dup, err := s.file.Walk([]string{}) - if err != nil { - t.Fatalf("%v: Walk(nil) failed, err: %v", s, err) - } - // Check that 'dup' is usable. - if _, _, _, err := dup.GetAttr(p9.AttrMask{}); err != nil { - t.Errorf("%v: GetAttr() failed, err: %v", s, err) - } - }) -} - -func TestReaddir(t *testing.T) { - runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) { - name := "dir" - if _, err := s.file.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { - t.Fatalf("%v: MkDir(%s) failed, err: %v", s, name, err) - } - name = "symlink" - if _, err := s.file.Symlink("/some/target", name, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { - t.Fatalf("%v: Symlink(%q) failed, err: %v", s, name, err) - } - name = "file" - _, f, _, _, err := s.file.Create(name, p9.ReadWrite, 0555, p9.UID(os.Getuid()), p9.GID(os.Getgid())) - if err != nil { - t.Fatalf("%v: createFile(root, %q) failed, err: %v", s, name, err) - } - f.Close() - - if _, _, _, err := s.file.Open(p9.ReadOnly); err != nil { - t.Fatalf("%v: Open(ReadOnly) failed, err: %v", s, err) - } - - dirents, err := s.file.Readdir(0, 10) - if err != nil { - t.Fatalf("%v: Readdir(0, 10) failed, err: %v", s, err) - } - if len(dirents) != 3 { - t.Fatalf("%v: Readdir(0, 10) wrong number of items, got: %v, expected: 3", s, len(dirents)) - } - var dir, symlink, file bool - for _, d := range dirents { - switch d.Name { - case "dir": - if d.Type != p9.TypeDir { - t.Errorf("%v: dirent.Type got: %v, expected: %v", s, d.Type, p9.TypeDir) - } - dir = true - case "symlink": - if d.Type != p9.TypeSymlink { - t.Errorf("%v: dirent.Type got: %v, expected: %v", s, d.Type, p9.TypeSymlink) - } - symlink = true - case "file": - if d.Type != p9.TypeRegular { - t.Errorf("%v: dirent.Type got: %v, expected: %v", s, d.Type, p9.TypeRegular) - } - file = true - default: - t.Errorf("%v: dirent.Name got: %v", s, d.Name) - } - - _, f, err := s.file.Walk([]string{d.Name}) - if err != nil { - t.Fatalf("%v: Walk({%s}) failed, err: %v", s, d.Name, err) - } - _, _, a, err := f.GetAttr(p9.AttrMask{}) - if err != nil { - t.Fatalf("%v: GetAttr() failed, err: %v", s, err) - } - if d.Type != a.Mode.QIDType() { - t.Errorf("%v: dirent.Type different than GetAttr().Mode.QIDType(), got: %v, expected: %v", s, d.Type, a.Mode.QIDType()) - } - } - if !dir || !symlink || !file { - t.Errorf("%v: Readdir(0, 10) wrong files returned, dir: %v, symlink: %v, file: %v", s, dir, symlink, file) - } - }) -} - -// Test that attach point can be written to when it points to a file, e.g. -// /etc/hosts. -func TestAttachFile(t *testing.T) { - conf := Config{ROMount: false} - dir, err := ioutil.TempDir("", "root-") - if err != nil { - t.Fatalf("ioutil.TempDir() failed, err: %v", err) - } - defer os.RemoveAll(dir) - - path := path.Join(dir, "test") - if _, err := os.Create(path); err != nil { - t.Fatalf("os.Create(%q) failed, err: %v", path, err) - } - - a, err := NewAttachPoint(path, conf) - if err != nil { - t.Fatalf("NewAttachPoint failed: %v", err) - } - root, err := a.Attach() - if err != nil { - t.Fatalf("Attach failed, err: %v", err) - } - - if _, _, _, err := root.Open(p9.ReadWrite); err != nil { - t.Fatalf("Open(ReadWrite) failed, err: %v", err) - } - defer root.Close() - - b := []byte("foobar") - w, err := root.WriteAt(b, 0) - if err != nil { - t.Fatalf("Write() failed, err: %v", err) - } - if w != len(b) { - t.Fatalf("Write() was partial, got: %d, expected: %d", w, len(b)) - } - rBuf := make([]byte, len(b)) - r, err := root.ReadAt(rBuf, 0) - if err != nil { - t.Fatalf("ReadAt() failed, err: %v", err) - } - if r != len(rBuf) { - t.Fatalf("ReadAt() was partial, got: %d, expected: %d", r, len(rBuf)) - } - if string(rBuf) != "foobar" { - t.Fatalf("ReadAt() wrong data, got: %s, expected: %s", string(rBuf), "foobar") - } -} - -func TestAttachInvalidType(t *testing.T) { - dir, err := ioutil.TempDir("", "attach-") - if err != nil { - t.Fatalf("ioutil.TempDir() failed, err: %v", err) - } - defer os.RemoveAll(dir) - - fifo := filepath.Join(dir, "fifo") - if err := syscall.Mkfifo(fifo, 0755); err != nil { - t.Fatalf("Mkfifo(%q): %v", fifo, err) - } - - dirFile, err := os.Open(dir) - if err != nil { - t.Fatalf("Open(%s): %v", dir, err) - } - defer dirFile.Close() - - // Bind a socket via /proc to be sure that a length of a socket path - // is less than UNIX_PATH_MAX. - socket := filepath.Join(fmt.Sprintf("/proc/self/fd/%d", dirFile.Fd()), "socket") - l, err := net.Listen("unix", socket) - if err != nil { - t.Fatalf("net.Listen(unix, %q): %v", socket, err) - } - defer l.Close() - - for _, tc := range []struct { - name string - path string - }{ - {name: "fifo", path: fifo}, - {name: "socket", path: socket}, - } { - t.Run(tc.name, func(t *testing.T) { - conf := Config{ROMount: false} - a, err := NewAttachPoint(tc.path, conf) - if err != nil { - t.Fatalf("NewAttachPoint failed: %v", err) - } - f, err := a.Attach() - if f != nil || err == nil { - t.Fatalf("Attach should have failed, got (%v, %v)", f, err) - } - }) - } -} - -func TestDoubleAttachError(t *testing.T) { - conf := Config{ROMount: false} - root, err := ioutil.TempDir("", "root-") - if err != nil { - t.Fatalf("ioutil.TempDir() failed, err: %v", err) - } - defer os.RemoveAll(root) - a, err := NewAttachPoint(root, conf) - if err != nil { - t.Fatalf("NewAttachPoint failed: %v", err) - } - - if _, err := a.Attach(); err != nil { - t.Fatalf("Attach failed: %v", err) - } - if _, err := a.Attach(); err == nil { - t.Fatalf("Attach should have failed, got %v want non-nil", err) - } -} diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD deleted file mode 100644 index c95d50294..000000000 --- a/runsc/sandbox/BUILD +++ /dev/null @@ -1,36 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "sandbox", - srcs = [ - "network.go", - "network_unsafe.go", - "sandbox.go", - ], - visibility = [ - "//runsc:__subpackages__", - ], - deps = [ - "//pkg/control/client", - "//pkg/control/server", - "//pkg/log", - "//pkg/sentry/control", - "//pkg/sentry/platform", - "//pkg/sync", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - "//pkg/urpc", - "//runsc/boot", - "//runsc/boot/platforms", - "//runsc/cgroup", - "//runsc/console", - "//runsc/specutils", - "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", - "@com_github_syndtr_gocapability//capability:go_default_library", - "@com_github_vishvananda_netlink//:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/runsc/sandbox/sandbox_state_autogen.go b/runsc/sandbox/sandbox_state_autogen.go new file mode 100755 index 000000000..79ebc2220 --- /dev/null +++ b/runsc/sandbox/sandbox_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package sandbox diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD deleted file mode 100644 index 4ccd77f63..000000000 --- a/runsc/specutils/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "specutils", - srcs = [ - "cri.go", - "fs.go", - "namespace.go", - "specutils.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/bits", - "//pkg/log", - "//pkg/sentry/kernel/auth", - "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", - "@com_github_syndtr_gocapability//capability:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "specutils_test", - size = "small", - srcs = ["specutils_test.go"], - library = ":specutils", - deps = ["@com_github_opencontainers_runtime-spec//specs-go:go_default_library"], -) diff --git a/runsc/specutils/cri.go b/runsc/specutils/cri.go index 9c5877cd5..9c5877cd5 100644..100755 --- a/runsc/specutils/cri.go +++ b/runsc/specutils/cri.go diff --git a/runsc/specutils/specutils_state_autogen.go b/runsc/specutils/specutils_state_autogen.go new file mode 100755 index 000000000..11eefbaa2 --- /dev/null +++ b/runsc/specutils/specutils_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package specutils diff --git a/runsc/specutils/specutils_test.go b/runsc/specutils/specutils_test.go deleted file mode 100644 index 2c86fffe8..000000000 --- a/runsc/specutils/specutils_test.go +++ /dev/null @@ -1,265 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package specutils - -import ( - "fmt" - "os/exec" - "strings" - "testing" - "time" - - specs "github.com/opencontainers/runtime-spec/specs-go" -) - -func TestWaitForReadyHappy(t *testing.T) { - cmd := exec.Command("/bin/sleep", "1000") - if err := cmd.Start(); err != nil { - t.Fatalf("cmd.Start() failed, err: %v", err) - } - defer cmd.Wait() - - var count int - err := WaitForReady(cmd.Process.Pid, 5*time.Second, func() (bool, error) { - if count < 3 { - count++ - return false, nil - } - return true, nil - }) - if err != nil { - t.Errorf("ProcessWaitReady got: %v, expected: nil", err) - } - cmd.Process.Kill() -} - -func TestWaitForReadyFail(t *testing.T) { - cmd := exec.Command("/bin/sleep", "1000") - if err := cmd.Start(); err != nil { - t.Fatalf("cmd.Start() failed, err: %v", err) - } - defer cmd.Wait() - - var count int - err := WaitForReady(cmd.Process.Pid, 5*time.Second, func() (bool, error) { - if count < 3 { - count++ - return false, nil - } - return false, fmt.Errorf("Fake error") - }) - if err == nil { - t.Errorf("ProcessWaitReady got: nil, expected: error") - } - cmd.Process.Kill() -} - -func TestWaitForReadyNotRunning(t *testing.T) { - cmd := exec.Command("/bin/true") - if err := cmd.Start(); err != nil { - t.Fatalf("cmd.Start() failed, err: %v", err) - } - defer cmd.Wait() - - err := WaitForReady(cmd.Process.Pid, 5*time.Second, func() (bool, error) { - return false, nil - }) - if err != nil && !strings.Contains(err.Error(), "terminated") { - t.Errorf("ProcessWaitReady got: %v, expected: process terminated", err) - } - if err == nil { - t.Errorf("ProcessWaitReady incorrectly succeeded") - } -} - -func TestWaitForReadyTimeout(t *testing.T) { - cmd := exec.Command("/bin/sleep", "1000") - if err := cmd.Start(); err != nil { - t.Fatalf("cmd.Start() failed, err: %v", err) - } - defer cmd.Wait() - - err := WaitForReady(cmd.Process.Pid, 50*time.Millisecond, func() (bool, error) { - return false, nil - }) - if !strings.Contains(err.Error(), "not running yet") { - t.Errorf("ProcessWaitReady got: %v, expected: not running yet", err) - } - cmd.Process.Kill() -} - -func TestSpecInvalid(t *testing.T) { - for _, test := range []struct { - name string - spec specs.Spec - error string - }{ - { - name: "valid", - spec: specs.Spec{ - Root: &specs.Root{Path: "/"}, - Process: &specs.Process{ - Args: []string{"/bin/true"}, - }, - Mounts: []specs.Mount{ - { - Source: "src", - Destination: "/dst", - }, - }, - }, - error: "", - }, - { - name: "valid+warning", - spec: specs.Spec{ - Root: &specs.Root{Path: "/"}, - Process: &specs.Process{ - Args: []string{"/bin/true"}, - // This is normally set by docker and will just cause warnings to be logged. - ApparmorProfile: "someprofile", - }, - // This is normally set by docker and will just cause warnings to be logged. - Linux: &specs.Linux{Seccomp: &specs.LinuxSeccomp{}}, - }, - error: "", - }, - { - name: "no root", - spec: specs.Spec{ - Process: &specs.Process{ - Args: []string{"/bin/true"}, - }, - }, - error: "must be defined", - }, - { - name: "empty root", - spec: specs.Spec{ - Root: &specs.Root{}, - Process: &specs.Process{ - Args: []string{"/bin/true"}, - }, - }, - error: "must be defined", - }, - { - name: "no process", - spec: specs.Spec{ - Root: &specs.Root{Path: "/"}, - }, - error: "must be defined", - }, - { - name: "empty args", - spec: specs.Spec{ - Root: &specs.Root{Path: "/"}, - Process: &specs.Process{}, - }, - error: "must be defined", - }, - { - name: "selinux", - spec: specs.Spec{ - Root: &specs.Root{Path: "/"}, - Process: &specs.Process{ - Args: []string{"/bin/true"}, - SelinuxLabel: "somelabel", - }, - }, - error: "is not supported", - }, - { - name: "solaris", - spec: specs.Spec{ - Root: &specs.Root{Path: "/"}, - Process: &specs.Process{ - Args: []string{"/bin/true"}, - }, - Solaris: &specs.Solaris{}, - }, - error: "is not supported", - }, - { - name: "windows", - spec: specs.Spec{ - Root: &specs.Root{Path: "/"}, - Process: &specs.Process{ - Args: []string{"/bin/true"}, - }, - Windows: &specs.Windows{}, - }, - error: "is not supported", - }, - { - name: "relative mount destination", - spec: specs.Spec{ - Root: &specs.Root{Path: "/"}, - Process: &specs.Process{ - Args: []string{"/bin/true"}, - }, - Mounts: []specs.Mount{ - { - Source: "src", - Destination: "dst", - }, - }, - }, - error: "must be an absolute path", - }, - { - name: "invalid mount option", - spec: specs.Spec{ - Root: &specs.Root{Path: "/"}, - Process: &specs.Process{ - Args: []string{"/bin/true"}, - }, - Mounts: []specs.Mount{ - { - Source: "/src", - Destination: "/dst", - Type: "bind", - Options: []string{"shared"}, - }, - }, - }, - error: "is not supported", - }, - { - name: "invalid rootfs propagation", - spec: specs.Spec{ - Root: &specs.Root{Path: "/"}, - Process: &specs.Process{ - Args: []string{"/bin/true"}, - }, - Linux: &specs.Linux{ - RootfsPropagation: "foo", - }, - }, - error: "root mount propagation option must specify private or slave", - }, - } { - err := ValidateSpec(&test.spec) - if len(test.error) == 0 { - if err != nil { - t.Errorf("ValidateSpec(%q) failed, err: %v", test.name, err) - } - } else { - if err == nil || !strings.Contains(err.Error(), test.error) { - t.Errorf("ValidateSpec(%q) wrong error, got: %v, want: .*%s.*", test.name, err, test.error) - } - } - } -} diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD deleted file mode 100644 index 945405303..000000000 --- a/runsc/testutil/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "testutil", - testonly = 1, - srcs = [ - "testutil.go", - "testutil_runfiles.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/log", - "//pkg/sync", - "//runsc/boot", - "//runsc/specutils", - "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", - ], -) diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go deleted file mode 100644 index 51e487715..000000000 --- a/runsc/testutil/testutil.go +++ /dev/null @@ -1,432 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testutil contains utility functions for runsc tests. -package testutil - -import ( - "bufio" - "context" - "debug/elf" - "encoding/base32" - "encoding/json" - "flag" - "fmt" - "io" - "io/ioutil" - "math" - "math/rand" - "net/http" - "os" - "os/exec" - "os/signal" - "path/filepath" - "strconv" - "strings" - "sync/atomic" - "syscall" - "time" - - "github.com/cenkalti/backoff" - specs "github.com/opencontainers/runtime-spec/specs-go" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/runsc/boot" - "gvisor.dev/gvisor/runsc/specutils" -) - -var ( - checkpoint = flag.Bool("checkpoint", true, "control checkpoint/restore support") -) - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -// IsCheckpointSupported returns the relevant command line flag. -func IsCheckpointSupported() bool { - return *checkpoint -} - -// TmpDir returns the absolute path to a writable directory that can be used as -// scratch by the test. -func TmpDir() string { - dir := os.Getenv("TEST_TMPDIR") - if dir == "" { - dir = "/tmp" - } - return dir -} - -// ConfigureExePath configures the executable for runsc in the test environment. -func ConfigureExePath() error { - path, err := FindFile("runsc/runsc") - if err != nil { - return err - } - specutils.ExePath = path - return nil -} - -// TestConfig returns the default configuration to use in tests. Note that -// 'RootDir' must be set by caller if required. -func TestConfig() *boot.Config { - logDir := "" - if dir, ok := os.LookupEnv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { - logDir = dir + "/" - } - return &boot.Config{ - Debug: true, - DebugLog: logDir, - LogFormat: "text", - DebugLogFormat: "text", - AlsoLogToStderr: true, - LogPackets: true, - Network: boot.NetworkNone, - Strace: true, - Platform: "ptrace", - FileAccess: boot.FileAccessExclusive, - NumNetworkChannels: 1, - - TestOnlyAllowRunAsCurrentUserWithoutChroot: true, - } -} - -// NewSpecWithArgs creates a simple spec with the given args suitable for use -// in tests. -func NewSpecWithArgs(args ...string) *specs.Spec { - return &specs.Spec{ - // The host filesystem root is the container root. - Root: &specs.Root{ - Path: "/", - Readonly: true, - }, - Process: &specs.Process{ - Args: args, - Env: []string{ - "PATH=" + os.Getenv("PATH"), - }, - Capabilities: specutils.AllCapabilities(), - }, - Mounts: []specs.Mount{ - // Hide the host /etc to avoid any side-effects. - // For example, bash reads /etc/passwd and if it is - // very big, tests can fail by timeout. - { - Type: "tmpfs", - Destination: "/etc", - }, - // Root is readonly, but many tests want to write to tmpdir. - // This creates a writable mount inside the root. Also, when tmpdir points - // to "/tmp", it makes the the actual /tmp to be mounted and not a tmpfs - // inside the sentry. - { - Type: "bind", - Destination: TmpDir(), - Source: TmpDir(), - }, - }, - Hostname: "runsc-test-hostname", - } -} - -// SetupRootDir creates a root directory for containers. -func SetupRootDir() (string, error) { - rootDir, err := ioutil.TempDir(TmpDir(), "containers") - if err != nil { - return "", fmt.Errorf("error creating root dir: %v", err) - } - return rootDir, nil -} - -// SetupContainer creates a bundle and root dir for the container, generates a -// test config, and writes the spec to config.json in the bundle dir. -func SetupContainer(spec *specs.Spec, conf *boot.Config) (rootDir, bundleDir string, err error) { - rootDir, err = SetupRootDir() - if err != nil { - return "", "", err - } - conf.RootDir = rootDir - bundleDir, err = SetupBundleDir(spec) - return rootDir, bundleDir, err -} - -// SetupBundleDir creates a bundle dir and writes the spec to config.json. -func SetupBundleDir(spec *specs.Spec) (bundleDir string, err error) { - bundleDir, err = ioutil.TempDir(TmpDir(), "bundle") - if err != nil { - return "", fmt.Errorf("error creating bundle dir: %v", err) - } - - if err = writeSpec(bundleDir, spec); err != nil { - return "", fmt.Errorf("error writing spec: %v", err) - } - return bundleDir, nil -} - -// writeSpec writes the spec to disk in the given directory. -func writeSpec(dir string, spec *specs.Spec) error { - b, err := json.Marshal(spec) - if err != nil { - return err - } - return ioutil.WriteFile(filepath.Join(dir, "config.json"), b, 0755) -} - -// UniqueContainerID generates a unique container id for each test. -// -// The container id is used to create an abstract unix domain socket, which must -// be unique. While the container forbids creating two containers with the same -// name, sometimes between test runs the socket does not get cleaned up quickly -// enough, causing container creation to fail. -func UniqueContainerID() string { - // Read 20 random bytes. - b := make([]byte, 20) - // "[Read] always returns len(p) and a nil error." --godoc - if _, err := rand.Read(b); err != nil { - panic("rand.Read failed: " + err.Error()) - } - // base32 encode the random bytes, so that the name is a valid - // container id and can be used as a socket name in the filesystem. - return fmt.Sprintf("test-container-%s", base32.StdEncoding.EncodeToString(b)) -} - -// Copy copies file from src to dst. -func Copy(src, dst string) error { - in, err := os.Open(src) - if err != nil { - return err - } - defer in.Close() - - out, err := os.Create(dst) - if err != nil { - return err - } - defer out.Close() - - _, err = io.Copy(out, in) - return err -} - -// Poll is a shorthand function to poll for something with given timeout. -func Poll(cb func() error, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx) - return backoff.Retry(cb, b) -} - -// WaitForHTTP tries GET requests on a port until the call succeeds or timeout. -func WaitForHTTP(port int, timeout time.Duration) error { - cb := func() error { - c := &http.Client{ - // Calculate timeout to be able to do minimum 5 attempts. - Timeout: timeout / 5, - } - url := fmt.Sprintf("http://localhost:%d/", port) - resp, err := c.Get(url) - if err != nil { - log.Infof("Waiting %s: %v", url, err) - return err - } - resp.Body.Close() - return nil - } - return Poll(cb, timeout) -} - -// Reaper reaps child processes. -type Reaper struct { - // mu protects ch, which will be nil if the reaper is not running. - mu sync.Mutex - ch chan os.Signal -} - -// Start starts reaping child processes. -func (r *Reaper) Start() { - r.mu.Lock() - defer r.mu.Unlock() - - if r.ch != nil { - panic("reaper.Start called on a running reaper") - } - - r.ch = make(chan os.Signal, 1) - signal.Notify(r.ch, syscall.SIGCHLD) - - go func() { - for { - r.mu.Lock() - ch := r.ch - r.mu.Unlock() - if ch == nil { - return - } - - _, ok := <-ch - if !ok { - // Channel closed. - return - } - for { - cpid, _ := syscall.Wait4(-1, nil, syscall.WNOHANG, nil) - if cpid < 1 { - break - } - } - } - }() -} - -// Stop stops reaping child processes. -func (r *Reaper) Stop() { - r.mu.Lock() - defer r.mu.Unlock() - - if r.ch == nil { - panic("reaper.Stop called on a stopped reaper") - } - - signal.Stop(r.ch) - close(r.ch) - r.ch = nil -} - -// StartReaper is a helper that starts a new Reaper and returns a function to -// stop it. -func StartReaper() func() { - r := &Reaper{} - r.Start() - return r.Stop -} - -// WaitUntilRead reads from the given reader until the wanted string is found -// or until timeout. -func WaitUntilRead(r io.Reader, want string, split bufio.SplitFunc, timeout time.Duration) error { - sc := bufio.NewScanner(r) - if split != nil { - sc.Split(split) - } - // done must be accessed atomically. A value greater than 0 indicates - // that the read loop can exit. - var done uint32 - doneCh := make(chan struct{}) - go func() { - for sc.Scan() { - t := sc.Text() - if strings.Contains(t, want) { - atomic.StoreUint32(&done, 1) - close(doneCh) - break - } - if atomic.LoadUint32(&done) > 0 { - break - } - } - }() - select { - case <-time.After(timeout): - atomic.StoreUint32(&done, 1) - return fmt.Errorf("timeout waiting to read %q", want) - case <-doneCh: - return nil - } -} - -// KillCommand kills the process running cmd unless it hasn't been started. It -// returns an error if it cannot kill the process unless the reason is that the -// process has already exited. -func KillCommand(cmd *exec.Cmd) error { - if cmd.Process == nil { - return nil - } - if err := cmd.Process.Kill(); err != nil { - if !strings.Contains(err.Error(), "process already finished") { - return fmt.Errorf("failed to kill process %v: %v", cmd, err) - } - } - return nil -} - -// WriteTmpFile writes text to a temporary file, closes the file, and returns -// the name of the file. -func WriteTmpFile(pattern, text string) (string, error) { - file, err := ioutil.TempFile(TmpDir(), pattern) - if err != nil { - return "", err - } - defer file.Close() - if _, err := file.Write([]byte(text)); err != nil { - return "", err - } - return file.Name(), nil -} - -// RandomName create a name with a 6 digit random number appended to it. -func RandomName(prefix string) string { - return fmt.Sprintf("%s-%06d", prefix, rand.Int31n(1000000)) -} - -// IsStatic returns true iff the given file is a static binary. -func IsStatic(filename string) (bool, error) { - f, err := elf.Open(filename) - if err != nil { - return false, err - } - for _, prog := range f.Progs { - if prog.Type == elf.PT_INTERP { - return false, nil // Has interpreter. - } - } - return true, nil -} - -// TestIndicesForShard returns indices for this test shard based on the -// TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars. -// -// If either of the env vars are not present, then the function will return all -// tests. If there are more shards than there are tests, then the returned list -// may be empty. -func TestIndicesForShard(numTests int) ([]int, error) { - var ( - shardIndex = 0 - shardTotal = 1 - ) - - indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS") - if indexStr != "" && totalStr != "" { - // Parse index and total to ints. - var err error - shardIndex, err = strconv.Atoi(indexStr) - if err != nil { - return nil, fmt.Errorf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err) - } - shardTotal, err = strconv.Atoi(totalStr) - if err != nil { - return nil, fmt.Errorf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err) - } - } - - // Calculate! - var indices []int - numBlocks := int(math.Ceil(float64(numTests) / float64(shardTotal))) - for i := 0; i < numBlocks; i++ { - pick := i*shardTotal + shardIndex - if pick < numTests { - indices = append(indices, pick) - } - } - return indices, nil -} diff --git a/runsc/testutil/testutil_runfiles.go b/runsc/testutil/testutil_runfiles.go deleted file mode 100644 index ece9ea9a1..000000000 --- a/runsc/testutil/testutil_runfiles.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testutil - -import ( - "fmt" - "os" - "path/filepath" -) - -// FindFile searchs for a file inside the test run environment. It returns the -// full path to the file. It fails if none or more than one file is found. -func FindFile(path string) (string, error) { - wd, err := os.Getwd() - if err != nil { - return "", err - } - - // The test root is demarcated by a path element called "__main__". Search for - // it backwards from the working directory. - root := wd - for { - dir, name := filepath.Split(root) - if name == "__main__" { - break - } - if len(dir) == 0 { - return "", fmt.Errorf("directory __main__ not found in %q", wd) - } - // Remove ending slash to loop around. - root = dir[:len(dir)-1] - } - - // Annoyingly, bazel adds the build type to the directory path for go - // binaries, but not for c++ binaries. We use two different patterns to - // to find our file. - patterns := []string{ - // Try the obvious path first. - filepath.Join(root, path), - // If it was a go binary, use a wildcard to match the build - // type. The pattern is: /test-path/__main__/directories/*/file. - filepath.Join(root, filepath.Dir(path), "*", filepath.Base(path)), - } - - for _, p := range patterns { - matches, err := filepath.Glob(p) - if err != nil { - // "The only possible returned error is ErrBadPattern, - // when pattern is malformed." -godoc - return "", fmt.Errorf("error globbing %q: %v", p, err) - } - switch len(matches) { - case 0: - // Try the next pattern. - case 1: - // We found it. - return matches[0], nil - default: - return "", fmt.Errorf("more than one match found for %q: %s", path, matches) - } - } - return "", fmt.Errorf("file %q not found", path) -} diff --git a/runsc/version_test.sh b/runsc/version_test.sh deleted file mode 100755 index 747350654..000000000 --- a/runsc/version_test.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -euf -x -o pipefail - -readonly runsc="$1" -readonly version=$($runsc --version) - -# Version should should not match VERSION, which is the default and which will -# also appear if something is wrong with workspace_status.sh script. -if [[ $version =~ "VERSION" ]]; then - echo "FAIL: Got bad version $version" - exit 1 -fi - -# Version should contain at least one number. -if [[ ! $version =~ [0-9] ]]; then - echo "FAIL: Got bad version $version" - exit 1 -fi - -echo "PASS: Got OK version $version" -exit 0 diff --git a/scripts/benchmark.sh b/scripts/benchmark.sh deleted file mode 100644 index 06d44f914..000000000 --- a/scripts/benchmark.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash - -# Copyright 2020 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Exporting for subprocesses as GCP APIs and tools check this environmental -# variable for authentication. -export GOOGLE_APPLICATION_CREDENTIALS="${KOKORO_KEYSTORE_DIR}/${GCLOUD_CREDENTIALS}" - -gcloud auth activate-service-account \ - --key-file "${GOOGLE_APPLICATION_CREDENTIALS}" - -gcloud config set project ${PROJECT} -gcloud config set compute/zone ${ZONE} - -bazel run //benchmarks:benchmarks -- \ - --verbose \ - run-gcp \ - startup \ - --runtime=runc \ - --runtime=runsc \ - --installers=head diff --git a/scripts/benchmarks.sh b/scripts/benchmarks.sh deleted file mode 100755 index 6b9065b07..000000000 --- a/scripts/benchmarks.sh +++ /dev/null @@ -1,53 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -#!/usr/bin/env bash - -if [ "$#" -lt "1" ]; then - echo "usage: $0 <--mock |--env=<filename>> ..." - echo "example: $0 --mock --runs=8" - exit 1 -fi - -source $(dirname $0)/common.sh - -readonly TIMESTAMP=`date "+%Y%m%d-%H%M%S"` -readonly OUTDIR="$(mktemp --tmpdir -d run-${TIMESTAMP}-XXX)" -readonly DEFAULT_RUNTIMES="--runtime=runc --runtime=runsc --runtime=runsc-kvm" -readonly ALL_RUNTIMES="--runtime=runc --runtime=runsc --runtime=runsc-kvm" - -run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} 'fio.(read|write)' --metric=bandwidth --size=5g --ioengine=sync --blocksize=1m > "${OUTDIR}/fio.csv" -run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} fio.rand --metric=bandwidth --size=5g --ioengine=sync --blocksize=4k --time=30 > "${OUTDIR}/tmp_fio.csv" -cat "${OUTDIR}/tmp_fio.csv" | grep "\(runc\|runsc\)" >> "${OUTDIR}/fio.csv" && rm "${OUTDIR}/tmp_fio.csv" - -run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} 'fio.(read|write)' --metric=bandwidth --tmpfs=True --size=5g --ioengine=sync --blocksize=1m > "${OUTDIR}/fio-tmpfs.csv" -run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} fio.rand --metric=bandwidth --tmpfs=True --size=5g --ioengine=sync --blocksize=4k --time=30 > "${OUTDIR}/tmp_fio.csv" -cat "${OUTDIR}/tmp_fio.csv" | grep "\(runc\|runsc\)" >> "${OUTDIR}/fio-tmpfs.csv" && rm "${OUTDIR}/tmp_fio.csv" - -run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} startup --count=50 > "${OUTDIR}/startup.csv" -run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} density > "${OUTDIR}/density.csv" - -run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} sysbench.cpu --threads=1 --max_prime=50000 --options='--max-time=5' > "${OUTDIR}/sysbench-cpu.csv" -run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} sysbench.memory --threads=1 --options='--memory-block-size=1M --memory-total-size=500G' > "${OUTDIR}/sysbench-memory.csv" -run //benchmarks:perf -- run "$@" ${ALL_RUNTIMES} syscall > "${OUTDIR}/syscall.csv" -run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} 'network.(upload|download)' --runs=20 > "${OUTDIR}/iperf.csv" -run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} ml.tensorflow > "${OUTDIR}/tensorflow.csv" -run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} media.ffmpeg > "${OUTDIR}/ffmpeg.csv" -run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} http.httpd --path=latin100k.txt --connections=1 --connections=5 --connections=10 --connections=25 > "${OUTDIR}/httpd100k.csv" -run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} http.httpd --path=latin10240k.txt --connections=1 --connections=5 --connections=10 --connections=25 > "${OUTDIR}/httpd10240k.csv" -run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} redis > "${OUTDIR}/redis.csv" -run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} 'http.(ruby|node)' > "${OUTDIR}/applications.csv" - -echo "${OUTPUT}" && exit 0 diff --git a/scripts/build.sh b/scripts/build.sh deleted file mode 100755 index 7c9c99800..000000000 --- a/scripts/build.sh +++ /dev/null @@ -1,88 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Install required packages for make_repository.sh et al. -apt_install dpkg-sig coreutils apt-utils xz-utils - -# Build runsc. -runsc=$(build -c opt //runsc) - -# Build packages. -pkgs=$(build -c opt //runsc:runsc-debian) - -# Stop here if we have no artifacts directory. -[[ -v KOKORO_ARTIFACTS_DIR ]] || exit 0 - -# install_raw installs raw artifacts. -install_raw() { - mkdir -p "$1" - cp -f "${runsc}" "$1"/runsc - sha512sum "$1"/runsc | awk '{print $1 " runsc"}' > "$1"/runsc.sha512 -} - -# Build a repository, if the key is available. -# -# Note that make_repository.sh script will install packages into the provided -# root, but will output to stdout a directory that can be copied arbitrarily -# into "${KOKORO_ARTIFACTS_DIR}"/dists/XXX. We do things this way because we -# will copy the same repository structure into multiple locations, below. -if [[ -v KOKORO_REPO_KEY ]]; then - repo=$(tools/make_repository.sh \ - "${KOKORO_KEYSTORE_DIR}/${KOKORO_REPO_KEY}" \ - gvisor-bot@google.com \ - main \ - "${KOKORO_ARTIFACTS_DIR}" \ - ${pkgs}) -fi - -# install_repo installs a repository. -# -# Note that packages are already installed, as noted above. -install_repo() { - if [[ -v repo ]]; then - rm -rf "$1" && mkdir -p "$(dirname "$1")" && cp -a "${repo}" "$1" - fi -} - -# If nightly, install only nightly artifacts. -if [[ "${KOKORO_BUILD_NIGHTLY:-false}" == "true" ]]; then - # The "latest" directory and current date. - stamp="$(date -Idate)" - install_raw "${KOKORO_ARTIFACTS_DIR}/nightly/latest" - install_raw "${KOKORO_ARTIFACTS_DIR}/nightly/${stamp}" - install_repo "${KOKORO_ARTIFACTS_DIR}/dists/nightly" -else - # Is it a tagged release? Build that. - tags="$(git tag --points-at HEAD)" - if ! [[ -z "${tags}" ]]; then - # Note that a given commit can match any number of tags. We have to iterate - # through all possible tags and produce associated artifacts. - for tag in ${tags}; do - name=$(echo "${tag}" | cut -d'-' -f2) - base=$(echo "${name}" | cut -d'.' -f1) - install_raw "${KOKORO_ARTIFACTS_DIR}/release/${name}" - install_raw "${KOKORO_ARTIFACTS_DIR}/release/latest" - install_repo "${KOKORO_ARTIFACTS_DIR}/dists/release" - install_repo "${KOKORO_ARTIFACTS_DIR}/dists/${base}" - done - else - # Otherwise, assume it is a raw master commit. - install_raw "${KOKORO_ARTIFACTS_DIR}/master/latest" - install_repo "${KOKORO_ARTIFACTS_DIR}/dists/master" - fi -fi diff --git a/scripts/common.sh b/scripts/common.sh deleted file mode 100755 index 735a383de..000000000 --- a/scripts/common.sh +++ /dev/null @@ -1,100 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -xeou pipefail - -# Get the path to the directory this script lives in. -# If this script is being called with `source`, $0 will be the path of the -# *sourcing* script, so we can't use `dirname $0` to find scripts in this -# directory. -if [[ -v BASH_SOURCE && "$0" != "$BASH_SOURCE" ]]; then - declare -r script_dir="$(dirname "$BASH_SOURCE")" -else - declare -r script_dir="$(dirname "$0")" -fi - -source "${script_dir}/common_build.sh" - -# Ensure it attempts to collect logs in all cases. -trap collect_logs EXIT - -function set_runtime() { - RUNTIME=${1:-runsc} - RUNSC_BIN=/tmp/"${RUNTIME}"/runsc - RUNSC_LOGS_DIR="$(dirname ${RUNSC_BIN})"/logs - RUNSC_LOGS="${RUNSC_LOGS_DIR}"/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND% -} - -function test_runsc() { - test --test_arg=--runtime=${RUNTIME} "$@" -} - -function install_runsc_for_test() { - local -r test_name=$1 - shift - if [[ -z "${test_name}" ]]; then - echo "Missing mandatory test name" - exit 1 - fi - - # Add test to the name, so it doesn't conflict with other runtimes. - set_runtime $(find_branch_name)_"${test_name}" - - # ${RUNSC_TEST_NAME} is set by tests (see dockerutil) to pass the test name - # down to the runtime. - install_runsc "${RUNTIME}" \ - --TESTONLY-test-name-env=RUNSC_TEST_NAME \ - --debug \ - --strace \ - --log-packets \ - "$@" -} - -# Installs the runsc with given runtime name. set_runtime must have been called -# to set runtime and logs location. -function install_runsc() { - local -r runtime=$1 - shift - - # Prepare the runtime binary. - local -r output=$(build //runsc) - mkdir -p "$(dirname ${RUNSC_BIN})" - cp -f "${output}" "${RUNSC_BIN}" - chmod 0755 "${RUNSC_BIN}" - - # Install the runtime. - sudo "${RUNSC_BIN}" install --experimental=true --runtime="${runtime}" -- --debug-log "${RUNSC_LOGS}" "$@" - - # Clear old logs files that may exist. - sudo rm -f "${RUNSC_LOGS_DIR}"/'*' - - # Restart docker to pick up the new runtime configuration. - sudo systemctl restart docker -} - -# Installs the given packages. Note that the package names should be verified to -# be correct, otherwise this may result in a loop that spins until time out. -function apt_install() { - while true; do - if (sudo apt-get update && sudo apt-get install -y "$@"); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - return $result - fi - done -} diff --git a/scripts/common_build.sh b/scripts/common_build.sh deleted file mode 100755 index 3be0bb21c..000000000 --- a/scripts/common_build.sh +++ /dev/null @@ -1,112 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -which bazel -bazel version - -# Switch into the workspace; only necessary if run with kokoro. -if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then - cd git/repo -elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then - cd github/repo -fi - -# Set the standard bazel flags. -declare -a BAZEL_FLAGS=( - "--show_timestamps" - "--test_output=errors" - "--keep_going" - "--verbose_failures=true" -) -if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then - BAZEL_FLAGS+=( - "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}" - "--config=remote" - ) -fi -declare -r BAZEL_FLAGS - -# Wrap bazel. -function build() { - bazel build "${BAZEL_FLAGS[@]}" "$@" 2>&1 \ - | tee /dev/fd/2 \ - | grep -E '^ bazel-bin/' \ - | awk '{ print $1; }' -} - -function test() { - bazel test "${BAZEL_FLAGS[@]}" "$@" -} - -function run() { - local binary=$1 - shift - bazel run "${binary}" -- "$@" -} - -function run_as_root() { - local binary=$1 - shift - bazel run --run_under="sudo" "${binary}" -- "$@" -} - -function collect_logs() { - # Zip out everything into a convenient form. - if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then - # Merge results files of all shards for each test suite. - for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do - junitparser merge `find $d -name test.xml` $d/test.xml - cat $d/shard_*_of_*/test.log > $d/test.log - if ls -l $d/shard_*_of_*/test.outputs/outputs.zip 2>/dev/null; then - zip -r -1 "$d/outputs.zip" $d/shard_*_of_*/test.outputs/outputs.zip - fi - done - find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf - # Move test logs to Kokoro directory. tar is used to conveniently perform - # renames while moving files. - find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" | - tar --create --files-from - --transform 's/test\./sponge_log./' | - tar --extract --directory ${KOKORO_ARTIFACTS_DIR} - - # Collect sentry logs, if any. - if [[ -v RUNSC_LOGS_DIR ]] && [[ -d "${RUNSC_LOGS_DIR}" ]]; then - # Check if the directory is empty or not (only the first line it needed). - local -r logs=$(ls "${RUNSC_LOGS_DIR}" | head -n1) - if [[ "${logs}" ]]; then - local -r archive=runsc_logs_"${RUNTIME}".tar.gz - if [[ -v KOKORO_BUILD_ARTIFACTS_SUBDIR ]]; then - echo "runsc logs will be uploaded to:" - echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp" - echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}" - fi - time tar \ - --verbose \ - --create \ - --gzip \ - --file="${KOKORO_ARTIFACTS_DIR}/${archive}" \ - --directory "${RUNSC_LOGS_DIR}" \ - . - fi - fi - fi -} - -function find_branch_name() { - git branch --show-current \ - || git rev-parse HEAD \ - || bazel info workspace \ - | xargs basename -} diff --git a/scripts/dev.sh b/scripts/dev.sh deleted file mode 100755 index a9107f33e..000000000 --- a/scripts/dev.sh +++ /dev/null @@ -1,75 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# common.sh sets '-x', but it's annoying to see so much output. -set +x - -# Defaults -declare -i REFRESH=0 -declare NAME=$(find_branch_name) - -while [[ $# -gt 0 ]]; do - case "$1" in - --refresh) - REFRESH=1 - ;; - --help) - echo "Use this script to build and install runsc with Docker." - echo - echo "usage: $0 [--refresh] [runtime_name]" - exit 1 - ;; - *) - NAME=$1 - ;; - esac - shift -done - -set_runtime "${NAME}" -echo -echo "Using runtime=${RUNTIME}" -echo - -echo Building runsc... -# Build first and fail on error. $() prevents "set -e" from reporting errors. -build //runsc -declare OUTPUT="$(build //runsc)" - -if [[ ${REFRESH} -eq 0 ]]; then - install_runsc "${RUNTIME}" --net-raw - install_runsc "${RUNTIME}-d" --net-raw --debug --strace --log-packets - install_runsc "${RUNTIME}-p" --net-raw --profile - - echo - echo "Runtimes ${RUNTIME}, ${RUNTIME}-d (debug enabled), and ${RUNTIME}-p installed." - echo "Use --runtime="${RUNTIME}" with your Docker command." - echo " docker run --rm --runtime="${RUNTIME}" hello-world" - echo - echo "If you rebuild, use $0 --refresh." - -else - mkdir -p "$(dirname ${RUNSC_BIN})" - cp -f ${OUTPUT} "${RUNSC_BIN}" - chmod a+rx "${RUNSC_BIN}" - - echo - echo "Runtime ${RUNTIME} refreshed." -fi - -echo "Logs are in: ${RUNSC_LOGS_DIR}" diff --git a/scripts/do_tests.sh b/scripts/do_tests.sh deleted file mode 100755 index a3a387c37..000000000 --- a/scripts/do_tests.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Build runsc. -build //runsc - -# run runsc do without root privileges. -run //runsc --rootless do true -run //runsc --rootless --network=none do true - -# run runsc do with root privileges. -run_as_root //runsc do true diff --git a/scripts/docker_tests.sh b/scripts/docker_tests.sh deleted file mode 100755 index 72ba05260..000000000 --- a/scripts/docker_tests.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -install_runsc_for_test docker -test_runsc //test/image:image_test //test/e2e:integration_test diff --git a/scripts/go.sh b/scripts/go.sh deleted file mode 100755 index 626ed8fa4..000000000 --- a/scripts/go.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Build the go path. -build :gopath - -# Build the synthetic branch. -tools/go_branch.sh - -# Checkout the new branch. -git checkout go && git clean -f - -go version - -# Build everything. -go build ./... - -# Push, if required. -if [[ -v KOKORO_GO_PUSH ]] && [[ "${KOKORO_GO_PUSH}" == "true" ]]; then - if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then - git config --global credential.helper cache - git credential approve <<EOF -protocol=https -host=github.com -username=$(cat "${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}") -password=x-oauth-basic -EOF - fi - git push origin go:go -fi diff --git a/scripts/hostnet_tests.sh b/scripts/hostnet_tests.sh deleted file mode 100755 index 41298293d..000000000 --- a/scripts/hostnet_tests.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Install the runtime and perform basic tests. -install_runsc_for_test hostnet --network=host -test_runsc --test_arg=-checkpoint=false //test/image:image_test //test/e2e:integration_test diff --git a/scripts/iptables_tests.sh b/scripts/iptables_tests.sh deleted file mode 100755 index b4a5211a5..000000000 --- a/scripts/iptables_tests.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -install_runsc_for_test iptables - -# Build the docker image for the test. -run //test/iptables/runner:runner-image --norun - -test //test/iptables:iptables_test \ - "--test_arg=--runtime=runc" \ - "--test_arg=--image=bazel/test/iptables/runner:runner-image" - -test //test/iptables:iptables_test \ - "--test_arg=--runtime=runsc" \ - "--test_arg=--image=bazel/test/iptables/runner:runner-image" diff --git a/scripts/issue_reviver.sh b/scripts/issue_reviver.sh deleted file mode 100755 index bac9b9192..000000000 --- a/scripts/issue_reviver.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -DIR=$(dirname $0) -source "${DIR}"/common.sh - -# Provide a credential file if available. -export OAUTH_TOKEN_FILE="" -if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then - OAUTH_TOKEN_FILE="${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}" -fi - -REPO_ROOT=$(cd "$(dirname "${DIR}")"; pwd) -run //tools/issue_reviver:issue_reviver --path "${REPO_ROOT}" --oauth-token-file="${OAUTH_TOKEN_FILE}" diff --git a/scripts/kvm_tests.sh b/scripts/kvm_tests.sh deleted file mode 100755 index 5662401df..000000000 --- a/scripts/kvm_tests.sh +++ /dev/null @@ -1,28 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Ensure that KVM is loaded, and we can use it. -(lsmod | grep -E '^(kvm_intel|kvm_amd)') || sudo modprobe kvm -sudo chmod a+rw /dev/kvm - -# Run all KVM platform tests (locally). -run_as_root //pkg/sentry/platform/kvm:kvm_test - -# Install the KVM runtime and run all integration tests. -install_runsc_for_test kvm --platform=kvm -test_runsc //test/image:image_test //test/e2e:integration_test diff --git a/scripts/make_tests.sh b/scripts/make_tests.sh deleted file mode 100755 index 79426756d..000000000 --- a/scripts/make_tests.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -top_level=$(git rev-parse --show-toplevel 2>/dev/null) -[[ $? -eq 0 ]] && cd "${top_level}" || exit 1 - -make -make runsc -make BAZEL_OPTIONS="build //..." bazel -make bazel-shutdown diff --git a/scripts/overlay_tests.sh b/scripts/overlay_tests.sh deleted file mode 100755 index 2a1f12c0b..000000000 --- a/scripts/overlay_tests.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Install the runtime and perform basic tests. -install_runsc_for_test overlay --overlay -test_runsc //test/image:image_test //test/e2e:integration_test diff --git a/scripts/packetdrill_tests.sh b/scripts/packetdrill_tests.sh deleted file mode 100755 index fc6bef79c..000000000 --- a/scripts/packetdrill_tests.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -install_runsc_for_test runsc-d -test_runsc $(bazel query "attr(tags, manual, tests(//test/packetdrill/...))") diff --git a/scripts/release.sh b/scripts/release.sh deleted file mode 100755 index e14ba04a7..000000000 --- a/scripts/release.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Tag a release only if provided. -if ! [[ -v KOKORO_RELEASE_COMMIT ]]; then - echo "No KOKORO_RELEASE_COMMIT provided." >&2 - exit 1 -fi -if ! [[ -v KOKORO_RELEASE_TAG ]]; then - echo "No KOKORO_RELEASE_TAG provided." >&2 - exit 1 -fi -if ! [[ -v KOKORO_RELNOTES ]]; then - echo "No KOKORO_RELNOTES provided." >&2 - exit 1 -fi -if ! [[ -r "${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}" ]]; then - echo "The file '${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}' is not readable." >&2 - exit 1 -fi - -# Unless an explicit releaser is provided, use the bot e-mail. -declare -r KOKORO_RELEASE_AUTHOR=${KOKORO_RELEASE_AUTHOR:-gvisor-bot} -declare -r EMAIL=${EMAIL:-${KOKORO_RELEASE_AUTHOR}@google.com} - -# Ensure we have an appropriate configuration for the tag. -git config --get user.name || git config user.name "gVisor-bot" -git config --get user.email || git config user.email "${EMAIL}" - -# Provide a credential if available. -if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then - git config --global credential.helper cache - git credential approve <<EOF -protocol=https -host=github.com -username=$(cat "${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}") -password=x-oauth-basic -EOF -fi - -# Run the release tool, which pushes to the origin repository. -tools/tag_release.sh \ - "${KOKORO_RELEASE_COMMIT}" \ - "${KOKORO_RELEASE_TAG}" \ - "${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}" diff --git a/scripts/root_tests.sh b/scripts/root_tests.sh deleted file mode 100755 index 4e4fcc76b..000000000 --- a/scripts/root_tests.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Reinstall the latest containerd shim. -declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim" -declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX) -declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX) -wget --no-verbose "${base}"/latest -O ${latest} -wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path} -chmod +x ${shim_path} -sudo mv ${shim_path} /usr/local/bin/gvisor-containerd-shim - -# Run the tests that require root. -install_runsc_for_test root -run_as_root //test/root:root_test --runtime=${RUNTIME} - diff --git a/scripts/simple_tests.sh b/scripts/simple_tests.sh deleted file mode 100755 index 3a15050c2..000000000 --- a/scripts/simple_tests.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Run all simple tests (locally). -test //pkg/... //runsc/... //tools/... //benchmarks/... //benchmarks/runner:runner_test diff --git a/scripts/swgso_tests.sh b/scripts/swgso_tests.sh deleted file mode 100755 index 0de2df1d2..000000000 --- a/scripts/swgso_tests.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Install the runtime and perform basic tests. -install_runsc_for_test swgso --software-gso=true --gso=false -test_runsc //test/image:image_test //test/e2e:integration_test diff --git a/scripts/syscall_kvm_tests.sh b/scripts/syscall_kvm_tests.sh deleted file mode 100755 index de85daa5a..000000000 --- a/scripts/syscall_kvm_tests.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# TODO(b/112165693): "test --test_tag_filters=runsc_kvm" can be used -# when the "manual" tag will be removed for kvm tests. -test `bazel query "attr(tags, runsc_kvm, tests(//test/syscalls/...))"` diff --git a/scripts/syscall_tests.sh b/scripts/syscall_tests.sh deleted file mode 100755 index a131b2d50..000000000 --- a/scripts/syscall_tests.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Run all ptrace-variants of the system call tests. -test --test_tag_filters=runsc_ptrace //test/syscalls/... diff --git a/test/BUILD b/test/BUILD deleted file mode 100644 index 34b950644..000000000 --- a/test/BUILD +++ /dev/null @@ -1 +0,0 @@ -package(licenses = ["notice"]) diff --git a/test/README.md b/test/README.md deleted file mode 100644 index 97fe7ea04..000000000 --- a/test/README.md +++ /dev/null @@ -1,40 +0,0 @@ -# Tests - -The tests defined under this path are verifying functionality beyond what unit -tests can cover, e.g. integration and end to end tests. Due to their nature, -they may need extra setup in the test machine and extra configuration to run. - -- **syscalls**: system call tests use a local runner, and do not require - additional configuration in the machine. -- **integration:** defines integration tests that uses `docker run` to test - functionality. -- **image:** basic end to end test for popular images. These require the same - setup as integration tests. -- **root:** tests that require to be run as root. These require the same setup - as integration tests. -- **util:** utilities library to support the tests. - -For the above noted cases, the relevant runtime must be installed via `runsc -install` before running. Just note that they require specific configuration to -work. This is handled automatically by the test scripts in the `scripts` -directory and they can be used to run tests locally on your machine. They are -also used to run these tests in `kokoro`. - -**Example:** - -To run image and integration tests, run: - -`./scripts/docker_test.sh` - -To run root tests, run: - -`./scripts/root_test.sh` - -There are a few other interesting variations for image and integration tests: - -* overlay: sets writable overlay inside the sentry -* hostnet: configures host network pass-thru, instead of netstack -* kvm: runsc the test using the KVM platform, instead of ptrace - -The test will build runsc, configure it with your local docker, restart -`dockerd`, and run tests. The location for runsc logs is printed to the output. diff --git a/test/e2e/BUILD b/test/e2e/BUILD deleted file mode 100644 index 76e04f878..000000000 --- a/test/e2e/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_test( - name = "integration_test", - size = "large", - srcs = [ - "exec_test.go", - "integration_test.go", - "regression_test.go", - ], - library = ":integration", - tags = [ - # Requires docker and runsc to be configured before the test runs. - "manual", - "local", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/bits", - "//runsc/dockerutil", - "//runsc/specutils", - "//runsc/testutil", - ], -) - -go_library( - name = "integration", - srcs = ["integration.go"], -) diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go deleted file mode 100644 index 4074d2285..000000000 --- a/test/e2e/exec_test.go +++ /dev/null @@ -1,275 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package integration provides end-to-end integration tests for runsc. These -// tests require docker and runsc to be installed on the machine. -// -// Each test calls docker commands to start up a container, and tests that it -// is behaving properly, with various runsc commands. The container is killed -// and deleted at the end. - -package integration - -import ( - "fmt" - "strconv" - "strings" - "syscall" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/bits" - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/specutils" -) - -// Test that exec uses the exact same capability set as the container. -func TestExecCapabilities(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-capabilities-test") - - // Start the container. - if err := d.Run("alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second) - if err != nil { - t.Fatalf("WaitForOutputSubmatch() timeout: %v", err) - } - if len(matches) != 2 { - t.Fatalf("There should be a match for the whole line and the capability bitmask") - } - want := fmt.Sprintf("CapEff:\t%s\n", matches[1]) - t.Log("Root capabilities:", want) - - // Now check that exec'd process capabilities match the root. - got, err := d.Exec("grep", "CapEff:", "/proc/self/status") - if err != nil { - t.Fatalf("docker exec failed: %v", err) - } - t.Logf("CapEff: %v", got) - if got != want { - t.Errorf("wrong capabilities, got: %q, want: %q", got, want) - } -} - -// Test that 'exec --privileged' adds all capabilities, except for CAP_NET_RAW -// which is removed from the container when --net-raw=false. -func TestExecPrivileged(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-privileged-test") - - // Start the container with all capabilities dropped. - if err := d.Run("--cap-drop=all", "alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - // Check that all capabilities where dropped from container. - matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second) - if err != nil { - t.Fatalf("WaitForOutputSubmatch() timeout: %v", err) - } - if len(matches) != 2 { - t.Fatalf("There should be a match for the whole line and the capability bitmask") - } - containerCaps, err := strconv.ParseUint(matches[1], 16, 64) - if err != nil { - t.Fatalf("failed to convert capabilities %q: %v", matches[1], err) - } - t.Logf("Container capabilities: %#x", containerCaps) - if containerCaps != 0 { - t.Fatalf("Container should have no capabilities: %x", containerCaps) - } - - // Check that 'exec --privileged' adds all capabilities, except - // for CAP_NET_RAW. - got, err := d.ExecWithFlags([]string{"--privileged"}, "grep", "CapEff:", "/proc/self/status") - if err != nil { - t.Fatalf("docker exec failed: %v", err) - } - t.Logf("Exec CapEff: %v", got) - want := fmt.Sprintf("CapEff:\t%016x\n", specutils.AllCapabilitiesUint64()&^bits.MaskOf64(int(linux.CAP_NET_RAW))) - if got != want { - t.Errorf("Wrong capabilities, got: %q, want: %q. Make sure runsc is not using '--net-raw'", got, want) - } -} - -func TestExecJobControl(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-job-control-test") - - // Start the container. - if err := d.Run("alpine", "sleep", "1000"); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - // Exec 'sh' with an attached pty. - cmd, ptmx, err := d.ExecWithTerminal("sh") - if err != nil { - t.Fatalf("docker exec failed: %v", err) - } - defer ptmx.Close() - - // Call "sleep 100 | cat" in the shell. We pipe to cat so that there - // will be two processes in the foreground process group. - if _, err := ptmx.Write([]byte("sleep 100 | cat\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // Give shell a few seconds to start executing the sleep. - time.Sleep(2 * time.Second) - - // Send a ^C to the pty, which should kill sleep and cat, but not the - // shell. \x03 is ASCII "end of text", which is the same as ^C. - if _, err := ptmx.Write([]byte{'\x03'}); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // The shell should still be alive at this point. Sleep should have - // exited with code 2+128=130. We'll exit with 10 plus that number, so - // that we can be sure that the shell did not get signalled. - if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // Exec process should exit with code 10+130=140. - ps, err := cmd.Process.Wait() - if err != nil { - t.Fatalf("error waiting for exec process: %v", err) - } - ws := ps.Sys().(syscall.WaitStatus) - if !ws.Exited() { - t.Errorf("ws.Exited got false, want true") - } - if got, want := ws.ExitStatus(), 140; got != want { - t.Errorf("ws.ExitedStatus got %d, want %d", got, want) - } -} - -// Test that failure to exec returns proper error message. -func TestExecError(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-error-test") - - // Start the container. - if err := d.Run("alpine", "sleep", "1000"); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - _, err := d.Exec("no_can_find") - if err == nil { - t.Fatalf("docker exec didn't fail") - } - if want := `error finding executable "no_can_find" in PATH`; !strings.Contains(err.Error(), want) { - t.Fatalf("docker exec wrong error, got: %s, want: .*%s.*", err.Error(), want) - } -} - -// Test that exec inherits environment from run. -func TestExecEnv(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-env-test") - - // Start the container with env FOO=BAR. - if err := d.Run("-e", "FOO=BAR", "alpine", "sleep", "1000"); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - // Exec "echo $FOO". - got, err := d.Exec("/bin/sh", "-c", "echo $FOO") - if err != nil { - t.Fatalf("docker exec failed: %v", err) - } - if got, want := strings.TrimSpace(got), "BAR"; got != want { - t.Errorf("bad output from 'docker exec'. Got %q; Want %q.", got, want) - } -} - -// TestRunEnvHasHome tests that run always has HOME environment set. -func TestRunEnvHasHome(t *testing.T) { - // Base alpine image does not have any environment variables set. - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("run-env-test") - - // Exec "echo $HOME". The 'bin' user's home dir is '/bin'. - got, err := d.RunFg("--user", "bin", "alpine", "/bin/sh", "-c", "echo $HOME") - if err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - if got, want := strings.TrimSpace(got), "/bin"; got != want { - t.Errorf("bad output from 'docker run'. Got %q; Want %q.", got, want) - } -} - -// Test that exec always has HOME environment set, even when not set in run. -func TestExecEnvHasHome(t *testing.T) { - // Base alpine image does not have any environment variables set. - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-env-home-test") - - // We will check that HOME is set for root user, and also for a new - // non-root user we will create. - newUID := 1234 - newHome := "/foo/bar" - - // Create a new user with a home directory, and then sleep. - script := fmt.Sprintf(` - mkdir -p -m 777 %s && \ - adduser foo -D -u %d -h %s && \ - sleep 1000`, newHome, newUID, newHome) - if err := d.Run("alpine", "/bin/sh", "-c", script); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - // Exec "echo $HOME", and expect to see "/root". - got, err := d.Exec("/bin/sh", "-c", "echo $HOME") - if err != nil { - t.Fatalf("docker exec failed: %v", err) - } - if want := "/root"; !strings.Contains(got, want) { - t.Errorf("wanted exec output to contain %q, got %q", want, got) - } - - // Execute the same as uid 123 and expect newHome. - got, err = d.ExecAsUser(strconv.Itoa(newUID), "/bin/sh", "-c", "echo $HOME") - if err != nil { - t.Fatalf("docker exec failed: %v", err) - } - if want := newHome; !strings.Contains(got, want) { - t.Errorf("wanted exec output to contain %q, got %q", want, got) - } -} diff --git a/test/e2e/integration.go b/test/e2e/integration.go deleted file mode 100644 index 4cd5f6c24..000000000 --- a/test/e2e/integration.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package integration is empty. See integration_test.go for description. -package integration diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go deleted file mode 100644 index cc4fbbaed..000000000 --- a/test/e2e/integration_test.go +++ /dev/null @@ -1,349 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package integration provides end-to-end integration tests for runsc. -// -// Each test calls docker commands to start up a container, and tests that it is -// behaving properly, with various runsc commands. The container is killed and -// deleted at the end. -// -// Setup instruction in test/README.md. -package integration - -import ( - "flag" - "fmt" - "net" - "net/http" - "os" - "strconv" - "strings" - "syscall" - "testing" - "time" - - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/testutil" -) - -// httpRequestSucceeds sends a request to a given url and checks that the status is OK. -func httpRequestSucceeds(client http.Client, server string, port int) error { - url := fmt.Sprintf("http://%s:%d", server, port) - // Ensure that content is being served. - resp, err := client.Get(url) - if err != nil { - return fmt.Errorf("error reaching http server: %v", err) - } - if want := http.StatusOK; resp.StatusCode != want { - return fmt.Errorf("wrong response code, got: %d, want: %d", resp.StatusCode, want) - } - return nil -} - -// TestLifeCycle tests a basic Create/Start/Stop docker container life cycle. -func TestLifeCycle(t *testing.T) { - if err := dockerutil.Pull("nginx"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("lifecycle-test") - if err := d.Create("-p", "80", "nginx"); err != nil { - t.Fatal("docker create failed:", err) - } - if err := d.Start(); err != nil { - d.CleanUp() - t.Fatal("docker start failed:", err) - } - - // Test that container is working - port, err := d.FindPort(80) - if err != nil { - t.Fatal("docker.FindPort(80) failed: ", err) - } - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Fatal("WaitForHTTP() timeout:", err) - } - client := http.Client{Timeout: time.Duration(2 * time.Second)} - if err := httpRequestSucceeds(client, "localhost", port); err != nil { - t.Error("http request failed:", err) - } - - if err := d.Stop(); err != nil { - d.CleanUp() - t.Fatal("docker stop failed:", err) - } - if err := d.Remove(); err != nil { - t.Fatal("docker rm failed:", err) - } -} - -func TestPauseResume(t *testing.T) { - const img = "gcr.io/gvisor-presubmit/python-hello" - if !testutil.IsCheckpointSupported() { - t.Log("Checkpoint is not supported, skipping test.") - return - } - - if err := dockerutil.Pull(img); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("pause-resume-test") - if err := d.Run("-p", "8080", img); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) - if err != nil { - t.Fatal("docker.FindPort(8080) failed:", err) - } - - // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Fatal("WaitForHTTP() timeout:", err) - } - - // Check that container is working. - client := http.Client{Timeout: time.Duration(2 * time.Second)} - if err := httpRequestSucceeds(client, "localhost", port); err != nil { - t.Error("http request failed:", err) - } - - if err := d.Pause(); err != nil { - t.Fatal("docker pause failed:", err) - } - - // Check if container is paused. - switch _, err := client.Get(fmt.Sprintf("http://localhost:%d", port)); v := err.(type) { - case nil: - t.Errorf("http req expected to fail but it succeeded") - case net.Error: - if !v.Timeout() { - t.Errorf("http req got error %v, wanted timeout", v) - } - default: - t.Errorf("http req got unexpected error %v", v) - } - - if err := d.Unpause(); err != nil { - t.Fatal("docker unpause failed:", err) - } - - // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Fatal("WaitForHTTP() timeout:", err) - } - - // Check if container is working again. - if err := httpRequestSucceeds(client, "localhost", port); err != nil { - t.Error("http request failed:", err) - } -} - -func TestCheckpointRestore(t *testing.T) { - const img = "gcr.io/gvisor-presubmit/python-hello" - if !testutil.IsCheckpointSupported() { - t.Log("Pause/resume is not supported, skipping test.") - return - } - - if err := dockerutil.Pull(img); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("save-restore-test") - if err := d.Run("-p", "8080", img); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - if err := d.Checkpoint("test"); err != nil { - t.Fatal("docker checkpoint failed:", err) - } - - if _, err := d.Wait(30 * time.Second); err != nil { - t.Fatal(err) - } - - // TODO(b/143498576): Remove Poll after github.com/moby/moby/issues/38963 is fixed. - if err := testutil.Poll(func() error { return d.Restore("test") }, 15*time.Second); err != nil { - t.Fatal("docker restore failed:", err) - } - - // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) - if err != nil { - t.Fatal("docker.FindPort(8080) failed:", err) - } - - // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Fatal("WaitForHTTP() timeout:", err) - } - - // Check if container is working again. - client := http.Client{Timeout: time.Duration(2 * time.Second)} - if err := httpRequestSucceeds(client, "localhost", port); err != nil { - t.Error("http request failed:", err) - } -} - -// Create client and server that talk to each other using the local IP. -func TestConnectToSelf(t *testing.T) { - d := dockerutil.MakeDocker("connect-to-self-test") - - // Creates server that replies "server" and exists. Sleeps at the end because - // 'docker exec' gets killed if the init process exists before it can finish. - if err := d.Run("ubuntu:trusty", "/bin/sh", "-c", "echo server | nc -l -p 8080 && sleep 1"); err != nil { - t.Fatal("docker run failed:", err) - } - defer d.CleanUp() - - // Finds IP address for host. - ip, err := d.Exec("/bin/sh", "-c", "cat /etc/hosts | grep ${HOSTNAME} | awk '{print $1}'") - if err != nil { - t.Fatal("docker exec failed:", err) - } - ip = strings.TrimRight(ip, "\n") - - // Runs client that sends "client" to the server and exits. - reply, err := d.Exec("/bin/sh", "-c", fmt.Sprintf("echo client | nc %s 8080", ip)) - if err != nil { - t.Fatal("docker exec failed:", err) - } - - // Ensure both client and server got the message from each other. - if want := "server\n"; reply != want { - t.Errorf("Error on server, want: %q, got: %q", want, reply) - } - if _, err := d.WaitForOutput("^client\n$", 1*time.Second); err != nil { - t.Fatal("docker.WaitForOutput(client) timeout:", err) - } -} - -func TestMemLimit(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("cgroup-test") - cmd := "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'" - out, err := d.RunFg("--memory=500MB", "alpine", "sh", "-c", cmd) - if err != nil { - t.Fatal("docker run failed:", err) - } - defer d.CleanUp() - - // Remove warning message that swap isn't present. - if strings.HasPrefix(out, "WARNING") { - lines := strings.Split(out, "\n") - if len(lines) != 3 { - t.Fatalf("invalid output: %s", out) - } - out = lines[1] - } - - got, err := strconv.ParseUint(strings.TrimSpace(out), 10, 64) - if err != nil { - t.Fatalf("failed to parse %q: %v", out, err) - } - if want := uint64(500 * 1024); got != want { - t.Errorf("MemTotal got: %d, want: %d", got, want) - } -} - -func TestNumCPU(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("cgroup-test") - cmd := "cat /proc/cpuinfo | grep 'processor.*:' | wc -l" - out, err := d.RunFg("--cpuset-cpus=0", "alpine", "sh", "-c", cmd) - if err != nil { - t.Fatal("docker run failed:", err) - } - defer d.CleanUp() - - got, err := strconv.Atoi(strings.TrimSpace(out)) - if err != nil { - t.Fatalf("failed to parse %q: %v", out, err) - } - if want := 1; got != want { - t.Errorf("MemTotal got: %d, want: %d", got, want) - } -} - -// TestJobControl tests that job control characters are handled properly. -func TestJobControl(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("job-control-test") - - // Start the container with an attached PTY. - _, ptmx, err := d.RunWithPty("alpine", "sh") - if err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer ptmx.Close() - defer d.CleanUp() - - // Call "sleep 100" in the shell. - if _, err := ptmx.Write([]byte("sleep 100\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // Give shell a few seconds to start executing the sleep. - time.Sleep(2 * time.Second) - - // Send a ^C to the pty, which should kill sleep, but not the shell. - // \x03 is ASCII "end of text", which is the same as ^C. - if _, err := ptmx.Write([]byte{'\x03'}); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // The shell should still be alive at this point. Sleep should have - // exited with code 2+128=130. We'll exit with 10 plus that number, so - // that we can be sure that the shell did not get signalled. - if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // Wait for the container to exit. - got, err := d.Wait(5 * time.Second) - if err != nil { - t.Fatalf("error getting exit code: %v", err) - } - // Container should exit with code 10+130=140. - if want := syscall.WaitStatus(140); got != want { - t.Errorf("container exited with code %d want %d", got, want) - } -} - -// TestTmpFile checks that files inside '/tmp' are not overridden. In addition, -// it checks that working dir is created if it doesn't exit. -func TestTmpFile(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("tmp-file-test") - if err := d.Run("-w=/tmp/foo/bar", "--read-only", "alpine", "touch", "/tmp/foo/bar/file"); err != nil { - t.Fatal("docker run failed:", err) - } - defer d.CleanUp() -} - -func TestMain(m *testing.M) { - dockerutil.EnsureSupportedDockerVersion() - flag.Parse() - os.Exit(m.Run()) -} diff --git a/test/e2e/regression_test.go b/test/e2e/regression_test.go deleted file mode 100644 index 2488be383..000000000 --- a/test/e2e/regression_test.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package integration - -import ( - "strings" - "testing" - - "gvisor.dev/gvisor/runsc/dockerutil" -) - -// Test that UDS can be created using overlay when parent directory is in lower -// layer only (b/134090485). -// -// Prerequisite: the directory where the socket file is created must not have -// been open for write before bind(2) is called. -func TestBindOverlay(t *testing.T) { - if err := dockerutil.Pull("ubuntu:trusty"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("bind-overlay-test") - - cmd := "nc -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -U /var/run/sock && wait $p" - got, err := d.RunFg("ubuntu:trusty", "bash", "-c", cmd) - if err != nil { - t.Fatal("docker run failed:", err) - } - - if want := "foobar-asdf"; !strings.Contains(got, want) { - t.Fatalf("docker run output is missing %q: %s", want, got) - } - defer d.CleanUp() -} diff --git a/test/image/BUILD b/test/image/BUILD deleted file mode 100644 index 7392ac54e..000000000 --- a/test/image/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_test( - name = "image_test", - size = "large", - srcs = [ - "image_test.go", - ], - data = [ - "latin10k.txt", - "mysql.sql", - "ruby.rb", - "ruby.sh", - ], - library = ":image", - tags = [ - # Requires docker and runsc to be configured before the test runs. - "manual", - "local", - ], - visibility = ["//:sandbox"], - deps = [ - "//runsc/dockerutil", - "//runsc/testutil", - ], -) - -go_library( - name = "image", - srcs = ["image.go"], -) diff --git a/test/image/image.go b/test/image/image.go deleted file mode 100644 index 297f1ab92..000000000 --- a/test/image/image.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package image is empty. See image_test.go for description. -package image diff --git a/test/image/image_test.go b/test/image/image_test.go deleted file mode 100644 index 0a1e19d6f..000000000 --- a/test/image/image_test.go +++ /dev/null @@ -1,353 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package image provides end-to-end image tests for runsc. - -// Each test calls docker commands to start up a container, and tests that it -// is behaving properly, like connecting to a port or looking at the output. -// The container is killed and deleted at the end. -// -// Setup instruction in test/README.md. -package image - -import ( - "flag" - "fmt" - "io/ioutil" - "log" - "net/http" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/testutil" -) - -func TestHelloWorld(t *testing.T) { - d := dockerutil.MakeDocker("hello-test") - if err := d.Run("hello-world"); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - if _, err := d.WaitForOutput("Hello from Docker!", 5*time.Second); err != nil { - t.Fatalf("docker didn't say hello: %v", err) - } -} - -func runHTTPRequest(port int) error { - url := fmt.Sprintf("http://localhost:%d/not-found", port) - resp, err := http.Get(url) - if err != nil { - return fmt.Errorf("error reaching http server: %v", err) - } - if want := http.StatusNotFound; resp.StatusCode != want { - return fmt.Errorf("Wrong response code, got: %d, want: %d", resp.StatusCode, want) - } - - url = fmt.Sprintf("http://localhost:%d/latin10k.txt", port) - resp, err = http.Get(url) - if err != nil { - return fmt.Errorf("Error reaching http server: %v", err) - } - if want := http.StatusOK; resp.StatusCode != want { - return fmt.Errorf("Wrong response code, got: %d, want: %d", resp.StatusCode, want) - } - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("Error reading http response: %v", err) - } - defer resp.Body.Close() - - // READALL is the last word in the file. Ensures everything was read. - if want := "READALL"; strings.HasSuffix(string(body), want) { - return fmt.Errorf("response doesn't contain %q, resp: %q", want, body) - } - return nil -} - -func testHTTPServer(t *testing.T, port int) { - const requests = 10 - ch := make(chan error, requests) - for i := 0; i < requests; i++ { - go func() { - start := time.Now() - err := runHTTPRequest(port) - log.Printf("Response time %v: %v", time.Since(start).String(), err) - ch <- err - }() - } - - for i := 0; i < requests; i++ { - err := <-ch - if err != nil { - t.Errorf("testHTTPServer(%d) failed: %v", port, err) - } - } -} - -func TestHttpd(t *testing.T) { - if err := dockerutil.Pull("httpd"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("http-test") - - dir, err := dockerutil.PrepareFiles("test/image/latin10k.txt") - if err != nil { - t.Fatalf("PrepareFiles() failed: %v", err) - } - - // Start the container. - mountArg := dockerutil.MountArg(dir, "/usr/local/apache2/htdocs", dockerutil.ReadOnly) - if err := d.Run("-p", "80", mountArg, "httpd"); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - // Find where port 80 is mapped to. - port, err := d.FindPort(80) - if err != nil { - t.Fatalf("docker.FindPort(80) failed: %v", err) - } - - // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Errorf("WaitForHTTP() timeout: %v", err) - } - - testHTTPServer(t, port) -} - -func TestNginx(t *testing.T) { - if err := dockerutil.Pull("nginx"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("net-test") - - dir, err := dockerutil.PrepareFiles("test/image/latin10k.txt") - if err != nil { - t.Fatalf("PrepareFiles() failed: %v", err) - } - - // Start the container. - mountArg := dockerutil.MountArg(dir, "/usr/share/nginx/html", dockerutil.ReadOnly) - if err := d.Run("-p", "80", mountArg, "nginx"); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - // Find where port 80 is mapped to. - port, err := d.FindPort(80) - if err != nil { - t.Fatalf("docker.FindPort(80) failed: %v", err) - } - - // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Errorf("WaitForHTTP() timeout: %v", err) - } - - testHTTPServer(t, port) -} - -func TestMysql(t *testing.T) { - if err := dockerutil.Pull("mysql"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("mysql-test") - - // Start the container. - if err := d.Run("-e", "MYSQL_ROOT_PASSWORD=foobar123", "mysql"); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - // Wait until it's up and running. - if _, err := d.WaitForOutput("port: 3306 MySQL Community Server", 3*time.Minute); err != nil { - t.Fatalf("docker.WaitForOutput() timeout: %v", err) - } - - client := dockerutil.MakeDocker("mysql-client-test") - dir, err := dockerutil.PrepareFiles("test/image/mysql.sql") - if err != nil { - t.Fatalf("PrepareFiles() failed: %v", err) - } - - // Tell mysql client to connect to the server and execute the file in verbose - // mode to verify the output. - args := []string{ - dockerutil.LinkArg(&d, "mysql"), - dockerutil.MountArg(dir, "/sql", dockerutil.ReadWrite), - "mysql", - "mysql", "-hmysql", "-uroot", "-pfoobar123", "-v", "-e", "source /sql/mysql.sql", - } - if err := client.Run(args...); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer client.CleanUp() - - // Ensure file executed to the end and shutdown mysql. - if _, err := client.WaitForOutput("--------------\nshutdown\n--------------", 15*time.Second); err != nil { - t.Fatalf("docker.WaitForOutput() timeout: %v", err) - } - if _, err := d.WaitForOutput("mysqld: Shutdown complete", 30*time.Second); err != nil { - t.Fatalf("docker.WaitForOutput() timeout: %v", err) - } -} - -func TestPythonHello(t *testing.T) { - // TODO(b/136503277): Once we have more complete python runtime tests, - // we can drop this one. - const img = "gcr.io/gvisor-presubmit/python-hello" - if err := dockerutil.Pull(img); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("python-hello-test") - if err := d.Run("-p", "8080", img); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) - if err != nil { - t.Fatalf("docker.FindPort(8080) failed: %v", err) - } - - // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Fatalf("WaitForHTTP() timeout: %v", err) - } - - // Ensure that content is being served. - url := fmt.Sprintf("http://localhost:%d", port) - resp, err := http.Get(url) - if err != nil { - t.Errorf("Error reaching http server: %v", err) - } - if want := http.StatusOK; resp.StatusCode != want { - t.Errorf("Wrong response code, got: %d, want: %d", resp.StatusCode, want) - } -} - -func TestTomcat(t *testing.T) { - if err := dockerutil.Pull("tomcat:8.0"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("tomcat-test") - if err := d.Run("-p", "8080", "tomcat:8.0"); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) - if err != nil { - t.Fatalf("docker.FindPort(8080) failed: %v", err) - } - - // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Fatalf("WaitForHTTP() timeout: %v", err) - } - - // Ensure that content is being served. - url := fmt.Sprintf("http://localhost:%d", port) - resp, err := http.Get(url) - if err != nil { - t.Errorf("Error reaching http server: %v", err) - } - if want := http.StatusOK; resp.StatusCode != want { - t.Errorf("Wrong response code, got: %d, want: %d", resp.StatusCode, want) - } -} - -func TestRuby(t *testing.T) { - if err := dockerutil.Pull("ruby"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("ruby-test") - - dir, err := dockerutil.PrepareFiles("test/image/ruby.rb", "test/image/ruby.sh") - if err != nil { - t.Fatalf("PrepareFiles() failed: %v", err) - } - if err := os.Chmod(filepath.Join(dir, "ruby.sh"), 0333); err != nil { - t.Fatalf("os.Chmod(%q, 0333) failed: %v", dir, err) - } - - if err := d.Run("-p", "8080", dockerutil.MountArg(dir, "/src", dockerutil.ReadOnly), "ruby", "/src/ruby.sh"); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) - if err != nil { - t.Fatalf("docker.FindPort(8080) failed: %v", err) - } - - // Wait until it's up and running, 'gem install' can take some time. - if err := testutil.WaitForHTTP(port, 1*time.Minute); err != nil { - t.Fatalf("WaitForHTTP() timeout: %v", err) - } - - // Ensure that content is being served. - url := fmt.Sprintf("http://localhost:%d", port) - resp, err := http.Get(url) - if err != nil { - t.Errorf("error reaching http server: %v", err) - } - if want := http.StatusOK; resp.StatusCode != want { - t.Errorf("wrong response code, got: %d, want: %d", resp.StatusCode, want) - } - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("error reading body: %v", err) - } - if got, want := string(body), "Hello World"; !strings.Contains(got, want) { - t.Errorf("invalid body content, got: %q, want: %q", got, want) - } -} - -func TestStdio(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("stdio-test") - - wantStdout := "hello stdout" - wantStderr := "bonjour stderr" - cmd := fmt.Sprintf("echo %q; echo %q 1>&2;", wantStdout, wantStderr) - if err := d.Run("alpine", "/bin/sh", "-c", cmd); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - for _, want := range []string{wantStdout, wantStderr} { - if _, err := d.WaitForOutput(want, 5*time.Second); err != nil { - t.Fatalf("docker didn't get output %q : %v", want, err) - } - } -} - -func TestMain(m *testing.M) { - dockerutil.EnsureSupportedDockerVersion() - flag.Parse() - os.Exit(m.Run()) -} diff --git a/test/image/latin10k.txt b/test/image/latin10k.txt deleted file mode 100644 index 61341e00b..000000000 --- a/test/image/latin10k.txt +++ /dev/null @@ -1,33 +0,0 @@ -Lorem ipsum dolor sit amet, consectetur adipiscing elit. Cras ut placerat felis. Maecenas urna est, auctor a efficitur sit amet, egestas et augue. Curabitur dignissim scelerisque nunc vel cursus. Ut vehicula est pretium, consectetur nunc non, pharetra ligula. Curabitur ut ultricies metus. Suspendisse pulvinar, orci sed fermentum vestibulum, eros turpis molestie lectus, nec elementum risus dolor mattis felis. Donec ultrices ipsum sem, at pretium lacus convallis at. Mauris nulla enim, tincidunt non bibendum at, vehicula pulvinar mauris. - -Duis in dapibus turpis. Pellentesque maximus magna odio, ac congue libero laoreet quis. Maecenas euismod risus in justo aliquam accumsan. Nunc quis ornare arcu, sit amet sodales elit. Phasellus nec scelerisque nisl, a tincidunt arcu. Proin ornare est nunc, sed suscipit orci interdum et. Suspendisse condimentum venenatis diam in tempor. Aliquam egestas lectus in rutrum tempus. Donec id egestas eros. Donec molestie consequat purus, sed posuere odio venenatis vitae. Nunc placerat augue id vehicula varius. In hac habitasse platea dictumst. Proin at est accumsan, venenatis quam a, fermentum risus. Phasellus posuere pellentesque enim, id suscipit magna consequat ut. Quisque ut tortor ante. - -Cras ut vulputate metus, a laoreet lectus. Vivamus ultrices molestie odio in tristique. Morbi faucibus mi eget sollicitudin fringilla. Fusce vitae lacinia ligula. Sed egestas sed diam eu posuere. Maecenas justo nisl, venenatis vel nibh vel, cursus aliquam velit. Praesent lacinia dui id erat venenatis rhoncus. Morbi gravida felis ante, sit amet vehicula orci rhoncus vitae. - -Sed finibus sagittis dictum. Proin auctor suscipit sem et mattis. Phasellus libero ligula, pellentesque ut felis porttitor, fermentum sollicitudin orci. Nulla eu nulla nibh. Fusce a eros risus. Proin vel magna risus. Donec nec elit eleifend, scelerisque sapien vitae, pharetra quam. Donec porttitor mauris scelerisque, tempus orci hendrerit, dapibus felis. Nullam libero elit, sollicitudin a aliquam at, ultrices in erat. Mauris eget ligula sodales, porta turpis et, scelerisque odio. Mauris mollis leo vitae purus gravida, in tempor nunc efficitur. Nulla facilisis posuere augue, nec pellentesque lectus eleifend ac. Vestibulum convallis est a feugiat tincidunt. Donec vitae enim volutpat, tincidunt eros eu, malesuada nibh. - -Quisque molestie, magna ornare elementum convallis, erat enim sagittis ipsum, eget porttitor sapien arcu id purus. Donec ut cursus diam. Nulla rutrum nulla et mi fermentum, vel tempus tellus posuere. Proin vitae pharetra nulla, nec ornare ex. Nulla consequat, augue a accumsan euismod, turpis leo ornare ligula, a pulvinar enim dolor ut augue. Quisque volutpat, lectus a varius mollis, nisl eros feugiat sem, at egestas lacus justo eu elit. Vestibulum scelerisque mauris est, sagittis interdum nunc accumsan sit amet. Maecenas aliquet ex ut lacus ornare, eu sagittis nibh imperdiet. Duis ultrices nisi velit, sed sodales risus sollicitudin et. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Etiam a accumsan augue, vitae pulvinar nulla. Pellentesque euismod sodales magna, nec luctus eros mattis eget. Sed lacinia suscipit lectus, eget consectetur dui pellentesque sed. Nullam nec mattis tellus. - -Aliquam erat volutpat. Praesent lobortis massa porttitor eros tincidunt, nec consequat diam pharetra. Duis efficitur non lorem sed mattis. Suspendisse justo nunc, pulvinar eu porttitor at, facilisis id eros. Suspendisse potenti. Cras molestie aliquet orci ut fermentum. In tempus aliquet eros nec suscipit. Suspendisse in mauris ut lectus ultrices blandit sit amet vitae est. Nam magna massa, porttitor ut semper id, feugiat vel quam. Suspendisse dignissim posuere scelerisque. Donec scelerisque lorem efficitur suscipit suscipit. Nunc luctus ligula et scelerisque lacinia. - -Suspendisse potenti. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Sed ultrices, sem in venenatis scelerisque, tellus ipsum porttitor urna, et iaculis lectus odio ac nisi. Integer luctus dui urna, at sollicitudin elit dapibus eu. Praesent nibh ante, porttitor a ante in, ullamcorper pretium felis. Aliquam vel tortor imperdiet, imperdiet lorem et, cursus mi. Proin tempus velit est, ut hendrerit metus gravida sed. Sed nibh sapien, faucibus quis ipsum in, scelerisque lacinia elit. In nec magna eu magna laoreet rhoncus. Donec vitae rutrum mauris. Integer urna felis, consequat at rhoncus vitae, auctor quis elit. Duis a pulvinar sem, nec gravida nisl. Nam non dapibus purus. Praesent vestibulum turpis nec erat porttitor, a scelerisque purus tincidunt. - -Nam fringilla leo nisi, nec placerat nisl luctus eget. Aenean malesuada nunc porta sapien sodales convallis. Suspendisse ut massa tempor, ullamcorper mi ut, faucibus turpis. Vivamus at sagittis metus. Donec varius ac mi eget sodales. Nulla feugiat, nulla eu fringilla fringilla, nunc lorem sollicitudin quam, vitae lacinia velit lorem eu orci. Mauris leo urna, pellentesque ac posuere non, pellentesque sit amet quam. - -Vestibulum porta diam urna, a aliquet nibh vestibulum et. Proin interdum bibendum nisl sed rhoncus. Sed vel diam hendrerit, faucibus ante et, hendrerit diam. Nunc dolor augue, mattis non dolor vel, luctus sodales neque. Cras malesuada fermentum dolor eu lobortis. Integer dapibus volutpat consequat. Maecenas posuere feugiat nunc. Donec vel mollis elit, volutpat consequat enim. Nulla id nisi finibus orci imperdiet elementum. Phasellus ultrices, elit vitae consequat rutrum, nisl est congue massa, quis condimentum justo nisi vitae turpis. Maecenas aliquet risus sit amet accumsan elementum. Proin non finibus elit, sit amet lobortis augue. - -Morbi pretium pulvinar sem vel sollicitudin. Proin imperdiet fringilla leo, non pellentesque lacus gravida nec. Vivamus ullamcorper consectetur ligula eu consectetur. Curabitur sit amet tempus purus. Curabitur quam quam, tincidunt eu tempus vel, volutpat at ipsum. Maecenas lobortis elit ac justo interdum, sit amet mattis ligula mollis. Sed posuere ligula et felis convallis tempor. Aliquam nec mollis velit. Donec varius sit amet erat at imperdiet. Nulla ipsum justo, tempor non sollicitudin gravida, dignissim vel orci. In hac habitasse platea dictumst. Cras cursus tellus id arcu aliquet accumsan. Phasellus ac erat dui. - -Duis mollis metus at mi luctus aliquam. Duis varius eget erat ac porttitor. Phasellus lobortis sagittis lacinia. Etiam sagittis eget erat in pulvinar. Phasellus sodales risus nec vulputate accumsan. Cras sit amet pellentesque dui. Praesent consequat felis mi, at vulputate diam convallis a. Donec hendrerit nibh vel justo consequat dictum. In euismod, dui sit amet malesuada suscipit, mauris ex rhoncus eros, sed ornare arcu nunc eu urna. Pellentesque eget erat augue. Integer rutrum mauris sem, nec sodales nulla cursus vel. Vivamus porta, urna vel varius vulputate, nulla arcu malesuada dui, a ultrices magna ante sed nibh. - -Morbi ultricies aliquam lorem id bibendum. Donec sit amet nunc vitae massa gravida eleifend hendrerit vel libero. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Nulla vestibulum tempus condimentum. Aliquam dolor ipsum, condimentum in sapien et, tempor iaculis nulla. Aenean non pharetra augue. Maecenas mattis dignissim maximus. Fusce elementum tincidunt massa sit amet lobortis. Phasellus nec pharetra dui, et malesuada ante. Nullam commodo pretium tellus. Praesent sollicitudin, enim eget imperdiet scelerisque, odio felis vulputate dolor, eget auctor neque tellus ac lorem. - -In consectetur augue et sapien feugiat varius. Nam tortor mi, consectetur ac felis non, elementum venenatis augue. Suspendisse ut tellus in est sagittis cursus. Quisque faucibus, neque sit amet semper congue, nibh augue finibus odio, vitae interdum dolor arcu eget arcu. Curabitur dictum risus massa, non tincidunt urna molestie non. Maecenas eu quam purus. Donec vulputate, dui eu accumsan blandit, mauris tortor tristique mi, sed blandit leo quam id quam. Ut venenatis sagittis malesuada. Integer non auctor orci. Duis consectetur massa felis. Fusce euismod est sit amet bibendum finibus. Vestibulum dolor ex, tempor at elit in, iaculis cursus dui. Nunc sed neque ac risus rutrum tempus sit amet at ante. In hac habitasse platea dictumst. - -Donec rutrum, velit nec viverra tincidunt, est velit viverra neque, quis auctor leo ex at lectus. Morbi eget purus nisi. Aliquam lacus dui, interdum vitae elit at, venenatis dignissim est. Duis ac mollis lorem. Vivamus a vestibulum quam. Maecenas non metus dolor. Praesent tortor nunc, tristique at nisl molestie, vulputate eleifend diam. Integer ultrices lacus odio, vel imperdiet enim accumsan id. Sed ligula tortor, interdum eu velit eget, pharetra pulvinar magna. Sed non lacus in eros tincidunt sagittis ac vel justo. Donec vitae leo sagittis, accumsan ante sit amet, accumsan odio. Ut volutpat ultricies tortor. Vestibulum tempus purus et est tristique sagittis quis vitae turpis. - -Nam iaculis neque lacus, eget euismod turpis blandit eget. In hac habitasse platea dictumst. Phasellus justo neque, scelerisque sit amet risus ut, pretium commodo nisl. Phasellus auctor sapien sed ex bibendum fermentum. Proin maximus odio a ante ornare, a feugiat lorem egestas. Etiam efficitur tortor a ante tincidunt interdum. Nullam non est ac massa congue efficitur sit amet nec eros. Nullam at ipsum vel mauris tincidunt efficitur. Duis pulvinar nisl elit, id auctor risus laoreet ac. Sed nunc mauris, tristique id leo ut, condimentum congue nunc. Sed ultricies, mauris et convallis faucibus, justo ex faucibus est, at lobortis purus justo non arcu. Integer vel facilisis elit, dapibus imperdiet mauris. - -Pellentesque non mattis turpis, eget bibendum velit. Fusce sollicitudin ante ac tincidunt rhoncus. Praesent porta scelerisque consequat. Donec eleifend faucibus sollicitudin. Quisque vitae purus eget tortor tempor ultrices. Maecenas mauris diam, semper vitae est non, imperdiet tempor magna. Duis elit lacus, auctor vestibulum enim eget, rhoncus porttitor tortor. - -Donec non rhoncus nibh. Cras dapibus justo vitae nunc accumsan, id congue erat egestas. Aenean at ante ante. Duis eleifend imperdiet dREADALL diff --git a/test/image/mysql.sql b/test/image/mysql.sql deleted file mode 100644 index 51554b98d..000000000 --- a/test/image/mysql.sql +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -SHOW databases; -USE mysql; - -CREATE TABLE foo (id int); -INSERT INTO foo VALUES(1); -SELECT * FROM foo; -DROP TABLE foo; - -shutdown; diff --git a/test/image/ruby.rb b/test/image/ruby.rb deleted file mode 100644 index aced49c6d..000000000 --- a/test/image/ruby.rb +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -require 'sinatra' - -set :bind, "0.0.0.0" -set :port, 8080 - -get '/' do - 'Hello World' -end - diff --git a/test/image/ruby.sh b/test/image/ruby.sh deleted file mode 100644 index ebe8d5b0e..000000000 --- a/test/image/ruby.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -e - -gem install sinatra -ruby /src/ruby.rb diff --git a/test/iptables/BUILD b/test/iptables/BUILD deleted file mode 100644 index 6bb3b82b5..000000000 --- a/test/iptables/BUILD +++ /dev/null @@ -1,36 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "iptables", - testonly = 1, - srcs = [ - "filter_input.go", - "filter_output.go", - "iptables.go", - "iptables_util.go", - "nat.go", - ], - visibility = ["//test/iptables:__subpackages__"], - deps = [ - "//runsc/testutil", - ], -) - -go_test( - name = "iptables_test", - srcs = [ - "iptables_test.go", - ], - library = ":iptables", - tags = [ - "local", - "manual", - ], - deps = [ - "//pkg/log", - "//runsc/dockerutil", - "//runsc/testutil", - ], -) diff --git a/test/iptables/README.md b/test/iptables/README.md deleted file mode 100644 index cc8a2fcac..000000000 --- a/test/iptables/README.md +++ /dev/null @@ -1,54 +0,0 @@ -# iptables Tests - -iptables tests are run via `scripts/iptables_test.sh`. - -iptables requires raw socket support, so you must add the `--net-raw=true` flag -to `/etc/docker/daemon.json` in order to use it. - -## Test Structure - -Each test implements `TestCase`, providing (1) a function to run inside the -container and (2) a function to run locally. Those processes are given each -others' IP addresses. The test succeeds when both functions succeed. - -The function inside the container (`ContainerAction`) typically sets some -iptables rules and then tries to send or receive packets. The local function -(`LocalAction`) will typically just send or receive packets. - -### Adding Tests - -1) Add your test to the `iptables` package. - -2) Register the test in an `init` function via `RegisterTestCase` (see -`filter_input.go` as an example). - -3) Add it to `iptables_test.go` (see the other tests in that file). - -Your test is now runnable with bazel! - -## Run individual tests - -Build and install `runsc`. Re-run this when you modify gVisor: - -```bash -$ bazel build //runsc && sudo cp bazel-bin/runsc/linux_amd64_pure_stripped/runsc $(which runsc) -``` - -Build the testing Docker container. Re-run this when you modify the test code in -this directory: - -```bash -$ bazel run //test/iptables/runner:runner-image -- --norun -``` - -Run an individual test via: - -```bash -$ bazel test //test/iptables:iptables_test --test_filter=<TESTNAME> -``` - -To run an individual test with `runc`: - -```bash -$ bazel test //test/iptables:iptables_test --test_filter=<TESTNAME> --test_arg=--runtime=runc -``` diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go deleted file mode 100644 index 141d20fbb..000000000 --- a/test/iptables/filter_input.go +++ /dev/null @@ -1,598 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package iptables - -import ( - "errors" - "fmt" - "net" - "time" -) - -const ( - dropPort = 2401 - acceptPort = 2402 - sendloopDuration = 2 * time.Second - network = "udp4" - chainName = "foochain" -) - -func init() { - RegisterTestCase(FilterInputDropAll{}) - RegisterTestCase(FilterInputDropDifferentUDPPort{}) - RegisterTestCase(FilterInputDropOnlyUDP{}) - RegisterTestCase(FilterInputDropTCPDestPort{}) - RegisterTestCase(FilterInputDropTCPSrcPort{}) - RegisterTestCase(FilterInputDropUDPPort{}) - RegisterTestCase(FilterInputDropUDP{}) - RegisterTestCase(FilterInputCreateUserChain{}) - RegisterTestCase(FilterInputDefaultPolicyAccept{}) - RegisterTestCase(FilterInputDefaultPolicyDrop{}) - RegisterTestCase(FilterInputReturnUnderflow{}) - RegisterTestCase(FilterInputSerializeJump{}) - RegisterTestCase(FilterInputJumpBasic{}) - RegisterTestCase(FilterInputJumpReturn{}) - RegisterTestCase(FilterInputJumpReturnDrop{}) - RegisterTestCase(FilterInputJumpBuiltin{}) - RegisterTestCase(FilterInputJumpTwice{}) -} - -// FilterInputDropUDP tests that we can drop UDP traffic. -type FilterInputDropUDP struct{} - -// Name implements TestCase.Name. -func (FilterInputDropUDP) Name() string { - return "FilterInputDropUDP" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropUDP) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { - return err - } - - // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { - return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { - return fmt.Errorf("error reading: %v", err) - } - - // At this point we know that reading timed out and never received a - // packet. - return nil -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputDropUDP) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, dropPort, sendloopDuration) -} - -// FilterInputDropOnlyUDP tests that "-p udp -j DROP" only affects UDP traffic. -type FilterInputDropOnlyUDP struct{} - -// Name implements TestCase.Name. -func (FilterInputDropOnlyUDP) Name() string { - return "FilterInputDropOnlyUDP" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropOnlyUDP) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { - return err - } - - // Listen for a TCP connection, which should be allowed. - if err := listenTCP(acceptPort, sendloopDuration); err != nil { - return fmt.Errorf("failed to establish a connection %v", err) - } - - return nil -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputDropOnlyUDP) LocalAction(ip net.IP) error { - // Try to establish a TCP connection with the container, which should - // succeed. - return connectTCP(ip, acceptPort, sendloopDuration) -} - -// FilterInputDropUDPPort tests that we can drop UDP traffic by port. -type FilterInputDropUDPPort struct{} - -// Name implements TestCase.Name. -func (FilterInputDropUDPPort) Name() string { - return "FilterInputDropUDPPort" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropUDPPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { - return err - } - - // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { - return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { - return fmt.Errorf("error reading: %v", err) - } - - // At this point we know that reading timed out and never received a - // packet. - return nil -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputDropUDPPort) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, dropPort, sendloopDuration) -} - -// FilterInputDropDifferentUDPPort tests that dropping traffic for a single UDP port -// doesn't drop packets on other ports. -type FilterInputDropDifferentUDPPort struct{} - -// Name implements TestCase.Name. -func (FilterInputDropDifferentUDPPort) Name() string { - return "FilterInputDropDifferentUDPPort" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropDifferentUDPPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { - return err - } - - // Listen for UDP packets on another port. - if err := listenUDP(acceptPort, sendloopDuration); err != nil { - return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err) - } - - return nil -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputDropDifferentUDPPort) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) -} - -// FilterInputDropTCPDestPort tests that connections are not accepted on specified source ports. -type FilterInputDropTCPDestPort struct{} - -// Name implements TestCase.Name. -func (FilterInputDropTCPDestPort) Name() string { - return "FilterInputDropTCPDestPort" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropTCPDestPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { - return err - } - - // Listen for TCP packets on drop port. - if err := listenTCP(dropPort, sendloopDuration); err == nil { - return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) - } - - return nil -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputDropTCPDestPort) LocalAction(ip net.IP) error { - if err := connectTCP(ip, dropPort, sendloopDuration); err == nil { - return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort) - } - - return nil -} - -// FilterInputDropTCPSrcPort tests that connections are not accepted on specified source ports. -type FilterInputDropTCPSrcPort struct{} - -// Name implements TestCase.Name. -func (FilterInputDropTCPSrcPort) Name() string { - return "FilterInputDropTCPSrcPort" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP) error { - // Drop anything from an ephemeral port. - if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", "1024:65535", "-j", "DROP"); err != nil { - return err - } - - // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { - return fmt.Errorf("connection destined to port %d should not be accepted, but was", dropPort) - } - - return nil -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputDropTCPSrcPort) LocalAction(ip net.IP) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { - return fmt.Errorf("connection should not be accepted, but was") - } - - return nil -} - -// FilterInputDropAll tests that we can drop all traffic to the INPUT chain. -type FilterInputDropAll struct{} - -// Name implements TestCase.Name. -func (FilterInputDropAll) Name() string { - return "FilterInputDropAll" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropAll) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-j", "DROP"); err != nil { - return err - } - - // Listen for all packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { - return fmt.Errorf("packets should have been dropped, but got a packet") - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { - return fmt.Errorf("error reading: %v", err) - } - - // At this point we know that reading timed out and never received a - // packet. - return nil -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputDropAll) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, dropPort, sendloopDuration) -} - -// FilterInputMultiUDPRules verifies that multiple UDP rules are applied -// correctly. This has the added benefit of testing whether we're serializing -// rules correctly -- if we do it incorrectly, the iptables tool will -// misunderstand and save the wrong tables. -type FilterInputMultiUDPRules struct{} - -// Name implements TestCase.Name. -func (FilterInputMultiUDPRules) Name() string { - return "FilterInputMultiUDPRules" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputMultiUDPRules) ContainerAction(ip net.IP) error { - rules := [][]string{ - {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"}, - {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"}, - {"-L"}, - } - return filterTableRules(rules) -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputMultiUDPRules) LocalAction(ip net.IP) error { - // No-op. - return nil -} - -// FilterInputRequireProtocolUDP checks that "-m udp" requires "-p udp" to be -// specified. -type FilterInputRequireProtocolUDP struct{} - -// Name implements TestCase.Name. -func (FilterInputRequireProtocolUDP) Name() string { - return "FilterInputRequireProtocolUDP" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputRequireProtocolUDP) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err == nil { - return errors.New("expected iptables to fail with out \"-p udp\", but succeeded") - } - return nil -} - -func (FilterInputRequireProtocolUDP) LocalAction(ip net.IP) error { - // No-op. - return nil -} - -// FilterInputCreateUserChain tests chain creation. -type FilterInputCreateUserChain struct{} - -// Name implements TestCase.Name. -func (FilterInputCreateUserChain) Name() string { - return "FilterInputCreateUserChain" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputCreateUserChain) ContainerAction(ip net.IP) error { - rules := [][]string{ - // Create a chain. - {"-N", chainName}, - // Add a simple rule to the chain. - {"-A", chainName, "-j", "DROP"}, - } - return filterTableRules(rules) -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputCreateUserChain) LocalAction(ip net.IP) error { - // No-op. - return nil -} - -// FilterInputDefaultPolicyAccept tests the default ACCEPT policy. -type FilterInputDefaultPolicyAccept struct{} - -// Name implements TestCase.Name. -func (FilterInputDefaultPolicyAccept) Name() string { - return "FilterInputDefaultPolicyAccept" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputDefaultPolicyAccept) ContainerAction(ip net.IP) error { - // Set the default policy to accept, then receive a packet. - if err := filterTable("-P", "INPUT", "ACCEPT"); err != nil { - return err - } - return listenUDP(acceptPort, sendloopDuration) -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputDefaultPolicyAccept) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) -} - -// FilterInputDefaultPolicyDrop tests the default DROP policy. -type FilterInputDefaultPolicyDrop struct{} - -// Name implements TestCase.Name. -func (FilterInputDefaultPolicyDrop) Name() string { - return "FilterInputDefaultPolicyDrop" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputDefaultPolicyDrop) ContainerAction(ip net.IP) error { - if err := filterTable("-P", "INPUT", "DROP"); err != nil { - return err - } - - // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { - return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { - return fmt.Errorf("error reading: %v", err) - } - - // At this point we know that reading timed out and never received a - // packet. - return nil -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputDefaultPolicyDrop) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) -} - -// FilterInputReturnUnderflow tests that -j RETURN in a built-in chain causes -// the underflow rule (i.e. default policy) to be executed. -type FilterInputReturnUnderflow struct{} - -// Name implements TestCase.Name. -func (FilterInputReturnUnderflow) Name() string { - return "FilterInputReturnUnderflow" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error { - // Add a RETURN rule followed by an unconditional accept, and set the - // default policy to DROP. - rules := [][]string{ - {"-A", "INPUT", "-j", "RETURN"}, - {"-A", "INPUT", "-j", "DROP"}, - {"-P", "INPUT", "ACCEPT"}, - } - if err := filterTableRules(rules); err != nil { - return err - } - - // We should receive packets, as the RETURN rule will trigger the default - // ACCEPT policy. - return listenUDP(acceptPort, sendloopDuration) -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputReturnUnderflow) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) -} - -// FilterInputSerializeJump verifies that we can serialize jumps. -type FilterInputSerializeJump struct{} - -// Name implements TestCase.Name. -func (FilterInputSerializeJump) Name() string { - return "FilterInputSerializeJump" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputSerializeJump) ContainerAction(ip net.IP) error { - // Write a JUMP rule, the serialize it with `-L`. - rules := [][]string{ - {"-N", chainName}, - {"-A", "INPUT", "-j", chainName}, - {"-L"}, - } - return filterTableRules(rules) -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputSerializeJump) LocalAction(ip net.IP) error { - // No-op. - return nil -} - -// FilterInputJumpBasic jumps to a chain and executes a rule there. -type FilterInputJumpBasic struct{} - -// Name implements TestCase.Name. -func (FilterInputJumpBasic) Name() string { - return "FilterInputJumpBasic" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpBasic) ContainerAction(ip net.IP) error { - rules := [][]string{ - {"-P", "INPUT", "DROP"}, - {"-N", chainName}, - {"-A", "INPUT", "-j", chainName}, - {"-A", chainName, "-j", "ACCEPT"}, - } - if err := filterTableRules(rules); err != nil { - return err - } - - // Listen for UDP packets on acceptPort. - return listenUDP(acceptPort, sendloopDuration) -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputJumpBasic) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) -} - -// FilterInputJumpReturn jumps, returns, and executes a rule. -type FilterInputJumpReturn struct{} - -// Name implements TestCase.Name. -func (FilterInputJumpReturn) Name() string { - return "FilterInputJumpReturn" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpReturn) ContainerAction(ip net.IP) error { - rules := [][]string{ - {"-N", chainName}, - {"-P", "INPUT", "ACCEPT"}, - {"-A", "INPUT", "-j", chainName}, - {"-A", chainName, "-j", "RETURN"}, - {"-A", chainName, "-j", "DROP"}, - } - if err := filterTableRules(rules); err != nil { - return err - } - - // Listen for UDP packets on acceptPort. - return listenUDP(acceptPort, sendloopDuration) -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputJumpReturn) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) -} - -// FilterInputJumpReturnDrop jumps to a chain, returns, and DROPs packets. -type FilterInputJumpReturnDrop struct{} - -// Name implements TestCase.Name. -func (FilterInputJumpReturnDrop) Name() string { - return "FilterInputJumpReturnDrop" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP) error { - rules := [][]string{ - {"-N", chainName}, - {"-A", "INPUT", "-j", chainName}, - {"-A", "INPUT", "-j", "DROP"}, - {"-A", chainName, "-j", "RETURN"}, - } - if err := filterTableRules(rules); err != nil { - return err - } - - // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { - return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { - return fmt.Errorf("error reading: %v", err) - } - - // At this point we know that reading timed out and never received a - // packet. - return nil -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputJumpReturnDrop) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, dropPort, sendloopDuration) -} - -// FilterInputJumpBuiltin verifies that jumping to a top-levl chain is illegal. -type FilterInputJumpBuiltin struct{} - -// Name implements TestCase.Name. -func (FilterInputJumpBuiltin) Name() string { - return "FilterInputJumpBuiltin" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpBuiltin) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-j", "OUTPUT"); err == nil { - return fmt.Errorf("iptables should be unable to jump to a built-in chain") - } - return nil -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputJumpBuiltin) LocalAction(ip net.IP) error { - // No-op. - return nil -} - -// FilterInputJumpTwice jumps twice, then returns twice and executes a rule. -type FilterInputJumpTwice struct{} - -// Name implements TestCase.Name. -func (FilterInputJumpTwice) Name() string { - return "FilterInputJumpTwice" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpTwice) ContainerAction(ip net.IP) error { - const chainName2 = chainName + "2" - rules := [][]string{ - {"-P", "INPUT", "DROP"}, - {"-N", chainName}, - {"-N", chainName2}, - {"-A", "INPUT", "-j", chainName}, - {"-A", chainName, "-j", chainName2}, - {"-A", "INPUT", "-j", "ACCEPT"}, - } - if err := filterTableRules(rules); err != nil { - return err - } - - // UDP packets should jump and return twice, eventually hitting the - // ACCEPT rule. - return listenUDP(acceptPort, sendloopDuration) -} - -// LocalAction implements TestCase.LocalAction. -func (FilterInputJumpTwice) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) -} diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go deleted file mode 100644 index 1314a5a92..000000000 --- a/test/iptables/filter_output.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package iptables - -import ( - "fmt" - "net" -) - -func init() { - RegisterTestCase(FilterOutputDropTCPDestPort{}) - RegisterTestCase(FilterOutputDropTCPSrcPort{}) -} - -// FilterOutputDropTCPDestPort tests that connections are not accepted on -// specified source ports. -type FilterOutputDropTCPDestPort struct{} - -// Name implements TestCase.Name. -func (FilterOutputDropTCPDestPort) Name() string { - return "FilterOutputDropTCPDestPort" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { - return err - } - - // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { - return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort) - } - - return nil -} - -// LocalAction implements TestCase.LocalAction. -func (FilterOutputDropTCPDestPort) LocalAction(ip net.IP) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { - return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) - } - - return nil -} - -// FilterOutputDropTCPSrcPort tests that connections are not accepted on -// specified source ports. -type FilterOutputDropTCPSrcPort struct{} - -// Name implements TestCase.Name. -func (FilterOutputDropTCPSrcPort) Name() string { - return "FilterOutputDropTCPSrcPort" -} - -// ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropTCPSrcPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { - return err - } - - // Listen for TCP packets on drop port. - if err := listenTCP(dropPort, sendloopDuration); err == nil { - return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) - } - - return nil -} - -// LocalAction implements TestCase.LocalAction. -func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP) error { - if err := connectTCP(ip, dropPort, sendloopDuration); err == nil { - return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort) - } - - return nil -} diff --git a/test/iptables/iptables.go b/test/iptables/iptables.go deleted file mode 100644 index 2e565d988..000000000 --- a/test/iptables/iptables.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package iptables contains a set of iptables tests implemented as TestCases -package iptables - -import ( - "fmt" - "net" -) - -// IPExchangePort is the port the container listens on to receive the IP -// address of the local process. -const IPExchangePort = 2349 - -// A TestCase contains one action to run in the container and one to run -// locally. The actions run concurrently and each must succeed for the test -// pass. -type TestCase interface { - // Name returns the name of the test. - Name() string - - // ContainerAction runs inside the container. It receives the IP of the - // local process. - ContainerAction(ip net.IP) error - - // LocalAction runs locally. It receives the IP of the container. - LocalAction(ip net.IP) error -} - -// Tests maps test names to TestCase. -// -// New TestCases are added by calling RegisterTestCase in an init function. -var Tests = map[string]TestCase{} - -// RegisterTestCase registers tc so it can be run. -func RegisterTestCase(tc TestCase) { - if _, ok := Tests[tc.Name()]; ok { - panic(fmt.Sprintf("TestCase %s already registered.", tc.Name())) - } - Tests[tc.Name()] = tc -} diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go deleted file mode 100644 index 56ba78107..000000000 --- a/test/iptables/iptables_test.go +++ /dev/null @@ -1,305 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package iptables - -import ( - "flag" - "fmt" - "net" - "os" - "path" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/testutil" -) - -const timeout = 18 * time.Second - -var image = flag.String("image", "bazel/test/iptables/runner:runner-image", "image to run tests in") - -type result struct { - output string - err error -} - -// singleTest runs a TestCase. Each test follows a pattern: -// - Create a container. -// - Get the container's IP. -// - Send the container our IP. -// - Start a new goroutine running the local action of the test. -// - Wait for both the container and local actions to finish. -// -// Container output is logged to $TEST_UNDECLARED_OUTPUTS_DIR if it exists, or -// to stderr. -func singleTest(test TestCase) error { - if _, ok := Tests[test.Name()]; !ok { - return fmt.Errorf("no test found with name %q. Has it been registered?", test.Name()) - } - - // Create and start the container. - cont := dockerutil.MakeDocker("gvisor-iptables") - defer cont.CleanUp() - resultChan := make(chan *result) - go func() { - output, err := cont.RunFg("--cap-add=NET_ADMIN", *image, "-name", test.Name()) - logContainer(output, err) - resultChan <- &result{output, err} - }() - - // Get the container IP. - ip, err := getIP(cont) - if err != nil { - return fmt.Errorf("failed to get container IP: %v", err) - } - - // Give the container our IP. - if err := sendIP(ip); err != nil { - return fmt.Errorf("failed to send IP to container: %v", err) - } - - // Run our side of the test. - errChan := make(chan error) - go func() { - errChan <- test.LocalAction(ip) - }() - - // Wait for both the container and local tests to finish. - var res *result - to := time.After(timeout) - for localDone := false; res == nil || !localDone; { - select { - case res = <-resultChan: - log.Infof("Container finished.") - case err, localDone = <-errChan: - log.Infof("Local finished.") - if err != nil { - return fmt.Errorf("local test failed: %v", err) - } - case <-to: - return fmt.Errorf("timed out after %f seconds", timeout.Seconds()) - } - } - - return res.err -} - -func getIP(cont dockerutil.Docker) (net.IP, error) { - // The container might not have started yet, so retry a few times. - var ipStr string - to := time.After(timeout) - for ipStr == "" { - ipStr, _ = cont.FindIP() - select { - case <-to: - return net.IP{}, fmt.Errorf("timed out getting IP after %f seconds", timeout.Seconds()) - default: - time.Sleep(250 * time.Millisecond) - } - } - ip := net.ParseIP(ipStr) - if ip == nil { - return net.IP{}, fmt.Errorf("invalid IP: %q", ipStr) - } - log.Infof("Container has IP of %s", ipStr) - return ip, nil -} - -func sendIP(ip net.IP) error { - contAddr := net.TCPAddr{ - IP: ip, - Port: IPExchangePort, - } - var conn *net.TCPConn - // The container may not be listening when we first connect, so retry - // upon error. - cb := func() error { - c, err := net.DialTCP("tcp4", nil, &contAddr) - conn = c - return err - } - if err := testutil.Poll(cb, timeout); err != nil { - return fmt.Errorf("timed out waiting to send IP, most recent error: %v", err) - } - if _, err := conn.Write([]byte{0}); err != nil { - return fmt.Errorf("error writing to container: %v", err) - } - return nil -} - -func logContainer(output string, err error) { - msg := fmt.Sprintf("Container error: %v\nContainer output:\n%v", err, output) - if artifactsDir := os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); artifactsDir != "" { - fpath := path.Join(artifactsDir, "container.log") - if file, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE, 0644); err != nil { - log.Warningf("Failed to open log file %q: %v", fpath, err) - } else { - defer file.Close() - if _, err := file.Write([]byte(msg)); err == nil { - return - } - log.Warningf("Failed to write to log file %s: %v", fpath, err) - } - } - - // We couldn't write to the output directory -- just log to stderr. - log.Infof(msg) -} - -func TestFilterInputDropUDP(t *testing.T) { - if err := singleTest(FilterInputDropUDP{}); err != nil { - t.Fatal(err) - } -} - -func TestFilterInputDropUDPPort(t *testing.T) { - if err := singleTest(FilterInputDropUDPPort{}); err != nil { - t.Fatal(err) - } -} - -func TestFilterInputDropDifferentUDPPort(t *testing.T) { - if err := singleTest(FilterInputDropDifferentUDPPort{}); err != nil { - t.Fatal(err) - } -} - -func TestFilterInputDropAll(t *testing.T) { - if err := singleTest(FilterInputDropAll{}); err != nil { - t.Fatal(err) - } -} - -func TestFilterInputDropOnlyUDP(t *testing.T) { - if err := singleTest(FilterInputDropOnlyUDP{}); err != nil { - t.Fatal(err) - } -} - -func TestNATRedirectUDPPort(t *testing.T) { - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") - if err := singleTest(NATRedirectUDPPort{}); err != nil { - t.Fatal(err) - } -} - -func TestNATRedirectTCPPort(t *testing.T) { - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") - if err := singleTest(NATRedirectTCPPort{}); err != nil { - t.Fatal(err) - } -} - -func TestNATDropUDP(t *testing.T) { - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") - if err := singleTest(NATDropUDP{}); err != nil { - t.Fatal(err) - } -} - -func TestNATAcceptAll(t *testing.T) { - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") - if err := singleTest(NATAcceptAll{}); err != nil { - t.Fatal(err) - } -} - -func TestFilterInputDropTCPDestPort(t *testing.T) { - if err := singleTest(FilterInputDropTCPDestPort{}); err != nil { - t.Fatal(err) - } -} - -func TestFilterInputDropTCPSrcPort(t *testing.T) { - if err := singleTest(FilterInputDropTCPSrcPort{}); err != nil { - t.Fatal(err) - } -} - -func TestFilterInputCreateUserChain(t *testing.T) { - if err := singleTest(FilterInputCreateUserChain{}); err != nil { - t.Fatal(err) - } -} - -func TestFilterInputDefaultPolicyAccept(t *testing.T) { - if err := singleTest(FilterInputDefaultPolicyAccept{}); err != nil { - t.Fatal(err) - } -} - -func TestFilterInputDefaultPolicyDrop(t *testing.T) { - if err := singleTest(FilterInputDefaultPolicyDrop{}); err != nil { - t.Fatal(err) - } -} - -func TestFilterInputReturnUnderflow(t *testing.T) { - if err := singleTest(FilterInputReturnUnderflow{}); err != nil { - t.Fatal(err) - } -} - -func TestFilterOutputDropTCPDestPort(t *testing.T) { - t.Skip("filter OUTPUT isn't supported yet (gvisor.dev/issue/170).") - if err := singleTest(FilterOutputDropTCPDestPort{}); err != nil { - t.Fatal(err) - } -} - -func TestFilterOutputDropTCPSrcPort(t *testing.T) { - t.Skip("filter OUTPUT isn't supported yet (gvisor.dev/issue/170).") - if err := singleTest(FilterOutputDropTCPSrcPort{}); err != nil { - t.Fatal(err) - } -} - -func TestJumpSerialize(t *testing.T) { - if err := singleTest(FilterInputSerializeJump{}); err != nil { - t.Fatal(err) - } -} - -func TestJumpBasic(t *testing.T) { - if err := singleTest(FilterInputJumpBasic{}); err != nil { - t.Fatal(err) - } -} - -func TestJumpReturn(t *testing.T) { - if err := singleTest(FilterInputJumpReturn{}); err != nil { - t.Fatal(err) - } -} - -func TestJumpReturnDrop(t *testing.T) { - if err := singleTest(FilterInputJumpReturnDrop{}); err != nil { - t.Fatal(err) - } -} - -func TestJumpBuiltin(t *testing.T) { - if err := singleTest(FilterInputJumpBuiltin{}); err != nil { - t.Fatal(err) - } -} - -func TestJumpTwice(t *testing.T) { - if err := singleTest(FilterInputJumpTwice{}); err != nil { - t.Fatal(err) - } -} diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go deleted file mode 100644 index 1f8dac4f1..000000000 --- a/test/iptables/iptables_util.go +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package iptables - -import ( - "fmt" - "net" - "os/exec" - "time" - - "gvisor.dev/gvisor/runsc/testutil" -) - -const iptablesBinary = "iptables" - -// filterTable calls `iptables -t filter` with the given args. -func filterTable(args ...string) error { - return tableCmd("filter", args) -} - -// natTable calls `iptables -t nat` with the given args. -func natTable(args ...string) error { - return tableCmd("nat", args) -} - -func tableCmd(table string, args []string) error { - args = append([]string{"-t", table}, args...) - cmd := exec.Command(iptablesBinary, args...) - if out, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("error running iptables with args %v\nerror: %v\noutput: %s", args, err, string(out)) - } - return nil -} - -// filterTableRules is like filterTable, but runs multiple iptables commands. -func filterTableRules(argsList [][]string) error { - for _, args := range argsList { - if err := filterTable(args...); err != nil { - return err - } - } - return nil -} - -// listenUDP listens on a UDP port and returns the value of net.Conn.Read() for -// the first read on that port. -func listenUDP(port int, timeout time.Duration) error { - localAddr := net.UDPAddr{ - Port: port, - } - conn, err := net.ListenUDP(network, &localAddr) - if err != nil { - return err - } - defer conn.Close() - conn.SetDeadline(time.Now().Add(timeout)) - _, err = conn.Read([]byte{0}) - return err -} - -// sendUDPLoop sends 1 byte UDP packets repeatedly to the IP and port specified -// over a duration. -func sendUDPLoop(ip net.IP, port int, duration time.Duration) error { - // Send packets for a few seconds. - remote := net.UDPAddr{ - IP: ip, - Port: port, - } - conn, err := net.DialUDP(network, nil, &remote) - if err != nil { - return err - } - defer conn.Close() - - to := time.After(duration) - for timedOut := false; !timedOut; { - // This may return an error (connection refused) if the remote - // hasn't started listening yet or they're dropping our - // packets. So we ignore Write errors and depend on the remote - // to report a failure if it doesn't get a packet it needs. - conn.Write([]byte{0}) - select { - case <-to: - timedOut = true - default: - time.Sleep(200 * time.Millisecond) - } - } - - return nil -} - -// listenTCP listens for connections on a TCP port. -func listenTCP(port int, timeout time.Duration) error { - localAddr := net.TCPAddr{ - Port: port, - } - - // Starts listening on port. - lConn, err := net.ListenTCP("tcp4", &localAddr) - if err != nil { - return err - } - defer lConn.Close() - - // Accept connections on port. - lConn.SetDeadline(time.Now().Add(timeout)) - conn, err := lConn.AcceptTCP() - if err != nil { - return err - } - conn.Close() - return nil -} - -// connectTCP connects to the given IP and port from an ephemeral local address. -func connectTCP(ip net.IP, port int, timeout time.Duration) error { - contAddr := net.TCPAddr{ - IP: ip, - Port: port, - } - // The container may not be listening when we first connect, so retry - // upon error. - callback := func() error { - conn, err := net.DialTCP("tcp4", nil, &contAddr) - if conn != nil { - conn.Close() - } - return err - } - if err := testutil.Poll(callback, timeout); err != nil { - return fmt.Errorf("timed out waiting to connect IP, most recent error: %v", err) - } - - return nil -} diff --git a/test/iptables/nat.go b/test/iptables/nat.go deleted file mode 100644 index 6ca6b46ca..000000000 --- a/test/iptables/nat.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package iptables - -import ( - "fmt" - "net" -) - -const ( - redirectPort = 42 -) - -func init() { - RegisterTestCase(NATRedirectUDPPort{}) - RegisterTestCase(NATRedirectTCPPort{}) - RegisterTestCase(NATDropUDP{}) - RegisterTestCase(NATAcceptAll{}) -} - -// NATRedirectUDPPort tests that packets are redirected to different port. -type NATRedirectUDPPort struct{} - -// Name implements TestCase.Name. -func (NATRedirectUDPPort) Name() string { - return "NATRedirectUDPPort" -} - -// ContainerAction implements TestCase.ContainerAction. -func (NATRedirectUDPPort) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { - return err - } - - if err := listenUDP(redirectPort, sendloopDuration); err != nil { - return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", redirectPort, err) - } - - return nil -} - -// LocalAction implements TestCase.LocalAction. -func (NATRedirectUDPPort) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) -} - -// NATRedirectTCPPort tests that connections are redirected on specified ports. -type NATRedirectTCPPort struct{} - -// Name implements TestCase.Name. -func (NATRedirectTCPPort) Name() string { - return "NATRedirectTCPPort" -} - -// ContainerAction implements TestCase.ContainerAction. -func (NATRedirectTCPPort) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { - return err - } - - // Listen for TCP packets on redirect port. - return listenTCP(redirectPort, sendloopDuration) -} - -// LocalAction implements TestCase.LocalAction. -func (NATRedirectTCPPort) LocalAction(ip net.IP) error { - return connectTCP(ip, dropPort, sendloopDuration) -} - -// NATDropUDP tests that packets are not received in ports other than redirect port. -type NATDropUDP struct{} - -// Name implements TestCase.Name. -func (NATDropUDP) Name() string { - return "NATDropUDP" -} - -// ContainerAction implements TestCase.ContainerAction. -func (NATDropUDP) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { - return err - } - - if err := listenUDP(acceptPort, sendloopDuration); err == nil { - return fmt.Errorf("packets on port %d should have been redirected to port %d", acceptPort, redirectPort) - } - - return nil -} - -// LocalAction implements TestCase.LocalAction. -func (NATDropUDP) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) -} - -// NATAcceptAll tests that all UDP packets are accepted. -type NATAcceptAll struct{} - -// Name implements TestCase.Name. -func (NATAcceptAll) Name() string { - return "NATAcceptAll" -} - -// ContainerAction implements TestCase.ContainerAction. -func (NATAcceptAll) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "ACCEPT"); err != nil { - return err - } - - if err := listenUDP(acceptPort, sendloopDuration); err != nil { - return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err) - } - - return nil -} - -// LocalAction implements TestCase.LocalAction. -func (NATAcceptAll) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) -} diff --git a/test/iptables/runner/BUILD b/test/iptables/runner/BUILD deleted file mode 100644 index b9199387a..000000000 --- a/test/iptables/runner/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "container_image", "go_binary", "go_image") - -package(licenses = ["notice"]) - -go_binary( - name = "runner", - testonly = 1, - srcs = ["main.go"], - deps = ["//test/iptables"], -) - -container_image( - name = "iptables-base", - base = "@iptables-test//image", -) - -go_image( - name = "runner-image", - testonly = 1, - srcs = ["main.go"], - base = ":iptables-base", - deps = ["//test/iptables"], -) diff --git a/test/iptables/runner/Dockerfile b/test/iptables/runner/Dockerfile deleted file mode 100644 index b77db44a1..000000000 --- a/test/iptables/runner/Dockerfile +++ /dev/null @@ -1,4 +0,0 @@ -# This Dockerfile builds the image hosted at -# gcr.io/gvisor-presubmit/iptables-test. -FROM ubuntu -RUN apt update && apt install -y iptables diff --git a/test/iptables/runner/main.go b/test/iptables/runner/main.go deleted file mode 100644 index 3c794114e..000000000 --- a/test/iptables/runner/main.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package main runs iptables tests from within a docker container. -package main - -import ( - "flag" - "fmt" - "log" - "net" - - "gvisor.dev/gvisor/test/iptables" -) - -var name = flag.String("name", "", "name of the test to run") - -func main() { - flag.Parse() - - // Find out which test we're running. - test, ok := iptables.Tests[*name] - if !ok { - log.Fatalf("No test found named %q", *name) - } - log.Printf("Running test %q", *name) - - // Get the IP of the local process. - ip, err := getIP() - if err != nil { - log.Fatal(err) - } - - // Run the test. - if err := test.ContainerAction(ip); err != nil { - log.Fatalf("Failed running test %q: %v", *name, err) - } -} - -// getIP listens for a connection from the local process and returns the source -// IP of that connection. -func getIP() (net.IP, error) { - localAddr := net.TCPAddr{ - Port: iptables.IPExchangePort, - } - listener, err := net.ListenTCP("tcp4", &localAddr) - if err != nil { - return net.IP{}, fmt.Errorf("failed listening for IP: %v", err) - } - defer listener.Close() - conn, err := listener.AcceptTCP() - if err != nil { - return net.IP{}, fmt.Errorf("failed accepting IP: %v", err) - } - defer conn.Close() - log.Printf("Connected to %v", conn.RemoteAddr()) - - return conn.RemoteAddr().(*net.TCPAddr).IP, nil -} diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD deleted file mode 100644 index fb0b2db41..000000000 --- a/test/packetdrill/BUILD +++ /dev/null @@ -1,48 +0,0 @@ -load("defs.bzl", "packetdrill_linux_test", "packetdrill_netstack_test", "packetdrill_test") - -package(licenses = ["notice"]) - -packetdrill_test( - name = "packetdrill_sanity_test", - scripts = ["sanity_test.pkt"], -) - -packetdrill_test( - name = "accept_ack_drop_test", - scripts = ["accept_ack_drop.pkt"], -) - -packetdrill_test( - name = "fin_wait2_timeout_test", - scripts = ["fin_wait2_timeout.pkt"], -) - -packetdrill_linux_test( - name = "tcp_user_timeout_test_linux_test", - scripts = ["linux/tcp_user_timeout.pkt"], -) - -packetdrill_netstack_test( - name = "tcp_user_timeout_test_netstack_test", - scripts = ["netstack/tcp_user_timeout.pkt"], -) - -packetdrill_test( - name = "listen_close_before_handshake_complete_test", - scripts = ["listen_close_before_handshake_complete.pkt"], -) - -packetdrill_test( - name = "no_rst_to_rst_test", - scripts = ["no_rst_to_rst.pkt"], -) - -packetdrill_test( - name = "tcp_defer_accept_test", - scripts = ["tcp_defer_accept.pkt"], -) - -packetdrill_test( - name = "tcp_defer_accept_timeout_test", - scripts = ["tcp_defer_accept_timeout.pkt"], -) diff --git a/test/packetdrill/Dockerfile b/test/packetdrill/Dockerfile deleted file mode 100644 index bd4451355..000000000 --- a/test/packetdrill/Dockerfile +++ /dev/null @@ -1,9 +0,0 @@ -FROM ubuntu:bionic - -RUN apt-get update -RUN apt-get install -y net-tools git iptables iputils-ping netcat tcpdump jq tar -RUN hash -r -RUN git clone --branch packetdrill-v2.0 \ - https://github.com/google/packetdrill.git -RUN cd packetdrill/gtests/net/packetdrill && ./configure && \ - apt-get install -y bison flex make && make diff --git a/test/packetdrill/accept_ack_drop.pkt b/test/packetdrill/accept_ack_drop.pkt deleted file mode 100644 index 76e638fd4..000000000 --- a/test/packetdrill/accept_ack_drop.pkt +++ /dev/null @@ -1,27 +0,0 @@ -// Test that the accept works if the final ACK is dropped and an ack with data -// follows the dropped ack. - -0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 -+0 bind(3, ..., ...) = 0 - -// Set backlog to 1 so that we can easily test. -+0 listen(3, 1) = 0 - -// Establish a connection without timestamps. -+0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> -+0.0 > S. 0:0(0) ack 1 <...> - -+0.0 < . 1:5(4) ack 1 win 257 -+0.0 > . 1:1(0) ack 5 <...> - -// This should cause connection to transition to connected state. -+0.000 accept(3, ..., ...) = 4 -+0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0 - -// Now read the data and we should get 4 bytes. -+0.000 read(4,..., 4) = 4 -+0.000 close(4) = 0 - -+0.0 > F. 1:1(0) ack 5 <...> -+0.0 < F. 5:5(0) ack 2 win 257 -+0.01 > . 2:2(0) ack 6 <...>
\ No newline at end of file diff --git a/test/packetdrill/defs.bzl b/test/packetdrill/defs.bzl deleted file mode 100644 index f499c177b..000000000 --- a/test/packetdrill/defs.bzl +++ /dev/null @@ -1,87 +0,0 @@ -"""Defines a rule for packetdrill test targets.""" - -def _packetdrill_test_impl(ctx): - test_runner = ctx.executable._test_runner - runner = ctx.actions.declare_file("%s-runner" % ctx.label.name) - - script_paths = [] - for script in ctx.files.scripts: - script_paths.append(script.short_path) - runner_content = "\n".join([ - "#!/bin/bash", - # This test will run part in a distinct user namespace. This can cause - # permission problems, because all runfiles may not be owned by the - # current user, and no other users will be mapped in that namespace. - # Make sure that everything is readable here. - "find . -type f -exec chmod a+rx {} \\;", - "find . -type d -exec chmod a+rx {} \\;", - "%s %s --init_script %s $@ -- %s\n" % ( - test_runner.short_path, - " ".join(ctx.attr.flags), - ctx.files._init_script[0].short_path, - " ".join(script_paths), - ), - ]) - ctx.actions.write(runner, runner_content, is_executable = True) - - transitive_files = depset() - if hasattr(ctx.attr._test_runner, "data_runfiles"): - transitive_files = depset(ctx.attr._test_runner.data_runfiles.files) - runfiles = ctx.runfiles( - files = [test_runner] + ctx.files._init_script + ctx.files.scripts, - transitive_files = transitive_files, - collect_default = True, - collect_data = True, - ) - return [DefaultInfo(executable = runner, runfiles = runfiles)] - -_packetdrill_test = rule( - attrs = { - "_test_runner": attr.label( - executable = True, - cfg = "host", - allow_files = True, - default = "packetdrill_test.sh", - ), - "_init_script": attr.label( - allow_single_file = True, - default = "packetdrill_setup.sh", - ), - "flags": attr.string_list( - mandatory = False, - default = [], - ), - "scripts": attr.label_list( - mandatory = True, - allow_files = True, - ), - }, - test = True, - implementation = _packetdrill_test_impl, -) - -_PACKETDRILL_TAGS = ["local", "manual"] - -def packetdrill_linux_test(name, **kwargs): - if "tags" not in kwargs: - kwargs["tags"] = _PACKETDRILL_TAGS - _packetdrill_test( - name = name, - flags = ["--dut_platform", "linux"], - **kwargs - ) - -def packetdrill_netstack_test(name, **kwargs): - if "tags" not in kwargs: - kwargs["tags"] = _PACKETDRILL_TAGS - _packetdrill_test( - name = name, - # This is the default runtime unless - # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value. - flags = ["--dut_platform", "netstack", "--runtime", "runsc-d"], - **kwargs - ) - -def packetdrill_test(name, **kwargs): - packetdrill_linux_test(name + "_linux_test", **kwargs) - packetdrill_netstack_test(name + "_netstack_test", **kwargs) diff --git a/test/packetdrill/fin_wait2_timeout.pkt b/test/packetdrill/fin_wait2_timeout.pkt deleted file mode 100644 index 613f0bec9..000000000 --- a/test/packetdrill/fin_wait2_timeout.pkt +++ /dev/null @@ -1,23 +0,0 @@ -// Test that a socket in FIN_WAIT_2 eventually times out and a subsequent -// packet generates a RST. - -0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 -+0 bind(3, ..., ...) = 0 - -+0 listen(3, 1) = 0 - -// Establish a connection without timestamps. -+0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> -+0 > S. 0:0(0) ack 1 <...> -+0 < P. 1:1(0) ack 1 win 257 - -+0.100 accept(3, ..., ...) = 4 -// set FIN_WAIT2 timeout to 1 seconds. -+0.100 setsockopt(4, SOL_TCP, TCP_LINGER2, [1], 4) = 0 -+0 close(4) = 0 - -+0 > F. 1:1(0) ack 1 <...> -+0 < . 1:1(0) ack 2 win 257 - -+1.1 < . 1:1(0) ack 2 win 257 -+0 > R 2:2(0) win 0 diff --git a/test/packetdrill/linux/tcp_user_timeout.pkt b/test/packetdrill/linux/tcp_user_timeout.pkt deleted file mode 100644 index 38018cb42..000000000 --- a/test/packetdrill/linux/tcp_user_timeout.pkt +++ /dev/null @@ -1,39 +0,0 @@ -// Test that a socket w/ TCP_USER_TIMEOUT set aborts the connection -// if there is pending unacked data after the user specified timeout. - -0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 -+0 bind(3, ..., ...) = 0 - -+0 listen(3, 1) = 0 - -// Establish a connection without timestamps. -+0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> -+0 > S. 0:0(0) ack 1 <...> -+0.1 < . 1:1(0) ack 1 win 32792 - -+0.100 accept(3, ..., ...) = 4 - -// Okay, we received nothing, and decide to close this idle socket. -// We set TCP_USER_TIMEOUT to 3 seconds because really it is not worth -// trying hard to cleanly close this flow, at the price of keeping -// a TCP structure in kernel for about 1 minute! -+2 setsockopt(4, SOL_TCP, TCP_USER_TIMEOUT, [3000], 4) = 0 - -// The write/ack is required mainly for netstack as netstack does -// not update its RTO during the handshake. -+0 write(4, ..., 100) = 100 -+0 > P. 1:101(100) ack 1 <...> -+0 < . 1:1(0) ack 101 win 32792 - -+0 close(4) = 0 - -+0 > F. 101:101(0) ack 1 <...> -+.3~+.400 > F. 101:101(0) ack 1 <...> -+.3~+.400 > F. 101:101(0) ack 1 <...> -+.6~+.800 > F. 101:101(0) ack 1 <...> -+1.2~+1.300 > F. 101:101(0) ack 1 <...> - -// We finally receive something from the peer, but it is way too late -// Our socket vanished because TCP_USER_TIMEOUT was really small. -+.1 < . 1:2(1) ack 102 win 32792 -+0 > R 102:102(0) win 0 diff --git a/test/packetdrill/listen_close_before_handshake_complete.pkt b/test/packetdrill/listen_close_before_handshake_complete.pkt deleted file mode 100644 index 51c3f1a32..000000000 --- a/test/packetdrill/listen_close_before_handshake_complete.pkt +++ /dev/null @@ -1,31 +0,0 @@ -// Test that closing a listening socket closes any connections in SYN-RCVD -// state and any packets bound for these connections generate a RESET. - -0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 -+0 bind(3, ..., ...) = 0 - -// Set backlog to 1 so that we can easily test. -+0 listen(3, 1) = 0 - -// Establish a connection without timestamps. -+0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> -+0 > S. 0:0(0) ack 1 <...> - -+0.100 close(3) = 0 -+0.1 < P. 1:1(0) ack 1 win 257 - -// Linux generates a reset with no ack number/bit set. This is contradictory to -// what is specified in Rule 1 under Reset Generation in -// https://tools.ietf.org/html/rfc793#section-3.4. -// "1. If the connection does not exist (CLOSED) then a reset is sent -// in response to any incoming segment except another reset. In -// particular, SYNs addressed to a non-existent connection are rejected -// by this means. -// -// If the incoming segment has an ACK field, the reset takes its -// sequence number from the ACK field of the segment, otherwise the -// reset has sequence number zero and the ACK field is set to the sum -// of the sequence number and segment length of the incoming segment. -// The connection remains in the CLOSED state." - -+0.0 > R 1:1(0) win 0
\ No newline at end of file diff --git a/test/packetdrill/netstack/tcp_user_timeout.pkt b/test/packetdrill/netstack/tcp_user_timeout.pkt deleted file mode 100644 index 60103adba..000000000 --- a/test/packetdrill/netstack/tcp_user_timeout.pkt +++ /dev/null @@ -1,38 +0,0 @@ -// Test that a socket w/ TCP_USER_TIMEOUT set aborts the connection -// if there is pending unacked data after the user specified timeout. - -0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 -+0 bind(3, ..., ...) = 0 - -+0 listen(3, 1) = 0 - -// Establish a connection without timestamps. -+0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> -+0 > S. 0:0(0) ack 1 <...> -+0.1 < . 1:1(0) ack 1 win 32792 - -+0.100 accept(3, ..., ...) = 4 - -// Okay, we received nothing, and decide to close this idle socket. -// We set TCP_USER_TIMEOUT to 3 seconds because really it is not worth -// trying hard to cleanly close this flow, at the price of keeping -// a TCP structure in kernel for about 1 minute! -+2 setsockopt(4, SOL_TCP, TCP_USER_TIMEOUT, [3000], 4) = 0 - -// The write/ack is required mainly for netstack as netstack does -// not update its RTO during the handshake. -+0 write(4, ..., 100) = 100 -+0 > P. 1:101(100) ack 1 <...> -+0 < . 1:1(0) ack 101 win 32792 - -+0 close(4) = 0 - -+0 > F. 101:101(0) ack 1 <...> -+.2~+.300 > F. 101:101(0) ack 1 <...> -+.4~+.500 > F. 101:101(0) ack 1 <...> -+.8~+.900 > F. 101:101(0) ack 1 <...> - -// We finally receive something from the peer, but it is way too late -// Our socket vanished because TCP_USER_TIMEOUT was really small. -+1.61 < . 1:2(1) ack 102 win 32792 -+0 > R 102:102(0) win 0 diff --git a/test/packetdrill/no_rst_to_rst.pkt b/test/packetdrill/no_rst_to_rst.pkt deleted file mode 100644 index 612747827..000000000 --- a/test/packetdrill/no_rst_to_rst.pkt +++ /dev/null @@ -1,36 +0,0 @@ -// Test a RST is not generated in response to a RST and a RST is correctly -// generated when an accepted endpoint is RST due to an incoming RST. - -0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 -+0 bind(3, ..., ...) = 0 - -+0 listen(3, 1) = 0 - -// Establish a connection without timestamps. -+0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> -+0 > S. 0:0(0) ack 1 <...> -+0 < P. 1:1(0) ack 1 win 257 - -+0.100 accept(3, ..., ...) = 4 - -+0.200 < R 1:1(0) win 0 - -+0.300 read(4,..., 4) = -1 ECONNRESET (Connection Reset by Peer) - -+0.00 < . 1:1(0) ack 1 win 257 - -// Linux generates a reset with no ack number/bit set. This is contradictory to -// what is specified in Rule 1 under Reset Generation in -// https://tools.ietf.org/html/rfc793#section-3.4. -// "1. If the connection does not exist (CLOSED) then a reset is sent -// in response to any incoming segment except another reset. In -// particular, SYNs addressed to a non-existent connection are rejected -// by this means. -// -// If the incoming segment has an ACK field, the reset takes its -// sequence number from the ACK field of the segment, otherwise the -// reset has sequence number zero and the ACK field is set to the sum -// of the sequence number and segment length of the incoming segment. -// The connection remains in the CLOSED state." - -+0.00 > R 1:1(0) win 0
\ No newline at end of file diff --git a/test/packetdrill/packetdrill_setup.sh b/test/packetdrill/packetdrill_setup.sh deleted file mode 100755 index b858072f0..000000000 --- a/test/packetdrill/packetdrill_setup.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script runs both within the sentry context and natively. It should tweak -# TCP parameters to match expectations found in the script files. -sysctl -q net.ipv4.tcp_sack=1 -sysctl -q net.ipv4.tcp_rmem="4096 2097152 $((8*1024*1024))" -sysctl -q net.ipv4.tcp_wmem="4096 2097152 $((8*1024*1024))" - -# There may be errors from the above, but they will show up in the test logs and -# we always want to proceed from this point. It's possible that values were -# already set correctly and the nodes were not available in the namespace. -exit 0 diff --git a/test/packetdrill/packetdrill_test.sh b/test/packetdrill/packetdrill_test.sh deleted file mode 100755 index c8268170f..000000000 --- a/test/packetdrill/packetdrill_test.sh +++ /dev/null @@ -1,225 +0,0 @@ -#!/bin/bash - -# Copyright 2020 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Run a packetdrill test. Two docker containers are made, one for the -# Device-Under-Test (DUT) and one for the test runner. Each is attached with -# two networks, one for control packets that aid the test and one for test -# packets which are sent as part of the test and observed for correctness. - -set -euxo pipefail - -function failure() { - local lineno=$1 - local msg=$2 - local filename="$0" - echo "FAIL: $filename:$lineno: $msg" -} -trap 'failure ${LINENO} "$BASH_COMMAND"' ERR - -declare -r LONGOPTS="dut_platform:,init_script:,runtime:" - -# Don't use declare below so that the error from getopt will end the script. -PARSED=$(getopt --options "" --longoptions=$LONGOPTS --name "$0" -- "$@") - -eval set -- "$PARSED" - -while true; do - case "$1" in - --dut_platform) - # Either "linux" or "netstack". - declare -r DUT_PLATFORM="$2" - shift 2 - ;; - --init_script) - declare -r INIT_SCRIPT="$2" - shift 2 - ;; - --runtime) - # Not readonly because there might be multiple --runtime arguments and we - # want to use just the last one. Only used if --dut_platform is - # "netstack". - declare RUNTIME="$2" - shift 2 - ;; - --) - shift - break - ;; - *) - echo "Programming error" - exit 3 - esac -done - -# All the other arguments are scripts. -declare -r scripts="$@" - -# Check that the required flags are defined in a way that is safe for "set -u". -if [[ "${DUT_PLATFORM-}" == "netstack" ]]; then - if [[ -z "${RUNTIME-}" ]]; then - echo "FAIL: Missing --runtime argument: ${RUNTIME-}" - exit 2 - fi - declare -r RUNTIME_ARG="--runtime ${RUNTIME}" -elif [[ "${DUT_PLATFORM-}" == "linux" ]]; then - declare -r RUNTIME_ARG="" -else - echo "FAIL: Bad or missing --dut_platform argument: ${DUT_PLATFORM-}" - exit 2 -fi -if [[ ! -x "${INIT_SCRIPT-}" ]]; then - echo "FAIL: Bad or missing --init_script: ${INIT_SCRIPT-}" - exit 2 -fi - -# Variables specific to the control network and interface start with CTRL_. -# Variables specific to the test network and interface start with TEST_. -# Variables specific to the DUT start with DUT_. -# Variables specific to the test runner start with TEST_RUNNER_. -declare -r PACKETDRILL="/packetdrill/gtests/net/packetdrill/packetdrill" -# Use random numbers so that test networks don't collide. -declare -r CTRL_NET="ctrl_net-$(shuf -i 0-99999999 -n 1)" -declare -r TEST_NET="test_net-$(shuf -i 0-99999999 -n 1)" -declare -r tolerance_usecs=100000 -# On both DUT and test runner, testing packets are on the eth2 interface. -declare -r TEST_DEVICE="eth2" -# Number of bits in the *_NET_PREFIX variables. -declare -r NET_MASK="24" -function new_net_prefix() { - # Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24. - echo "$(shuf -i 192-223 -n 1).$(shuf -i 0-255 -n 1).$(shuf -i 0-255 -n 1)" -} -# Last bits of the DUT's IP address. -declare -r DUT_NET_SUFFIX=".10" -# Control port. -declare -r CTRL_PORT="40000" -# Last bits of the test runner's IP address. -declare -r TEST_RUNNER_NET_SUFFIX=".20" -declare -r TIMEOUT="60" -declare -r IMAGE_TAG="gcr.io/gvisor-presubmit/packetdrill" - -# Make sure that docker is installed. -docker --version - -function finish { - local cleanup_success=1 - for net in "${CTRL_NET}" "${TEST_NET}"; do - # Kill all processes attached to ${net}. - for docker_command in "kill" "rm"; do - (docker network inspect "${net}" \ - --format '{{range $key, $value := .Containers}}{{$key}} {{end}}' \ - | xargs -r docker "${docker_command}") || \ - cleanup_success=0 - done - # Remove the network. - docker network rm "${net}" || \ - cleanup_success=0 - done - - if ((!$cleanup_success)); then - echo "FAIL: Cleanup command failed" - exit 4 - fi -} -trap finish EXIT - -# Subnet for control packets between test runner and DUT. -declare CTRL_NET_PREFIX=$(new_net_prefix) -while ! docker network create \ - "--subnet=${CTRL_NET_PREFIX}.0/${NET_MASK}" "${CTRL_NET}"; do - sleep 0.1 - declare CTRL_NET_PREFIX=$(new_net_prefix) -done - -# Subnet for the packets that are part of the test. -declare TEST_NET_PREFIX=$(new_net_prefix) -while ! docker network create \ - "--subnet=${TEST_NET_PREFIX}.0/${NET_MASK}" "${TEST_NET}"; do - sleep 0.1 - declare TEST_NET_PREFIX=$(new_net_prefix) -done - -docker pull "${IMAGE_TAG}" - -# Create the DUT container and connect to network. -DUT=$(docker create ${RUNTIME_ARG} --privileged --rm \ - --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG}) -docker network connect "${CTRL_NET}" \ - --ip "${CTRL_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \ - || (docker kill ${DUT}; docker rm ${DUT}; false) -docker network connect "${TEST_NET}" \ - --ip "${TEST_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \ - || (docker kill ${DUT}; docker rm ${DUT}; false) -docker start "${DUT}" - -# Create the test runner container and connect to network. -TEST_RUNNER=$(docker create --privileged --rm \ - --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG}) -docker network connect "${CTRL_NET}" \ - --ip "${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" "${TEST_RUNNER}" \ - || (docker kill ${TEST_RUNNER}; docker rm ${REST_RUNNER}; false) -docker network connect "${TEST_NET}" \ - --ip "${TEST_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" "${TEST_RUNNER}" \ - || (docker kill ${TEST_RUNNER}; docker rm ${REST_RUNNER}; false) -docker start "${TEST_RUNNER}" - -# Run tcpdump in the test runner unbuffered, without dns resolution, just on the -# interface with the test packets. -docker exec -t ${TEST_RUNNER} tcpdump -U -n -i "${TEST_DEVICE}" & - -# Start a packetdrill server on the test_runner. The packetdrill server sends -# packets and asserts that they are received. -docker exec -d "${TEST_RUNNER}" \ - ${PACKETDRILL} --wire_server --wire_server_dev="${TEST_DEVICE}" \ - --wire_server_ip="${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \ - --wire_server_port="${CTRL_PORT}" \ - --local_ip="${TEST_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \ - --remote_ip="${TEST_NET_PREFIX}${DUT_NET_SUFFIX}" - -# Because the Linux kernel receives the SYN-ACK but didn't send the SYN it will -# issue a RST. To prevent this IPtables can be used to filter those out. -docker exec "${TEST_RUNNER}" \ - iptables -A OUTPUT -p tcp --tcp-flags RST RST -j DROP - -# Wait for the packetdrill server on the test runner to come. Attempt to -# connect to it from the DUT every 100 milliseconds until success. -while ! docker exec "${DUT}" \ - nc -zv "${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" "${CTRL_PORT}"; do - sleep 0.1 -done - -# Copy the packetdrill setup script to the DUT. -docker cp -L "${INIT_SCRIPT}" "${DUT}:packetdrill_setup.sh" - -# Copy the packetdrill scripts to the DUT. -declare -a dut_scripts -for script in $scripts; do - docker cp -L "${script}" "${DUT}:$(basename ${script})" - dut_scripts+=("/$(basename ${script})") -done - -# Start a packetdrill client on the DUT. The packetdrill client runs POSIX -# socket commands and also sends instructions to the server. -docker exec -t "${DUT}" \ - ${PACKETDRILL} --wire_client --wire_client_dev="${TEST_DEVICE}" \ - --wire_server_ip="${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \ - --wire_server_port="${CTRL_PORT}" \ - --local_ip="${TEST_NET_PREFIX}${DUT_NET_SUFFIX}" \ - --remote_ip="${TEST_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \ - --init_scripts=/packetdrill_setup.sh \ - --tolerance_usecs="${tolerance_usecs}" "${dut_scripts[@]}" - -echo PASS: No errors. diff --git a/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt b/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt deleted file mode 100644 index a86b90ce6..000000000 --- a/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt +++ /dev/null @@ -1,9 +0,0 @@ -// Test that a listening socket generates a RST when it receives an -// ACK and syn cookies are not in use. - -0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 -+0 bind(3, ..., ...) = 0 - -+0 listen(3, 1) = 0 -+0.1 < . 1:1(0) ack 1 win 32792 -+0 > R 1:1(0) ack 0 win 0
\ No newline at end of file diff --git a/test/packetdrill/sanity_test.pkt b/test/packetdrill/sanity_test.pkt deleted file mode 100644 index b3b58c366..000000000 --- a/test/packetdrill/sanity_test.pkt +++ /dev/null @@ -1,7 +0,0 @@ -// Basic sanity test. One system call. -// -// All of the plumbing has to be working however, and the packetdrill wire -// client needs to be able to connect to the wire server and send the script, -// probe local interfaces, run through the test w/ timings, etc. - -0.000 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 diff --git a/test/packetdrill/tcp_defer_accept.pkt b/test/packetdrill/tcp_defer_accept.pkt deleted file mode 100644 index a17f946db..000000000 --- a/test/packetdrill/tcp_defer_accept.pkt +++ /dev/null @@ -1,48 +0,0 @@ -// Test that a bare ACK does not complete a connection when TCP_DEFER_ACCEPT -// timeout is not hit but an ACK w/ data does complete and deliver the -// connection to the accept queue. - -0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 -+0 setsockopt(3, SOL_TCP, TCP_DEFER_ACCEPT, [5], 4) = 0 -+0.000 fcntl(3, F_SETFL, O_RDWR|O_NONBLOCK) = 0 -+0 bind(3, ..., ...) = 0 - -// Set backlog to 1 so that we can easily test. -+0 listen(3, 1) = 0 - -// Establish a connection without timestamps. -+0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> -+0.0 > S. 0:0(0) ack 1 <...> - -// Send a bare ACK this should not complete the connection as we -// set the TCP_DEFER_ACCEPT above. -+0.0 < . 1:1(0) ack 1 win 257 - -// The bare ACK should be dropped and no connection should be delivered -// to the accept queue. -+0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block) - -// Send another bare ACK and it should still fail we set TCP_DEFER_ACCEPT -// to 5 seconds above. -+2.5 < . 1:1(0) ack 1 win 257 -+0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block) - -// set accept socket back to blocking. -+0.000 fcntl(3, F_SETFL, O_RDWR) = 0 - -// Now send an ACK w/ data. This should complete the connection -// and deliver the socket to the accept queue. -+0.1 < . 1:5(4) ack 1 win 257 -+0.0 > . 1:1(0) ack 5 <...> - -// This should cause connection to transition to connected state. -+0.000 accept(3, ..., ...) = 4 -+0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0 - -// Now read the data and we should get 4 bytes. -+0.000 read(4,..., 4) = 4 -+0.000 close(4) = 0 - -+0.0 > F. 1:1(0) ack 5 <...> -+0.0 < F. 5:5(0) ack 2 win 257 -+0.01 > . 2:2(0) ack 6 <...>
\ No newline at end of file diff --git a/test/packetdrill/tcp_defer_accept_timeout.pkt b/test/packetdrill/tcp_defer_accept_timeout.pkt deleted file mode 100644 index 201fdeb14..000000000 --- a/test/packetdrill/tcp_defer_accept_timeout.pkt +++ /dev/null @@ -1,48 +0,0 @@ -// Test that a bare ACK is accepted after TCP_DEFER_ACCEPT timeout -// is hit and a connection is delivered. - -0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 -+0 setsockopt(3, SOL_TCP, TCP_DEFER_ACCEPT, [3], 4) = 0 -+0.000 fcntl(3, F_SETFL, O_RDWR|O_NONBLOCK) = 0 -+0 bind(3, ..., ...) = 0 - -// Set backlog to 1 so that we can easily test. -+0 listen(3, 1) = 0 - -// Establish a connection without timestamps. -+0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> -+0.0 > S. 0:0(0) ack 1 <...> - -// Send a bare ACK this should not complete the connection as we -// set the TCP_DEFER_ACCEPT above. -+0.0 < . 1:1(0) ack 1 win 257 - -// The bare ACK should be dropped and no connection should be delivered -// to the accept queue. -+0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block) - -// Send another bare ACK and it should still fail we set TCP_DEFER_ACCEPT -// to 5 seconds above. -+2.5 < . 1:1(0) ack 1 win 257 -+0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block) - -// set accept socket back to blocking. -+0.000 fcntl(3, F_SETFL, O_RDWR) = 0 - -// We should see one more retransmit of the SYN-ACK as a last ditch -// attempt when TCP_DEFER_ACCEPT timeout is hit to trigger another -// ACK or a packet with data. -+.35~+2.35 > S. 0:0(0) ack 1 <...> - -// Now send another bare ACK after TCP_DEFER_ACCEPT time has been passed. -+0.0 < . 1:1(0) ack 1 win 257 - -// The ACK above should cause connection to transition to connected state. -+0.000 accept(3, ..., ...) = 4 -+0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0 - -+0.000 close(4) = 0 - -+0.0 > F. 1:1(0) ack 1 <...> -+0.0 < F. 1:1(0) ack 2 win 257 -+0.01 > . 2:2(0) ack 2 <...> diff --git a/test/perf/BUILD b/test/perf/BUILD deleted file mode 100644 index 0a0def6a3..000000000 --- a/test/perf/BUILD +++ /dev/null @@ -1,116 +0,0 @@ -load("//test/runner:defs.bzl", "syscall_test") - -package(licenses = ["notice"]) - -syscall_test( - test = "//test/perf/linux:clock_getres_benchmark", -) - -syscall_test( - test = "//test/perf/linux:clock_gettime_benchmark", -) - -syscall_test( - test = "//test/perf/linux:death_benchmark", -) - -syscall_test( - test = "//test/perf/linux:epoll_benchmark", -) - -syscall_test( - size = "large", - test = "//test/perf/linux:fork_benchmark", -) - -syscall_test( - size = "large", - test = "//test/perf/linux:futex_benchmark", -) - -syscall_test( - size = "enormous", - tags = ["nogotsan"], - test = "//test/perf/linux:getdents_benchmark", -) - -syscall_test( - size = "large", - test = "//test/perf/linux:getpid_benchmark", -) - -syscall_test( - size = "enormous", - tags = ["nogotsan"], - test = "//test/perf/linux:gettid_benchmark", -) - -syscall_test( - size = "large", - test = "//test/perf/linux:mapping_benchmark", -) - -syscall_test( - size = "large", - add_overlay = True, - test = "//test/perf/linux:open_benchmark", -) - -syscall_test( - test = "//test/perf/linux:pipe_benchmark", -) - -syscall_test( - size = "large", - add_overlay = True, - test = "//test/perf/linux:randread_benchmark", -) - -syscall_test( - size = "large", - add_overlay = True, - test = "//test/perf/linux:read_benchmark", -) - -syscall_test( - size = "large", - test = "//test/perf/linux:sched_yield_benchmark", -) - -syscall_test( - size = "large", - test = "//test/perf/linux:send_recv_benchmark", -) - -syscall_test( - size = "large", - add_overlay = True, - test = "//test/perf/linux:seqwrite_benchmark", -) - -syscall_test( - size = "enormous", - test = "//test/perf/linux:signal_benchmark", -) - -syscall_test( - test = "//test/perf/linux:sleep_benchmark", -) - -syscall_test( - size = "large", - add_overlay = True, - test = "//test/perf/linux:stat_benchmark", -) - -syscall_test( - size = "enormous", - add_overlay = True, - test = "//test/perf/linux:unlink_benchmark", -) - -syscall_test( - size = "large", - add_overlay = True, - test = "//test/perf/linux:write_benchmark", -) diff --git a/test/perf/linux/BUILD b/test/perf/linux/BUILD deleted file mode 100644 index b4e907826..000000000 --- a/test/perf/linux/BUILD +++ /dev/null @@ -1,356 +0,0 @@ -load("//tools:defs.bzl", "cc_binary", "gbenchmark", "gtest") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -cc_binary( - name = "getpid_benchmark", - testonly = 1, - srcs = [ - "getpid_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:test_main", - ], -) - -cc_binary( - name = "send_recv_benchmark", - testonly = 1, - srcs = [ - "send_recv_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/syscalls/linux:socket_test_util", - "//test/util:file_descriptor", - "//test/util:logging", - "//test/util:posix_error", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - "@com_google_absl//absl/synchronization", - ], -) - -cc_binary( - name = "gettid_benchmark", - testonly = 1, - srcs = [ - "gettid_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:test_main", - ], -) - -cc_binary( - name = "sched_yield_benchmark", - testonly = 1, - srcs = [ - "sched_yield_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "clock_getres_benchmark", - testonly = 1, - srcs = [ - "clock_getres_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:test_main", - ], -) - -cc_binary( - name = "clock_gettime_benchmark", - testonly = 1, - srcs = [ - "clock_gettime_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:test_main", - "@com_google_absl//absl/time", - ], -) - -cc_binary( - name = "open_benchmark", - testonly = 1, - srcs = [ - "open_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:fs_util", - "//test/util:logging", - "//test/util:temp_path", - "//test/util:test_main", - ], -) - -cc_binary( - name = "read_benchmark", - testonly = 1, - srcs = [ - "read_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:fs_util", - "//test/util:logging", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "randread_benchmark", - testonly = 1, - srcs = [ - "randread_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:file_descriptor", - "//test/util:logging", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "@com_google_absl//absl/random", - ], -) - -cc_binary( - name = "write_benchmark", - testonly = 1, - srcs = [ - "write_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:logging", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "seqwrite_benchmark", - testonly = 1, - srcs = [ - "seqwrite_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:logging", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "@com_google_absl//absl/random", - ], -) - -cc_binary( - name = "pipe_benchmark", - testonly = 1, - srcs = [ - "pipe_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:logging", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "fork_benchmark", - testonly = 1, - srcs = [ - "fork_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:cleanup", - "//test/util:file_descriptor", - "//test/util:logging", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - "@com_google_absl//absl/synchronization", - ], -) - -cc_binary( - name = "futex_benchmark", - testonly = 1, - srcs = [ - "futex_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:logging", - "//test/util:test_main", - "//test/util:thread_util", - "@com_google_absl//absl/time", - ], -) - -cc_binary( - name = "epoll_benchmark", - testonly = 1, - srcs = [ - "epoll_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:epoll_util", - "//test/util:file_descriptor", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - "@com_google_absl//absl/time", - ], -) - -cc_binary( - name = "death_benchmark", - testonly = 1, - srcs = [ - "death_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:logging", - "//test/util:test_main", - ], -) - -cc_binary( - name = "mapping_benchmark", - testonly = 1, - srcs = [ - "mapping_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:logging", - "//test/util:memory_util", - "//test/util:posix_error", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "signal_benchmark", - testonly = 1, - srcs = [ - "signal_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:logging", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "getdents_benchmark", - testonly = 1, - srcs = [ - "getdents_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:file_descriptor", - "//test/util:fs_util", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "sleep_benchmark", - testonly = 1, - srcs = [ - "sleep_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:logging", - "//test/util:test_main", - ], -) - -cc_binary( - name = "stat_benchmark", - testonly = 1, - srcs = [ - "stat_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:fs_util", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "@com_google_absl//absl/strings", - ], -) - -cc_binary( - name = "unlink_benchmark", - testonly = 1, - srcs = [ - "unlink_benchmark.cc", - ], - deps = [ - gbenchmark, - gtest, - "//test/util:fs_util", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) diff --git a/test/perf/linux/clock_getres_benchmark.cc b/test/perf/linux/clock_getres_benchmark.cc deleted file mode 100644 index b051293ad..000000000 --- a/test/perf/linux/clock_getres_benchmark.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <time.h> - -#include "gtest/gtest.h" -#include "benchmark/benchmark.h" - -namespace gvisor { -namespace testing { - -namespace { - -// clock_getres(1) is very nearly a no-op syscall, but it does require copying -// out to a userspace struct. It thus provides a nice small copy-out benchmark. -void BM_ClockGetRes(benchmark::State& state) { - struct timespec ts; - for (auto _ : state) { - clock_getres(CLOCK_MONOTONIC, &ts); - } -} - -BENCHMARK(BM_ClockGetRes); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/clock_gettime_benchmark.cc b/test/perf/linux/clock_gettime_benchmark.cc deleted file mode 100644 index 6691bebd9..000000000 --- a/test/perf/linux/clock_gettime_benchmark.cc +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <pthread.h> -#include <time.h> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "benchmark/benchmark.h" - -namespace gvisor { -namespace testing { - -namespace { - -void BM_ClockGettimeThreadCPUTime(benchmark::State& state) { - clockid_t clockid; - ASSERT_EQ(0, pthread_getcpuclockid(pthread_self(), &clockid)); - struct timespec tp; - - for (auto _ : state) { - clock_gettime(clockid, &tp); - } -} - -BENCHMARK(BM_ClockGettimeThreadCPUTime); - -void BM_VDSOClockGettime(benchmark::State& state) { - const clockid_t clock = state.range(0); - struct timespec tp; - absl::Time start = absl::Now(); - - // Don't benchmark the calibration phase. - while (absl::Now() < start + absl::Milliseconds(2100)) { - clock_gettime(clock, &tp); - } - - for (auto _ : state) { - clock_gettime(clock, &tp); - } -} - -BENCHMARK(BM_VDSOClockGettime)->Arg(CLOCK_MONOTONIC)->Arg(CLOCK_REALTIME); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/death_benchmark.cc b/test/perf/linux/death_benchmark.cc deleted file mode 100644 index cb2b6fd07..000000000 --- a/test/perf/linux/death_benchmark.cc +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <signal.h> - -#include "gtest/gtest.h" -#include "benchmark/benchmark.h" -#include "test/util/logging.h" - -namespace gvisor { -namespace testing { - -namespace { - -// DeathTest is not so much a microbenchmark as a macrobenchmark. It is testing -// the ability of gVisor (on whatever platform) to execute all the related -// stack-dumping routines associated with EXPECT_EXIT / EXPECT_DEATH. -TEST(DeathTest, ZeroEqualsOne) { - EXPECT_EXIT({ TEST_CHECK(0 == 1); }, ::testing::KilledBySignal(SIGABRT), ""); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/epoll_benchmark.cc b/test/perf/linux/epoll_benchmark.cc deleted file mode 100644 index 0b121338a..000000000 --- a/test/perf/linux/epoll_benchmark.cc +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/epoll.h> -#include <sys/eventfd.h> - -#include <atomic> -#include <cerrno> -#include <cstdint> -#include <cstdlib> -#include <ctime> -#include <memory> - -#include "gtest/gtest.h" -#include "absl/time/time.h" -#include "benchmark/benchmark.h" -#include "test/util/epoll_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Returns a new eventfd. -PosixErrorOr<FileDescriptor> NewEventFD() { - int fd = eventfd(0, /* flags = */ 0); - MaybeSave(); - if (fd < 0) { - return PosixError(errno, "eventfd"); - } - return FileDescriptor(fd); -} - -// Also stolen from epoll.cc unit tests. -void BM_EpollTimeout(benchmark::State& state) { - constexpr int kFDsPerEpoll = 3; - auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - - std::vector<FileDescriptor> eventfds; - for (int i = 0; i < kFDsPerEpoll; i++) { - eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); - ASSERT_NO_ERRNO( - RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0)); - } - - struct epoll_event result[kFDsPerEpoll]; - int timeout_ms = state.range(0); - - for (auto _ : state) { - EXPECT_EQ(0, epoll_wait(epollfd.get(), result, kFDsPerEpoll, timeout_ms)); - } -} - -BENCHMARK(BM_EpollTimeout)->Range(0, 8); - -// Also stolen from epoll.cc unit tests. -void BM_EpollAllEvents(benchmark::State& state) { - auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - const int fds_per_epoll = state.range(0); - constexpr uint64_t kEventVal = 5; - - std::vector<FileDescriptor> eventfds; - for (int i = 0; i < fds_per_epoll; i++) { - eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); - ASSERT_NO_ERRNO( - RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0)); - - ASSERT_THAT(WriteFd(eventfds[i].get(), &kEventVal, sizeof(kEventVal)), - SyscallSucceedsWithValue(sizeof(kEventVal))); - } - - std::vector<struct epoll_event> result(fds_per_epoll); - - for (auto _ : state) { - EXPECT_EQ(fds_per_epoll, - epoll_wait(epollfd.get(), result.data(), fds_per_epoll, 0)); - } -} - -BENCHMARK(BM_EpollAllEvents)->Range(2, 1024); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/fork_benchmark.cc b/test/perf/linux/fork_benchmark.cc deleted file mode 100644 index 84fdbc8a0..000000000 --- a/test/perf/linux/fork_benchmark.cc +++ /dev/null @@ -1,350 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <unistd.h> - -#include "gtest/gtest.h" -#include "absl/synchronization/barrier.h" -#include "benchmark/benchmark.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/logging.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -constexpr int kBusyMax = 250; - -// Do some CPU-bound busy-work. -int busy(int max) { - // Prevent the compiler from optimizing this work away, - volatile int count = 0; - - for (int i = 1; i < max; i++) { - for (int j = 2; j < i / 2; j++) { - if (i % j == 0) { - count++; - } - } - } - - return count; -} - -void BM_CPUBoundUniprocess(benchmark::State& state) { - for (auto _ : state) { - busy(kBusyMax); - } -} - -BENCHMARK(BM_CPUBoundUniprocess); - -void BM_CPUBoundAsymmetric(benchmark::State& state) { - const size_t max = state.max_iterations; - pid_t child = fork(); - if (child == 0) { - for (int i = 0; i < max; i++) { - busy(kBusyMax); - } - _exit(0); - } - ASSERT_THAT(child, SyscallSucceeds()); - ASSERT_TRUE(state.KeepRunningBatch(max)); - - int status; - EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status)); - EXPECT_EQ(0, WEXITSTATUS(status)); - ASSERT_FALSE(state.KeepRunning()); -} - -BENCHMARK(BM_CPUBoundAsymmetric)->UseRealTime(); - -void BM_CPUBoundSymmetric(benchmark::State& state) { - std::vector<pid_t> children; - auto child_cleanup = Cleanup([&] { - for (const pid_t child : children) { - int status; - EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status)); - EXPECT_EQ(0, WEXITSTATUS(status)); - } - ASSERT_FALSE(state.KeepRunning()); - }); - - const int processes = state.range(0); - for (int i = 0; i < processes; i++) { - size_t cur = (state.max_iterations + (processes - 1)) / processes; - if ((state.iterations() + cur) >= state.max_iterations) { - cur = state.max_iterations - state.iterations(); - } - pid_t child = fork(); - if (child == 0) { - for (int i = 0; i < cur; i++) { - busy(kBusyMax); - } - _exit(0); - } - ASSERT_THAT(child, SyscallSucceeds()); - if (cur > 0) { - // We can have a zero cur here, depending. - ASSERT_TRUE(state.KeepRunningBatch(cur)); - } - children.push_back(child); - } -} - -BENCHMARK(BM_CPUBoundSymmetric)->Range(2, 16)->UseRealTime(); - -// Child routine for ProcessSwitch/ThreadSwitch. -// Reads from readfd and writes the result to writefd. -void SwitchChild(int readfd, int writefd) { - while (1) { - char buf; - int ret = ReadFd(readfd, &buf, 1); - if (ret == 0) { - break; - } - TEST_CHECK_MSG(ret == 1, "read failed"); - - ret = WriteFd(writefd, &buf, 1); - if (ret == -1) { - TEST_CHECK_MSG(errno == EPIPE, "unexpected write failure"); - break; - } - TEST_CHECK_MSG(ret == 1, "write failed"); - } -} - -// Send bytes in a loop through a series of pipes, each passing through a -// different process. -// -// Proc 0 Proc 1 -// * ----------> * -// ^ Pipe 1 | -// | | -// | Pipe 0 | Pipe 2 -// | | -// | | -// | Pipe 3 v -// * <---------- * -// Proc 3 Proc 2 -// -// This exercises context switching through multiple processes. -void BM_ProcessSwitch(benchmark::State& state) { - // Code below assumes there are at least two processes. - const int num_processes = state.range(0); - ASSERT_GE(num_processes, 2); - - std::vector<pid_t> children; - auto child_cleanup = Cleanup([&] { - for (const pid_t child : children) { - int status; - EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status)); - EXPECT_EQ(0, WEXITSTATUS(status)); - } - }); - - // Must come after children, as the FDs must be closed before the children - // will exit. - std::vector<FileDescriptor> read_fds; - std::vector<FileDescriptor> write_fds; - - for (int i = 0; i < num_processes; i++) { - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - read_fds.emplace_back(fds[0]); - write_fds.emplace_back(fds[1]); - } - - // This process is one of the processes in the loop. It will be considered - // index 0. - for (int i = 1; i < num_processes; i++) { - // Read from current pipe index, write to next. - const int read_index = i; - const int read_fd = read_fds[read_index].get(); - - const int write_index = (i + 1) % num_processes; - const int write_fd = write_fds[write_index].get(); - - // std::vector isn't safe to use from the fork child. - FileDescriptor* read_array = read_fds.data(); - FileDescriptor* write_array = write_fds.data(); - - pid_t child = fork(); - if (!child) { - // Close all other FDs. - for (int j = 0; j < num_processes; j++) { - if (j != read_index) { - read_array[j].reset(); - } - if (j != write_index) { - write_array[j].reset(); - } - } - - SwitchChild(read_fd, write_fd); - _exit(0); - } - ASSERT_THAT(child, SyscallSucceeds()); - children.push_back(child); - } - - // Read from current pipe index (0), write to next (1). - const int read_index = 0; - const int read_fd = read_fds[read_index].get(); - - const int write_index = 1; - const int write_fd = write_fds[write_index].get(); - - // Kick start the loop. - char buf = 'a'; - ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); - - for (auto _ : state) { - ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1)); - ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); - } -} - -BENCHMARK(BM_ProcessSwitch)->Range(2, 16)->UseRealTime(); - -// Equivalent to BM_ThreadSwitch using threads instead of processes. -void BM_ThreadSwitch(benchmark::State& state) { - // Code below assumes there are at least two threads. - const int num_threads = state.range(0); - ASSERT_GE(num_threads, 2); - - // Must come after threads, as the FDs must be closed before the children - // will exit. - std::vector<std::unique_ptr<ScopedThread>> threads; - std::vector<FileDescriptor> read_fds; - std::vector<FileDescriptor> write_fds; - - for (int i = 0; i < num_threads; i++) { - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - read_fds.emplace_back(fds[0]); - write_fds.emplace_back(fds[1]); - } - - // This thread is one of the threads in the loop. It will be considered - // index 0. - for (int i = 1; i < num_threads; i++) { - // Read from current pipe index, write to next. - // - // Transfer ownership of the FDs to the thread. - const int read_index = i; - const int read_fd = read_fds[read_index].release(); - - const int write_index = (i + 1) % num_threads; - const int write_fd = write_fds[write_index].release(); - - threads.emplace_back(std::make_unique<ScopedThread>([read_fd, write_fd] { - FileDescriptor read(read_fd); - FileDescriptor write(write_fd); - SwitchChild(read.get(), write.get()); - })); - } - - // Read from current pipe index (0), write to next (1). - const int read_index = 0; - const int read_fd = read_fds[read_index].get(); - - const int write_index = 1; - const int write_fd = write_fds[write_index].get(); - - // Kick start the loop. - char buf = 'a'; - ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); - - for (auto _ : state) { - ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1)); - ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); - } - - // The two FDs still owned by this thread are closed, causing the next thread - // to exit its loop and close its FDs, and so on until all threads exit. -} - -BENCHMARK(BM_ThreadSwitch)->Range(2, 16)->UseRealTime(); - -void BM_ThreadStart(benchmark::State& state) { - const int num_threads = state.range(0); - - for (auto _ : state) { - state.PauseTiming(); - - auto barrier = new absl::Barrier(num_threads + 1); - std::vector<std::unique_ptr<ScopedThread>> threads; - - state.ResumeTiming(); - - for (size_t i = 0; i < num_threads; ++i) { - threads.emplace_back(std::make_unique<ScopedThread>([barrier] { - if (barrier->Block()) { - delete barrier; - } - })); - } - - if (barrier->Block()) { - delete barrier; - } - - state.PauseTiming(); - - for (const auto& thread : threads) { - thread->Join(); - } - - state.ResumeTiming(); - } -} - -BENCHMARK(BM_ThreadStart)->Range(1, 2048)->UseRealTime(); - -// Benchmark the complete fork + exit + wait. -void BM_ProcessLifecycle(benchmark::State& state) { - const int num_procs = state.range(0); - - std::vector<pid_t> pids(num_procs); - for (auto _ : state) { - for (size_t i = 0; i < num_procs; ++i) { - int pid = fork(); - if (pid == 0) { - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - pids[i] = pid; - } - - for (const int pid : pids) { - ASSERT_THAT(RetryEINTR(waitpid)(pid, nullptr, 0), - SyscallSucceedsWithValue(pid)); - } - } -} - -BENCHMARK(BM_ProcessLifecycle)->Range(1, 512)->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/futex_benchmark.cc b/test/perf/linux/futex_benchmark.cc deleted file mode 100644 index b349d50bf..000000000 --- a/test/perf/linux/futex_benchmark.cc +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <linux/futex.h> - -#include <atomic> -#include <cerrno> -#include <cstdint> -#include <cstdlib> -#include <ctime> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "benchmark/benchmark.h" -#include "test/util/logging.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -inline int FutexWait(std::atomic<int32_t>* v, int32_t val) { - return syscall(SYS_futex, v, FUTEX_BITSET_MATCH_ANY, nullptr); -} - -inline int FutexWaitRelativeTimeout(std::atomic<int32_t>* v, int32_t val, - const struct timespec* reltime) { - return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, reltime); -} - -inline int FutexWaitAbsoluteTimeout(std::atomic<int32_t>* v, int32_t val, - const struct timespec* abstime) { - return syscall(SYS_futex, v, FUTEX_BITSET_MATCH_ANY, abstime); -} - -inline int FutexWaitBitsetAbsoluteTimeout(std::atomic<int32_t>* v, int32_t val, - int32_t bits, - const struct timespec* abstime) { - return syscall(SYS_futex, v, FUTEX_WAIT_BITSET_PRIVATE | FUTEX_CLOCK_REALTIME, - val, abstime, nullptr, bits); -} - -inline int FutexWake(std::atomic<int32_t>* v, int32_t count) { - return syscall(SYS_futex, v, FUTEX_WAKE_PRIVATE, count); -} - -// This just uses FUTEX_WAKE on an address with nothing waiting, very simple. -void BM_FutexWakeNop(benchmark::State& state) { - std::atomic<int32_t> v(0); - - for (auto _ : state) { - EXPECT_EQ(0, FutexWake(&v, 1)); - } -} - -BENCHMARK(BM_FutexWakeNop); - -// This just uses FUTEX_WAIT on an address whose value has changed, i.e., the -// syscall won't wait. -void BM_FutexWaitNop(benchmark::State& state) { - std::atomic<int32_t> v(0); - - for (auto _ : state) { - EXPECT_EQ(-EAGAIN, FutexWait(&v, 1)); - } -} - -BENCHMARK(BM_FutexWaitNop); - -// This uses FUTEX_WAIT with a timeout on an address whose value never -// changes, such that it always times out. Timeout overhead can be estimated by -// timer overruns for short timeouts. -void BM_FutexWaitTimeout(benchmark::State& state) { - const int timeout_ns = state.range(0); - std::atomic<int32_t> v(0); - auto ts = absl::ToTimespec(absl::Nanoseconds(timeout_ns)); - - for (auto _ : state) { - EXPECT_EQ(-ETIMEDOUT, FutexWaitRelativeTimeout(&v, 0, &ts)); - } -} - -BENCHMARK(BM_FutexWaitTimeout) - ->Arg(1) - ->Arg(10) - ->Arg(100) - ->Arg(1000) - ->Arg(10000); - -// This calls FUTEX_WAIT_BITSET with CLOCK_REALTIME. -void BM_FutexWaitBitset(benchmark::State& state) { - std::atomic<int32_t> v(0); - int timeout_ns = state.range(0); - auto ts = absl::ToTimespec(absl::Nanoseconds(timeout_ns)); - for (auto _ : state) { - EXPECT_EQ(-ETIMEDOUT, FutexWaitBitsetAbsoluteTimeout(&v, 0, 1, &ts)); - } -} - -BENCHMARK(BM_FutexWaitBitset)->Range(0, 100000); - -int64_t GetCurrentMonotonicTimeNanos() { - struct timespec ts; - TEST_CHECK(clock_gettime(CLOCK_MONOTONIC, &ts) != -1); - return ts.tv_sec * 1000000000ULL + ts.tv_nsec; -} - -void SpinNanos(int64_t delay_ns) { - if (delay_ns <= 0) { - return; - } - const int64_t end = GetCurrentMonotonicTimeNanos() + delay_ns; - while (GetCurrentMonotonicTimeNanos() < end) { - // spin - } -} - -// Each iteration of FutexRoundtripDelayed involves a thread sending a futex -// wakeup to another thread, which spins for delay_us and then sends a futex -// wakeup back. The time per iteration is 2* (delay_us + kBeforeWakeDelayNs + -// futex/scheduling overhead). -void BM_FutexRoundtripDelayed(benchmark::State& state) { - const int delay_us = state.range(0); - - const int64_t delay_ns = delay_us * 1000; - // Spin for an extra kBeforeWakeDelayNs before invoking FUTEX_WAKE to reduce - // the probability that the wakeup comes before the wait, preventing the wait - // from ever taking effect and causing the benchmark to underestimate the - // actual wakeup time. - constexpr int64_t kBeforeWakeDelayNs = 500; - std::atomic<int32_t> v(0); - ScopedThread t([&] { - for (int i = 0; i < state.max_iterations; i++) { - SpinNanos(delay_ns); - while (v.load(std::memory_order_acquire) == 0) { - FutexWait(&v, 0); - } - SpinNanos(kBeforeWakeDelayNs + delay_ns); - v.store(0, std::memory_order_release); - FutexWake(&v, 1); - } - }); - for (auto _ : state) { - SpinNanos(kBeforeWakeDelayNs + delay_ns); - v.store(1, std::memory_order_release); - FutexWake(&v, 1); - SpinNanos(delay_ns); - while (v.load(std::memory_order_acquire) == 1) { - FutexWait(&v, 1); - } - } -} - -BENCHMARK(BM_FutexRoundtripDelayed) - ->Arg(0) - ->Arg(10) - ->Arg(20) - ->Arg(50) - ->Arg(100); - -// FutexLock is a simple, dumb futex based lock implementation. -// It will try to acquire the lock by atomically incrementing the -// lock word. If it did not increment the lock from 0 to 1, someone -// else has the lock, so it will FUTEX_WAIT until it is woken in -// the unlock path. -class FutexLock { - public: - FutexLock() : lock_word_(0) {} - - void lock(struct timespec* deadline) { - int32_t val; - while ((val = lock_word_.fetch_add(1, std::memory_order_acquire) + 1) != - 1) { - // If we didn't get the lock by incrementing from 0 to 1, - // do a FUTEX_WAIT with the desired current value set to - // val. If val is no longer what the atomic increment returned, - // someone might have set it to 0 so we can try to acquire - // again. - int ret = FutexWaitAbsoluteTimeout(&lock_word_, val, deadline); - if (ret == 0 || ret == -EWOULDBLOCK || ret == -EINTR) { - continue; - } else { - FAIL() << "unexpected FUTEX_WAIT return: " << ret; - } - } - } - - void unlock() { - // Store 0 into the lock word and wake one waiter. We intentionally - // ignore the return value of the FUTEX_WAKE here, since there may be - // no waiters to wake anyway. - lock_word_.store(0, std::memory_order_release); - (void)FutexWake(&lock_word_, 1); - } - - private: - std::atomic<int32_t> lock_word_; -}; - -FutexLock* test_lock; // Used below. - -void FutexContend(benchmark::State& state, int thread_index, - struct timespec* deadline) { - int counter = 0; - if (thread_index == 0) { - test_lock = new FutexLock(); - } - for (auto _ : state) { - test_lock->lock(deadline); - counter++; - test_lock->unlock(); - } - if (thread_index == 0) { - delete test_lock; - } - state.SetItemsProcessed(state.iterations()); -} - -void BM_FutexContend(benchmark::State& state) { - FutexContend(state, state.thread_index, nullptr); -} - -BENCHMARK(BM_FutexContend)->ThreadRange(1, 1024)->UseRealTime(); - -void BM_FutexDeadlineContend(benchmark::State& state) { - auto deadline = absl::ToTimespec(absl::Now() + absl::Minutes(10)); - FutexContend(state, state.thread_index, &deadline); -} - -BENCHMARK(BM_FutexDeadlineContend)->ThreadRange(1, 1024)->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc deleted file mode 100644 index afc599ad2..000000000 --- a/test/perf/linux/getdents_benchmark.cc +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "benchmark/benchmark.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -#ifndef SYS_getdents64 -#if defined(__x86_64__) -#define SYS_getdents64 217 -#elif defined(__aarch64__) -#define SYS_getdents64 217 -#else -#error "Unknown architecture" -#endif -#endif // SYS_getdents64 - -namespace gvisor { -namespace testing { - -namespace { - -constexpr int kBufferSize = 16384; - -PosixErrorOr<TempPath> CreateDirectory(int count, - std::vector<std::string>* files) { - ASSIGN_OR_RETURN_ERRNO(TempPath dir, TempPath::CreateDir()); - - ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd, - Open(dir.path(), O_RDONLY | O_DIRECTORY)); - - for (int i = 0; i < count; i++) { - auto file = NewTempRelPath(); - auto res = MknodAt(dfd, file, S_IFREG | 0644, 0); - RETURN_IF_ERRNO(res); - files->push_back(file); - } - - return std::move(dir); -} - -PosixError CleanupDirectory(const TempPath& dir, - std::vector<std::string>* files) { - ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd, - Open(dir.path(), O_RDONLY | O_DIRECTORY)); - - for (auto it = files->begin(); it != files->end(); ++it) { - auto res = UnlinkAt(dfd, *it, 0); - RETURN_IF_ERRNO(res); - } - return NoError(); -} - -// Creates a directory containing `files` files, and reads all the directory -// entries from the directory using a single FD. -void BM_GetdentsSameFD(benchmark::State& state) { - // Create directory with given files. - const int count = state.range(0); - - // Keep a vector of all of the file TempPaths that is destroyed before dir. - // - // Normally, we'd simply allow dir to recursively clean up the contained - // files, but that recursive cleanup uses getdents, which may be very slow in - // extreme benchmarks. - TempPath dir; - std::vector<std::string> files; - dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files)); - - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY)); - char buffer[kBufferSize]; - - // We read all directory entries on each iteration, but report this as a - // "batch" iteration so that reported times are per file. - while (state.KeepRunningBatch(count)) { - ASSERT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceeds()); - - int ret; - do { - ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize), - SyscallSucceeds()); - } while (ret > 0); - } - - ASSERT_NO_ERRNO(CleanupDirectory(dir, &files)); - - state.SetItemsProcessed(state.iterations()); -} - -BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 16)->UseRealTime(); - -// Creates a directory containing `files` files, and reads all the directory -// entries from the directory using a new FD each time. -void BM_GetdentsNewFD(benchmark::State& state) { - // Create directory with given files. - const int count = state.range(0); - - // Keep a vector of all of the file TempPaths that is destroyed before dir. - // - // Normally, we'd simply allow dir to recursively clean up the contained - // files, but that recursive cleanup uses getdents, which may be very slow in - // extreme benchmarks. - TempPath dir; - std::vector<std::string> files; - dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files)); - char buffer[kBufferSize]; - - // We read all directory entries on each iteration, but report this as a - // "batch" iteration so that reported times are per file. - while (state.KeepRunningBatch(count)) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY)); - - int ret; - do { - ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize), - SyscallSucceeds()); - } while (ret > 0); - } - - ASSERT_NO_ERRNO(CleanupDirectory(dir, &files)); - - state.SetItemsProcessed(state.iterations()); -} - -BENCHMARK(BM_GetdentsNewFD)->Range(1, 1 << 12)->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/getpid_benchmark.cc b/test/perf/linux/getpid_benchmark.cc deleted file mode 100644 index db74cb264..000000000 --- a/test/perf/linux/getpid_benchmark.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/syscall.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "benchmark/benchmark.h" - -namespace gvisor { -namespace testing { - -namespace { - -void BM_Getpid(benchmark::State& state) { - for (auto _ : state) { - syscall(SYS_getpid); - } -} - -BENCHMARK(BM_Getpid); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/gettid_benchmark.cc b/test/perf/linux/gettid_benchmark.cc deleted file mode 100644 index 8f4961f5e..000000000 --- a/test/perf/linux/gettid_benchmark.cc +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/syscall.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "benchmark/benchmark.h" - -namespace gvisor { -namespace testing { - -namespace { - -void BM_Gettid(benchmark::State& state) { - for (auto _ : state) { - syscall(SYS_gettid); - } -} - -BENCHMARK(BM_Gettid)->ThreadRange(1, 4000)->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/mapping_benchmark.cc b/test/perf/linux/mapping_benchmark.cc deleted file mode 100644 index 39c30fe69..000000000 --- a/test/perf/linux/mapping_benchmark.cc +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdlib.h> -#include <sys/mman.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "benchmark/benchmark.h" -#include "test/util/logging.h" -#include "test/util/memory_util.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Conservative value for /proc/sys/vm/max_map_count, which limits the number of -// VMAs, minus a safety margin for VMAs that already exist for the test binary. -// The default value for max_map_count is -// include/linux/mm.h:DEFAULT_MAX_MAP_COUNT = 65530. -constexpr size_t kMaxVMAs = 64001; - -// Map then unmap pages without touching them. -void BM_MapUnmap(benchmark::State& state) { - // Number of pages to map. - const int pages = state.range(0); - - while (state.KeepRunning()) { - void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed"); - - int ret = munmap(addr, pages * kPageSize); - TEST_CHECK_MSG(ret == 0, "munmap failed"); - } -} - -BENCHMARK(BM_MapUnmap)->Range(1, 1 << 17)->UseRealTime(); - -// Map, touch, then unmap pages. -void BM_MapTouchUnmap(benchmark::State& state) { - // Number of pages to map. - const int pages = state.range(0); - - while (state.KeepRunning()) { - void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed"); - - char* c = reinterpret_cast<char*>(addr); - char* end = c + pages * kPageSize; - while (c < end) { - *c = 42; - c += kPageSize; - } - - int ret = munmap(addr, pages * kPageSize); - TEST_CHECK_MSG(ret == 0, "munmap failed"); - } -} - -BENCHMARK(BM_MapTouchUnmap)->Range(1, 1 << 17)->UseRealTime(); - -// Map and touch many pages, unmapping all at once. -// -// NOTE(b/111429208): This is a regression test to ensure performant mapping and -// allocation even with tons of mappings. -void BM_MapTouchMany(benchmark::State& state) { - // Number of pages to map. - const int page_count = state.range(0); - - while (state.KeepRunning()) { - std::vector<void*> pages; - - for (int i = 0; i < page_count; i++) { - void* addr = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed"); - - char* c = reinterpret_cast<char*>(addr); - *c = 42; - - pages.push_back(addr); - } - - for (void* addr : pages) { - int ret = munmap(addr, kPageSize); - TEST_CHECK_MSG(ret == 0, "munmap failed"); - } - } - - state.SetBytesProcessed(kPageSize * page_count * state.iterations()); -} - -BENCHMARK(BM_MapTouchMany)->Range(1, 1 << 12)->UseRealTime(); - -void BM_PageFault(benchmark::State& state) { - // Map the region in which we will take page faults. To ensure that each page - // fault maps only a single page, each page we touch must correspond to a - // distinct VMA. Thus we need a 1-page gap between each 1-page VMA. However, - // each gap consists of a PROT_NONE VMA, instead of an unmapped hole, so that - // if there are background threads running, they can't inadvertently creating - // mappings in our gaps that are unmapped when the test ends. - size_t test_pages = kMaxVMAs; - // Ensure that test_pages is odd, since we want the test region to both - // begin and end with a mapped page. - if (test_pages % 2 == 0) { - test_pages--; - } - const size_t test_region_bytes = test_pages * kPageSize; - // Use MAP_SHARED here because madvise(MADV_DONTNEED) on private mappings on - // gVisor won't force future sentry page faults (by design). Use MAP_POPULATE - // so that Linux pre-allocates the shmem file used to back the mapping. - Mapping m = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(test_region_bytes, PROT_READ, MAP_SHARED | MAP_POPULATE)); - for (size_t i = 0; i < test_pages / 2; i++) { - ASSERT_THAT( - mprotect(reinterpret_cast<void*>(m.addr() + ((2 * i + 1) * kPageSize)), - kPageSize, PROT_NONE), - SyscallSucceeds()); - } - - const size_t mapped_pages = test_pages / 2 + 1; - // "Start" at the end of the mapped region to force the mapped region to be - // reset, since we mapped it with MAP_POPULATE. - size_t cur_page = mapped_pages; - for (auto _ : state) { - if (cur_page >= mapped_pages) { - // We've reached the end of our mapped region and have to reset it to - // incur page faults again. - state.PauseTiming(); - ASSERT_THAT(madvise(m.ptr(), test_region_bytes, MADV_DONTNEED), - SyscallSucceeds()); - cur_page = 0; - state.ResumeTiming(); - } - const uintptr_t addr = m.addr() + (2 * cur_page * kPageSize); - const char c = *reinterpret_cast<volatile char*>(addr); - benchmark::DoNotOptimize(c); - cur_page++; - } -} - -BENCHMARK(BM_PageFault)->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/open_benchmark.cc b/test/perf/linux/open_benchmark.cc deleted file mode 100644 index 68008f6d5..000000000 --- a/test/perf/linux/open_benchmark.cc +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <stdlib.h> -#include <unistd.h> - -#include <memory> -#include <string> -#include <vector> - -#include "gtest/gtest.h" -#include "benchmark/benchmark.h" -#include "test/util/fs_util.h" -#include "test/util/logging.h" -#include "test/util/temp_path.h" - -namespace gvisor { -namespace testing { - -namespace { - -void BM_Open(benchmark::State& state) { - const int size = state.range(0); - std::vector<TempPath> cache; - for (int i = 0; i < size; i++) { - auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - cache.emplace_back(std::move(path)); - } - - unsigned int seed = 1; - for (auto _ : state) { - const int chosen = rand_r(&seed) % size; - int fd = open(cache[chosen].path().c_str(), O_RDONLY); - TEST_CHECK(fd != -1); - close(fd); - } -} - -BENCHMARK(BM_Open)->Range(1, 128)->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/pipe_benchmark.cc b/test/perf/linux/pipe_benchmark.cc deleted file mode 100644 index 8f5f6a2a3..000000000 --- a/test/perf/linux/pipe_benchmark.cc +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdlib.h> -#include <sys/stat.h> -#include <unistd.h> - -#include <cerrno> - -#include "gtest/gtest.h" -#include "benchmark/benchmark.h" -#include "test/util/logging.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -void BM_Pipe(benchmark::State& state) { - int fds[2]; - TEST_CHECK(pipe(fds) == 0); - - const int size = state.range(0); - std::vector<char> wbuf(size); - std::vector<char> rbuf(size); - RandomizeBuffer(wbuf.data(), size); - - ScopedThread t([&] { - auto const fd = fds[1]; - for (int i = 0; i < state.max_iterations; i++) { - TEST_CHECK(WriteFd(fd, wbuf.data(), wbuf.size()) == size); - } - }); - - for (auto _ : state) { - TEST_CHECK(ReadFd(fds[0], rbuf.data(), rbuf.size()) == size); - } - - t.Join(); - - close(fds[0]); - close(fds[1]); - - state.SetBytesProcessed(static_cast<int64_t>(size) * - static_cast<int64_t>(state.iterations())); -} - -BENCHMARK(BM_Pipe)->Range(1, 1 << 20)->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/randread_benchmark.cc b/test/perf/linux/randread_benchmark.cc deleted file mode 100644 index b0eb8c24e..000000000 --- a/test/perf/linux/randread_benchmark.cc +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <stdlib.h> -#include <sys/stat.h> -#include <sys/uio.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "benchmark/benchmark.h" -#include "test/util/file_descriptor.h" -#include "test/util/logging.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Create a 1GB file that will be read from at random positions. This should -// invalid any performance gains from caching. -const uint64_t kFileSize = 1ULL << 30; - -// How many bytes to write at once to initialize the file used to read from. -const uint32_t kWriteSize = 65536; - -// Largest benchmarked read unit. -const uint32_t kMaxRead = 1UL << 26; - -TempPath CreateFile(uint64_t file_size) { - auto path = TempPath::CreateFile().ValueOrDie(); - FileDescriptor fd = Open(path.path(), O_WRONLY).ValueOrDie(); - - // Try to minimize syscalls by using maximum size writev() requests. - std::vector<char> buffer(kWriteSize); - RandomizeBuffer(buffer.data(), buffer.size()); - const std::vector<std::vector<struct iovec>> iovecs_list = - GenerateIovecs(file_size, buffer.data(), buffer.size()); - for (const auto& iovecs : iovecs_list) { - TEST_CHECK(writev(fd.get(), iovecs.data(), iovecs.size()) >= 0); - } - - return path; -} - -// Global test state, initialized once per process lifetime. -struct GlobalState { - const TempPath tmpfile; - explicit GlobalState(TempPath tfile) : tmpfile(std::move(tfile)) {} -}; - -GlobalState& GetGlobalState() { - // This gets created only once throughout the lifetime of the process. - // Use a dynamically allocated object (that is never deleted) to avoid order - // of destruction of static storage variables issues. - static GlobalState* const state = - // The actual file size is the maximum random seek range (kFileSize) + the - // maximum read size so we can read that number of bytes at the end of the - // file. - new GlobalState(CreateFile(kFileSize + kMaxRead)); - return *state; -} - -void BM_RandRead(benchmark::State& state) { - const int size = state.range(0); - - GlobalState& global_state = GetGlobalState(); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(global_state.tmpfile.path(), O_RDONLY)); - std::vector<char> buf(size); - - unsigned int seed = 1; - for (auto _ : state) { - TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(), - rand_r(&seed) % kFileSize) == size); - } - - state.SetBytesProcessed(static_cast<int64_t>(size) * - static_cast<int64_t>(state.iterations())); -} - -BENCHMARK(BM_RandRead)->Range(1, kMaxRead)->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/read_benchmark.cc b/test/perf/linux/read_benchmark.cc deleted file mode 100644 index 62445867d..000000000 --- a/test/perf/linux/read_benchmark.cc +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <stdlib.h> -#include <sys/stat.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "benchmark/benchmark.h" -#include "test/util/fs_util.h" -#include "test/util/logging.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -void BM_Read(benchmark::State& state) { - const int size = state.range(0); - const std::string contents(size, 0); - auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), contents, TempPath::kDefaultFileMode)); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDONLY)); - - std::vector<char> buf(size); - for (auto _ : state) { - TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(), 0) == size); - } - - state.SetBytesProcessed(static_cast<int64_t>(size) * - static_cast<int64_t>(state.iterations())); -} - -BENCHMARK(BM_Read)->Range(1, 1 << 26)->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/sched_yield_benchmark.cc b/test/perf/linux/sched_yield_benchmark.cc deleted file mode 100644 index 6756b5575..000000000 --- a/test/perf/linux/sched_yield_benchmark.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sched.h> - -#include "gtest/gtest.h" -#include "benchmark/benchmark.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -void BM_Sched_yield(benchmark::State& state) { - for (auto ignored : state) { - TEST_CHECK(sched_yield() == 0); - } -} - -BENCHMARK(BM_Sched_yield)->ThreadRange(1, 2000)->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/send_recv_benchmark.cc b/test/perf/linux/send_recv_benchmark.cc deleted file mode 100644 index d73e49523..000000000 --- a/test/perf/linux/send_recv_benchmark.cc +++ /dev/null @@ -1,372 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <netinet/in.h> -#include <netinet/tcp.h> -#include <poll.h> -#include <sys/ioctl.h> -#include <sys/socket.h> - -#include <cstring> - -#include "gtest/gtest.h" -#include "absl/synchronization/notification.h" -#include "benchmark/benchmark.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/logging.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -constexpr ssize_t kMessageSize = 1024; - -class Message { - public: - explicit Message(int byte = 0) : Message(byte, kMessageSize, 0) {} - - explicit Message(int byte, int sz) : Message(byte, sz, 0) {} - - explicit Message(int byte, int sz, int cmsg_sz) - : buffer_(sz, byte), cmsg_buffer_(cmsg_sz, 0) { - iov_.iov_base = buffer_.data(); - iov_.iov_len = sz; - hdr_.msg_iov = &iov_; - hdr_.msg_iovlen = 1; - hdr_.msg_control = cmsg_buffer_.data(); - hdr_.msg_controllen = cmsg_sz; - } - - struct msghdr* header() { - return &hdr_; - } - - private: - std::vector<char> buffer_; - std::vector<char> cmsg_buffer_; - struct iovec iov_ = {}; - struct msghdr hdr_ = {}; -}; - -void BM_Recvmsg(benchmark::State& state) { - int sockets[2]; - TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); - FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); - absl::Notification notification; - Message send_msg('a'), recv_msg; - - ScopedThread t([&send_msg, &send_socket, ¬ification] { - while (!notification.HasBeenNotified()) { - sendmsg(send_socket.get(), send_msg.header(), 0); - } - }); - - int64_t bytes_received = 0; - for (auto ignored : state) { - int n = recvmsg(recv_socket.get(), recv_msg.header(), 0); - TEST_CHECK(n > 0); - bytes_received += n; - } - - notification.Notify(); - recv_socket.reset(); - - state.SetBytesProcessed(bytes_received); -} - -BENCHMARK(BM_Recvmsg)->UseRealTime(); - -void BM_Sendmsg(benchmark::State& state) { - int sockets[2]; - TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); - FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); - absl::Notification notification; - Message send_msg('a'), recv_msg; - - ScopedThread t([&recv_msg, &recv_socket, ¬ification] { - while (!notification.HasBeenNotified()) { - recvmsg(recv_socket.get(), recv_msg.header(), 0); - } - }); - - int64_t bytes_sent = 0; - for (auto ignored : state) { - int n = sendmsg(send_socket.get(), send_msg.header(), 0); - TEST_CHECK(n > 0); - bytes_sent += n; - } - - notification.Notify(); - send_socket.reset(); - - state.SetBytesProcessed(bytes_sent); -} - -BENCHMARK(BM_Sendmsg)->UseRealTime(); - -void BM_Recvfrom(benchmark::State& state) { - int sockets[2]; - TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); - FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); - absl::Notification notification; - char send_buffer[kMessageSize], recv_buffer[kMessageSize]; - - ScopedThread t([&send_socket, &send_buffer, ¬ification] { - while (!notification.HasBeenNotified()) { - sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0); - } - }); - - int bytes_received = 0; - for (auto ignored : state) { - int n = recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr, - nullptr); - TEST_CHECK(n > 0); - bytes_received += n; - } - - notification.Notify(); - recv_socket.reset(); - - state.SetBytesProcessed(bytes_received); -} - -BENCHMARK(BM_Recvfrom)->UseRealTime(); - -void BM_Sendto(benchmark::State& state) { - int sockets[2]; - TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); - FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); - absl::Notification notification; - char send_buffer[kMessageSize], recv_buffer[kMessageSize]; - - ScopedThread t([&recv_socket, &recv_buffer, ¬ification] { - while (!notification.HasBeenNotified()) { - recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr, - nullptr); - } - }); - - int64_t bytes_sent = 0; - for (auto ignored : state) { - int n = sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0); - TEST_CHECK(n > 0); - bytes_sent += n; - } - - notification.Notify(); - send_socket.reset(); - - state.SetBytesProcessed(bytes_sent); -} - -BENCHMARK(BM_Sendto)->UseRealTime(); - -PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) { - struct sockaddr_storage addr; - memset(&addr, 0, sizeof(addr)); - addr.ss_family = family; - switch (family) { - case AF_INET: - reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr = - htonl(INADDR_LOOPBACK); - break; - case AF_INET6: - reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr = - in6addr_loopback; - break; - default: - return PosixError(EINVAL, - absl::StrCat("unknown socket family: ", family)); - } - return addr; -} - -// BM_RecvmsgWithControlBuf measures the performance of recvmsg when we allocate -// space for control messages. Note that we do not expect to receive any. -void BM_RecvmsgWithControlBuf(benchmark::State& state) { - auto listen_socket = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP)); - - // Initialize address to the loopback one. - sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET6)); - socklen_t addrlen = sizeof(addr); - - // Bind to some port then start listening. - ASSERT_THAT(bind(listen_socket.get(), - reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); - - ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds()); - - // Get the address we're listening on, then connect to it. We need to do this - // because we're allowing the stack to pick a port for us. - ASSERT_THAT(getsockname(listen_socket.get(), - reinterpret_cast<struct sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - auto send_socket = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP)); - - ASSERT_THAT( - RetryEINTR(connect)(send_socket.get(), - reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); - - // Accept the connection. - auto recv_socket = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr)); - - absl::Notification notification; - Message send_msg('a'); - // Create a msghdr with a buffer allocated for control messages. - Message recv_msg(0, kMessageSize, /*cmsg_sz=*/24); - - ScopedThread t([&send_msg, &send_socket, ¬ification] { - while (!notification.HasBeenNotified()) { - sendmsg(send_socket.get(), send_msg.header(), 0); - } - }); - - int64_t bytes_received = 0; - for (auto ignored : state) { - int n = recvmsg(recv_socket.get(), recv_msg.header(), 0); - TEST_CHECK(n > 0); - bytes_received += n; - } - - notification.Notify(); - recv_socket.reset(); - - state.SetBytesProcessed(bytes_received); -} - -BENCHMARK(BM_RecvmsgWithControlBuf)->UseRealTime(); - -// BM_SendmsgTCP measures the sendmsg throughput with varying payload sizes. -// -// state.Args[0] indicates whether the underlying socket should be blocking or -// non-blocking w/ 0 indicating non-blocking and 1 to indicate blocking. -// state.Args[1] is the size of the payload to be used per sendmsg call. -void BM_SendmsgTCP(benchmark::State& state) { - auto listen_socket = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); - - // Initialize address to the loopback one. - sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET)); - socklen_t addrlen = sizeof(addr); - - // Bind to some port then start listening. - ASSERT_THAT(bind(listen_socket.get(), - reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); - - ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds()); - - // Get the address we're listening on, then connect to it. We need to do this - // because we're allowing the stack to pick a port for us. - ASSERT_THAT(getsockname(listen_socket.get(), - reinterpret_cast<struct sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - auto send_socket = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); - - ASSERT_THAT( - RetryEINTR(connect)(send_socket.get(), - reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); - - // Accept the connection. - auto recv_socket = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr)); - - // Check if we want to run the test w/ a blocking send socket - // or non-blocking. - const int blocking = state.range(0); - if (!blocking) { - // Set the send FD to O_NONBLOCK. - int opts; - ASSERT_THAT(opts = fcntl(send_socket.get(), F_GETFL), SyscallSucceeds()); - opts |= O_NONBLOCK; - ASSERT_THAT(fcntl(send_socket.get(), F_SETFL, opts), SyscallSucceeds()); - } - - absl::Notification notification; - - // Get the buffer size we should use for this iteration of the test. - const int buf_size = state.range(1); - Message send_msg('a', buf_size), recv_msg(0, buf_size); - - ScopedThread t([&recv_msg, &recv_socket, ¬ification] { - while (!notification.HasBeenNotified()) { - TEST_CHECK(recvmsg(recv_socket.get(), recv_msg.header(), 0) >= 0); - } - }); - - int64_t bytes_sent = 0; - int ncalls = 0; - for (auto ignored : state) { - int sent = 0; - while (true) { - struct msghdr hdr = {}; - struct iovec iov = {}; - struct msghdr* snd_header = send_msg.header(); - iov.iov_base = static_cast<char*>(snd_header->msg_iov->iov_base) + sent; - iov.iov_len = snd_header->msg_iov->iov_len - sent; - hdr.msg_iov = &iov; - hdr.msg_iovlen = 1; - int n = RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0); - ncalls++; - if (n > 0) { - sent += n; - if (sent == buf_size) { - break; - } - // n can be > 0 but less than requested size. In which case we don't - // poll. - continue; - } - // Poll the fd for it to become writable. - struct pollfd poll_fd = {send_socket.get(), POLL_OUT, 0}; - EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10), - SyscallSucceedsWithValue(0)); - } - bytes_sent += static_cast<int64_t>(sent); - } - - notification.Notify(); - send_socket.reset(); - state.SetBytesProcessed(bytes_sent); -} - -void Args(benchmark::internal::Benchmark* benchmark) { - for (int blocking = 0; blocking < 2; blocking++) { - for (int buf_size = 1024; buf_size <= 256 << 20; buf_size *= 2) { - benchmark->Args({blocking, buf_size}); - } - } -} - -BENCHMARK(BM_SendmsgTCP)->Apply(&Args)->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/seqwrite_benchmark.cc b/test/perf/linux/seqwrite_benchmark.cc deleted file mode 100644 index af49e4477..000000000 --- a/test/perf/linux/seqwrite_benchmark.cc +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <stdlib.h> -#include <sys/stat.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "benchmark/benchmark.h" -#include "test/util/logging.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// The maximum file size of the test file, when writes get beyond this point -// they wrap around. This should be large enough to blow away caches. -const uint64_t kMaxFile = 1 << 30; - -// Perform writes of various sizes sequentially to one file. Wraps around if it -// goes above a certain maximum file size. -void BM_SeqWrite(benchmark::State& state) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY)); - - const int size = state.range(0); - std::vector<char> buf(size); - RandomizeBuffer(buf.data(), buf.size()); - - // Start writes at offset 0. - uint64_t offset = 0; - for (auto _ : state) { - TEST_CHECK(PwriteFd(fd.get(), buf.data(), buf.size(), offset) == - buf.size()); - offset += buf.size(); - // Wrap around if going above the maximum file size. - if (offset >= kMaxFile) { - offset = 0; - } - } - - state.SetBytesProcessed(static_cast<int64_t>(size) * - static_cast<int64_t>(state.iterations())); -} - -BENCHMARK(BM_SeqWrite)->Range(1, 1 << 26)->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/signal_benchmark.cc b/test/perf/linux/signal_benchmark.cc deleted file mode 100644 index cec679191..000000000 --- a/test/perf/linux/signal_benchmark.cc +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <signal.h> -#include <string.h> - -#include "gtest/gtest.h" -#include "benchmark/benchmark.h" -#include "test/util/logging.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -void FixupHandler(int sig, siginfo_t* si, void* void_ctx) { - static unsigned int dataval = 0; - - // Skip the offending instruction. - ucontext_t* ctx = reinterpret_cast<ucontext_t*>(void_ctx); - ctx->uc_mcontext.gregs[REG_RAX] = reinterpret_cast<greg_t>(&dataval); -} - -void BM_FaultSignalFixup(benchmark::State& state) { - // Set up the signal handler. - struct sigaction sa = {}; - sigemptyset(&sa.sa_mask); - sa.sa_sigaction = FixupHandler; - sa.sa_flags = SA_SIGINFO; - TEST_CHECK(sigaction(SIGSEGV, &sa, nullptr) == 0); - - // Fault, fault, fault. - for (auto _ : state) { - // Trigger the segfault. - asm volatile( - "movq $0, %%rax\n" - "movq $0x77777777, (%%rax)\n" - : - : - : "rax"); - } -} - -BENCHMARK(BM_FaultSignalFixup)->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/sleep_benchmark.cc b/test/perf/linux/sleep_benchmark.cc deleted file mode 100644 index 99ef05117..000000000 --- a/test/perf/linux/sleep_benchmark.cc +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <sys/syscall.h> -#include <time.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "benchmark/benchmark.h" -#include "test/util/logging.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Sleep for 'param' nanoseconds. -void BM_Sleep(benchmark::State& state) { - const int nanoseconds = state.range(0); - - for (auto _ : state) { - struct timespec ts; - ts.tv_sec = 0; - ts.tv_nsec = nanoseconds; - - int ret; - do { - ret = syscall(SYS_nanosleep, &ts, &ts); - if (ret < 0) { - TEST_CHECK(errno == EINTR); - } - } while (ret < 0); - } -} - -BENCHMARK(BM_Sleep) - ->Arg(0) - ->Arg(1) - ->Arg(1000) // 1us - ->Arg(1000 * 1000) // 1ms - ->Arg(10 * 1000 * 1000) // 10ms - ->Arg(50 * 1000 * 1000) // 50ms - ->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/stat_benchmark.cc b/test/perf/linux/stat_benchmark.cc deleted file mode 100644 index f15424482..000000000 --- a/test/perf/linux/stat_benchmark.cc +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "absl/strings/str_cat.h" -#include "benchmark/benchmark.h" -#include "test/util/fs_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Creates a file in a nested directory hierarchy at least `depth` directories -// deep, and stats that file multiple times. -void BM_Stat(benchmark::State& state) { - // Create nested directories with given depth. - int depth = state.range(0); - const TempPath top_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - std::string dir_path = top_dir.path(); - - while (depth-- > 0) { - // Don't use TempPath because it will make paths too long to use. - // - // The top_dir destructor will clean up this whole tree. - dir_path = JoinPath(dir_path, absl::StrCat(depth)); - ASSERT_NO_ERRNO(Mkdir(dir_path, 0755)); - } - - // Create the file that will be stat'd. - const TempPath file = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir_path)); - - struct stat st; - for (auto _ : state) { - ASSERT_THAT(stat(file.path().c_str(), &st), SyscallSucceeds()); - } -} - -BENCHMARK(BM_Stat)->Range(1, 100)->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/unlink_benchmark.cc b/test/perf/linux/unlink_benchmark.cc deleted file mode 100644 index 92243a042..000000000 --- a/test/perf/linux/unlink_benchmark.cc +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "benchmark/benchmark.h" -#include "test/util/fs_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Creates a directory containing `files` files, and unlinks all the files. -void BM_Unlink(benchmark::State& state) { - // Create directory with given files. - const int file_count = state.range(0); - - // We unlink all files on each iteration, but report this as a "batch" - // iteration so that reported times are per file. - TempPath dir; - while (state.KeepRunningBatch(file_count)) { - state.PauseTiming(); - // N.B. dir is declared outside the loop so that destruction of the previous - // iteration's directory occurs here, inside of PauseTiming. - dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - std::vector<TempPath> files; - for (int i = 0; i < file_count; i++) { - TempPath file = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); - files.push_back(std::move(file)); - } - state.ResumeTiming(); - - while (!files.empty()) { - // Destructor unlinks. - files.pop_back(); - } - } - - state.SetItemsProcessed(state.iterations()); -} - -BENCHMARK(BM_Unlink)->Range(1, 100 * 1000)->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/perf/linux/write_benchmark.cc b/test/perf/linux/write_benchmark.cc deleted file mode 100644 index 7b060c70e..000000000 --- a/test/perf/linux/write_benchmark.cc +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <stdlib.h> -#include <sys/stat.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "benchmark/benchmark.h" -#include "test/util/logging.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -void BM_Write(benchmark::State& state) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY)); - - const int size = state.range(0); - std::vector<char> buf(size); - RandomizeBuffer(buf.data(), size); - - for (auto _ : state) { - TEST_CHECK(PwriteFd(fd.get(), buf.data(), size, 0) == size); - } - - state.SetBytesProcessed(static_cast<int64_t>(size) * - static_cast<int64_t>(state.iterations())); -} - -BENCHMARK(BM_Write)->Range(1, 1 << 26)->UseRealTime(); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/root/BUILD b/test/root/BUILD deleted file mode 100644 index ddc9b4955..000000000 --- a/test/root/BUILD +++ /dev/null @@ -1,46 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "root", - srcs = ["root.go"], -) - -go_test( - name = "root_test", - size = "small", - srcs = [ - "cgroup_test.go", - "chroot_test.go", - "crictl_test.go", - "main_test.go", - "oom_score_adj_test.go", - "runsc_test.go", - ], - data = [ - "//runsc", - ], - library = ":root", - tags = [ - # Requires docker and runsc to be configured before the test runs. - # Also test only runs as root. - "manual", - "local", - ], - visibility = ["//:sandbox"], - deps = [ - "//runsc/boot", - "//runsc/cgroup", - "//runsc/container", - "//runsc/criutil", - "//runsc/dockerutil", - "//runsc/specutils", - "//runsc/testutil", - "//test/root/testdata", - "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", - "@com_github_syndtr_gocapability//capability:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go deleted file mode 100644 index 4038661cb..000000000 --- a/test/root/cgroup_test.go +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package root - -import ( - "bufio" - "fmt" - "io/ioutil" - "os" - "os/exec" - "path/filepath" - "strconv" - "strings" - "testing" - "time" - - "gvisor.dev/gvisor/runsc/cgroup" - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/testutil" -) - -func verifyPid(pid int, path string) error { - f, err := os.Open(path) - if err != nil { - return err - } - defer f.Close() - - var gots []int - scanner := bufio.NewScanner(f) - for scanner.Scan() { - got, err := strconv.Atoi(scanner.Text()) - if err != nil { - return err - } - if got == pid { - return nil - } - gots = append(gots, got) - } - if scanner.Err() != nil { - return scanner.Err() - } - return fmt.Errorf("got: %s, want: %d", gots, pid) -} - -// TestCgroup sets cgroup options and checks that cgroup was properly configured. -func TestMemCGroup(t *testing.T) { - allocMemSize := 128 << 20 - if err := dockerutil.Pull("python"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("memusage-test") - - // Start a new container and allocate the specified about of memory. - args := []string{ - "--memory=256MB", - "python", - "python", - "-c", - fmt.Sprintf("import time; s = 'a' * %d; time.sleep(100)", allocMemSize), - } - if err := d.Run(args...); err != nil { - t.Fatal("docker create failed:", err) - } - defer d.CleanUp() - - gid, err := d.ID() - if err != nil { - t.Fatalf("Docker.ID() failed: %v", err) - } - t.Logf("cgroup ID: %s", gid) - - path := filepath.Join("/sys/fs/cgroup/memory/docker", gid, "memory.usage_in_bytes") - memUsage := 0 - - // Wait when the container will allocate memory. - start := time.Now() - for time.Now().Sub(start) < 30*time.Second { - outRaw, err := ioutil.ReadFile(path) - if err != nil { - t.Fatalf("failed to read %q: %v", path, err) - } - out := strings.TrimSpace(string(outRaw)) - memUsage, err = strconv.Atoi(out) - if err != nil { - t.Fatalf("Atoi(%v): %v", out, err) - } - - if memUsage > allocMemSize { - return - } - - time.Sleep(100 * time.Millisecond) - } - - t.Fatalf("%vMB is less than %vMB: %v", memUsage>>20, allocMemSize>>20) -} - -// TestCgroup sets cgroup options and checks that cgroup was properly configured. -func TestCgroup(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("cgroup-test") - - // This is not a comprehensive list of attributes. - // - // Note that we are specifically missing cpusets, which fail if specified. - // In any case, it's unclear if cpusets can be reliably tested here: these - // are often run on a single core virtual machine, and there is only a single - // CPU available in our current set, and every container's set. - attrs := []struct { - arg string - ctrl string - file string - want string - skipIfNotFound bool - }{ - { - arg: "--cpu-shares=1000", - ctrl: "cpu", - file: "cpu.shares", - want: "1000", - }, - { - arg: "--cpu-period=2000", - ctrl: "cpu", - file: "cpu.cfs_period_us", - want: "2000", - }, - { - arg: "--cpu-quota=3000", - ctrl: "cpu", - file: "cpu.cfs_quota_us", - want: "3000", - }, - { - arg: "--kernel-memory=100MB", - ctrl: "memory", - file: "memory.kmem.limit_in_bytes", - want: "104857600", - }, - { - arg: "--memory=1GB", - ctrl: "memory", - file: "memory.limit_in_bytes", - want: "1073741824", - }, - { - arg: "--memory-reservation=500MB", - ctrl: "memory", - file: "memory.soft_limit_in_bytes", - want: "524288000", - }, - { - arg: "--memory-swap=2GB", - ctrl: "memory", - file: "memory.memsw.limit_in_bytes", - want: "2147483648", - skipIfNotFound: true, // swap may be disabled on the machine. - }, - { - arg: "--memory-swappiness=5", - ctrl: "memory", - file: "memory.swappiness", - want: "5", - }, - { - arg: "--blkio-weight=750", - ctrl: "blkio", - file: "blkio.weight", - want: "750", - }, - } - - args := make([]string, 0, len(attrs)) - for _, attr := range attrs { - args = append(args, attr.arg) - } - - args = append(args, "alpine", "sleep", "10000") - if err := d.Run(args...); err != nil { - t.Fatal("docker create failed:", err) - } - defer d.CleanUp() - - gid, err := d.ID() - if err != nil { - t.Fatalf("Docker.ID() failed: %v", err) - } - t.Logf("cgroup ID: %s", gid) - - // Check list of attributes defined above. - for _, attr := range attrs { - path := filepath.Join("/sys/fs/cgroup", attr.ctrl, "docker", gid, attr.file) - out, err := ioutil.ReadFile(path) - if err != nil { - if os.IsNotExist(err) && attr.skipIfNotFound { - t.Logf("skipped %s/%s", attr.ctrl, attr.file) - continue - } - t.Fatalf("failed to read %q: %v", path, err) - } - if got := strings.TrimSpace(string(out)); got != attr.want { - t.Errorf("arg: %q, cgroup attribute %s/%s, got: %q, want: %q", attr.arg, attr.ctrl, attr.file, got, attr.want) - } - } - - // Check that sandbox is inside cgroup. - controllers := []string{ - "blkio", - "cpu", - "cpuset", - "memory", - "net_cls", - "net_prio", - "devices", - "freezer", - "perf_event", - "pids", - "systemd", - } - pid, err := d.SandboxPid() - if err != nil { - t.Fatalf("SandboxPid: %v", err) - } - for _, ctrl := range controllers { - path := filepath.Join("/sys/fs/cgroup", ctrl, "docker", gid, "cgroup.procs") - if err := verifyPid(pid, path); err != nil { - t.Errorf("cgroup control %q processes: %v", ctrl, err) - } - } -} - -func TestCgroupParent(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("cgroup-test") - - parent := testutil.RandomName("runsc") - if err := d.Run("--cgroup-parent", parent, "alpine", "sleep", "10000"); err != nil { - t.Fatal("docker create failed:", err) - } - defer d.CleanUp() - gid, err := d.ID() - if err != nil { - t.Fatalf("Docker.ID() failed: %v", err) - } - t.Logf("cgroup ID: %s", gid) - - // Check that sandbox is inside cgroup. - pid, err := d.SandboxPid() - if err != nil { - t.Fatalf("SandboxPid: %v", err) - } - - // Finds cgroup for the sandbox's parent process to check that cgroup is - // created in the right location relative to the parent. - cmd := fmt.Sprintf("grep PPid: /proc/%d/status | sed 's/PPid:\\s//'", pid) - ppid, err := exec.Command("bash", "-c", cmd).CombinedOutput() - if err != nil { - t.Fatalf("Executing %q: %v", cmd, err) - } - cgroups, err := cgroup.LoadPaths(strings.TrimSpace(string(ppid))) - if err != nil { - t.Fatalf("cgroup.LoadPath(%s): %v", ppid, err) - } - path := filepath.Join("/sys/fs/cgroup/memory", cgroups["memory"], parent, gid, "cgroup.procs") - if err := verifyPid(pid, path); err != nil { - t.Errorf("cgroup control %q processes: %v", "memory", err) - } -} diff --git a/test/root/chroot_test.go b/test/root/chroot_test.go deleted file mode 100644 index be0f63d18..000000000 --- a/test/root/chroot_test.go +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package root is used for tests that requires sysadmin privileges run. -package root - -import ( - "fmt" - "io/ioutil" - "os/exec" - "path/filepath" - "strconv" - "strings" - "testing" - - "gvisor.dev/gvisor/runsc/dockerutil" -) - -// TestChroot verifies that the sandbox is chroot'd and that mounts are cleaned -// up after the sandbox is destroyed. -func TestChroot(t *testing.T) { - d := dockerutil.MakeDocker("chroot-test") - if err := d.Run("alpine", "sleep", "10000"); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - pid, err := d.SandboxPid() - if err != nil { - t.Fatalf("Docker.SandboxPid(): %v", err) - } - - // Check that sandbox is chroot'ed. - procRoot := filepath.Join("/proc", strconv.Itoa(pid), "root") - chroot, err := filepath.EvalSymlinks(procRoot) - if err != nil { - t.Fatalf("error resolving /proc/<pid>/root symlink: %v", err) - } - if chroot != "/" { - t.Errorf("sandbox is not chroot'd, it should be inside: /, got: %q", chroot) - } - - path, err := filepath.EvalSymlinks(filepath.Join("/proc", strconv.Itoa(pid), "cwd")) - if err != nil { - t.Fatalf("error resolving /proc/<pid>/cwd symlink: %v", err) - } - if chroot != path { - t.Errorf("sandbox current dir is wrong, want: %q, got: %q", chroot, path) - } - - fi, err := ioutil.ReadDir(procRoot) - if err != nil { - t.Fatalf("error listing %q: %v", chroot, err) - } - if want, got := 1, len(fi); want != got { - t.Fatalf("chroot dir got %d entries, want %d", got, want) - } - - // chroot dir is prepared by runsc and should contains only /proc. - if fi[0].Name() != "proc" { - t.Errorf("chroot got children %v, want %v", fi[0].Name(), "proc") - } - - d.CleanUp() -} - -func TestChrootGofer(t *testing.T) { - d := dockerutil.MakeDocker("chroot-test") - if err := d.Run("alpine", "sleep", "10000"); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - // It's tricky to find gofers. Get sandbox PID first, then find parent. From - // parent get all immediate children, remove the sandbox, and everything else - // are gofers. - sandPID, err := d.SandboxPid() - if err != nil { - t.Fatalf("Docker.SandboxPid(): %v", err) - } - - // Find sandbox's parent PID. - cmd := fmt.Sprintf("grep PPid /proc/%d/status | awk '{print $2}'", sandPID) - parent, err := exec.Command("sh", "-c", cmd).CombinedOutput() - if err != nil { - t.Fatalf("failed to fetch runsc (%d) parent PID: %v, out:\n%s", sandPID, err, string(parent)) - } - parentPID, err := strconv.Atoi(strings.TrimSpace(string(parent))) - if err != nil { - t.Fatalf("failed to parse PPID %q: %v", string(parent), err) - } - - // Get all children from parent. - childrenOut, err := exec.Command("/usr/bin/pgrep", "-P", strconv.Itoa(parentPID)).CombinedOutput() - if err != nil { - t.Fatalf("failed to fetch containerd-shim children: %v", err) - } - children := strings.Split(strings.TrimSpace(string(childrenOut)), "\n") - - // This where the root directory is mapped on the host and that's where the - // gofer must have chroot'd to. - root := "/root" - - for _, child := range children { - childPID, err := strconv.Atoi(child) - if err != nil { - t.Fatalf("failed to parse child PID %q: %v", child, err) - } - if childPID == sandPID { - // Skip the sandbox, all other immediate children are gofers. - continue - } - - // Check that gofer is chroot'ed. - chroot, err := filepath.EvalSymlinks(filepath.Join("/proc", child, "root")) - if err != nil { - t.Fatalf("error resolving /proc/<pid>/root symlink: %v", err) - } - if root != chroot { - t.Errorf("gofer chroot is wrong, want: %q, got: %q", root, chroot) - } - - path, err := filepath.EvalSymlinks(filepath.Join("/proc", child, "cwd")) - if err != nil { - t.Fatalf("error resolving /proc/<pid>/cwd symlink: %v", err) - } - if root != path { - t.Errorf("gofer current dir is wrong, want: %q, got: %q", root, path) - } - } -} diff --git a/test/root/crictl_test.go b/test/root/crictl_test.go deleted file mode 100644 index 3f90c4c6a..000000000 --- a/test/root/crictl_test.go +++ /dev/null @@ -1,295 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package root - -import ( - "bytes" - "fmt" - "io" - "io/ioutil" - "log" - "net/http" - "os" - "os/exec" - "path" - "path/filepath" - "strings" - "testing" - "time" - - "gvisor.dev/gvisor/runsc/criutil" - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/specutils" - "gvisor.dev/gvisor/runsc/testutil" - "gvisor.dev/gvisor/test/root/testdata" -) - -// Tests for crictl have to be run as root (rather than in a user namespace) -// because crictl creates named network namespaces in /var/run/netns/. - -// TestCrictlSanity refers to b/112433158. -func TestCrictlSanity(t *testing.T) { - // Setup containerd and crictl. - crictl, cleanup, err := setup(t) - if err != nil { - t.Fatalf("failed to setup crictl: %v", err) - } - defer cleanup() - podID, contID, err := crictl.StartPodAndContainer("httpd", testdata.Sandbox, testdata.Httpd) - if err != nil { - t.Fatal(err) - } - - // Look for the httpd page. - if err = httpGet(crictl, podID, "index.html"); err != nil { - t.Fatalf("failed to get page: %v", err) - } - - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatal(err) - } -} - -// TestMountPaths refers to b/117635704. -func TestMountPaths(t *testing.T) { - // Setup containerd and crictl. - crictl, cleanup, err := setup(t) - if err != nil { - t.Fatalf("failed to setup crictl: %v", err) - } - defer cleanup() - podID, contID, err := crictl.StartPodAndContainer("httpd", testdata.Sandbox, testdata.HttpdMountPaths) - if err != nil { - t.Fatal(err) - } - - // Look for the directory available at /test. - if err = httpGet(crictl, podID, "test"); err != nil { - t.Fatalf("failed to get page: %v", err) - } - - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatal(err) - } -} - -// TestMountPaths refers to b/118728671. -func TestMountOverSymlinks(t *testing.T) { - // Setup containerd and crictl. - crictl, cleanup, err := setup(t) - if err != nil { - t.Fatalf("failed to setup crictl: %v", err) - } - defer cleanup() - podID, contID, err := crictl.StartPodAndContainer("k8s.gcr.io/busybox", testdata.Sandbox, testdata.MountOverSymlink) - if err != nil { - t.Fatal(err) - } - - out, err := crictl.Exec(contID, "readlink", "/etc/resolv.conf") - if err != nil { - t.Fatal(err) - } - if want := "/tmp/resolv.conf"; !strings.Contains(string(out), want) { - t.Fatalf("/etc/resolv.conf is not pointing to %q: %q", want, string(out)) - } - - etc, err := crictl.Exec(contID, "cat", "/etc/resolv.conf") - if err != nil { - t.Fatal(err) - } - tmp, err := crictl.Exec(contID, "cat", "/tmp/resolv.conf") - if err != nil { - t.Fatal(err) - } - if tmp != etc { - t.Fatalf("file content doesn't match:\n\t/etc/resolv.conf: %s\n\t/tmp/resolv.conf: %s", string(etc), string(tmp)) - } - - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatal(err) - } -} - -// TestHomeDir tests that the HOME environment variable is set for -// multi-containers. -func TestHomeDir(t *testing.T) { - // Setup containerd and crictl. - crictl, cleanup, err := setup(t) - if err != nil { - t.Fatalf("failed to setup crictl: %v", err) - } - defer cleanup() - contSpec := testdata.SimpleSpec("root", "k8s.gcr.io/busybox", []string{"sleep", "1000"}) - podID, contID, err := crictl.StartPodAndContainer("k8s.gcr.io/busybox", testdata.Sandbox, contSpec) - if err != nil { - t.Fatal(err) - } - - t.Run("root container", func(t *testing.T) { - out, err := crictl.Exec(contID, "sh", "-c", "echo $HOME") - if err != nil { - t.Fatal(err) - } - if got, want := strings.TrimSpace(string(out)), "/root"; got != want { - t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want) - } - }) - - t.Run("sub-container", func(t *testing.T) { - // Create a sub container in the same pod. - subContSpec := testdata.SimpleSpec("subcontainer", "k8s.gcr.io/busybox", []string{"sleep", "1000"}) - subContID, err := crictl.StartContainer(podID, "k8s.gcr.io/busybox", testdata.Sandbox, subContSpec) - if err != nil { - t.Fatal(err) - } - - out, err := crictl.Exec(subContID, "sh", "-c", "echo $HOME") - if err != nil { - t.Fatal(err) - } - if got, want := strings.TrimSpace(string(out)), "/root"; got != want { - t.Fatalf("Home directory invalid. Got %q, Want: %q", got, want) - } - - if err := crictl.StopContainer(subContID); err != nil { - t.Fatal(err) - } - }) - - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatal(err) - } - -} - -// setup sets up before a test. Specifically it: -// * Creates directories and a socket for containerd to utilize. -// * Runs containerd and waits for it to reach a "ready" state for testing. -// * Returns a cleanup function that should be called at the end of the test. -func setup(t *testing.T) (*criutil.Crictl, func(), error) { - var cleanups []func() - cleanupFunc := func() { - for i := len(cleanups) - 1; i >= 0; i-- { - cleanups[i]() - } - } - cleanup := specutils.MakeCleanup(cleanupFunc) - defer cleanup.Clean() - - // Create temporary containerd root and state directories, and a socket - // via which crictl and containerd communicate. - containerdRoot, err := ioutil.TempDir(testutil.TmpDir(), "containerd-root") - if err != nil { - t.Fatalf("failed to create containerd root: %v", err) - } - cleanups = append(cleanups, func() { os.RemoveAll(containerdRoot) }) - containerdState, err := ioutil.TempDir(testutil.TmpDir(), "containerd-state") - if err != nil { - t.Fatalf("failed to create containerd state: %v", err) - } - cleanups = append(cleanups, func() { os.RemoveAll(containerdState) }) - sockAddr := filepath.Join(testutil.TmpDir(), "containerd-test.sock") - - // We rewrite a configuration. This is based on the current docker - // configuration for the runtime under test. - runtime, err := dockerutil.RuntimePath() - if err != nil { - t.Fatalf("error discovering runtime path: %v", err) - } - config, err := testutil.WriteTmpFile("containerd-config", testdata.ContainerdConfig(runtime)) - if err != nil { - t.Fatalf("failed to write containerd config") - } - cleanups = append(cleanups, func() { os.RemoveAll(config) }) - - // Start containerd. - containerd := exec.Command(getContainerd(), - "--config", config, - "--log-level", "debug", - "--root", containerdRoot, - "--state", containerdState, - "--address", sockAddr) - cleanups = append(cleanups, func() { - if err := testutil.KillCommand(containerd); err != nil { - log.Printf("error killing containerd: %v", err) - } - }) - containerdStderr, err := containerd.StderrPipe() - if err != nil { - t.Fatalf("failed to get containerd stderr: %v", err) - } - containerdStdout, err := containerd.StdoutPipe() - if err != nil { - t.Fatalf("failed to get containerd stdout: %v", err) - } - if err := containerd.Start(); err != nil { - t.Fatalf("failed running containerd: %v", err) - } - - // Wait for containerd to boot. Then put all containerd output into a - // buffer to be logged at the end of the test. - testutil.WaitUntilRead(containerdStderr, "Start streaming server", nil, 10*time.Second) - stdoutBuf := &bytes.Buffer{} - stderrBuf := &bytes.Buffer{} - go func() { io.Copy(stdoutBuf, containerdStdout) }() - go func() { io.Copy(stderrBuf, containerdStderr) }() - cleanups = append(cleanups, func() { - t.Logf("containerd stdout: %s", string(stdoutBuf.Bytes())) - t.Logf("containerd stderr: %s", string(stderrBuf.Bytes())) - }) - - cleanup.Release() - return criutil.NewCrictl(20*time.Second, sockAddr), cleanupFunc, nil -} - -// httpGet GETs the contents of a file served from a pod on port 80. -func httpGet(crictl *criutil.Crictl, podID, filePath string) error { - // Get the IP of the httpd server. - ip, err := crictl.PodIP(podID) - if err != nil { - return fmt.Errorf("failed to get IP from pod %q: %v", podID, err) - } - - // GET the page. We may be waiting for the server to start, so retry - // with a timeout. - var resp *http.Response - cb := func() error { - r, err := http.Get(fmt.Sprintf("http://%s", path.Join(ip, filePath))) - resp = r - return err - } - if err := testutil.Poll(cb, 20*time.Second); err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - return fmt.Errorf("bad status returned: %d", resp.StatusCode) - } - return nil -} - -func getContainerd() string { - // Use the local path if it exists, otherwise, use the system one. - if _, err := os.Stat("/usr/local/bin/containerd"); err == nil { - return "/usr/local/bin/containerd" - } - return "/usr/bin/containerd" -} diff --git a/test/root/main_test.go b/test/root/main_test.go deleted file mode 100644 index d74dec85f..000000000 --- a/test/root/main_test.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package root - -import ( - "flag" - "fmt" - "os" - "testing" - - "github.com/syndtr/gocapability/capability" - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/specutils" -) - -// TestMain is the main function for root tests. This function checks the -// supported docker version, required capabilities, and configures the executable -// path for runsc. -func TestMain(m *testing.M) { - flag.Parse() - - if !specutils.HasCapabilities(capability.CAP_SYS_ADMIN, capability.CAP_DAC_OVERRIDE) { - fmt.Println("Test requires sysadmin privileges to run. Try again with sudo.") - os.Exit(1) - } - - dockerutil.EnsureSupportedDockerVersion() - - // Configure exe for tests. - path, err := dockerutil.RuntimePath() - if err != nil { - panic(err.Error()) - } - specutils.ExePath = path - - os.Exit(m.Run()) -} diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go deleted file mode 100644 index 126f0975a..000000000 --- a/test/root/oom_score_adj_test.go +++ /dev/null @@ -1,386 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package root - -import ( - "fmt" - "os" - "testing" - - specs "github.com/opencontainers/runtime-spec/specs-go" - "gvisor.dev/gvisor/runsc/boot" - "gvisor.dev/gvisor/runsc/container" - "gvisor.dev/gvisor/runsc/specutils" - "gvisor.dev/gvisor/runsc/testutil" -) - -var ( - maxOOMScoreAdj = 1000 - highOOMScoreAdj = 500 - lowOOMScoreAdj = -500 - minOOMScoreAdj = -1000 -) - -// Tests for oom_score_adj have to be run as root (rather than in a user -// namespace) because we need to adjust oom_score_adj for PIDs other than our -// own and test values below 0. - -// TestOOMScoreAdjSingle tests that oom_score_adj is set properly in a -// single container sandbox. -func TestOOMScoreAdjSingle(t *testing.T) { - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - - ppid, err := specutils.GetParentPid(os.Getpid()) - if err != nil { - t.Fatalf("getting parent pid: %v", err) - } - parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid) - if err != nil { - t.Fatalf("getting parent oom_score_adj: %v", err) - } - - testCases := []struct { - Name string - - // OOMScoreAdj is the oom_score_adj set to the OCI spec. If nil then - // no value is set. - OOMScoreAdj *int - }{ - { - Name: "max", - OOMScoreAdj: &maxOOMScoreAdj, - }, - { - Name: "high", - OOMScoreAdj: &highOOMScoreAdj, - }, - { - Name: "low", - OOMScoreAdj: &lowOOMScoreAdj, - }, - { - Name: "min", - OOMScoreAdj: &minOOMScoreAdj, - }, - { - Name: "nil", - OOMScoreAdj: &parentOOMScoreAdj, - }, - } - - for _, testCase := range testCases { - t.Run(testCase.Name, func(t *testing.T) { - id := testutil.UniqueContainerID() - s := testutil.NewSpecWithArgs("sleep", "1000") - s.Process.OOMScoreAdj = testCase.OOMScoreAdj - - containers, cleanup, err := startContainers(conf, []*specs.Spec{s}, []string{id}) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - c := containers[0] - - // Verify the gofer's oom_score_adj - if testCase.OOMScoreAdj != nil { - goferScore, err := specutils.GetOOMScoreAdj(c.GoferPid) - if err != nil { - t.Fatalf("error reading gofer oom_score_adj: %v", err) - } - if goferScore != *testCase.OOMScoreAdj { - t.Errorf("gofer oom_score_adj got: %d, want: %d", goferScore, *testCase.OOMScoreAdj) - } - - // Verify the sandbox's oom_score_adj. - // - // The sandbox should be the same for all containers so just use - // the first one. - sandboxPid := c.Sandbox.Pid - sandboxScore, err := specutils.GetOOMScoreAdj(sandboxPid) - if err != nil { - t.Fatalf("error reading sandbox oom_score_adj: %v", err) - } - if sandboxScore != *testCase.OOMScoreAdj { - t.Errorf("sandbox oom_score_adj got: %d, want: %d", sandboxScore, *testCase.OOMScoreAdj) - } - } - }) - } -} - -// TestOOMScoreAdjMulti tests that oom_score_adj is set properly in a -// multi-container sandbox. -func TestOOMScoreAdjMulti(t *testing.T) { - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - - ppid, err := specutils.GetParentPid(os.Getpid()) - if err != nil { - t.Fatalf("getting parent pid: %v", err) - } - parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid) - if err != nil { - t.Fatalf("getting parent oom_score_adj: %v", err) - } - - testCases := []struct { - Name string - - // OOMScoreAdj is the oom_score_adj set to the OCI spec. If nil then - // no value is set. One value for each container. The first value is the - // root container. - OOMScoreAdj []*int - - // Expected is the expected oom_score_adj of the sandbox. If nil, then - // this value is ignored. - Expected *int - - // Remove is a set of container indexes to remove from the sandbox. - Remove []int - - // ExpectedAfterRemove is the expected oom_score_adj of the sandbox - // after containers are removed. Ignored if nil. - ExpectedAfterRemove *int - }{ - // A single container CRI test case. This should not happen in - // practice as there should be at least one container besides the pause - // container. However, we include a test case to ensure sane behavior. - { - Name: "single", - OOMScoreAdj: []*int{&highOOMScoreAdj}, - Expected: &parentOOMScoreAdj, - }, - { - Name: "multi_no_value", - OOMScoreAdj: []*int{nil, nil, nil}, - Expected: &parentOOMScoreAdj, - }, - { - Name: "multi_non_nil_root", - OOMScoreAdj: []*int{&minOOMScoreAdj, nil, nil}, - Expected: &parentOOMScoreAdj, - }, - { - Name: "multi_value", - OOMScoreAdj: []*int{&minOOMScoreAdj, &highOOMScoreAdj, &lowOOMScoreAdj}, - // The lowest value excluding the root container is expected. - Expected: &lowOOMScoreAdj, - }, - { - Name: "multi_min_value", - OOMScoreAdj: []*int{&minOOMScoreAdj, &lowOOMScoreAdj}, - // The lowest value excluding the root container is expected. - Expected: &lowOOMScoreAdj, - }, - { - Name: "multi_max_value", - OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj}, - // The lowest value excluding the root container is expected. - Expected: &highOOMScoreAdj, - }, - { - Name: "remove_adjusted", - OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj}, - // The lowest value excluding the root container is expected. - Expected: &highOOMScoreAdj, - // Remove highOOMScoreAdj container. - Remove: []int{2}, - ExpectedAfterRemove: &maxOOMScoreAdj, - }, - { - // This test removes all non-root sandboxes with a specified oomScoreAdj. - Name: "remove_to_nil", - OOMScoreAdj: []*int{&minOOMScoreAdj, nil, &lowOOMScoreAdj}, - Expected: &lowOOMScoreAdj, - // Remove lowOOMScoreAdj container. - Remove: []int{2}, - // The oom_score_adj expected after remove is that of the parent process. - ExpectedAfterRemove: &parentOOMScoreAdj, - }, - { - Name: "remove_no_effect", - OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj}, - // The lowest value excluding the root container is expected. - Expected: &highOOMScoreAdj, - // Remove the maxOOMScoreAdj container. - Remove: []int{1}, - ExpectedAfterRemove: &highOOMScoreAdj, - }, - } - - for _, testCase := range testCases { - t.Run(testCase.Name, func(t *testing.T) { - var cmds [][]string - var oomScoreAdj []*int - var toRemove []string - - for _, oomScore := range testCase.OOMScoreAdj { - oomScoreAdj = append(oomScoreAdj, oomScore) - cmds = append(cmds, []string{"sleep", "100"}) - } - - specs, ids := createSpecs(cmds...) - for i, spec := range specs { - // Ensure the correct value is set, including no value. - spec.Process.OOMScoreAdj = oomScoreAdj[i] - - for _, j := range testCase.Remove { - if i == j { - toRemove = append(toRemove, ids[i]) - } - } - } - - containers, cleanup, err := startContainers(conf, specs, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() - - for i, c := range containers { - if oomScoreAdj[i] != nil { - // Verify the gofer's oom_score_adj - score, err := specutils.GetOOMScoreAdj(c.GoferPid) - if err != nil { - t.Fatalf("error reading gofer oom_score_adj: %v", err) - } - if score != *oomScoreAdj[i] { - t.Errorf("gofer oom_score_adj got: %d, want: %d", score, *oomScoreAdj[i]) - } - } - } - - // Verify the sandbox's oom_score_adj. - // - // The sandbox should be the same for all containers so just use - // the first one. - sandboxPid := containers[0].Sandbox.Pid - if testCase.Expected != nil { - score, err := specutils.GetOOMScoreAdj(sandboxPid) - if err != nil { - t.Fatalf("error reading sandbox oom_score_adj: %v", err) - } - if score != *testCase.Expected { - t.Errorf("sandbox oom_score_adj got: %d, want: %d", score, *testCase.Expected) - } - } - - if len(toRemove) == 0 { - return - } - - // Remove containers. - for _, removeID := range toRemove { - for _, c := range containers { - if c.ID == removeID { - c.Destroy() - } - } - } - - // Check the new adjusted oom_score_adj. - if testCase.ExpectedAfterRemove != nil { - scoreAfterRemove, err := specutils.GetOOMScoreAdj(sandboxPid) - if err != nil { - t.Fatalf("error reading sandbox oom_score_adj: %v", err) - } - if scoreAfterRemove != *testCase.ExpectedAfterRemove { - t.Errorf("sandbox oom_score_adj got: %d, want: %d", scoreAfterRemove, *testCase.ExpectedAfterRemove) - } - } - }) - } -} - -func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) { - var specs []*specs.Spec - var ids []string - rootID := testutil.UniqueContainerID() - - for i, cmd := range cmds { - spec := testutil.NewSpecWithArgs(cmd...) - if i == 0 { - spec.Annotations = map[string]string{ - specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeSandbox, - } - ids = append(ids, rootID) - } else { - spec.Annotations = map[string]string{ - specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeContainer, - specutils.ContainerdSandboxIDAnnotation: rootID, - } - ids = append(ids, testutil.UniqueContainerID()) - } - specs = append(specs, spec) - } - return specs, ids -} - -func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*container.Container, func(), error) { - if len(conf.RootDir) == 0 { - panic("conf.RootDir not set. Call testutil.SetupRootDir() to set.") - } - - var containers []*container.Container - var bundles []string - cleanup := func() { - for _, c := range containers { - c.Destroy() - } - for _, b := range bundles { - os.RemoveAll(b) - } - } - for i, spec := range specs { - bundleDir, err := testutil.SetupBundleDir(spec) - if err != nil { - cleanup() - return nil, nil, fmt.Errorf("error setting up container: %v", err) - } - bundles = append(bundles, bundleDir) - - args := container.Args{ - ID: ids[i], - Spec: spec, - BundleDir: bundleDir, - } - cont, err := container.New(conf, args) - if err != nil { - cleanup() - return nil, nil, fmt.Errorf("error creating container: %v", err) - } - containers = append(containers, cont) - - if err := cont.Start(conf); err != nil { - cleanup() - return nil, nil, fmt.Errorf("error starting container: %v", err) - } - } - return containers, cleanup, nil -} diff --git a/test/root/root.go b/test/root/root.go deleted file mode 100644 index 0f1d29faf..000000000 --- a/test/root/root.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package root is used for tests that requires sysadmin privileges run. First, -// follow the setup instruction in runsc/test/README.md. You should also have -// docker, containerd, and crictl installed. To run these tests from the -// project root directory: -// -// ./scripts/root_tests.sh -package root diff --git a/test/root/runsc_test.go b/test/root/runsc_test.go deleted file mode 100644 index 90373e2db..000000000 --- a/test/root/runsc_test.go +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package root - -import ( - "bytes" - "fmt" - "io/ioutil" - "os" - "os/exec" - "path/filepath" - "strconv" - "strings" - "testing" - "time" - - "github.com/cenkalti/backoff" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/runsc/specutils" - "gvisor.dev/gvisor/runsc/testutil" -) - -// TestDoKill checks that when "runsc do..." is killed, the sandbox process is -// also terminated. This ensures that parent death signal is propagate to the -// sandbox process correctly. -func TestDoKill(t *testing.T) { - // Make the sandbox process be reparented here when it's killed, so we can - // wait for it. - if err := unix.Prctl(unix.PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0); err != nil { - t.Fatalf("prctl(PR_SET_CHILD_SUBREAPER): %v", err) - } - - cmd := exec.Command(specutils.ExePath, "do", "sleep", "10000") - buf := &bytes.Buffer{} - cmd.Stdout = buf - cmd.Stderr = buf - cmd.Start() - - var pid int - findSandbox := func() error { - var err error - pid, err = sandboxPid(cmd.Process.Pid) - if err != nil { - return &backoff.PermanentError{Err: err} - } - if pid == 0 { - return fmt.Errorf("sandbox process not found") - } - return nil - } - if err := testutil.Poll(findSandbox, 10*time.Second); err != nil { - t.Fatalf("failed to find sandbox: %v", err) - } - t.Logf("Found sandbox, pid: %d", pid) - - if err := cmd.Process.Kill(); err != nil { - t.Fatalf("failed to kill run process: %v", err) - } - cmd.Wait() - t.Logf("Parent process killed (%d). Output: %s", cmd.Process.Pid, buf.String()) - - ch := make(chan struct{}) - go func() { - defer func() { ch <- struct{}{} }() - t.Logf("Waiting for sandbox process (%d) termination", pid) - if _, err := unix.Wait4(pid, nil, 0, nil); err != nil { - t.Errorf("error waiting for sandbox process (%d): %v", pid, err) - } - }() - select { - case <-ch: - // Done - case <-time.After(5 * time.Second): - t.Fatalf("timeout waiting for sandbox process (%d) to exit", pid) - } -} - -// sandboxPid looks for the sandbox process inside the process tree starting -// from "pid". It returns 0 and no error if no sandbox process is found. It -// returns error if anything failed. -func sandboxPid(pid int) (int, error) { - cmd := exec.Command("pgrep", "-P", strconv.Itoa(pid)) - buf := &bytes.Buffer{} - cmd.Stdout = buf - if err := cmd.Start(); err != nil { - return 0, err - } - ps, err := cmd.Process.Wait() - if err != nil { - return 0, err - } - if ps.ExitCode() == 1 { - // pgrep returns 1 when no process is found. - return 0, nil - } - - var children []int - for _, line := range strings.Split(buf.String(), "\n") { - if len(line) == 0 { - continue - } - child, err := strconv.Atoi(line) - if err != nil { - return 0, err - } - - cmdline, err := ioutil.ReadFile(filepath.Join("/proc", line, "cmdline")) - if err != nil { - if os.IsNotExist(err) { - // Raced with process exit. - continue - } - return 0, err - } - args := strings.SplitN(string(cmdline), "\x00", 2) - if len(args) == 0 { - return 0, fmt.Errorf("malformed cmdline file: %q", cmdline) - } - // The sandbox process has the first argument set to "runsc-sandbox". - if args[0] == "runsc-sandbox" { - return child, nil - } - - children = append(children, child) - } - - // Sandbox process wasn't found, try another level down. - for _, pid := range children { - sand, err := sandboxPid(pid) - if err != nil { - return 0, err - } - if sand != 0 { - return sand, nil - } - // Not found, continue the search. - } - return 0, nil -} diff --git a/test/root/testdata/BUILD b/test/root/testdata/BUILD deleted file mode 100644 index 6859541ad..000000000 --- a/test/root/testdata/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "testdata", - srcs = [ - "busybox.go", - "containerd_config.go", - "httpd.go", - "httpd_mount_paths.go", - "sandbox.go", - "simple.go", - ], - visibility = [ - "//:sandbox", - ], -) diff --git a/test/root/testdata/busybox.go b/test/root/testdata/busybox.go deleted file mode 100644 index e4dbd2843..000000000 --- a/test/root/testdata/busybox.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testdata - -// MountOverSymlink is a JSON config for a container that /etc/resolv.conf is a -// symlink to /tmp/resolv.conf. -var MountOverSymlink = ` -{ - "metadata": { - "name": "busybox" - }, - "image": { - "image": "k8s.gcr.io/busybox" - }, - "command": [ - "sleep", - "1000" - ] -} -` diff --git a/test/root/testdata/containerd_config.go b/test/root/testdata/containerd_config.go deleted file mode 100644 index e12f1ec88..000000000 --- a/test/root/testdata/containerd_config.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testdata contains data required for root tests. -package testdata - -import "fmt" - -// containerdConfigTemplate is a .toml config for containerd. It contains a -// formatting verb so the runtime field can be set via fmt.Sprintf. -const containerdConfigTemplate = ` -disabled_plugins = ["restart"] -[plugins.linux] - runtime = "%s" - runtime_root = "/tmp/test-containerd/runsc" - shim = "/usr/local/bin/gvisor-containerd-shim" - shim_debug = true - -[plugins.cri.containerd.runtimes.runsc] - runtime_type = "io.containerd.runtime.v1.linux" - runtime_engine = "%s" -` - -// ContainerdConfig returns a containerd config file with the specified -// runtime. -func ContainerdConfig(runtime string) string { - return fmt.Sprintf(containerdConfigTemplate, runtime, runtime) -} diff --git a/test/root/testdata/httpd.go b/test/root/testdata/httpd.go deleted file mode 100644 index 45d5e33d4..000000000 --- a/test/root/testdata/httpd.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testdata - -// Httpd is a JSON config for an httpd container. -const Httpd = ` -{ - "metadata": { - "name": "httpd" - }, - "image":{ - "image": "httpd" - }, - "mounts": [ - ], - "linux": { - }, - "log_path": "httpd.log" -} -` diff --git a/test/root/testdata/httpd_mount_paths.go b/test/root/testdata/httpd_mount_paths.go deleted file mode 100644 index ac3f4446a..000000000 --- a/test/root/testdata/httpd_mount_paths.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testdata - -// HttpdMountPaths is a JSON config for an httpd container with additional -// mounts. -const HttpdMountPaths = ` -{ - "metadata": { - "name": "httpd" - }, - "image":{ - "image": "httpd" - }, - "mounts": [ - { - "container_path": "/var/run/secrets/kubernetes.io/serviceaccount", - "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/volumes/kubernetes.io~secret/default-token-2rpfx", - "readonly": true - }, - { - "container_path": "/etc/hosts", - "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/etc-hosts", - "readonly": false - }, - { - "container_path": "/dev/termination-log", - "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/containers/httpd/d1709580", - "readonly": false - }, - { - "container_path": "/usr/local/apache2/htdocs/test", - "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064", - "readonly": true - } - ], - "linux": { - }, - "log_path": "httpd.log" -} -` diff --git a/test/root/testdata/sandbox.go b/test/root/testdata/sandbox.go deleted file mode 100644 index 0db210370..000000000 --- a/test/root/testdata/sandbox.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testdata - -// Sandbox is a default JSON config for a sandbox. -const Sandbox = ` -{ - "metadata": { - "name": "default-sandbox", - "namespace": "default", - "attempt": 1, - "uid": "hdishd83djaidwnduwk28bcsb" - }, - "linux": { - }, - "log_directory": "/tmp" -} -` diff --git a/test/root/testdata/simple.go b/test/root/testdata/simple.go deleted file mode 100644 index 1cca53f0c..000000000 --- a/test/root/testdata/simple.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testdata - -import ( - "encoding/json" - "fmt" -) - -// SimpleSpec returns a JSON config for a simple container that runs the -// specified command in the specified image. -func SimpleSpec(name, image string, cmd []string) string { - cmds, err := json.Marshal(cmd) - if err != nil { - // This shouldn't happen. - panic(err) - } - return fmt.Sprintf(` -{ - "metadata": { - "name": %q - }, - "image": { - "image": %q - }, - "command": %s - } -`, name, image, cmds) -} diff --git a/test/runner/BUILD b/test/runner/BUILD deleted file mode 100644 index 9959ef9b0..000000000 --- a/test/runner/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//tools:defs.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "runner", - testonly = 1, - srcs = ["runner.go"], - data = [ - "//runsc", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/log", - "//runsc/specutils", - "//runsc/testutil", - "//test/runner/gtest", - "//test/uds", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl deleted file mode 100644 index 56743a526..000000000 --- a/test/runner/defs.bzl +++ /dev/null @@ -1,198 +0,0 @@ -"""Defines a rule for syscall test targets.""" - -load("//tools:defs.bzl", "default_platform", "loopback", "platforms") - -def _runner_test_impl(ctx): - # Generate a runner binary. - runner = ctx.actions.declare_file("%s-runner" % ctx.label.name) - runner_content = "\n".join([ - "#!/bin/bash", - "set -euf -x -o pipefail", - "if [[ -n \"${TEST_UNDECLARED_OUTPUTS_DIR}\" ]]; then", - " mkdir -p \"${TEST_UNDECLARED_OUTPUTS_DIR}\"", - " chmod a+rwx \"${TEST_UNDECLARED_OUTPUTS_DIR}\"", - "fi", - "exec %s %s %s\n" % ( - ctx.files.runner[0].short_path, - " ".join(ctx.attr.runner_args), - ctx.files.test[0].short_path, - ), - ]) - ctx.actions.write(runner, runner_content, is_executable = True) - - # Return with all transitive files. - runfiles = ctx.runfiles( - transitive_files = depset(transitive = [ - depset(target.data_runfiles.files) - for target in (ctx.attr.runner, ctx.attr.test) - if hasattr(target, "data_runfiles") - ]), - files = ctx.files.runner + ctx.files.test, - collect_default = True, - collect_data = True, - ) - return [DefaultInfo(executable = runner, runfiles = runfiles)] - -_runner_test = rule( - attrs = { - "runner": attr.label( - default = "//test/runner:runner", - ), - "test": attr.label( - mandatory = True, - ), - "runner_args": attr.string_list(), - "data": attr.label_list( - allow_files = True, - ), - }, - test = True, - implementation = _runner_test_impl, -) - -def _syscall_test( - test, - shard_count, - size, - platform, - use_tmpfs, - tags, - network = "none", - file_access = "exclusive", - overlay = False, - add_uds_tree = False): - # Prepend "runsc" to non-native platform names. - full_platform = platform if platform == "native" else "runsc_" + platform - - # Name the test appropriately. - name = test.split(":")[1] + "_" + full_platform - if file_access == "shared": - name += "_shared" - if overlay: - name += "_overlay" - if network != "none": - name += "_" + network + "net" - - # Apply all tags. - if tags == None: - tags = [] - - # Add the full_platform and file access in a tag to make it easier to run - # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared. - tags += [full_platform, "file_" + file_access] - - # Hash this target into one of 15 buckets. This can be used to - # randomly split targets between different workflows. - hash15 = hash(native.package_name() + name) % 15 - tags.append("hash15:" + str(hash15)) - - # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until - # we figure out how to request ipv4 sockets on Guitar machines. - if network == "host": - tags.append("noguitar") - - # Disable off-host networking. - tags.append("requires-net:loopback") - - runner_args = [ - # Arguments are passed directly to runner binary. - "--platform=" + platform, - "--network=" + network, - "--use-tmpfs=" + str(use_tmpfs), - "--file-access=" + file_access, - "--overlay=" + str(overlay), - "--add-uds-tree=" + str(add_uds_tree), - ] - - # Call the rule above. - _runner_test( - name = name, - test = test, - runner_args = runner_args, - data = [loopback], - size = size, - tags = tags, - shard_count = shard_count, - ) - -def syscall_test( - test, - shard_count = 5, - size = "small", - use_tmpfs = False, - add_overlay = False, - add_uds_tree = False, - add_hostinet = False, - tags = None): - """syscall_test is a macro that will create targets for all platforms. - - Args: - test: the test target. - shard_count: shards for defined tests. - size: the defined test size. - use_tmpfs: use tmpfs in the defined tests. - add_overlay: add an overlay test. - add_uds_tree: add a UDS test. - add_hostinet: add a hostinet test. - tags: starting test tags. - """ - if not tags: - tags = [] - - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "native", - use_tmpfs = False, - add_uds_tree = add_uds_tree, - tags = tags, - ) - - for (platform, platform_tags) in platforms.items(): - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = platform, - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = platform_tags + tags, - ) - - if add_overlay: - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = default_platform, - use_tmpfs = False, # overlay is adding a writable tmpfs on top of root. - add_uds_tree = add_uds_tree, - tags = platforms[default_platform] + tags, - overlay = True, - ) - - if not use_tmpfs: - # Also test shared gofer access. - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = default_platform, - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = platforms[default_platform] + tags, - file_access = "shared", - ) - - if add_hostinet: - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = default_platform, - use_tmpfs = use_tmpfs, - network = "host", - add_uds_tree = add_uds_tree, - tags = platforms[default_platform] + tags, - ) diff --git a/test/runner/gtest/BUILD b/test/runner/gtest/BUILD deleted file mode 100644 index de4b2727c..000000000 --- a/test/runner/gtest/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "gtest", - srcs = ["gtest.go"], - visibility = ["//:sandbox"], -) diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go deleted file mode 100644 index 869169ad5..000000000 --- a/test/runner/gtest/gtest.go +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package gtest contains helpers for running google-test tests from Go. -package gtest - -import ( - "fmt" - "os/exec" - "strings" -) - -var ( - // listTestFlag is the flag that will list tests in gtest binaries. - listTestFlag = "--gtest_list_tests" - - // filterTestFlag is the flag that will filter tests in gtest binaries. - filterTestFlag = "--gtest_filter" - - // listBechmarkFlag is the flag that will list benchmarks in gtest binaries. - listBenchmarkFlag = "--benchmark_list_tests" - - // filterBenchmarkFlag is the flag that will run specified benchmarks. - filterBenchmarkFlag = "--benchmark_filter" -) - -// TestCase is a single gtest test case. -type TestCase struct { - // Suite is the suite for this test. - Suite string - - // Name is the name of this individual test. - Name string - - // all indicates that this will run without flags. This takes - // precendence over benchmark below. - all bool - - // benchmark indicates that this is a benchmark. In this case, the - // suite will be empty, and we will use the appropriate test and - // benchmark flags. - benchmark bool -} - -// FullName returns the name of the test including the suite. It is suitable to -// pass to "-gtest_filter". -func (tc TestCase) FullName() string { - return fmt.Sprintf("%s.%s", tc.Suite, tc.Name) -} - -// Args returns arguments to be passed when invoking the test. -func (tc TestCase) Args() []string { - if tc.all { - return []string{} // No arguments. - } - if tc.benchmark { - return []string{ - fmt.Sprintf("%s=^%s$", filterBenchmarkFlag, tc.Name), - fmt.Sprintf("%s=", filterTestFlag), - } - } - return []string{ - fmt.Sprintf("%s=%s", filterTestFlag, tc.FullName()), - } -} - -// ParseTestCases calls a gtest test binary to list its test and returns a -// slice with the name and suite of each test. -// -// If benchmarks is true, then benchmarks will be included in the list of test -// cases provided. Note that this requires the binary to support the -// benchmarks_list_tests flag. -func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]TestCase, error) { - // Run to extract test cases. - args := append([]string{listTestFlag}, extraArgs...) - cmd := exec.Command(testBin, args...) - out, err := cmd.Output() - if err != nil { - // We failed to list tests with the given flags. Just - // return something that will run the binary with no - // flags, which should execute all tests. - return []TestCase{ - TestCase{ - Suite: "Default", - Name: "All", - all: true, - }, - }, nil - } - - // Parse test output. - var t []TestCase - var suite string - for _, line := range strings.Split(string(out), "\n") { - // Strip comments. - line = strings.Split(line, "#")[0] - - // New suite? - if !strings.HasPrefix(line, " ") { - suite = strings.TrimSuffix(strings.TrimSpace(line), ".") - continue - } - - // Individual test. - name := strings.TrimSpace(line) - - // Do we have a suite yet? - if suite == "" { - return nil, fmt.Errorf("test without a suite: %v", name) - } - - // Add this individual test. - t = append(t, TestCase{ - Suite: suite, - Name: name, - }) - } - - // Finished? - if !benchmarks { - return t, nil - } - - // Run again to extract benchmarks. - args = append([]string{listBenchmarkFlag}, extraArgs...) - cmd = exec.Command(testBin, args...) - out, err = cmd.Output() - if err != nil { - // We were able to enumerate tests above, but not benchmarks? - // We requested them, so we return an error in this case. - exitErr, ok := err.(*exec.ExitError) - if !ok { - return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v", err) - } - return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v\nstderr\n%s", err, exitErr.Stderr) - } - - out = []byte(strings.Trim(string(out), "\n")) - - // Parse benchmark output. - for _, line := range strings.Split(string(out), "\n") { - // Strip comments. - line = strings.Split(line, "#")[0] - - // Single benchmark. - name := strings.TrimSpace(line) - - // Add the single benchmark. - t = append(t, TestCase{ - Suite: "Benchmarks", - Name: name, - benchmark: true, - }) - } - - return t, nil -} diff --git a/test/runner/runner.go b/test/runner/runner.go deleted file mode 100644 index a78ef38e0..000000000 --- a/test/runner/runner.go +++ /dev/null @@ -1,477 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Binary syscall_test_runner runs the syscall test suites in gVisor -// containers and on the host platform. -package main - -import ( - "flag" - "fmt" - "io/ioutil" - "os" - "os/exec" - "os/signal" - "path/filepath" - "strings" - "syscall" - "testing" - "time" - - specs "github.com/opencontainers/runtime-spec/specs-go" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/runsc/specutils" - "gvisor.dev/gvisor/runsc/testutil" - "gvisor.dev/gvisor/test/runner/gtest" - "gvisor.dev/gvisor/test/uds" -) - -var ( - debug = flag.Bool("debug", false, "enable debug logs") - strace = flag.Bool("strace", false, "enable strace logs") - platform = flag.String("platform", "ptrace", "platform to run on") - network = flag.String("network", "none", "network stack to run on (sandbox, host, none)") - useTmpfs = flag.Bool("use-tmpfs", false, "mounts tmpfs for /tmp") - fileAccess = flag.String("file-access", "exclusive", "mounts root in exclusive or shared mode") - overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable tmpfs overlay") - parallel = flag.Bool("parallel", false, "run tests in parallel") - runscPath = flag.String("runsc", "", "path to runsc binary") - - addUDSTree = flag.Bool("add-uds-tree", false, "expose a tree of UDS utilities for use in tests") -) - -// runTestCaseNative runs the test case directly on the host machine. -func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { - // These tests might be running in parallel, so make sure they have a - // unique test temp dir. - tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "") - if err != nil { - t.Fatalf("could not create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - // Replace TEST_TMPDIR in the current environment with something - // unique. - env := os.Environ() - newEnvVar := "TEST_TMPDIR=" + tmpDir - var found bool - for i, kv := range env { - if strings.HasPrefix(kv, "TEST_TMPDIR=") { - env[i] = newEnvVar - found = true - break - } - } - if !found { - env = append(env, newEnvVar) - } - // Remove env variables that cause the gunit binary to write output - // files, since they will stomp on eachother, and on the output files - // from this go test. - env = filterEnv(env, []string{"GUNIT_OUTPUT", "TEST_PREMATURE_EXIT_FILE", "XML_OUTPUT_FILE"}) - - // Remove shard env variables so that the gunit binary does not try to - // intepret them. - env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) - - if *addUDSTree { - socketDir, cleanup, err := uds.CreateSocketTree("/tmp") - if err != nil { - t.Fatalf("failed to create socket tree: %v", err) - } - defer cleanup() - - env = append(env, "TEST_UDS_TREE="+socketDir) - // On Linux, the concept of "attach" location doesn't exist. - // Just pass the same path to make these test identical. - env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir) - } - - cmd := exec.Command(testBin, tc.Args()...) - cmd.Env = env - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - ws := err.(*exec.ExitError).Sys().(syscall.WaitStatus) - t.Errorf("test %q exited with status %d, want 0", tc.FullName(), ws.ExitStatus()) - } -} - -// runRunsc runs spec in runsc in a standard test configuration. -// -// runsc logs will be saved to a path in TEST_UNDECLARED_OUTPUTS_DIR. -// -// Returns an error if the sandboxed application exits non-zero. -func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { - bundleDir, err := testutil.SetupBundleDir(spec) - if err != nil { - return fmt.Errorf("SetupBundleDir failed: %v", err) - } - defer os.RemoveAll(bundleDir) - - rootDir, err := testutil.SetupRootDir() - if err != nil { - return fmt.Errorf("SetupRootDir failed: %v", err) - } - defer os.RemoveAll(rootDir) - - name := tc.FullName() - id := testutil.UniqueContainerID() - log.Infof("Running test %q in container %q", name, id) - specutils.LogSpec(spec) - - args := []string{ - "-root", rootDir, - "-network", *network, - "-log-format=text", - "-TESTONLY-unsafe-nonroot=true", - "-net-raw=true", - fmt.Sprintf("-panic-signal=%d", syscall.SIGTERM), - "-watchdog-action=panic", - "-platform", *platform, - "-file-access", *fileAccess, - } - if *overlay { - args = append(args, "-overlay") - } - if *debug { - args = append(args, "-debug", "-log-packets=true") - } - if *strace { - args = append(args, "-strace") - } - if *addUDSTree { - args = append(args, "-fsgofer-host-uds") - } - - if outDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { - tdir := filepath.Join(outDir, strings.Replace(name, "/", "_", -1)) - if err := os.MkdirAll(tdir, 0755); err != nil { - return fmt.Errorf("could not create test dir: %v", err) - } - debugLogDir, err := ioutil.TempDir(tdir, "runsc") - if err != nil { - return fmt.Errorf("could not create temp dir: %v", err) - } - debugLogDir += "/" - log.Infof("runsc logs: %s", debugLogDir) - args = append(args, "-debug-log", debugLogDir) - - // Default -log sends messages to stderr which makes reading the test log - // difficult. Instead, drop them when debug log is enabled given it's a - // better place for these messages. - args = append(args, "-log=/dev/null") - } - - // Current process doesn't have CAP_SYS_ADMIN, create user namespace and run - // as root inside that namespace to get it. - rArgs := append(args, "run", "--bundle", bundleDir, id) - cmd := exec.Command(*runscPath, rArgs...) - cmd.SysProcAttr = &syscall.SysProcAttr{ - Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNS, - // Set current user/group as root inside the namespace. - UidMappings: []syscall.SysProcIDMap{ - {ContainerID: 0, HostID: os.Getuid(), Size: 1}, - }, - GidMappings: []syscall.SysProcIDMap{ - {ContainerID: 0, HostID: os.Getgid(), Size: 1}, - }, - GidMappingsEnableSetgroups: false, - Credential: &syscall.Credential{ - Uid: 0, - Gid: 0, - }, - } - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGTERM) - go func() { - s, ok := <-sig - if !ok { - return - } - log.Warningf("%s: Got signal: %v", name, s) - done := make(chan bool) - dArgs := append([]string{}, args...) - dArgs = append(dArgs, "-alsologtostderr=true", "debug", "--stacks", id) - go func(dArgs []string) { - cmd := exec.Command(*runscPath, dArgs...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Run() - done <- true - }(dArgs) - - timeout := time.After(3 * time.Second) - select { - case <-timeout: - log.Infof("runsc debug --stacks is timeouted") - case <-done: - } - - log.Warningf("Send SIGTERM to the sandbox process") - dArgs = append(args, "debug", - fmt.Sprintf("--signal=%d", syscall.SIGTERM), - id) - cmd := exec.Command(*runscPath, dArgs...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Run() - }() - - err = cmd.Run() - - signal.Stop(sig) - close(sig) - - return err -} - -// setupUDSTree updates the spec to expose a UDS tree for gofer socket testing. -func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) { - socketDir, cleanup, err := uds.CreateSocketTree("/tmp") - if err != nil { - return nil, fmt.Errorf("failed to create socket tree: %v", err) - } - - // Standard access to entire tree. - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp/sockets", - Source: socketDir, - Type: "bind", - }) - - // Individial attach points for each socket to test mounts that attach - // directly to the sockets. - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp/sockets-attach/stream/echo", - Source: filepath.Join(socketDir, "stream/echo"), - Type: "bind", - }) - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp/sockets-attach/stream/nonlistening", - Source: filepath.Join(socketDir, "stream/nonlistening"), - Type: "bind", - }) - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp/sockets-attach/seqpacket/echo", - Source: filepath.Join(socketDir, "seqpacket/echo"), - Type: "bind", - }) - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp/sockets-attach/seqpacket/nonlistening", - Source: filepath.Join(socketDir, "seqpacket/nonlistening"), - Type: "bind", - }) - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp/sockets-attach/dgram/null", - Source: filepath.Join(socketDir, "dgram/null"), - Type: "bind", - }) - - spec.Process.Env = append(spec.Process.Env, "TEST_UDS_TREE=/tmp/sockets") - spec.Process.Env = append(spec.Process.Env, "TEST_UDS_ATTACH_TREE=/tmp/sockets-attach") - - return cleanup, nil -} - -// runsTestCaseRunsc runs the test case in runsc. -func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { - // Run a new container with the test executable and filter for the - // given test suite and name. - spec := testutil.NewSpecWithArgs(append([]string{testBin}, tc.Args()...)...) - - // Mark the root as writeable, as some tests attempt to - // write to the rootfs, and expect EACCES, not EROFS. - spec.Root.Readonly = false - - // Test spec comes with pre-defined mounts that we don't want. Reset it. - spec.Mounts = nil - if *useTmpfs { - // Forces '/tmp' to be mounted as tmpfs, otherwise test that rely on - // features only available in gVisor's internal tmpfs may fail. - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp", - Type: "tmpfs", - }) - } else { - // Use a gofer-backed directory as '/tmp'. - // - // Tests might be running in parallel, so make sure each has a - // unique test temp dir. - // - // Some tests (e.g., sticky) access this mount from other - // users, so make sure it is world-accessible. - tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "") - if err != nil { - t.Fatalf("could not create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - if err := os.Chmod(tmpDir, 0777); err != nil { - t.Fatalf("could not chmod temp dir: %v", err) - } - - spec.Mounts = append(spec.Mounts, specs.Mount{ - Type: "bind", - Destination: "/tmp", - Source: tmpDir, - }) - } - - // Set environment variables that indicate we are - // running in gVisor with the given platform and network. - platformVar := "TEST_ON_GVISOR" - networkVar := "GVISOR_NETWORK" - env := append(os.Environ(), platformVar+"="+*platform, networkVar+"="+*network) - - // Remove env variables that cause the gunit binary to write output - // files, since they will stomp on eachother, and on the output files - // from this go test. - env = filterEnv(env, []string{"GUNIT_OUTPUT", "TEST_PREMATURE_EXIT_FILE", "XML_OUTPUT_FILE"}) - - // Remove shard env variables so that the gunit binary does not try to - // intepret them. - env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) - - // Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to - // be backed by tmpfs. - for i, kv := range env { - if strings.HasPrefix(kv, "TEST_TMPDIR=") { - env[i] = "TEST_TMPDIR=/tmp" - break - } - } - - spec.Process.Env = env - - if *addUDSTree { - cleanup, err := setupUDSTree(spec) - if err != nil { - t.Fatalf("error creating UDS tree: %v", err) - } - defer cleanup() - } - - if err := runRunsc(tc, spec); err != nil { - t.Errorf("test %q failed with error %v, want nil", tc.FullName(), err) - } -} - -// filterEnv returns an environment with the blacklisted variables removed. -func filterEnv(env, blacklist []string) []string { - var out []string - for _, kv := range env { - ok := true - for _, k := range blacklist { - if strings.HasPrefix(kv, k+"=") { - ok = false - break - } - } - if ok { - out = append(out, kv) - } - } - return out -} - -func fatalf(s string, args ...interface{}) { - fmt.Fprintf(os.Stderr, s+"\n", args...) - os.Exit(1) -} - -func matchString(a, b string) (bool, error) { - return a == b, nil -} - -func main() { - flag.Parse() - if flag.NArg() != 1 { - fatalf("test must be provided") - } - testBin := flag.Args()[0] // Only argument. - - log.SetLevel(log.Info) - if *debug { - log.SetLevel(log.Debug) - } - - if *platform != "native" && *runscPath == "" { - if err := testutil.ConfigureExePath(); err != nil { - panic(err.Error()) - } - *runscPath = specutils.ExePath - } - - // Make sure stdout and stderr are opened with O_APPEND, otherwise logs - // from outside the sandbox can (and will) stomp on logs from inside - // the sandbox. - for _, f := range []*os.File{os.Stdout, os.Stderr} { - flags, err := unix.FcntlInt(f.Fd(), unix.F_GETFL, 0) - if err != nil { - fatalf("error getting file flags for %v: %v", f, err) - } - if flags&unix.O_APPEND == 0 { - flags |= unix.O_APPEND - if _, err := unix.FcntlInt(f.Fd(), unix.F_SETFL, flags); err != nil { - fatalf("error setting file flags for %v: %v", f, err) - } - } - } - - // Get all test cases in each binary. - testCases, err := gtest.ParseTestCases(testBin, true) - if err != nil { - fatalf("ParseTestCases(%q) failed: %v", testBin, err) - } - - // Get subset of tests corresponding to shard. - indices, err := testutil.TestIndicesForShard(len(testCases)) - if err != nil { - fatalf("TestsForShard() failed: %v", err) - } - - // Resolve the absolute path for the binary. - testBin, err = filepath.Abs(testBin) - if err != nil { - fatalf("Abs() failed: %v", err) - } - - // Run the tests. - var tests []testing.InternalTest - for _, tci := range indices { - // Capture tc. - tc := testCases[tci] - tests = append(tests, testing.InternalTest{ - Name: fmt.Sprintf("%s_%s", tc.Suite, tc.Name), - F: func(t *testing.T) { - if *parallel { - t.Parallel() - } - if *platform == "native" { - // Run the test case on host. - runTestCaseNative(testBin, tc, t) - } else { - // Run the test case in runsc. - runTestCaseRunsc(testBin, tc, t) - } - }, - }) - } - - testing.Main(matchString, tests, nil, nil) -} diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD deleted file mode 100644 index 2c472bf8d..000000000 --- a/test/runtimes/BUILD +++ /dev/null @@ -1,53 +0,0 @@ -# These packages are used to run language runtime tests inside gVisor sandboxes. - -load("//tools:defs.bzl", "go_binary", "go_test") -load("//test/runtimes:build_defs.bzl", "runtime_test") - -package(licenses = ["notice"]) - -go_binary( - name = "runner", - testonly = 1, - srcs = ["runner.go"], - deps = [ - "//runsc/dockerutil", - "//runsc/testutil", - ], -) - -runtime_test( - name = "go1.12", - blacklist_file = "blacklist_go1.12.csv", - lang = "go", -) - -runtime_test( - name = "java11", - blacklist_file = "blacklist_java11.csv", - lang = "java", -) - -runtime_test( - name = "nodejs12.4.0", - blacklist_file = "blacklist_nodejs12.4.0.csv", - lang = "nodejs", -) - -runtime_test( - name = "php7.3.6", - blacklist_file = "blacklist_php7.3.6.csv", - lang = "php", -) - -runtime_test( - name = "python3.7.3", - blacklist_file = "blacklist_python3.7.3.csv", - lang = "python", -) - -go_test( - name = "blacklist_test", - size = "small", - srcs = ["blacklist_test.go"], - library = ":runner", -) diff --git a/test/runtimes/README.md b/test/runtimes/README.md deleted file mode 100644 index 42d722553..000000000 --- a/test/runtimes/README.md +++ /dev/null @@ -1,56 +0,0 @@ -# Runtimes Tests Dockerfiles - -The Dockerfiles defined under this path are configured to host the execution of -the runtimes language tests. Each Dockerfile can support the language indicated -by its directory. - -The following runtimes are currently supported: - -- Go 1.12 -- Java 11 -- Node.js 12 -- PHP 7.3 -- Python 3.7 - -### Building and pushing the images: - -The canonical source of images is the -[gvisor-presubmit container registry](https://gcr.io/gvisor-presubmit/). You can -build new images with the following command: - -```bash -$ cd images -$ docker build -f Dockerfile_$LANG [-t $NAME] . -``` - -To push them to our container registry, set the tag in the command above to -`gcr.io/gvisor-presubmit/$LANG`, then push them. (Note that you will need -appropriate permissions to the `gvisor-presubmit` GCP project.) - -```bash -gcloud docker -- push gcr.io/gvisor-presubmit/$LANG -``` - -#### Running in Docker locally: - -1) [Install and configure Docker](https://docs.docker.com/install/) - -2) Pull the image you want to run: - -```bash -$ docker pull gcr.io/gvisor-presubmit/$LANG -``` - -3) Run docker with the image. - -```bash -$ docker run [--runtime=runsc] --rm -it $NAME [FLAG] -``` - -Running the command with no flags will cause all the available tests to execute. - -Flags can be added for additional functionality: - -- --list: Print a list of all available tests -- --test <name>: Run a single test from the list of available tests -- --v: Print the language version diff --git a/test/runtimes/blacklist_go1.12.csv b/test/runtimes/blacklist_go1.12.csv deleted file mode 100644 index 8c8ae0c5d..000000000 --- a/test/runtimes/blacklist_go1.12.csv +++ /dev/null @@ -1,16 +0,0 @@ -test name,bug id,comment -cgo_errors,,FLAKY -cgo_test,,FLAKY -go_test:cmd/go,,FLAKY -go_test:cmd/vendor/golang.org/x/sys/unix,b/118783622,/dev devices missing -go_test:net,b/118784196,socket: invalid argument. Works as intended: see bug. -go_test:os,b/118780122,we have a pollable filesystem but that's a surprise -go_test:os/signal,b/118780860,/dev/pts not properly supported -go_test:runtime,b/118782341,sigtrap not reported or caught or something -go_test:syscall,b/118781998,bad bytes -- bad mem addr -race,b/118782931,thread sanitizer. Works as intended: b/62219744. -runtime:cpu124,b/118778254,segmentation fault -test:0_1,,FLAKY -testasan,, -testcarchive,b/118782924,no sigpipe -testshared,,FLAKY diff --git a/test/runtimes/blacklist_java11.csv b/test/runtimes/blacklist_java11.csv deleted file mode 100644 index c012e5a56..000000000 --- a/test/runtimes/blacklist_java11.csv +++ /dev/null @@ -1,126 +0,0 @@ -test name,bug id,comment -com/sun/crypto/provider/Cipher/PBE/PKCS12Cipher.java,,Fails in Docker -com/sun/jdi/NashornPopFrameTest.java,, -com/sun/jdi/ProcessAttachTest.java,, -com/sun/management/HotSpotDiagnosticMXBean/CheckOrigin.java,,Fails in Docker -com/sun/management/OperatingSystemMXBean/GetCommittedVirtualMemorySize.java,, -com/sun/management/UnixOperatingSystemMXBean/GetMaxFileDescriptorCount.sh,, -com/sun/tools/attach/AttachSelf.java,, -com/sun/tools/attach/BasicTests.java,, -com/sun/tools/attach/PermissionTest.java,, -com/sun/tools/attach/StartManagementAgent.java,, -com/sun/tools/attach/TempDirTest.java,, -com/sun/tools/attach/modules/Driver.java,, -java/lang/Character/CheckScript.java,,Fails in Docker -java/lang/Character/CheckUnicode.java,,Fails in Docker -java/lang/Class/GetPackageBootLoaderChildLayer.java,, -java/lang/ClassLoader/nativeLibrary/NativeLibraryTest.java,,Fails in Docker -java/lang/String/nativeEncoding/StringPlatformChars.java,, -java/net/DatagramSocket/ReuseAddressTest.java,, -java/net/DatagramSocket/SendDatagramToBadAddress.java,b/78473345, -java/net/Inet4Address/PingThis.java,, -java/net/InterfaceAddress/NetworkPrefixLength.java,b/78507103, -java/net/MulticastSocket/MulticastTTL.java,, -java/net/MulticastSocket/Promiscuous.java,, -java/net/MulticastSocket/SetLoopbackMode.java,, -java/net/MulticastSocket/SetTTLAndGetTTL.java,, -java/net/MulticastSocket/Test.java,, -java/net/MulticastSocket/TestDefaults.java,, -java/net/MulticastSocket/TimeToLive.java,, -java/net/NetworkInterface/NetworkInterfaceStreamTest.java,, -java/net/Socket/SetSoLinger.java,b/78527327,SO_LINGER is not yet supported -java/net/Socket/TrafficClass.java,b/78527818,Not supported on gVisor -java/net/Socket/UrgentDataTest.java,b/111515323, -java/net/Socket/setReuseAddress/Basic.java,b/78519214,SO_REUSEADDR enabled by default -java/net/SocketOption/OptionsTest.java,,Fails in Docker -java/net/SocketOption/TcpKeepAliveTest.java,, -java/net/SocketPermission/SocketPermissionTest.java,, -java/net/URLConnection/6212146/TestDriver.java,,Fails in Docker -java/net/httpclient/RequestBuilderTest.java,,Fails in Docker -java/net/httpclient/ShortResponseBody.java,, -java/net/httpclient/ShortResponseBodyWithRetry.java,, -java/nio/channels/AsyncCloseAndInterrupt.java,, -java/nio/channels/AsynchronousServerSocketChannel/Basic.java,, -java/nio/channels/AsynchronousSocketChannel/Basic.java,b/77921528,SO_KEEPALIVE is not settable -java/nio/channels/DatagramChannel/BasicMulticastTests.java,, -java/nio/channels/DatagramChannel/SocketOptionTests.java,,Fails in Docker -java/nio/channels/DatagramChannel/UseDGWithIPv6.java,, -java/nio/channels/FileChannel/directio/DirectIOTest.java,,Fails in Docker -java/nio/channels/Selector/OutOfBand.java,, -java/nio/channels/Selector/SelectWithConsumer.java,,Flaky -java/nio/channels/ServerSocketChannel/SocketOptionTests.java,, -java/nio/channels/SocketChannel/LingerOnClose.java,, -java/nio/channels/SocketChannel/SocketOptionTests.java,b/77965901, -java/nio/channels/spi/SelectorProvider/inheritedChannel/InheritedChannelTest.java,,Fails in Docker -java/rmi/activation/Activatable/extLoadedImpl/ext.sh,, -java/rmi/transport/checkLeaseInfoLeak/CheckLeaseLeak.java,, -java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker -java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker -java/util/Calendar/JapaneseEraNameTest.java,, -java/util/Currency/CurrencyTest.java,,Fails in Docker -java/util/Currency/ValidateISO4217.java,,Fails in Docker -java/util/Locale/LSRDataTest.java,, -java/util/concurrent/locks/Lock/TimedAcquireLeak.java,, -java/util/jar/JarFile/mrjar/MultiReleaseJarAPI.java,,Fails in Docker -java/util/logging/LogManager/Configuration/updateConfiguration/SimpleUpdateConfigWithInputStreamTest.java,, -java/util/logging/TestLoggerWeakRefLeak.java,, -javax/imageio/AppletResourceTest.java,, -javax/management/security/HashedPasswordFileTest.java,, -javax/net/ssl/SSLSession/JSSERenegotiate.java,,Fails in Docker -javax/sound/sampled/AudioInputStream/FrameLengthAfterConversion.java,, -jdk/jfr/event/runtime/TestNetworkUtilizationEvent.java,, -jdk/jfr/event/runtime/TestThreadParkEvent.java,, -jdk/jfr/event/sampling/TestNative.java,, -jdk/jfr/jcmd/TestJcmdChangeLogLevel.java,, -jdk/jfr/jcmd/TestJcmdConfigure.java,, -jdk/jfr/jcmd/TestJcmdDump.java,, -jdk/jfr/jcmd/TestJcmdDumpGeneratedFilename.java,, -jdk/jfr/jcmd/TestJcmdDumpLimited.java,, -jdk/jfr/jcmd/TestJcmdDumpPathToGCRoots.java,, -jdk/jfr/jcmd/TestJcmdLegacy.java,, -jdk/jfr/jcmd/TestJcmdSaveToFile.java,, -jdk/jfr/jcmd/TestJcmdStartDirNotExist.java,, -jdk/jfr/jcmd/TestJcmdStartInvaldFile.java,, -jdk/jfr/jcmd/TestJcmdStartPathToGCRoots.java,, -jdk/jfr/jcmd/TestJcmdStartStopDefault.java,, -jdk/jfr/jcmd/TestJcmdStartWithOptions.java,, -jdk/jfr/jcmd/TestJcmdStartWithSettings.java,, -jdk/jfr/jcmd/TestJcmdStopInvalidFile.java,, -jdk/jfr/jvm/TestJfrJavaBase.java,, -jdk/jfr/startupargs/TestStartRecording.java,, -jdk/modules/incubator/ImageModules.java,, -jdk/net/Sockets/ExtOptionTest.java,, -jdk/net/Sockets/QuickAckTest.java,, -lib/security/cacerts/VerifyCACerts.java,, -sun/management/jmxremote/bootstrap/CustomLauncherTest.java,, -sun/management/jmxremote/bootstrap/JvmstatCountersTest.java,, -sun/management/jmxremote/bootstrap/LocalManagementTest.java,, -sun/management/jmxremote/bootstrap/RmiRegistrySslTest.java,, -sun/management/jmxremote/bootstrap/RmiSslBootstrapTest.sh,, -sun/management/jmxremote/startstop/JMXStartStopTest.java,, -sun/management/jmxremote/startstop/JMXStatusPerfCountersTest.java,, -sun/management/jmxremote/startstop/JMXStatusTest.java,, -sun/text/resources/LocaleDataTest.java,, -sun/tools/jcmd/TestJcmdSanity.java,, -sun/tools/jhsdb/AlternateHashingTest.java,, -sun/tools/jhsdb/BasicLauncherTest.java,, -sun/tools/jhsdb/HeapDumpTest.java,, -sun/tools/jhsdb/heapconfig/JMapHeapConfigTest.java,, -sun/tools/jinfo/BasicJInfoTest.java,, -sun/tools/jinfo/JInfoTest.java,, -sun/tools/jmap/BasicJMapTest.java,, -sun/tools/jstack/BasicJStackTest.java,, -sun/tools/jstack/DeadlockDetectionTest.java,, -sun/tools/jstatd/TestJstatdExternalRegistry.java,, -sun/tools/jstatd/TestJstatdPort.java,,Flaky -sun/tools/jstatd/TestJstatdPortAndServer.java,,Flaky -sun/util/calendar/zi/TestZoneInfo310.java,, -tools/jar/modularJar/Basic.java,, -tools/jar/multiRelease/Basic.java,, -tools/jimage/JImageExtractTest.java,, -tools/jimage/JImageTest.java,, -tools/jlink/JLinkTest.java,, -tools/jlink/plugins/IncludeLocalesPluginTest.java,, -tools/jmod/hashes/HashesTest.java,, -tools/launcher/BigJar.java,b/111611473, -tools/launcher/modules/patch/systemmodules/PatchSystemModules.java,, diff --git a/test/runtimes/blacklist_nodejs12.4.0.csv b/test/runtimes/blacklist_nodejs12.4.0.csv deleted file mode 100644 index 4ab4e2927..000000000 --- a/test/runtimes/blacklist_nodejs12.4.0.csv +++ /dev/null @@ -1,47 +0,0 @@ -test name,bug id,comment -benchmark/test-benchmark-fs.js,, -benchmark/test-benchmark-module.js,, -benchmark/test-benchmark-napi.js,, -doctool/test-make-doc.js,b/68848110,Expected to fail. -fixtures/test-error-first-line-offset.js,, -fixtures/test-fs-readfile-error.js,, -fixtures/test-fs-stat-sync-overflow.js,, -internet/test-dgram-broadcast-multi-process.js,, -internet/test-dgram-multicast-multi-process.js,, -internet/test-dgram-multicast-set-interface-lo.js,, -parallel/test-cluster-dgram-reuse.js,b/64024294, -parallel/test-dgram-bind-fd.js,b/132447356, -parallel/test-dgram-create-socket-handle-fd.js,b/132447238, -parallel/test-dgram-createSocket-type.js,b/68847739, -parallel/test-dgram-socket-buffer-size.js,b/68847921, -parallel/test-fs-access.js,, -parallel/test-fs-write-stream-double-close.js,, -parallel/test-fs-write-stream-throw-type-error.js,b/110226209, -parallel/test-fs-write-stream.js,, -parallel/test-http2-respond-file-error-pipe-offset.js,, -parallel/test-os.js,, -parallel/test-process-uid-gid.js,, -pseudo-tty/test-assert-colors.js,, -pseudo-tty/test-assert-no-color.js,, -pseudo-tty/test-assert-position-indicator.js,, -pseudo-tty/test-async-wrap-getasyncid-tty.js,, -pseudo-tty/test-fatal-error.js,, -pseudo-tty/test-handle-wrap-isrefed-tty.js,, -pseudo-tty/test-readable-tty-keepalive.js,, -pseudo-tty/test-set-raw-mode-reset-process-exit.js,, -pseudo-tty/test-set-raw-mode-reset-signal.js,, -pseudo-tty/test-set-raw-mode-reset.js,, -pseudo-tty/test-stderr-stdout-handle-sigwinch.js,, -pseudo-tty/test-stdout-read.js,, -pseudo-tty/test-tty-color-support.js,, -pseudo-tty/test-tty-isatty.js,, -pseudo-tty/test-tty-stdin-call-end.js,, -pseudo-tty/test-tty-stdin-end.js,, -pseudo-tty/test-stdin-write.js,, -pseudo-tty/test-tty-stdout-end.js,, -pseudo-tty/test-tty-stdout-resize.js,, -pseudo-tty/test-tty-stream-constructors.js,, -pseudo-tty/test-tty-window-size.js,, -pseudo-tty/test-tty-wrap.js,, -pummel/test-net-pingpong.js,, -pummel/test-vm-memleak.js,, diff --git a/test/runtimes/blacklist_php7.3.6.csv b/test/runtimes/blacklist_php7.3.6.csv deleted file mode 100644 index 456bf7487..000000000 --- a/test/runtimes/blacklist_php7.3.6.csv +++ /dev/null @@ -1,29 +0,0 @@ -test name,bug id,comment -ext/intl/tests/bug77895.phpt,, -ext/intl/tests/dateformat_bug65683_2.phpt,, -ext/mbstring/tests/bug76319.phpt,, -ext/mbstring/tests/bug76958.phpt,, -ext/mbstring/tests/bug77025.phpt,, -ext/mbstring/tests/bug77165.phpt,, -ext/mbstring/tests/bug77454.phpt,, -ext/mbstring/tests/mb_convert_encoding_leak.phpt,, -ext/mbstring/tests/mb_strrpos_encoding_3rd_param.phpt,, -ext/standard/tests/file/filetype_variation.phpt,, -ext/standard/tests/file/fopen_variation19.phpt,, -ext/standard/tests/file/php_fd_wrapper_01.phpt,, -ext/standard/tests/file/php_fd_wrapper_02.phpt,, -ext/standard/tests/file/php_fd_wrapper_03.phpt,, -ext/standard/tests/file/php_fd_wrapper_04.phpt,, -ext/standard/tests/file/realpath_bug77484.phpt,, -ext/standard/tests/file/rename_variation.phpt,b/68717309, -ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,, -ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,, -ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,, -ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,, -ext/standard/tests/network/bug20134.phpt,, -tests/output/stream_isatty_err.phpt,b/68720279, -tests/output/stream_isatty_in-err.phpt,b/68720282, -tests/output/stream_isatty_in-out-err.phpt,, -tests/output/stream_isatty_in-out.phpt,b/68720299, -tests/output/stream_isatty_out-err.phpt,b/68720311, -tests/output/stream_isatty_out.phpt,b/68720325, diff --git a/test/runtimes/blacklist_python3.7.3.csv b/test/runtimes/blacklist_python3.7.3.csv deleted file mode 100644 index 2b9947212..000000000 --- a/test/runtimes/blacklist_python3.7.3.csv +++ /dev/null @@ -1,27 +0,0 @@ -test name,bug id,comment -test_asynchat,b/76031995,SO_REUSEADDR -test_asyncio,,Fails on Docker. -test_asyncore,b/76031995,SO_REUSEADDR -test_epoll,, -test_fcntl,,fcntl invalid argument -- artificial test to make sure something works in 64 bit mode. -test_ftplib,,Fails in Docker -test_httplib,b/76031995,SO_REUSEADDR -test_imaplib,, -test_logging,, -test_multiprocessing_fork,,Flaky. Sometimes times out. -test_multiprocessing_forkserver,,Flaky. Sometimes times out. -test_multiprocessing_main_handling,,Flaky. Sometimes times out. -test_multiprocessing_spawn,,Flaky. Sometimes times out. -test_nntplib,b/76031995,tests should not set SO_REUSEADDR -test_poplib,,Fails on Docker -test_posix,b/76174079,posix.sched_get_priority_min not implemented + posix.sched_rr_get_interval not permitted -test_pty,b/76157709,out of pty devices -test_readline,b/76157709,out of pty devices -test_resource,b/76174079, -test_selectors,b/76116849,OSError not raised with epoll -test_smtplib,b/76031995,SO_REUSEADDR and unclosed sockets -test_socket,b/75983380, -test_ssl,b/76031995,SO_REUSEADDR -test_subprocess,, -test_support,b/76031995,SO_REUSEADDR -test_telnetlib,b/76031995,SO_REUSEADDR diff --git a/test/runtimes/blacklist_test.go b/test/runtimes/blacklist_test.go deleted file mode 100644 index 52f49b984..000000000 --- a/test/runtimes/blacklist_test.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "flag" - "os" - "testing" -) - -func TestMain(m *testing.M) { - flag.Parse() - os.Exit(m.Run()) -} - -// Test that the blacklist parses without error. -func TestBlacklists(t *testing.T) { - bl, err := getBlacklist() - if err != nil { - t.Fatalf("error parsing blacklist: %v", err) - } - if *blacklistFile != "" && len(bl) == 0 { - t.Errorf("got empty blacklist for file %q", blacklistFile) - } -} diff --git a/test/runtimes/build_defs.bzl b/test/runtimes/build_defs.bzl deleted file mode 100644 index 92e275a76..000000000 --- a/test/runtimes/build_defs.bzl +++ /dev/null @@ -1,75 +0,0 @@ -"""Defines a rule for runtime test targets.""" - -load("//tools:defs.bzl", "go_test", "loopback") - -def runtime_test( - name, - lang, - image_repo = "gcr.io/gvisor-presubmit", - image_name = None, - blacklist_file = None, - shard_count = 50, - size = "enormous"): - """Generates sh_test and blacklist test targets for a given runtime. - - Args: - name: The name of the runtime being tested. Typically, the lang + version. - This is used in the names of the generated test targets. - lang: The language being tested. - image_repo: The docker repository containing the proctor image to run. - i.e., the prefix to the fully qualified docker image id. - image_name: The name of the image in the image_repo. - Defaults to the test name. - blacklist_file: A test blacklist to pass to the runtime test's runner. - shard_count: See Bazel common test attributes. - size: See Bazel common test attributes. - """ - if image_name == None: - image_name = name - args = [ - "--lang", - lang, - "--image", - "/".join([image_repo, image_name]), - ] - data = [ - ":runner", - loopback, - ] - if blacklist_file: - args += ["--blacklist_file", "test/runtimes/" + blacklist_file] - data += [blacklist_file] - - # Add a test that the blacklist parses correctly. - blacklist_test(name, blacklist_file) - - sh_test( - name = name + "_test", - srcs = ["runner.sh"], - args = args, - data = data, - size = size, - shard_count = shard_count, - tags = [ - # Requires docker and runsc to be configured before the test runs. - "local", - # Don't include test target in wildcard target patterns. - "manual", - ], - ) - -def blacklist_test(name, blacklist_file): - """Test that a blacklist parses correctly.""" - go_test( - name = name + "_blacklist_test", - library = ":runner", - srcs = ["blacklist_test.go"], - args = ["--blacklist_file", "test/runtimes/" + blacklist_file], - data = [blacklist_file], - ) - -def sh_test(**kwargs): - """Wraps the standard sh_test.""" - native.sh_test( - **kwargs - ) diff --git a/test/runtimes/images/Dockerfile_go1.12 b/test/runtimes/images/Dockerfile_go1.12 deleted file mode 100644 index ab9d6abf3..000000000 --- a/test/runtimes/images/Dockerfile_go1.12 +++ /dev/null @@ -1,10 +0,0 @@ -# Go is easy, since we already have everything we need to compile the proctor -# binary and run the tests in the golang Docker image. -FROM golang:1.12 -ADD ["proctor/", "/go/src/proctor/"] -RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"] - -# Pre-compile the tests so we don't need to do so in each test run. -RUN ["go", "tool", "dist", "test", "-compile-only"] - -ENTRYPOINT ["/proctor", "--runtime=go"] diff --git a/test/runtimes/images/Dockerfile_java11 b/test/runtimes/images/Dockerfile_java11 deleted file mode 100644 index 9b7c3d5a3..000000000 --- a/test/runtimes/images/Dockerfile_java11 +++ /dev/null @@ -1,30 +0,0 @@ -# Compile the proctor binary. -FROM golang:1.12 AS golang -ADD ["proctor/", "/go/src/proctor/"] -RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"] - -FROM ubuntu:bionic -RUN apt-get update && apt-get install -y \ - autoconf \ - build-essential \ - curl \ - make \ - openjdk-11-jdk \ - unzip \ - zip - -# Download the JDK test library. -WORKDIR /root -RUN set -ex \ - && curl -fsSL --retry 10 -o /tmp/jdktests.tar.gz http://hg.openjdk.java.net/jdk/jdk11/archive/76072a077ee1.tar.gz/test \ - && tar -xzf /tmp/jdktests.tar.gz \ - && mv jdk11-76072a077ee1/test test \ - && rm -f /tmp/jdktests.tar.gz - -# Install jtreg and add to PATH. -RUN curl -o jtreg.tar.gz https://ci.adoptopenjdk.net/view/Dependencies/job/jtreg/lastSuccessfulBuild/artifact/jtreg-4.2.0-tip.tar.gz -RUN tar -xzf jtreg.tar.gz -ENV PATH="/root/jtreg/bin:$PATH" - -COPY --from=golang /proctor /proctor -ENTRYPOINT ["/proctor", "--runtime=java"] diff --git a/test/runtimes/images/Dockerfile_nodejs12.4.0 b/test/runtimes/images/Dockerfile_nodejs12.4.0 deleted file mode 100644 index 26f68b487..000000000 --- a/test/runtimes/images/Dockerfile_nodejs12.4.0 +++ /dev/null @@ -1,28 +0,0 @@ -# Compile the proctor binary. -FROM golang:1.12 AS golang -ADD ["proctor/", "/go/src/proctor/"] -RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"] - -FROM ubuntu:bionic -RUN apt-get update && apt-get install -y \ - curl \ - dumb-init \ - g++ \ - make \ - python - -WORKDIR /root -ARG VERSION=v12.4.0 -RUN curl -o node-${VERSION}.tar.gz https://nodejs.org/dist/${VERSION}/node-${VERSION}.tar.gz -RUN tar -zxf node-${VERSION}.tar.gz - -WORKDIR /root/node-${VERSION} -RUN ./configure -RUN make -RUN make test-build - -COPY --from=golang /proctor /proctor - -# Including dumb-init emulates the Linux "init" process, preventing the failure -# of tests involving worker processes. -ENTRYPOINT ["/usr/bin/dumb-init", "/proctor", "--runtime=nodejs"] diff --git a/test/runtimes/images/Dockerfile_php7.3.6 b/test/runtimes/images/Dockerfile_php7.3.6 deleted file mode 100644 index e6b4c6329..000000000 --- a/test/runtimes/images/Dockerfile_php7.3.6 +++ /dev/null @@ -1,27 +0,0 @@ -# Compile the proctor binary. -FROM golang:1.12 AS golang -ADD ["proctor/", "/go/src/proctor/"] -RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"] - -FROM ubuntu:bionic -RUN apt-get update && apt-get install -y \ - autoconf \ - automake \ - bison \ - build-essential \ - curl \ - libtool \ - libxml2-dev \ - re2c - -WORKDIR /root -ARG VERSION=7.3.6 -RUN curl -o php-${VERSION}.tar.gz https://www.php.net/distributions/php-${VERSION}.tar.gz -RUN tar -zxf php-${VERSION}.tar.gz - -WORKDIR /root/php-${VERSION} -RUN ./configure -RUN make - -COPY --from=golang /proctor /proctor -ENTRYPOINT ["/proctor", "--runtime=php"] diff --git a/test/runtimes/images/Dockerfile_python3.7.3 b/test/runtimes/images/Dockerfile_python3.7.3 deleted file mode 100644 index 905cd22d7..000000000 --- a/test/runtimes/images/Dockerfile_python3.7.3 +++ /dev/null @@ -1,30 +0,0 @@ -# Compile the proctor binary. -FROM golang:1.12 AS golang -ADD ["proctor/", "/go/src/proctor/"] -RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"] - -FROM ubuntu:bionic - -RUN apt-get update && apt-get install -y \ - curl \ - gcc \ - libbz2-dev \ - libffi-dev \ - liblzma-dev \ - libreadline-dev \ - libssl-dev \ - make \ - zlib1g-dev - -# Use flags -LJO to follow the html redirect and download .tar.gz. -WORKDIR /root -ARG VERSION=3.7.3 -RUN curl -LJO https://github.com/python/cpython/archive/v${VERSION}.tar.gz -RUN tar -zxf cpython-${VERSION}.tar.gz - -WORKDIR /root/cpython-${VERSION} -RUN ./configure --with-pydebug -RUN make -s -j2 - -COPY --from=golang /proctor /proctor -ENTRYPOINT ["/proctor", "--runtime=python"] diff --git a/test/runtimes/images/proctor/BUILD b/test/runtimes/images/proctor/BUILD deleted file mode 100644 index 85e004c45..000000000 --- a/test/runtimes/images/proctor/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -load("//tools:defs.bzl", "go_binary", "go_test") - -package(licenses = ["notice"]) - -go_binary( - name = "proctor", - srcs = [ - "go.go", - "java.go", - "nodejs.go", - "php.go", - "proctor.go", - "python.go", - ], - visibility = ["//test/runtimes/images:__subpackages__"], -) - -go_test( - name = "proctor_test", - size = "small", - srcs = ["proctor_test.go"], - library = ":proctor", - deps = [ - "//runsc/testutil", - ], -) diff --git a/test/runtimes/images/proctor/go.go b/test/runtimes/images/proctor/go.go deleted file mode 100644 index 3e2d5d8db..000000000 --- a/test/runtimes/images/proctor/go.go +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "fmt" - "os" - "os/exec" - "regexp" - "strings" -) - -var ( - goTestRegEx = regexp.MustCompile(`^.+\.go$`) - - // Directories with .dir contain helper files for tests. - // Exclude benchmarks and stress tests. - goDirFilter = regexp.MustCompile(`^(bench|stress)\/.+$|^.+\.dir.+$`) -) - -// Location of Go tests on disk. -const goTestDir = "/usr/local/go/test" - -// goRunner implements TestRunner for Go. -// -// There are two types of Go tests: "Go tool tests" and "Go tests on disk". -// "Go tool tests" are found and executed using `go tool dist test`. "Go tests -// on disk" are found in the /usr/local/go/test directory and are executed -// using `go run run.go`. -type goRunner struct{} - -var _ TestRunner = goRunner{} - -// ListTests implements TestRunner.ListTests. -func (goRunner) ListTests() ([]string, error) { - // Go tool dist test tests. - args := []string{"tool", "dist", "test", "-list"} - cmd := exec.Command("go", args...) - cmd.Stderr = os.Stderr - out, err := cmd.Output() - if err != nil { - return nil, fmt.Errorf("failed to list: %v", err) - } - var toolSlice []string - for _, test := range strings.Split(string(out), "\n") { - toolSlice = append(toolSlice, test) - } - - // Go tests on disk. - diskSlice, err := search(goTestDir, goTestRegEx) - if err != nil { - return nil, err - } - // Remove items from /bench/, /stress/ and .dir files - diskFiltered := diskSlice[:0] - for _, file := range diskSlice { - if !goDirFilter.MatchString(file) { - diskFiltered = append(diskFiltered, file) - } - } - - return append(toolSlice, diskFiltered...), nil -} - -// TestCmd implements TestRunner.TestCmd. -func (goRunner) TestCmd(test string) *exec.Cmd { - // Check if test exists on disk by searching for file of the same name. - // This will determine whether or not it is a Go test on disk. - if strings.HasSuffix(test, ".go") { - // Test has suffix ".go" which indicates a disk test, run it as such. - cmd := exec.Command("go", "run", "run.go", "-v", "--", test) - cmd.Dir = goTestDir - return cmd - } - - // No ".go" suffix, run as a tool test. - return exec.Command("go", "tool", "dist", "test", "-run", test) -} diff --git a/test/runtimes/images/proctor/java.go b/test/runtimes/images/proctor/java.go deleted file mode 100644 index 8b362029d..000000000 --- a/test/runtimes/images/proctor/java.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "fmt" - "os" - "os/exec" - "regexp" - "strings" -) - -// Directories to exclude from tests. -var javaExclDirs = regexp.MustCompile(`(^(sun\/security)|(java\/util\/stream)|(java\/time)| )`) - -// Location of java tests. -const javaTestDir = "/root/test/jdk" - -// javaRunner implements TestRunner for Java. -type javaRunner struct{} - -var _ TestRunner = javaRunner{} - -// ListTests implements TestRunner.ListTests. -func (javaRunner) ListTests() ([]string, error) { - args := []string{ - "-dir:" + javaTestDir, - "-ignore:quiet", - "-a", - "-listtests", - ":jdk_core", - ":jdk_svc", - ":jdk_sound", - ":jdk_imageio", - } - cmd := exec.Command("jtreg", args...) - cmd.Stderr = os.Stderr - out, err := cmd.Output() - if err != nil { - return nil, fmt.Errorf("jtreg -listtests : %v", err) - } - var testSlice []string - for _, test := range strings.Split(string(out), "\n") { - if !javaExclDirs.MatchString(test) { - testSlice = append(testSlice, test) - } - } - return testSlice, nil -} - -// TestCmd implements TestRunner.TestCmd. -func (javaRunner) TestCmd(test string) *exec.Cmd { - args := []string{ - "-noreport", - "-dir:" + javaTestDir, - test, - } - return exec.Command("jtreg", args...) -} diff --git a/test/runtimes/images/proctor/nodejs.go b/test/runtimes/images/proctor/nodejs.go deleted file mode 100644 index bd57db444..000000000 --- a/test/runtimes/images/proctor/nodejs.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "os/exec" - "path/filepath" - "regexp" -) - -var nodejsTestRegEx = regexp.MustCompile(`^test-[^-].+\.js$`) - -// Location of nodejs tests relative to working dir. -const nodejsTestDir = "test" - -// nodejsRunner implements TestRunner for NodeJS. -type nodejsRunner struct{} - -var _ TestRunner = nodejsRunner{} - -// ListTests implements TestRunner.ListTests. -func (nodejsRunner) ListTests() ([]string, error) { - testSlice, err := search(nodejsTestDir, nodejsTestRegEx) - if err != nil { - return nil, err - } - return testSlice, nil -} - -// TestCmd implements TestRunner.TestCmd. -func (nodejsRunner) TestCmd(test string) *exec.Cmd { - args := []string{filepath.Join("tools", "test.py"), test} - return exec.Command("/usr/bin/python", args...) -} diff --git a/test/runtimes/images/proctor/php.go b/test/runtimes/images/proctor/php.go deleted file mode 100644 index 9115040e1..000000000 --- a/test/runtimes/images/proctor/php.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "os/exec" - "regexp" -) - -var phpTestRegEx = regexp.MustCompile(`^.+\.phpt$`) - -// phpRunner implements TestRunner for PHP. -type phpRunner struct{} - -var _ TestRunner = phpRunner{} - -// ListTests implements TestRunner.ListTests. -func (phpRunner) ListTests() ([]string, error) { - testSlice, err := search(".", phpTestRegEx) - if err != nil { - return nil, err - } - return testSlice, nil -} - -// TestCmd implements TestRunner.TestCmd. -func (phpRunner) TestCmd(test string) *exec.Cmd { - args := []string{"test", "TESTS=" + test} - return exec.Command("make", args...) -} diff --git a/test/runtimes/images/proctor/proctor.go b/test/runtimes/images/proctor/proctor.go deleted file mode 100644 index b54abe434..000000000 --- a/test/runtimes/images/proctor/proctor.go +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Binary proctor runs the test for a particular runtime. It is meant to be -// included in Docker images for all runtime tests. -package main - -import ( - "flag" - "fmt" - "log" - "os" - "os/exec" - "os/signal" - "path/filepath" - "regexp" - "syscall" -) - -// TestRunner is an interface that must be implemented for each runtime -// integrated with proctor. -type TestRunner interface { - // ListTests returns a string slice of tests available to run. - ListTests() ([]string, error) - - // TestCmd returns an *exec.Cmd that will run the given test. - TestCmd(test string) *exec.Cmd -} - -var ( - runtime = flag.String("runtime", "", "name of runtime") - list = flag.Bool("list", false, "list all available tests") - testName = flag.String("test", "", "run a single test from the list of available tests") - pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") -) - -func main() { - flag.Parse() - - if *pause { - pauseAndReap() - panic("pauseAndReap should never return") - } - - if *runtime == "" { - log.Fatalf("runtime flag must be provided") - } - - tr, err := testRunnerForRuntime(*runtime) - if err != nil { - log.Fatalf("%v", err) - } - - // List tests. - if *list { - tests, err := tr.ListTests() - if err != nil { - log.Fatalf("failed to list tests: %v", err) - } - for _, test := range tests { - fmt.Println(test) - } - return - } - - var tests []string - if *testName == "" { - // Run every test. - tests, err = tr.ListTests() - if err != nil { - log.Fatalf("failed to get all tests: %v", err) - } - } else { - // Run a single test. - tests = []string{*testName} - } - for _, test := range tests { - cmd := tr.TestCmd(test) - cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr - if err := cmd.Run(); err != nil { - log.Fatalf("FAIL: %v", err) - } - } -} - -// testRunnerForRuntime returns a new TestRunner for the given runtime. -func testRunnerForRuntime(runtime string) (TestRunner, error) { - switch runtime { - case "go": - return goRunner{}, nil - case "java": - return javaRunner{}, nil - case "nodejs": - return nodejsRunner{}, nil - case "php": - return phpRunner{}, nil - case "python": - return pythonRunner{}, nil - } - return nil, fmt.Errorf("invalid runtime %q", runtime) -} - -// pauseAndReap is like init. It runs forever and reaps any children. -func pauseAndReap() { - // Get notified of any new children. - ch := make(chan os.Signal, 1) - signal.Notify(ch, syscall.SIGCHLD) - - for { - if _, ok := <-ch; !ok { - // Channel closed. This should not happen. - panic("signal channel closed") - } - - // Reap the child. - for { - if cpid, _ := syscall.Wait4(-1, nil, syscall.WNOHANG, nil); cpid < 1 { - break - } - } - } -} - -// search is a helper function to find tests in the given directory that match -// the regex. -func search(root string, testFilter *regexp.Regexp) ([]string, error) { - var testSlice []string - - err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - - name := filepath.Base(path) - - if info.IsDir() || !testFilter.MatchString(name) { - return nil - } - - relPath, err := filepath.Rel(root, path) - if err != nil { - return err - } - testSlice = append(testSlice, relPath) - return nil - }) - if err != nil { - return nil, fmt.Errorf("walking %q: %v", root, err) - } - - return testSlice, nil -} diff --git a/test/runtimes/images/proctor/proctor_test.go b/test/runtimes/images/proctor/proctor_test.go deleted file mode 100644 index 6bb61d142..000000000 --- a/test/runtimes/images/proctor/proctor_test.go +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "io/ioutil" - "os" - "path/filepath" - "reflect" - "regexp" - "strings" - "testing" - - "gvisor.dev/gvisor/runsc/testutil" -) - -func touch(t *testing.T, name string) { - t.Helper() - f, err := os.Create(name) - if err != nil { - t.Fatal(err) - } - if err := f.Close(); err != nil { - t.Fatal(err) - } -} - -func TestSearchEmptyDir(t *testing.T) { - td, err := ioutil.TempDir(testutil.TmpDir(), "searchtest") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(td) - - var want []string - - testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`) - got, err := search(td, testFilter) - if err != nil { - t.Errorf("search error: %v", err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("Found %#v; want %#v", got, want) - } -} - -func TestSearch(t *testing.T) { - td, err := ioutil.TempDir(testutil.TmpDir(), "searchtest") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(td) - - // Creating various files similar to the test filter regex. - files := []string{ - "emp/", - "tee/", - "test-foo.tc", - "test-foo.tc", - "test-bar.tc", - "test-sam.tc", - "Test-que.tc", - "test-brett", - "test--abc.tc", - "test---xyz.tc", - "test-bool.TC", - "--test-gvs.tc", - " test-pew.tc", - "dir/test_baz.tc", - "dir/testsnap.tc", - "dir/test-luk.tc", - "dir/nest/test-ok.tc", - "dir/dip/diz/goog/test-pack.tc", - "dir/dip/diz/wobble/thud/test-cas.e", - "dir/dip/diz/wobble/thud/test-cas.tc", - } - want := []string{ - "dir/dip/diz/goog/test-pack.tc", - "dir/dip/diz/wobble/thud/test-cas.tc", - "dir/nest/test-ok.tc", - "dir/test-luk.tc", - "test-bar.tc", - "test-foo.tc", - "test-sam.tc", - } - - for _, item := range files { - if strings.HasSuffix(item, "/") { - // This item is a directory, create it. - if err := os.MkdirAll(filepath.Join(td, item), 0755); err != nil { - t.Fatal(err) - } - } else { - // This item is a file, create the directory and touch file. - // Create directory in which file should be created - fullDirPath := filepath.Join(td, filepath.Dir(item)) - if err := os.MkdirAll(fullDirPath, 0755); err != nil { - t.Fatal(err) - } - // Create file with full path to file. - touch(t, filepath.Join(td, item)) - } - } - - testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`) - got, err := search(td, testFilter) - if err != nil { - t.Errorf("search error: %v", err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("Found %#v; want %#v", got, want) - } -} diff --git a/test/runtimes/images/proctor/python.go b/test/runtimes/images/proctor/python.go deleted file mode 100644 index b9e0fbe6f..000000000 --- a/test/runtimes/images/proctor/python.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "fmt" - "os" - "os/exec" - "strings" -) - -// pythonRunner implements TestRunner for Python. -type pythonRunner struct{} - -var _ TestRunner = pythonRunner{} - -// ListTests implements TestRunner.ListTests. -func (pythonRunner) ListTests() ([]string, error) { - args := []string{"-m", "test", "--list-tests"} - cmd := exec.Command("./python", args...) - cmd.Stderr = os.Stderr - out, err := cmd.Output() - if err != nil { - return nil, fmt.Errorf("failed to list: %v", err) - } - var toolSlice []string - for _, test := range strings.Split(string(out), "\n") { - toolSlice = append(toolSlice, test) - } - return toolSlice, nil -} - -// TestCmd implements TestRunner.TestCmd. -func (pythonRunner) TestCmd(test string) *exec.Cmd { - args := []string{"-m", "test", test} - return exec.Command("./python", args...) -} diff --git a/test/runtimes/runner.go b/test/runtimes/runner.go deleted file mode 100644 index ddb890dbc..000000000 --- a/test/runtimes/runner.go +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Binary runner runs the runtime tests in a Docker container. -package main - -import ( - "encoding/csv" - "flag" - "fmt" - "io" - "os" - "sort" - "strings" - "testing" - "time" - - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/testutil" -) - -var ( - lang = flag.String("lang", "", "language runtime to test") - image = flag.String("image", "", "docker image with runtime tests") - blacklistFile = flag.String("blacklist_file", "", "file containing blacklist of tests to exclude, in CSV format with fields: test name, bug id, comment") -) - -// Wait time for each test to run. -const timeout = 5 * time.Minute - -func main() { - flag.Parse() - if *lang == "" || *image == "" { - fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n") - os.Exit(1) - } - - os.Exit(runTests()) -} - -// runTests is a helper that is called by main. It exists so that we can run -// defered functions before exiting. It returns an exit code that should be -// passed to os.Exit. -func runTests() int { - // Get tests to blacklist. - blacklist, err := getBlacklist() - if err != nil { - fmt.Fprintf(os.Stderr, "Error getting blacklist: %s\n", err.Error()) - return 1 - } - - // Create a single docker container that will be used for all tests. - d := dockerutil.MakeDocker("gvisor-" + *lang) - defer d.CleanUp() - - // Get a slice of tests to run. This will also start a single Docker - // container that will be used to run each test. The final test will - // stop the Docker container. - tests, err := getTests(d, blacklist) - if err != nil { - fmt.Fprintf(os.Stderr, "%s\n", err.Error()) - return 1 - } - - m := testing.MainStart(testDeps{}, tests, nil, nil) - return m.Run() -} - -// getTests returns a slice of tests to run, subject to the shard size and -// index. -func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.InternalTest, error) { - // Pull the image. - if err := dockerutil.Pull(*image); err != nil { - return nil, fmt.Errorf("docker pull %q failed: %v", *image, err) - } - - // Run proctor with --pause flag to keep container alive forever. - if err := d.Run(*image, "--pause"); err != nil { - return nil, fmt.Errorf("docker run failed: %v", err) - } - - // Get a list of all tests in the image. - list, err := d.Exec("/proctor", "--runtime", *lang, "--list") - if err != nil { - return nil, fmt.Errorf("docker exec failed: %v", err) - } - - // Calculate a subset of tests to run corresponding to the current - // shard. - tests := strings.Fields(list) - sort.Strings(tests) - indices, err := testutil.TestIndicesForShard(len(tests)) - if err != nil { - return nil, fmt.Errorf("TestsForShard() failed: %v", err) - } - - var itests []testing.InternalTest - for _, tci := range indices { - // Capture tc in this scope. - tc := tests[tci] - itests = append(itests, testing.InternalTest{ - Name: tc, - F: func(t *testing.T) { - // Is the test blacklisted? - if _, ok := blacklist[tc]; ok { - t.Skip("SKIP: blacklisted test %q", tc) - } - - var ( - now = time.Now() - done = make(chan struct{}) - output string - err error - ) - - go func() { - fmt.Printf("RUNNING %s...\n", tc) - output, err = d.Exec("/proctor", "--runtime", *lang, "--test", tc) - close(done) - }() - - select { - case <-done: - if err == nil { - fmt.Printf("PASS: %s (%v)\n\n", tc, time.Since(now)) - return - } - t.Errorf("FAIL: %s (%v):\n%s\n", tc, time.Since(now), output) - case <-time.After(timeout): - t.Errorf("TIMEOUT: %s (%v):\n%s\n", tc, time.Since(now), output) - } - }, - }) - } - return itests, nil -} - -// getBlacklist reads the blacklist file and returns a set of test names to -// exclude. -func getBlacklist() (map[string]struct{}, error) { - blacklist := make(map[string]struct{}) - if *blacklistFile == "" { - return blacklist, nil - } - file, err := testutil.FindFile(*blacklistFile) - if err != nil { - return nil, err - } - f, err := os.Open(file) - if err != nil { - return nil, err - } - defer f.Close() - - r := csv.NewReader(f) - - // First line is header. Skip it. - if _, err := r.Read(); err != nil { - return nil, err - } - - for { - record, err := r.Read() - if err == io.EOF { - break - } - if err != nil { - return nil, err - } - blacklist[record[0]] = struct{}{} - } - return blacklist, nil -} - -// testDeps implements testing.testDeps (an unexported interface), and is -// required to use testing.MainStart. -type testDeps struct{} - -func (f testDeps) MatchString(a, b string) (bool, error) { return a == b, nil } -func (f testDeps) StartCPUProfile(io.Writer) error { return nil } -func (f testDeps) StopCPUProfile() {} -func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil } -func (f testDeps) ImportPath() string { return "" } -func (f testDeps) StartTestLog(io.Writer) {} -func (f testDeps) StopTestLog() error { return nil } diff --git a/test/runtimes/runner.sh b/test/runtimes/runner.sh deleted file mode 100755 index a8d9a3460..000000000 --- a/test/runtimes/runner.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -euf -x -o pipefail - -echo -- "$@" - -# Create outputs dir if it does not exist. -if [[ -n "${TEST_UNDECLARED_OUTPUTS_DIR}" ]]; then - mkdir -p "${TEST_UNDECLARED_OUTPUTS_DIR}" - chmod a+rwx "${TEST_UNDECLARED_OUTPUTS_DIR}" -fi - -# Update the timestamp on the shard status file. Bazel looks for this. -touch "${TEST_SHARD_STATUS_FILE}" - -# Get location of runner binary. -readonly runner=$(find "${TEST_SRCDIR}" -name runner) - -# Pass the arguments of this script directly to the runner. -exec "${runner}" "$@" - diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD deleted file mode 100644 index 9800a0cdf..000000000 --- a/test/syscalls/BUILD +++ /dev/null @@ -1,740 +0,0 @@ -load("//test/runner:defs.bzl", "syscall_test") - -package(licenses = ["notice"]) - -syscall_test(test = "//test/syscalls/linux:32bit_test") - -syscall_test(test = "//test/syscalls/linux:accept_bind_stream_test") - -syscall_test( - size = "large", - shard_count = 50, - test = "//test/syscalls/linux:accept_bind_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:access_test", -) - -syscall_test(test = "//test/syscalls/linux:affinity_test") - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:aio_test", -) - -syscall_test( - size = "medium", - shard_count = 5, - test = "//test/syscalls/linux:alarm_test", -) - -syscall_test(test = "//test/syscalls/linux:arch_prctl_test") - -syscall_test(test = "//test/syscalls/linux:bad_test") - -syscall_test( - size = "large", - add_overlay = True, - test = "//test/syscalls/linux:bind_test", -) - -syscall_test(test = "//test/syscalls/linux:brk_test") - -syscall_test(test = "//test/syscalls/linux:socket_test") - -syscall_test( - size = "large", - shard_count = 50, - # Takes too long for TSAN. Since this is kind of a stress test that doesn't - # involve much concurrency, TSAN's usefulness here is limited anyway. - tags = ["nogotsan"], - test = "//test/syscalls/linux:socket_stress_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:chdir_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:chmod_test", -) - -syscall_test( - size = "medium", - add_overlay = True, - test = "//test/syscalls/linux:chown_test", - use_tmpfs = True, # chwon tests require gofer to be running as root. -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:chroot_test", -) - -syscall_test(test = "//test/syscalls/linux:clock_getres_test") - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:clock_gettime_test", -) - -syscall_test(test = "//test/syscalls/linux:clock_nanosleep_test") - -syscall_test(test = "//test/syscalls/linux:concurrency_test") - -syscall_test( - add_uds_tree = True, - test = "//test/syscalls/linux:connect_external_test", - use_tmpfs = True, -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:creat_test", -) - -syscall_test(test = "//test/syscalls/linux:dev_test") - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:dup_test", -) - -syscall_test(test = "//test/syscalls/linux:epoll_test") - -syscall_test(test = "//test/syscalls/linux:eventfd_test") - -syscall_test(test = "//test/syscalls/linux:exceptions_test") - -syscall_test( - size = "medium", - add_overlay = True, - test = "//test/syscalls/linux:exec_test", -) - -syscall_test( - size = "medium", - add_overlay = True, - test = "//test/syscalls/linux:exec_binary_test", -) - -syscall_test(test = "//test/syscalls/linux:exit_test") - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:fadvise64_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:fallocate_test", -) - -syscall_test(test = "//test/syscalls/linux:fault_test") - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:fchdir_test", -) - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:fcntl_test", -) - -syscall_test( - size = "medium", - add_overlay = True, - test = "//test/syscalls/linux:flock_test", -) - -syscall_test(test = "//test/syscalls/linux:fork_test") - -syscall_test(test = "//test/syscalls/linux:fpsig_fork_test") - -syscall_test(test = "//test/syscalls/linux:fpsig_nested_test") - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:fsync_test", -) - -syscall_test( - size = "medium", - shard_count = 5, - test = "//test/syscalls/linux:futex_test", -) - -syscall_test(test = "//test/syscalls/linux:getcpu_host_test") - -syscall_test(test = "//test/syscalls/linux:getcpu_test") - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:getdents_test", -) - -syscall_test(test = "//test/syscalls/linux:getrandom_test") - -syscall_test(test = "//test/syscalls/linux:getrusage_test") - -syscall_test( - size = "medium", - add_overlay = False, # TODO(gvisor.dev/issue/317): enable when fixed. - test = "//test/syscalls/linux:inotify_test", -) - -syscall_test( - size = "medium", - add_overlay = True, - test = "//test/syscalls/linux:ioctl_test", -) - -syscall_test( - test = "//test/syscalls/linux:iptables_test", -) - -syscall_test( - size = "large", - shard_count = 5, - test = "//test/syscalls/linux:itimer_test", -) - -syscall_test(test = "//test/syscalls/linux:kill_test") - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:link_test", - use_tmpfs = True, # gofer needs CAP_DAC_READ_SEARCH to use AT_EMPTY_PATH with linkat(2) -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:lseek_test", -) - -syscall_test(test = "//test/syscalls/linux:madvise_test") - -syscall_test(test = "//test/syscalls/linux:memory_accounting_test") - -syscall_test(test = "//test/syscalls/linux:mempolicy_test") - -syscall_test(test = "//test/syscalls/linux:mincore_test") - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:mkdir_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:mknod_test", -) - -syscall_test( - size = "medium", - shard_count = 5, - test = "//test/syscalls/linux:mmap_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:mount_test", -) - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:mremap_test", -) - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:msync_test", -) - -syscall_test(test = "//test/syscalls/linux:munmap_test") - -syscall_test(test = "//test/syscalls/linux:network_namespace_test") - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:open_create_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:open_test", -) - -syscall_test(test = "//test/syscalls/linux:packet_socket_raw_test") - -syscall_test(test = "//test/syscalls/linux:packet_socket_test") - -syscall_test(test = "//test/syscalls/linux:partial_bad_buffer_test") - -syscall_test(test = "//test/syscalls/linux:pause_test") - -syscall_test( - size = "large", - add_overlay = True, - shard_count = 5, - test = "//test/syscalls/linux:pipe_test", -) - -syscall_test(test = "//test/syscalls/linux:poll_test") - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:ppoll_test", -) - -syscall_test(test = "//test/syscalls/linux:prctl_setuid_test") - -syscall_test(test = "//test/syscalls/linux:prctl_test") - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:pread64_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:preadv_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:preadv2_test", -) - -syscall_test(test = "//test/syscalls/linux:priority_test") - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:proc_test", -) - -syscall_test(test = "//test/syscalls/linux:proc_net_test") - -syscall_test(test = "//test/syscalls/linux:proc_pid_oomscore_test") - -syscall_test(test = "//test/syscalls/linux:proc_pid_smaps_test") - -syscall_test(test = "//test/syscalls/linux:proc_pid_uid_gid_map_test") - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:pselect_test", -) - -syscall_test(test = "//test/syscalls/linux:ptrace_test") - -syscall_test( - size = "medium", - shard_count = 5, - test = "//test/syscalls/linux:pty_test", -) - -syscall_test( - test = "//test/syscalls/linux:pty_root_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:pwritev2_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:pwrite64_test", -) - -syscall_test(test = "//test/syscalls/linux:raw_socket_hdrincl_test") - -syscall_test(test = "//test/syscalls/linux:raw_socket_icmp_test") - -syscall_test(test = "//test/syscalls/linux:raw_socket_ipv4_test") - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:read_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:readahead_test", -) - -syscall_test( - size = "medium", - shard_count = 5, - test = "//test/syscalls/linux:readv_socket_test", -) - -syscall_test( - size = "medium", - add_overlay = True, - test = "//test/syscalls/linux:readv_test", -) - -syscall_test( - size = "medium", - add_overlay = True, - test = "//test/syscalls/linux:rename_test", -) - -syscall_test(test = "//test/syscalls/linux:rlimits_test") - -syscall_test(test = "//test/syscalls/linux:rseq_test") - -syscall_test(test = "//test/syscalls/linux:rtsignal_test") - -syscall_test(test = "//test/syscalls/linux:signalfd_test") - -syscall_test(test = "//test/syscalls/linux:sched_test") - -syscall_test(test = "//test/syscalls/linux:sched_yield_test") - -syscall_test(test = "//test/syscalls/linux:seccomp_test") - -syscall_test(test = "//test/syscalls/linux:select_test") - -syscall_test( - shard_count = 20, - test = "//test/syscalls/linux:semaphore_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:sendfile_socket_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:sendfile_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:splice_test", -) - -syscall_test(test = "//test/syscalls/linux:sigaction_test") - -# TODO(b/119826902): Enable once the test passes in runsc. -# syscall_test(test = "//test/syscalls/linux:sigaltstack_test") - -syscall_test(test = "//test/syscalls/linux:sigiret_test") - -syscall_test(test = "//test/syscalls/linux:sigprocmask_test") - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:sigstop_test", -) - -syscall_test(test = "//test/syscalls/linux:sigtimedwait_test") - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:shm_test", -) - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:socket_abstract_non_blocking_test", -) - -syscall_test( - size = "large", - shard_count = 50, - test = "//test/syscalls/linux:socket_abstract_test", -) - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:socket_domain_non_blocking_test", -) - -syscall_test( - size = "large", - shard_count = 50, - test = "//test/syscalls/linux:socket_domain_test", -) - -syscall_test( - size = "medium", - add_overlay = True, - test = "//test/syscalls/linux:socket_filesystem_non_blocking_test", -) - -syscall_test( - size = "large", - add_overlay = True, - shard_count = 50, - test = "//test/syscalls/linux:socket_filesystem_test", -) - -syscall_test( - size = "large", - shard_count = 50, - test = "//test/syscalls/linux:socket_inet_loopback_test", -) - -syscall_test( - size = "large", - shard_count = 50, - test = "//test/syscalls/linux:socket_ip_tcp_generic_loopback_test", -) - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:socket_ip_tcp_loopback_non_blocking_test", -) - -syscall_test( - size = "large", - shard_count = 50, - test = "//test/syscalls/linux:socket_ip_tcp_loopback_test", -) - -syscall_test( - size = "medium", - shard_count = 50, - test = "//test/syscalls/linux:socket_ip_tcp_udp_generic_loopback_test", -) - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:socket_ip_udp_loopback_non_blocking_test", -) - -syscall_test( - size = "large", - shard_count = 50, - test = "//test/syscalls/linux:socket_ip_udp_loopback_test", -) - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_test", -) - -syscall_test(test = "//test/syscalls/linux:socket_ip_unbound_test") - -syscall_test(test = "//test/syscalls/linux:socket_netdevice_test") - -syscall_test(test = "//test/syscalls/linux:socket_netlink_test") - -syscall_test(test = "//test/syscalls/linux:socket_netlink_route_test") - -syscall_test(test = "//test/syscalls/linux:socket_netlink_uevent_test") - -syscall_test(test = "//test/syscalls/linux:socket_blocking_local_test") - -syscall_test(test = "//test/syscalls/linux:socket_blocking_ip_test") - -syscall_test(test = "//test/syscalls/linux:socket_non_stream_blocking_local_test") - -syscall_test(test = "//test/syscalls/linux:socket_non_stream_blocking_udp_test") - -syscall_test( - size = "large", - test = "//test/syscalls/linux:socket_stream_blocking_local_test", -) - -syscall_test( - size = "large", - test = "//test/syscalls/linux:socket_stream_blocking_tcp_test", -) - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:socket_stream_local_test", -) - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:socket_stream_nonblock_local_test", -) - -syscall_test( - # NOTE(b/116636318): Large sendmsg may stall a long time. - size = "enormous", - shard_count = 5, - test = "//test/syscalls/linux:socket_unix_dgram_local_test", -) - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:socket_unix_dgram_non_blocking_test", -) - -syscall_test( - size = "large", - add_overlay = True, - shard_count = 50, - test = "//test/syscalls/linux:socket_unix_pair_test", -) - -syscall_test( - # NOTE(b/116636318): Large sendmsg may stall a long time. - size = "enormous", - shard_count = 5, - test = "//test/syscalls/linux:socket_unix_seqpacket_local_test", -) - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:socket_unix_stream_test", -) - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:socket_unix_unbound_abstract_test", -) - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:socket_unix_unbound_dgram_test", -) - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:socket_unix_unbound_filesystem_test", -) - -syscall_test( - size = "medium", - shard_count = 10, - test = "//test/syscalls/linux:socket_unix_unbound_seqpacket_test", -) - -syscall_test( - size = "large", - shard_count = 50, - test = "//test/syscalls/linux:socket_unix_unbound_stream_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:statfs_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:stat_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:stat_times_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:sticky_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:symlink_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:sync_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:sync_file_range_test", -) - -syscall_test(test = "//test/syscalls/linux:sysinfo_test") - -syscall_test(test = "//test/syscalls/linux:syslog_test") - -syscall_test(test = "//test/syscalls/linux:sysret_test") - -syscall_test( - size = "medium", - shard_count = 10, - test = "//test/syscalls/linux:tcp_socket_test", -) - -syscall_test(test = "//test/syscalls/linux:tgkill_test") - -syscall_test(test = "//test/syscalls/linux:timerfd_test") - -syscall_test(test = "//test/syscalls/linux:timers_test") - -syscall_test(test = "//test/syscalls/linux:time_test") - -syscall_test(test = "//test/syscalls/linux:tkill_test") - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:truncate_test", -) - -syscall_test(test = "//test/syscalls/linux:tuntap_test") - -syscall_test( - add_hostinet = True, - test = "//test/syscalls/linux:tuntap_hostinet_test", -) - -syscall_test(test = "//test/syscalls/linux:udp_bind_test") - -syscall_test( - size = "medium", - add_hostinet = True, - shard_count = 10, - test = "//test/syscalls/linux:udp_socket_test", -) - -syscall_test(test = "//test/syscalls/linux:uidgid_test") - -syscall_test(test = "//test/syscalls/linux:uname_test") - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:unlink_test", -) - -syscall_test(test = "//test/syscalls/linux:unshare_test") - -syscall_test(test = "//test/syscalls/linux:utimes_test") - -syscall_test( - size = "medium", - test = "//test/syscalls/linux:vdso_clock_gettime_test", -) - -syscall_test(test = "//test/syscalls/linux:vdso_test") - -syscall_test(test = "//test/syscalls/linux:vsyscall_test") - -syscall_test(test = "//test/syscalls/linux:vfork_test") - -syscall_test( - size = "medium", - shard_count = 5, - test = "//test/syscalls/linux:wait_test", -) - -syscall_test( - add_overlay = True, - test = "//test/syscalls/linux:write_test", -) - -syscall_test(test = "//test/syscalls/linux:proc_net_unix_test") - -syscall_test(test = "//test/syscalls/linux:proc_net_tcp_test") - -syscall_test(test = "//test/syscalls/linux:proc_net_udp_test") diff --git a/test/syscalls/README.md b/test/syscalls/README.md deleted file mode 100644 index 9e0991940..000000000 --- a/test/syscalls/README.md +++ /dev/null @@ -1,107 +0,0 @@ -# gVisor system call test suite - -This is a test suite for Linux system calls. It runs under both gVisor and -Linux, and ensures compatibility between the two. - -When adding support for a new syscall (or syscall argument) to gVisor, a -corresponding syscall test should be added. It's usually recommended to write -the test first and make sure that it passes on Linux before making changes to -gVisor. - -This document outlines the general guidelines for tests and specific rules that -must be followed for new tests. - -## Running the tests - -Each test file generates three different test targets that run in different -environments: - -* a `native` target that runs directly on the host machine, -* a `runsc_ptrace` target that runs inside runsc using the ptrace platform, and -* a `runsc_kvm` target that runs inside runsc using the KVM platform. - -For example, the test in `access_test.cc` generates the following targets: - -* `//test/syscalls:access_test_native` -* `//test/syscalls:access_test_runsc_ptrace` -* `//test/syscalls:access_test_runsc_kvm` - -Any of these targets can be run directly via `bazel test`. - -```bash -$ bazel test //test/syscalls:access_test_native -$ bazel test //test/syscalls:access_test_runsc_ptrace -$ bazel test //test/syscalls:access_test_runsc_kvm -``` - -To run all the tests on a particular platform, you can filter by the platform -tag: - -```bash -# Run all tests in native environment: -$ bazel test --test_tag_filters=native //test/syscalls/... - -# Run all tests in runsc with ptrace: -$ bazel test --test_tag_filters=runsc_ptrace //test/syscalls/... - -# Run all tests in runsc with kvm: -$ bazel test --test_tag_filters=runsc_kvm //test/syscalls/... -``` - -You can also run all the tests on every platform. (Warning, this may take a -while to run.) - -```bash -# Run all tests on every platform: -$ bazel test //test/syscalls/... -``` - -## Writing new tests - -Whenever we add support for a new syscall, or add support for a new argument or -option for a syscall, we should always add a new test (perhaps many new tests). - -In general, it is best to write the test first and make sure it passes on Linux -by running the test on the `native` platform on a Linux machine. This ensures -that the gVisor implementation matches actual Linux behavior. Sometimes man -pages contain errors, so always check the actual Linux behavior. - -gVisor uses the [Google Test][googletest] test framework, with a few custom -matchers and guidelines, described below. - -### Syscall matchers - -When testing an individual system call, use the following syscall matchers, -which will match the value returned by the syscall and the errno. - -```cc -SyscallSucceeds() -SyscallSucceedsWithValue(...) -SyscallFails() -SyscallFailsWithErrno(...) -``` - -### Use test utilities (RAII classes) - -The test utilties are written as RAII classes. These utilities should be -preferred over custom test harnesses. - -Local class instances should be preferred, wherever possible, over full test -fixtures. - -A test utility should be created when there is more than one test that requires -that same functionality, otherwise the class should be test local. - -## Save/Restore support in tests - -gVisor supports save/restore, and our syscall tests are written in a way to -enable saving/restoring at certain points. Hence, there are calls to -`MaybeSave`, and certain tests that should not trigger saves are named with -`NoSave`. - -However, the current open-source test runner does not yet support triggering -save/restore, so these functions and annotations have no effect on the tests. We -plan on extending the test runner to trigger save/restore. Until then, these -functions and annotations should be ignored. - -[googletest]: https://github.com/abseil/googletest diff --git a/test/syscalls/linux/32bit.cc b/test/syscalls/linux/32bit.cc deleted file mode 100644 index 3c825477c..000000000 --- a/test/syscalls/linux/32bit.cc +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <string.h> -#include <sys/mman.h> - -#include "gtest/gtest.h" -#include "absl/base/macros.h" -#include "test/util/memory_util.h" -#include "test/util/platform_util.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" - -#ifndef __x86_64__ -#error "This test is x86-64 specific." -#endif - -namespace gvisor { -namespace testing { - -namespace { - -constexpr char kInt3 = '\xcc'; -constexpr char kInt80[2] = {'\xcd', '\x80'}; -constexpr char kSyscall[2] = {'\x0f', '\x05'}; -constexpr char kSysenter[2] = {'\x0f', '\x34'}; - -void ExitGroup32(const char instruction[2], int code) { - const Mapping m = ASSERT_NO_ERRNO_AND_VALUE( - Mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE | PROT_EXEC, - MAP_PRIVATE | MAP_ANONYMOUS | MAP_32BIT, -1, 0)); - - // Fill with INT 3 in case we execute too far. - memset(m.ptr(), kInt3, m.len()); - - // Copy in the actual instruction. - memcpy(m.ptr(), instruction, 2); - - // We're playing *extremely* fast-and-loose with the various syscall ABIs - // here, which we can more-or-less get away with since exit_group doesn't - // return. - // - // SYSENTER expects the user stack in (%ebp) and arg6 in 0(%ebp). The kernel - // will unconditionally dereference %ebp for arg6, so we must pass a valid - // address or it will return EFAULT. - // - // SYSENTER also unconditionally returns to thread_info->sysenter_return which - // is ostensibly a stub in the 32-bit VDSO. But a 64-bit binary doesn't have - // the 32-bit VDSO mapped, so sysenter_return will simply be the value - // inherited from the most recent 32-bit ancestor, or NULL if there is none. - // As a result, return would not return from SYSENTER. - asm volatile( - "movl $252, %%eax\n" // exit_group - "movl %[code], %%ebx\n" // code - "movl %%edx, %%ebp\n" // SYSENTER: user stack (use IP as a valid addr) - "leaq -20(%%rsp), %%rsp\n" - "movl $0x2b, 16(%%rsp)\n" // SS = CPL3 data segment - "movl $0,12(%%rsp)\n" // ESP = nullptr (unused) - "movl $0, 8(%%rsp)\n" // EFLAGS - "movl $0x23, 4(%%rsp)\n" // CS = CPL3 32-bit code segment - "movl %%edx, 0(%%rsp)\n" // EIP - "iretl\n" - "int $3\n" - : - : [ code ] "m"(code), [ ip ] "d"(m.ptr()) - : "rax", "rbx"); -} - -constexpr int kExitCode = 42; - -TEST(Syscall32Bit, Int80) { - switch (PlatformSupport32Bit()) { - case PlatformSupport::NotSupported: - break; - case PlatformSupport::Segfault: - EXPECT_EXIT(ExitGroup32(kInt80, kExitCode), - ::testing::KilledBySignal(SIGSEGV), ""); - break; - - case PlatformSupport::Ignored: - // Since the call is ignored, we'll hit the int3 trap. - EXPECT_EXIT(ExitGroup32(kInt80, kExitCode), - ::testing::KilledBySignal(SIGTRAP), ""); - break; - - case PlatformSupport::Allowed: - EXPECT_EXIT(ExitGroup32(kInt80, kExitCode), ::testing::ExitedWithCode(42), - ""); - break; - } -} - -TEST(Syscall32Bit, Sysenter) { - if ((PlatformSupport32Bit() == PlatformSupport::Allowed || - PlatformSupport32Bit() == PlatformSupport::Ignored) && - GetCPUVendor() == CPUVendor::kAMD) { - // SYSENTER is an illegal instruction in compatibility mode on AMD. - EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode), - ::testing::KilledBySignal(SIGILL), ""); - return; - } - - switch (PlatformSupport32Bit()) { - case PlatformSupport::NotSupported: - break; - - case PlatformSupport::Segfault: - EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode), - ::testing::KilledBySignal(SIGSEGV), ""); - break; - - case PlatformSupport::Ignored: - // See above, except expected code is SIGSEGV. - EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode), - ::testing::KilledBySignal(SIGSEGV), ""); - break; - - case PlatformSupport::Allowed: - EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode), - ::testing::ExitedWithCode(42), ""); - break; - } -} - -TEST(Syscall32Bit, Syscall) { - if ((PlatformSupport32Bit() == PlatformSupport::Allowed || - PlatformSupport32Bit() == PlatformSupport::Ignored) && - GetCPUVendor() == CPUVendor::kIntel) { - // SYSCALL is an illegal instruction in compatibility mode on Intel. - EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode), - ::testing::KilledBySignal(SIGILL), ""); - return; - } - - switch (PlatformSupport32Bit()) { - case PlatformSupport::NotSupported: - break; - - case PlatformSupport::Segfault: - EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode), - ::testing::KilledBySignal(SIGSEGV), ""); - break; - - case PlatformSupport::Ignored: - // See above. - EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode), - ::testing::KilledBySignal(SIGSEGV), ""); - break; - - case PlatformSupport::Allowed: - EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode), - ::testing::ExitedWithCode(42), ""); - break; - } -} - -// Far call code called below. -// -// Input stack layout: -// -// %esp+12 lcall segment -// %esp+8 lcall address offset -// %esp+0 return address -// -// The lcall will enter compatibility mode and jump to the call address (the -// address of the lret). The lret will return to 64-bit mode at the retq, which -// will return to the external caller of this function. -// -// Since this enters compatibility mode, it must be mapped in a 32-bit region of -// address space and have a 32-bit stack pointer. -constexpr char kFarCall[] = { - '\x67', '\xff', '\x5c', '\x24', '\x08', // lcall *8(%esp) - '\xc3', // retq - '\xcb', // lret -}; - -void FarCall32() { - const Mapping m = ASSERT_NO_ERRNO_AND_VALUE( - Mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE | PROT_EXEC, - MAP_PRIVATE | MAP_ANONYMOUS | MAP_32BIT, -1, 0)); - - // Fill with INT 3 in case we execute too far. - memset(m.ptr(), kInt3, m.len()); - - // 32-bit code. - memcpy(m.ptr(), kFarCall, sizeof(kFarCall)); - - // Use the end of the code page as its stack. - uintptr_t stack = m.endaddr(); - - uintptr_t lcall = m.addr(); - uintptr_t lret = m.addr() + sizeof(kFarCall) - 1; - - // N.B. We must save and restore RSP manually. GCC can do so automatically - // with an "rsp" clobber, but clang cannot. - asm volatile( - // Place the address of lret (%edx) and the 32-bit code segment (0x23) on - // the 32-bit stack for lcall. - "subl $0x8, %%ecx\n" - "movl $0x23, 4(%%ecx)\n" - "movl %%edx, 0(%%ecx)\n" - - // Save the current stack and switch to 32-bit stack. - "pushq %%rbp\n" - "movq %%rsp, %%rbp\n" - "movq %%rcx, %%rsp\n" - - // Run the lcall code. - "callq *%%rbx\n" - - // Restore the old stack. - "leaveq\n" - : "+c"(stack) - : "b"(lcall), "d"(lret)); -} - -TEST(Call32Bit, Disallowed) { - switch (PlatformSupport32Bit()) { - case PlatformSupport::NotSupported: - break; - - case PlatformSupport::Segfault: - EXPECT_EXIT(FarCall32(), ::testing::KilledBySignal(SIGSEGV), ""); - break; - - case PlatformSupport::Ignored: - ABSL_FALLTHROUGH_INTENDED; - case PlatformSupport::Allowed: - // Shouldn't crash. - FarCall32(); - } -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD deleted file mode 100644 index 636e5db12..000000000 --- a/test/syscalls/linux/BUILD +++ /dev/null @@ -1,3875 +0,0 @@ -load("//tools:defs.bzl", "cc_binary", "cc_library", "default_net_util", "gtest", "select_arch", "select_system") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -exports_files( - [ - "socket.cc", - "socket_inet_loopback.cc", - "socket_ip_loopback_blocking.cc", - "socket_ip_tcp_generic_loopback.cc", - "socket_ip_tcp_loopback.cc", - "socket_ip_tcp_loopback_blocking.cc", - "socket_ip_tcp_loopback_nonblock.cc", - "socket_ip_tcp_udp_generic.cc", - "socket_ip_udp_loopback.cc", - "socket_ip_udp_loopback_blocking.cc", - "socket_ip_udp_loopback_nonblock.cc", - "socket_ip_unbound.cc", - "socket_ipv4_tcp_unbound_external_networking_test.cc", - "socket_ipv4_udp_unbound_external_networking_test.cc", - "socket_ipv4_udp_unbound_loopback.cc", - "tcp_socket.cc", - "udp_bind.cc", - "udp_socket.cc", - ], - visibility = ["//:sandbox"], -) - -cc_binary( - name = "sigaltstack_check", - testonly = 1, - srcs = ["sigaltstack_check.cc"], - deps = ["//test/util:logging"], -) - -cc_binary( - name = "exec_assert_closed_workload", - testonly = 1, - srcs = ["exec_assert_closed_workload.cc"], - deps = [ - "@com_google_absl//absl/strings", - ], -) - -cc_binary( - name = "exec_basic_workload", - testonly = 1, - srcs = [ - "exec.h", - "exec_basic_workload.cc", - ], -) - -cc_binary( - name = "exec_proc_exe_workload", - testonly = 1, - srcs = ["exec_proc_exe_workload.cc"], - deps = [ - "//test/util:fs_util", - "//test/util:posix_error", - ], -) - -cc_binary( - name = "exec_state_workload", - testonly = 1, - srcs = ["exec_state_workload.cc"], - deps = ["@com_google_absl//absl/strings"], -) - -sh_binary( - name = "exit_script", - testonly = 1, - srcs = [ - "exit_script.sh", - ], -) - -cc_binary( - name = "priority_execve", - testonly = 1, - srcs = [ - "priority_execve.cc", - ], -) - -cc_library( - name = "base_poll_test", - testonly = 1, - srcs = ["base_poll_test.cc"], - hdrs = ["base_poll_test.h"], - deps = [ - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - gtest, - "//test/util:logging", - "//test/util:signal_util", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_library( - name = "file_base", - testonly = 1, - hdrs = ["file_base.h"], - deps = [ - "//test/util:file_descriptor", - "@com_google_absl//absl/strings", - gtest, - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_util", - ], -) - -cc_library( - name = "socket_netlink_util", - testonly = 1, - srcs = ["socket_netlink_util.cc"], - hdrs = ["socket_netlink_util.h"], - deps = [ - ":socket_test_util", - "//test/util:file_descriptor", - "//test/util:posix_error", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "socket_netlink_route_util", - testonly = 1, - srcs = ["socket_netlink_route_util.cc"], - hdrs = ["socket_netlink_route_util.h"], - deps = [ - ":socket_netlink_util", - "@com_google_absl//absl/types:optional", - ], -) - -cc_library( - name = "socket_test_util", - testonly = 1, - srcs = [ - "socket_test_util.cc", - "socket_test_util_impl.cc", - ], - hdrs = ["socket_test_util.h"], - defines = select_system(), - deps = default_net_util() + [ - gtest, - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:optional", - "//test/util:file_descriptor", - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_library( - name = "unix_domain_socket_test_util", - testonly = 1, - srcs = ["unix_domain_socket_test_util.cc"], - hdrs = ["unix_domain_socket_test_util.h"], - deps = [ - ":socket_test_util", - "@com_google_absl//absl/strings", - gtest, - "//test/util:test_util", - ], -) - -cc_library( - name = "ip_socket_test_util", - testonly = 1, - srcs = ["ip_socket_test_util.cc"], - hdrs = ["ip_socket_test_util.h"], - deps = [ - ":socket_test_util", - "@com_google_absl//absl/strings", - ], -) - -cc_binary( - name = "clock_nanosleep_test", - testonly = 1, - srcs = ["clock_nanosleep.cc"], - linkstatic = 1, - deps = [ - "//test/util:cleanup", - "@com_google_absl//absl/time", - gtest, - "//test/util:posix_error", - "//test/util:signal_util", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - "//test/util:timer_util", - ], -) - -cc_binary( - name = "32bit_test", - testonly = 1, - srcs = select_arch( - amd64 = ["32bit.cc"], - arm64 = [], - ), - linkstatic = 1, - deps = [ - "@com_google_absl//absl/base:core_headers", - gtest, - "//test/util:memory_util", - "//test/util:platform_util", - "//test/util:posix_error", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "accept_bind_test", - testonly = 1, - srcs = ["accept_bind.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - "//test/util:file_descriptor", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "accept_bind_stream_test", - testonly = 1, - srcs = ["accept_bind_stream.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - "//test/util:file_descriptor", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "access_test", - testonly = 1, - srcs = ["access.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:fs_util", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "affinity_test", - testonly = 1, - srcs = ["affinity.cc"], - linkstatic = 1, - deps = [ - "//test/util:cleanup", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - gtest, - "//test/util:posix_error", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "aio_test", - testonly = 1, - srcs = [ - "aio.cc", - "file_base.h", - ], - linkstatic = 1, - deps = [ - "//test/util:cleanup", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - gtest, - "//test/util:memory_util", - "//test/util:posix_error", - "//test/util:proc_util", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "alarm_test", - testonly = 1, - srcs = ["alarm.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "@com_google_absl//absl/time", - gtest, - "//test/util:logging", - "//test/util:signal_util", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "bad_test", - testonly = 1, - srcs = ["bad.cc"], - linkstatic = 1, - visibility = [ - "//:sandbox", - ], - deps = [ - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "bind_test", - testonly = 1, - srcs = ["bind.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_test", - testonly = 1, - srcs = ["socket.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - "//test/util:file_descriptor", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "brk_test", - testonly = 1, - srcs = ["brk.cc"], - linkstatic = 1, - deps = [ - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "chdir_test", - testonly = 1, - srcs = ["chdir.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "chmod_test", - testonly = 1, - srcs = ["chmod.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:file_descriptor", - "//test/util:fs_util", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "chown_test", - testonly = 1, - srcs = ["chown.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/synchronization", - gtest, - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "sticky_test", - testonly = 1, - srcs = ["sticky.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/flags:flag", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "chroot_test", - testonly = 1, - srcs = ["chroot.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:cleanup", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - gtest, - "//test/util:mount_util", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "clock_getres_test", - testonly = 1, - srcs = ["clock_getres.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "clock_gettime_test", - testonly = 1, - srcs = ["clock_gettime.cc"], - linkstatic = 1, - deps = [ - "@com_google_absl//absl/time", - gtest, - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "concurrency_test", - testonly = 1, - srcs = ["concurrency.cc"], - linkstatic = 1, - deps = [ - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - gtest, - "//test/util:platform_util", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "connect_external_test", - testonly = 1, - srcs = ["connect_external.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - "//test/util:file_descriptor", - "//test/util:fs_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "creat_test", - testonly = 1, - srcs = ["creat.cc"], - linkstatic = 1, - deps = [ - "//test/util:fs_util", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "dev_test", - testonly = 1, - srcs = ["dev.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "dup_test", - testonly = 1, - srcs = ["dup.cc"], - linkstatic = 1, - deps = [ - "//test/util:eventfd_util", - "//test/util:file_descriptor", - gtest, - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "epoll_test", - testonly = 1, - srcs = ["epoll.cc"], - linkstatic = 1, - deps = [ - "//test/util:epoll_util", - "//test/util:eventfd_util", - "//test/util:file_descriptor", - gtest, - "//test/util:posix_error", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "eventfd_test", - testonly = 1, - srcs = ["eventfd.cc"], - linkstatic = 1, - deps = [ - "//test/util:epoll_util", - "//test/util:eventfd_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "exceptions_test", - testonly = 1, - srcs = select_arch( - amd64 = ["exceptions.cc"], - arm64 = [], - ), - linkstatic = 1, - deps = [ - gtest, - "//test/util:logging", - "//test/util:platform_util", - "//test/util:signal_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "getcpu_test", - testonly = 1, - srcs = ["getcpu.cc"], - linkstatic = 1, - deps = [ - "@com_google_absl//absl/time", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "getcpu_host_test", - testonly = 1, - srcs = ["getcpu.cc"], - linkstatic = 1, - deps = [ - "@com_google_absl//absl/time", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "getrusage_test", - testonly = 1, - srcs = ["getrusage.cc"], - linkstatic = 1, - deps = [ - "@com_google_absl//absl/time", - gtest, - "//test/util:logging", - "//test/util:memory_util", - "//test/util:signal_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "exec_binary_test", - testonly = 1, - srcs = select_arch( - amd64 = ["exec_binary.cc"], - arm64 = [], - ), - linkstatic = 1, - deps = [ - "//test/util:cleanup", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - gtest, - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:proc_util", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "exec_test", - testonly = 1, - srcs = [ - "exec.cc", - "exec.h", - ], - data = [ - ":exec_assert_closed_workload", - ":exec_basic_workload", - ":exec_proc_exe_workload", - ":exec_state_workload", - ":exit_script", - ":priority_execve", - ], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:optional", - gtest, - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "exit_test", - testonly = 1, - srcs = ["exit.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "@com_google_absl//absl/time", - gtest, - "//test/util:test_main", - "//test/util:test_util", - "//test/util:time_util", - ], -) - -cc_binary( - name = "fallocate_test", - testonly = 1, - srcs = ["fallocate.cc"], - linkstatic = 1, - deps = [ - ":file_base", - "//test/util:cleanup", - "//test/util:file_descriptor", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "fault_test", - testonly = 1, - srcs = ["fault.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "fchdir_test", - testonly = 1, - srcs = ["fchdir.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "fcntl_test", - testonly = 1, - srcs = ["fcntl.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - "//test/util:cleanup", - "//test/util:eventfd_util", - "//test/util:fs_util", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - gtest, - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:save_util", - "//test/util:temp_path", - "//test/util:test_util", - "//test/util:timer_util", - ], -) - -cc_binary( - name = "flock_test", - testonly = 1, - srcs = [ - "file_base.h", - "flock.cc", - ], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - gtest, - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - "//test/util:timer_util", - ], -) - -cc_binary( - name = "fork_test", - testonly = 1, - srcs = ["fork.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "@com_google_absl//absl/time", - gtest, - "//test/util:logging", - "//test/util:memory_util", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "fpsig_fork_test", - testonly = 1, - srcs = select_arch( - amd64 = ["fpsig_fork.cc"], - arm64 = [], - ), - linkstatic = 1, - deps = [ - gtest, - "//test/util:logging", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "fpsig_nested_test", - testonly = 1, - srcs = select_arch( - amd64 = ["fpsig_nested.cc"], - arm64 = [], - ), - linkstatic = 1, - deps = [ - gtest, - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "sync_file_range_test", - testonly = 1, - srcs = ["sync_file_range.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "fsync_test", - testonly = 1, - srcs = ["fsync.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "futex_test", - testonly = 1, - srcs = ["futex.cc"], - linkstatic = 1, - deps = [ - "//test/util:cleanup", - "//test/util:file_descriptor", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/time", - gtest, - "//test/util:memory_util", - "//test/util:save_util", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - "//test/util:time_util", - "//test/util:timer_util", - ], -) - -cc_binary( - name = "getdents_test", - testonly = 1, - srcs = ["getdents.cc"], - linkstatic = 1, - deps = [ - "//test/util:eventfd_util", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - gtest, - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "getrandom_test", - testonly = 1, - srcs = ["getrandom.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "inotify_test", - testonly = 1, - srcs = ["inotify.cc"], - linkstatic = 1, - deps = [ - "//test/util:epoll_util", - "//test/util:file_descriptor", - "//test/util:fs_util", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/time", - ], -) - -cc_binary( - name = "ioctl_test", - testonly = 1, - srcs = ["ioctl.cc"], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_test_util", - ":unix_domain_socket_test_util", - "//test/util:file_descriptor", - gtest, - "//test/util:signal_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_library( - name = "iptables_types", - testonly = 1, - hdrs = [ - "iptables.h", - ], -) - -cc_binary( - name = "iptables_test", - testonly = 1, - srcs = [ - "iptables.cc", - ], - linkstatic = 1, - deps = [ - ":iptables_types", - ":socket_test_util", - "//test/util:capability_util", - "//test/util:file_descriptor", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "itimer_test", - testonly = 1, - srcs = ["itimer.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - gtest, - "//test/util:logging", - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:signal_util", - "//test/util:test_util", - "//test/util:thread_util", - "//test/util:timer_util", - ], -) - -cc_binary( - name = "kill_test", - testonly = 1, - srcs = ["kill.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:file_descriptor", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - gtest, - "//test/util:logging", - "//test/util:signal_util", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "link_test", - testonly = 1, - srcs = ["link.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/strings", - gtest, - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "lseek_test", - testonly = 1, - srcs = ["lseek.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "madvise_test", - testonly = 1, - srcs = ["madvise.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - gtest, - "//test/util:logging", - "//test/util:memory_util", - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "mempolicy_test", - testonly = 1, - srcs = ["mempolicy.cc"], - linkstatic = 1, - deps = [ - "//test/util:cleanup", - "@com_google_absl//absl/memory", - gtest, - "//test/util:memory_util", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "mincore_test", - testonly = 1, - srcs = ["mincore.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:memory_util", - "//test/util:posix_error", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "mkdir_test", - testonly = 1, - srcs = ["mkdir.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:fs_util", - gtest, - "//test/util:temp_path", - "//test/util:temp_umask", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "mknod_test", - testonly = 1, - srcs = ["mknod.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "mlock_test", - testonly = 1, - srcs = ["mlock.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:cleanup", - gtest, - "//test/util:memory_util", - "//test/util:multiprocess_util", - "//test/util:rlimit_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "mmap_test", - testonly = 1, - srcs = ["mmap.cc"], - linkstatic = 1, - deps = [ - "//test/util:cleanup", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - gtest, - "//test/util:memory_util", - "//test/util:multiprocess_util", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "mount_test", - testonly = 1, - srcs = ["mount.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - gtest, - "//test/util:mount_util", - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "mremap_test", - testonly = 1, - srcs = ["mremap.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "@com_google_absl//absl/strings", - gtest, - "//test/util:logging", - "//test/util:memory_util", - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "msync_test", - testonly = 1, - srcs = ["msync.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "//test/util:memory_util", - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "munmap_test", - testonly = 1, - srcs = ["munmap.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "open_test", - testonly = 1, - srcs = [ - "file_base.h", - "open.cc", - ], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:cleanup", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - gtest, - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "open_create_test", - testonly = 1, - srcs = ["open_create.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:file_descriptor", - "//test/util:fs_util", - gtest, - "//test/util:temp_path", - "//test/util:temp_umask", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "packet_socket_raw_test", - testonly = 1, - srcs = ["packet_socket_raw.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - "//test/util:capability_util", - "//test/util:file_descriptor", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/base:endian", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "packet_socket_test", - testonly = 1, - srcs = ["packet_socket.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - "//test/util:capability_util", - "//test/util:file_descriptor", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/base:endian", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "pty_test", - testonly = 1, - srcs = ["pty.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:file_descriptor", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - gtest, - "//test/util:posix_error", - "//test/util:pty_util", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "pty_root_test", - testonly = 1, - srcs = ["pty_root.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:file_descriptor", - "@com_google_absl//absl/base:core_headers", - gtest, - "//test/util:posix_error", - "//test/util:pty_util", - "//test/util:test_main", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "partial_bad_buffer_test", - testonly = 1, - srcs = ["partial_bad_buffer.cc"], - linkstatic = 1, - deps = [ - "//test/syscalls/linux:socket_test_util", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/time", - gtest, - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "pause_test", - testonly = 1, - srcs = ["pause.cc"], - linkstatic = 1, - deps = [ - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - gtest, - "//test/util:signal_util", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "pipe_test", - testonly = 1, - srcs = ["pipe.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - gtest, - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "poll_test", - testonly = 1, - srcs = ["poll.cc"], - linkstatic = 1, - deps = [ - ":base_poll_test", - "//test/util:eventfd_util", - "//test/util:file_descriptor", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - gtest, - "//test/util:logging", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "ppoll_test", - testonly = 1, - srcs = ["ppoll.cc"], - linkstatic = 1, - deps = [ - ":base_poll_test", - "@com_google_absl//absl/time", - gtest, - "//test/util:signal_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "arch_prctl_test", - testonly = 1, - srcs = select_arch( - amd64 = ["arch_prctl.cc"], - arm64 = [], - ), - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "prctl_test", - testonly = 1, - srcs = ["prctl.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:cleanup", - "@com_google_absl//absl/flags:flag", - gtest, - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "prctl_setuid_test", - testonly = 1, - srcs = ["prctl_setuid.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "@com_google_absl//absl/flags:flag", - gtest, - "//test/util:logging", - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "pread64_test", - testonly = 1, - srcs = ["pread64.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "preadv_test", - testonly = 1, - srcs = ["preadv.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "@com_google_absl//absl/time", - gtest, - "//test/util:logging", - "//test/util:memory_util", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - "//test/util:timer_util", - ], -) - -cc_binary( - name = "preadv2_test", - testonly = 1, - srcs = [ - "file_base.h", - "preadv2.cc", - ], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - gtest, - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "priority_test", - testonly = 1, - srcs = ["priority.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - gtest, - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "proc_test", - testonly = 1, - srcs = ["proc.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:cleanup", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - gtest, - "//test/util:memory_util", - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_util", - "//test/util:thread_util", - "//test/util:time_util", - "//test/util:timer_util", - ], -) - -cc_binary( - name = "proc_net_test", - testonly = 1, - srcs = ["proc_net.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - "//test/util:capability_util", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "proc_pid_oomscore_test", - testonly = 1, - srcs = ["proc_pid_oomscore.cc"], - linkstatic = 1, - deps = [ - "//test/util:fs_util", - "//test/util:test_main", - "//test/util:test_util", - "@com_google_absl//absl/strings", - ], -) - -cc_binary( - name = "proc_pid_smaps_test", - testonly = 1, - srcs = ["proc_pid_smaps.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:optional", - gtest, - "//test/util:memory_util", - "//test/util:posix_error", - "//test/util:proc_util", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "proc_pid_uid_gid_map_test", - testonly = 1, - srcs = ["proc_pid_uid_gid_map.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:cleanup", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - gtest, - "//test/util:logging", - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:save_util", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:time_util", - ], -) - -cc_binary( - name = "pselect_test", - testonly = 1, - srcs = ["pselect.cc"], - linkstatic = 1, - deps = [ - ":base_poll_test", - "@com_google_absl//absl/time", - gtest, - "//test/util:signal_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "ptrace_test", - testonly = 1, - srcs = ["ptrace.cc"], - linkstatic = 1, - deps = [ - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/time", - gtest, - "//test/util:logging", - "//test/util:multiprocess_util", - "//test/util:platform_util", - "//test/util:signal_util", - "//test/util:test_util", - "//test/util:thread_util", - "//test/util:time_util", - ], -) - -cc_binary( - name = "pwrite64_test", - testonly = 1, - srcs = ["pwrite64.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "pwritev2_test", - testonly = 1, - srcs = [ - "pwritev2.cc", - ], - linkstatic = 1, - deps = [ - ":file_base", - "//test/util:file_descriptor", - "@com_google_absl//absl/strings", - gtest, - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "raw_socket_hdrincl_test", - testonly = 1, - srcs = ["raw_socket_hdrincl.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - "//test/util:capability_util", - "//test/util:file_descriptor", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/base:endian", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "raw_socket_ipv4_test", - testonly = 1, - srcs = ["raw_socket_ipv4.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - "//test/util:capability_util", - "//test/util:file_descriptor", - "@com_google_absl//absl/base:core_headers", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "raw_socket_icmp_test", - testonly = 1, - srcs = ["raw_socket_icmp.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - "//test/util:capability_util", - "//test/util:file_descriptor", - "@com_google_absl//absl/base:core_headers", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "read_test", - testonly = 1, - srcs = ["read.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "readahead_test", - testonly = 1, - srcs = ["readahead.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "readv_test", - testonly = 1, - srcs = [ - "file_base.h", - "readv.cc", - "readv_common.cc", - "readv_common.h", - ], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "@com_google_absl//absl/strings", - gtest, - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:timer_util", - ], -) - -cc_binary( - name = "readv_socket_test", - testonly = 1, - srcs = [ - "readv_common.cc", - "readv_common.h", - "readv_socket.cc", - ], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "@com_google_absl//absl/strings", - gtest, - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "rename_test", - testonly = 1, - srcs = ["rename.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:cleanup", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "rlimits_test", - testonly = 1, - srcs = ["rlimits.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "rseq_test", - testonly = 1, - srcs = ["rseq.cc"], - data = ["//test/syscalls/linux/rseq"], - linkstatic = 1, - deps = [ - "//test/syscalls/linux/rseq:lib", - gtest, - "//test/util:logging", - "//test/util:multiprocess_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "rtsignal_test", - testonly = 1, - srcs = ["rtsignal.cc"], - linkstatic = 1, - deps = [ - "//test/util:cleanup", - gtest, - "//test/util:logging", - "//test/util:posix_error", - "//test/util:signal_util", - "//test/util:test_util", - ], -) - -cc_binary( - name = "sched_test", - testonly = 1, - srcs = ["sched.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "sched_yield_test", - testonly = 1, - srcs = ["sched_yield.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "seccomp_test", - testonly = 1, - srcs = ["seccomp.cc"], - linkstatic = 1, - deps = [ - "@com_google_absl//absl/base:core_headers", - gtest, - "//test/util:logging", - "//test/util:memory_util", - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:proc_util", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "select_test", - testonly = 1, - srcs = ["select.cc"], - linkstatic = 1, - deps = [ - ":base_poll_test", - "//test/util:file_descriptor", - "@com_google_absl//absl/time", - gtest, - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:rlimit_util", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "sendfile_test", - testonly = 1, - srcs = ["sendfile.cc"], - linkstatic = 1, - deps = [ - "//test/util:eventfd_util", - "//test/util:file_descriptor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "sendfile_socket_test", - testonly = 1, - srcs = ["sendfile_socket.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - "//test/util:file_descriptor", - "@com_google_absl//absl/strings", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "splice_test", - testonly = 1, - srcs = ["splice.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "sigaction_test", - testonly = 1, - srcs = ["sigaction.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "sigaltstack_test", - testonly = 1, - srcs = ["sigaltstack.cc"], - data = [ - ":sigaltstack_check", - ], - linkstatic = 1, - deps = [ - "//test/util:cleanup", - "//test/util:fs_util", - gtest, - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:signal_util", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "sigiret_test", - testonly = 1, - srcs = select_arch( - amd64 = ["sigiret.cc"], - arm64 = [], - ), - linkstatic = 1, - deps = [ - gtest, - "//test/util:logging", - "//test/util:signal_util", - "//test/util:test_util", - "//test/util:timer_util", - ] + select_arch( - amd64 = [], - arm64 = ["//test/util:test_main"], - ), -) - -cc_binary( - name = "signalfd_test", - testonly = 1, - srcs = ["signalfd.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "@com_google_absl//absl/synchronization", - gtest, - "//test/util:logging", - "//test/util:posix_error", - "//test/util:signal_util", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "sigprocmask_test", - testonly = 1, - srcs = ["sigprocmask.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:signal_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "sigstop_test", - testonly = 1, - srcs = ["sigstop.cc"], - linkstatic = 1, - deps = [ - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/time", - gtest, - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "sigtimedwait_test", - testonly = 1, - srcs = ["sigtimedwait.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "@com_google_absl//absl/time", - gtest, - "//test/util:logging", - "//test/util:signal_util", - "//test/util:test_util", - "//test/util:thread_util", - "//test/util:timer_util", - ], -) - -cc_library( - name = "socket_generic_test_cases", - testonly = 1, - srcs = [ - "socket_generic.cc", - ], - hdrs = [ - "socket_generic.h", - ], - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - gtest, - "//test/util:test_util", - ], - alwayslink = 1, -) - -cc_binary( - name = "socket_stress_test", - testonly = 1, - srcs = [ - "socket_generic_stress.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_test_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_library( - name = "socket_unix_dgram_test_cases", - testonly = 1, - srcs = ["socket_unix_dgram.cc"], - hdrs = ["socket_unix_dgram.h"], - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - gtest, - "//test/util:test_util", - ], - alwayslink = 1, -) - -cc_library( - name = "socket_unix_seqpacket_test_cases", - testonly = 1, - srcs = ["socket_unix_seqpacket.cc"], - hdrs = ["socket_unix_seqpacket.h"], - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - gtest, - "//test/util:test_util", - ], - alwayslink = 1, -) - -cc_library( - name = "socket_ip_tcp_generic_test_cases", - testonly = 1, - srcs = [ - "socket_ip_tcp_generic.cc", - ], - hdrs = [ - "socket_ip_tcp_generic.h", - ], - deps = [ - ":socket_test_util", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/time", - gtest, - "//test/util:test_util", - "//test/util:thread_util", - ], - alwayslink = 1, -) - -cc_library( - name = "socket_non_blocking_test_cases", - testonly = 1, - srcs = [ - "socket_non_blocking.cc", - ], - hdrs = [ - "socket_non_blocking.h", - ], - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - gtest, - "//test/util:test_util", - ], - alwayslink = 1, -) - -cc_library( - name = "socket_unix_non_stream_test_cases", - testonly = 1, - srcs = [ - "socket_unix_non_stream.cc", - ], - hdrs = [ - "socket_unix_non_stream.h", - ], - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - gtest, - "//test/util:memory_util", - "//test/util:test_util", - ], - alwayslink = 1, -) - -cc_library( - name = "socket_non_stream_test_cases", - testonly = 1, - srcs = [ - "socket_non_stream.cc", - ], - hdrs = [ - "socket_non_stream.h", - ], - deps = [ - ":ip_socket_test_util", - ":socket_test_util", - ":unix_domain_socket_test_util", - gtest, - "//test/util:test_util", - ], - alwayslink = 1, -) - -cc_library( - name = "socket_ip_udp_test_cases", - testonly = 1, - srcs = [ - "socket_ip_udp_generic.cc", - ], - hdrs = [ - "socket_ip_udp_generic.h", - ], - deps = [ - ":ip_socket_test_util", - ":socket_test_util", - gtest, - "//test/util:test_util", - ], - alwayslink = 1, -) - -cc_library( - name = "socket_ipv4_udp_unbound_test_cases", - testonly = 1, - srcs = [ - "socket_ipv4_udp_unbound.cc", - ], - hdrs = [ - "socket_ipv4_udp_unbound.h", - ], - deps = [ - ":ip_socket_test_util", - ":socket_test_util", - "@com_google_absl//absl/memory", - gtest, - "//test/util:test_util", - ], - alwayslink = 1, -) - -cc_library( - name = "socket_ipv4_udp_unbound_external_networking_test_cases", - testonly = 1, - srcs = [ - "socket_ipv4_udp_unbound_external_networking.cc", - ], - hdrs = [ - "socket_ipv4_udp_unbound_external_networking.h", - ], - deps = [ - ":ip_socket_test_util", - ":socket_test_util", - gtest, - "//test/util:test_util", - ], - alwayslink = 1, -) - -cc_library( - name = "socket_ipv4_tcp_unbound_external_networking_test_cases", - testonly = 1, - srcs = [ - "socket_ipv4_tcp_unbound_external_networking.cc", - ], - hdrs = [ - "socket_ipv4_tcp_unbound_external_networking.h", - ], - deps = [ - ":ip_socket_test_util", - ":socket_test_util", - gtest, - "//test/util:test_util", - ], - alwayslink = 1, -) - -cc_binary( - name = "socket_abstract_test", - testonly = 1, - srcs = [ - "socket_abstract.cc", - ], - linkstatic = 1, - deps = [ - ":socket_generic_test_cases", - ":socket_test_util", - ":socket_unix_cmsg_test_cases", - ":socket_unix_test_cases", - ":unix_domain_socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_abstract_non_blocking_test", - testonly = 1, - srcs = [ - "socket_unix_abstract_nonblock.cc", - ], - linkstatic = 1, - deps = [ - ":socket_non_blocking_test_cases", - ":socket_test_util", - ":unix_domain_socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_unix_dgram_local_test", - testonly = 1, - srcs = ["socket_unix_dgram_local.cc"], - linkstatic = 1, - deps = [ - ":socket_non_stream_test_cases", - ":socket_test_util", - ":socket_unix_dgram_test_cases", - ":socket_unix_non_stream_test_cases", - ":unix_domain_socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_unix_dgram_non_blocking_test", - testonly = 1, - srcs = ["socket_unix_dgram_non_blocking.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_unix_seqpacket_local_test", - testonly = 1, - srcs = [ - "socket_unix_seqpacket_local.cc", - ], - linkstatic = 1, - deps = [ - ":socket_non_stream_test_cases", - ":socket_test_util", - ":socket_unix_non_stream_test_cases", - ":socket_unix_seqpacket_test_cases", - ":unix_domain_socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_unix_stream_test", - testonly = 1, - srcs = ["socket_unix_stream.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_ip_tcp_generic_loopback_test", - testonly = 1, - srcs = [ - "socket_ip_tcp_generic_loopback.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_ip_tcp_generic_test_cases", - ":socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_ip_tcp_udp_generic_loopback_test", - testonly = 1, - srcs = [ - "socket_ip_tcp_udp_generic.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_test_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_ip_tcp_loopback_test", - testonly = 1, - srcs = [ - "socket_ip_tcp_loopback.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_generic_test_cases", - ":socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_ip_tcp_loopback_non_blocking_test", - testonly = 1, - srcs = [ - "socket_ip_tcp_loopback_nonblock.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_non_blocking_test_cases", - ":socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_ip_udp_loopback_test", - testonly = 1, - srcs = [ - "socket_ip_udp_loopback.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_generic_test_cases", - ":socket_ip_udp_test_cases", - ":socket_non_stream_test_cases", - ":socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_ipv4_udp_unbound_external_networking_test", - testonly = 1, - srcs = [ - "socket_ipv4_udp_unbound_external_networking_test.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_ipv4_udp_unbound_external_networking_test_cases", - ":socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_ipv4_tcp_unbound_external_networking_test", - testonly = 1, - srcs = [ - "socket_ipv4_tcp_unbound_external_networking_test.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_ipv4_tcp_unbound_external_networking_test_cases", - ":socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_bind_to_device_test", - testonly = 1, - srcs = [ - "socket_bind_to_device.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_bind_to_device_util", - ":socket_test_util", - "//test/util:capability_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "socket_bind_to_device_sequence_test", - testonly = 1, - srcs = [ - "socket_bind_to_device_sequence.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_bind_to_device_util", - ":socket_test_util", - "//test/util:capability_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "socket_bind_to_device_distribution_test", - testonly = 1, - srcs = [ - "socket_bind_to_device_distribution.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_bind_to_device_util", - ":socket_test_util", - "//test/util:capability_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "socket_ip_udp_loopback_non_blocking_test", - testonly = 1, - srcs = [ - "socket_ip_udp_loopback_nonblock.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_non_blocking_test_cases", - ":socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_ipv4_udp_unbound_loopback_test", - testonly = 1, - srcs = [ - "socket_ipv4_udp_unbound_loopback.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_ipv4_udp_unbound_test_cases", - ":socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_ip_unbound_test", - testonly = 1, - srcs = [ - "socket_ip_unbound.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_test_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_domain_test", - testonly = 1, - srcs = [ - "socket_unix_domain.cc", - ], - linkstatic = 1, - deps = [ - ":socket_generic_test_cases", - ":socket_test_util", - ":unix_domain_socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_domain_non_blocking_test", - testonly = 1, - srcs = [ - "socket_unix_pair_nonblock.cc", - ], - linkstatic = 1, - deps = [ - ":socket_non_blocking_test_cases", - ":socket_test_util", - ":unix_domain_socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_filesystem_test", - testonly = 1, - srcs = [ - "socket_filesystem.cc", - ], - linkstatic = 1, - deps = [ - ":socket_generic_test_cases", - ":socket_test_util", - ":socket_unix_cmsg_test_cases", - ":socket_unix_test_cases", - ":unix_domain_socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_filesystem_non_blocking_test", - testonly = 1, - srcs = [ - "socket_unix_filesystem_nonblock.cc", - ], - linkstatic = 1, - deps = [ - ":socket_non_blocking_test_cases", - ":socket_test_util", - ":unix_domain_socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_inet_loopback_test", - testonly = 1, - srcs = ["socket_inet_loopback.cc"], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_test_util", - "//test/util:file_descriptor", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - gtest, - "//test/util:posix_error", - "//test/util:save_util", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "socket_netlink_test", - testonly = 1, - srcs = ["socket_netlink.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - "//test/util:file_descriptor", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_netlink_route_test", - testonly = 1, - srcs = ["socket_netlink_route.cc"], - linkstatic = 1, - deps = [ - ":socket_netlink_util", - ":socket_test_util", - "//test/util:capability_util", - "//test/util:cleanup", - "//test/util:file_descriptor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:optional", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_netlink_uevent_test", - testonly = 1, - srcs = ["socket_netlink_uevent.cc"], - linkstatic = 1, - deps = [ - ":socket_netlink_util", - ":socket_test_util", - "//test/util:file_descriptor", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -# These socket tests are in a library because the test cases are shared -# across several test build targets. -cc_library( - name = "socket_stream_test_cases", - testonly = 1, - srcs = [ - "socket_stream.cc", - ], - hdrs = [ - "socket_stream.h", - ], - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - "@com_google_absl//absl/time", - gtest, - "//test/util:test_util", - ], - alwayslink = 1, -) - -cc_library( - name = "socket_blocking_test_cases", - testonly = 1, - srcs = [ - "socket_blocking.cc", - ], - hdrs = [ - "socket_blocking.h", - ], - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - "@com_google_absl//absl/time", - gtest, - "//test/util:test_util", - "//test/util:thread_util", - "//test/util:timer_util", - ], - alwayslink = 1, -) - -cc_library( - name = "socket_unix_test_cases", - testonly = 1, - srcs = [ - "socket_unix.cc", - ], - hdrs = [ - "socket_unix.h", - ], - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - "@com_google_absl//absl/strings", - gtest, - "//test/util:test_util", - "//test/util:thread_util", - ], - alwayslink = 1, -) - -cc_library( - name = "socket_unix_cmsg_test_cases", - testonly = 1, - srcs = [ - "socket_unix_cmsg.cc", - ], - hdrs = [ - "socket_unix_cmsg.h", - ], - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - "@com_google_absl//absl/strings", - gtest, - "//test/util:test_util", - "//test/util:thread_util", - ], - alwayslink = 1, -) - -cc_library( - name = "socket_stream_blocking_test_cases", - testonly = 1, - srcs = [ - "socket_stream_blocking.cc", - ], - hdrs = [ - "socket_stream_blocking.h", - ], - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - "@com_google_absl//absl/time", - gtest, - "//test/util:test_util", - "//test/util:thread_util", - "//test/util:timer_util", - ], - alwayslink = 1, -) - -cc_library( - name = "socket_stream_nonblocking_test_cases", - testonly = 1, - srcs = [ - "socket_stream_nonblock.cc", - ], - hdrs = [ - "socket_stream_nonblock.h", - ], - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - gtest, - "//test/util:test_util", - ], - alwayslink = 1, -) - -cc_library( - name = "socket_non_stream_blocking_test_cases", - testonly = 1, - srcs = [ - "socket_non_stream_blocking.cc", - ], - hdrs = [ - "socket_non_stream_blocking.h", - ], - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - "@com_google_absl//absl/time", - gtest, - "//test/util:test_util", - "//test/util:thread_util", - ], - alwayslink = 1, -) - -cc_library( - name = "socket_bind_to_device_util", - testonly = 1, - srcs = [ - "socket_bind_to_device_util.cc", - ], - hdrs = [ - "socket_bind_to_device_util.h", - ], - deps = [ - "//test/util:test_util", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - ], - alwayslink = 1, -) - -cc_binary( - name = "socket_stream_local_test", - testonly = 1, - srcs = [ - "socket_unix_stream_local.cc", - ], - linkstatic = 1, - deps = [ - ":socket_stream_test_cases", - ":socket_test_util", - ":unix_domain_socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_stream_blocking_local_test", - testonly = 1, - srcs = [ - "socket_unix_stream_blocking_local.cc", - ], - linkstatic = 1, - deps = [ - ":socket_stream_blocking_test_cases", - ":socket_test_util", - ":unix_domain_socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_stream_blocking_tcp_test", - testonly = 1, - srcs = [ - "socket_ip_tcp_loopback_blocking.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_stream_blocking_test_cases", - ":socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_stream_nonblock_local_test", - testonly = 1, - srcs = [ - "socket_unix_stream_nonblock_local.cc", - ], - linkstatic = 1, - deps = [ - ":socket_stream_nonblocking_test_cases", - ":socket_test_util", - ":unix_domain_socket_test_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_unix_unbound_dgram_test", - testonly = 1, - srcs = ["socket_unix_unbound_dgram.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_unix_unbound_abstract_test", - testonly = 1, - srcs = ["socket_unix_unbound_abstract.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_unix_unbound_filesystem_test", - testonly = 1, - srcs = ["socket_unix_unbound_filesystem.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_blocking_local_test", - testonly = 1, - srcs = [ - "socket_unix_blocking_local.cc", - ], - linkstatic = 1, - deps = [ - ":socket_blocking_test_cases", - ":socket_test_util", - ":unix_domain_socket_test_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_blocking_ip_test", - testonly = 1, - srcs = [ - "socket_ip_loopback_blocking.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_blocking_test_cases", - ":socket_test_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_non_stream_blocking_local_test", - testonly = 1, - srcs = [ - "socket_unix_non_stream_blocking_local.cc", - ], - linkstatic = 1, - deps = [ - ":socket_non_stream_blocking_test_cases", - ":socket_test_util", - ":unix_domain_socket_test_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_non_stream_blocking_udp_test", - testonly = 1, - srcs = [ - "socket_ip_udp_loopback_blocking.cc", - ], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - ":socket_non_stream_blocking_test_cases", - ":socket_test_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_unix_pair_test", - testonly = 1, - srcs = [ - "socket_unix_pair.cc", - ], - linkstatic = 1, - deps = [ - ":socket_test_util", - ":socket_unix_cmsg_test_cases", - ":socket_unix_test_cases", - ":unix_domain_socket_test_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_unix_unbound_seqpacket_test", - testonly = 1, - srcs = ["socket_unix_unbound_seqpacket.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_unix_unbound_stream_test", - testonly = 1, - srcs = ["socket_unix_unbound_stream.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "socket_netdevice_test", - testonly = 1, - srcs = ["socket_netdevice.cc"], - linkstatic = 1, - deps = [ - ":socket_netlink_util", - ":socket_test_util", - "//test/util:file_descriptor", - "@com_google_absl//absl/base:endian", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "stat_test", - testonly = 1, - srcs = [ - "file_base.h", - "stat.cc", - ], - linkstatic = 1, - deps = [ - "//test/util:cleanup", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - gtest, - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "stat_times_test", - testonly = 1, - srcs = ["stat_times.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "@com_google_absl//absl/time", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "statfs_test", - testonly = 1, - srcs = [ - "file_base.h", - "statfs.cc", - ], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "@com_google_absl//absl/strings", - gtest, - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "symlink_test", - testonly = 1, - srcs = ["symlink.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:file_descriptor", - "//test/util:fs_util", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "sync_test", - testonly = 1, - srcs = ["sync.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "sysinfo_test", - testonly = 1, - srcs = ["sysinfo.cc"], - linkstatic = 1, - deps = [ - "@com_google_absl//absl/time", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "syslog_test", - testonly = 1, - srcs = ["syslog.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "sysret_test", - testonly = 1, - srcs = select_arch( - amd64 = ["sysret.cc"], - arm64 = [], - ), - linkstatic = 1, - deps = [ - gtest, - "//test/util:logging", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "tcp_socket_test", - testonly = 1, - srcs = ["tcp_socket.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - "//test/util:file_descriptor", - "@com_google_absl//absl/time", - gtest, - "//test/util:posix_error", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "tgkill_test", - testonly = 1, - srcs = ["tgkill.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:signal_util", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "time_test", - testonly = 1, - srcs = ["time.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:proc_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "timerfd_test", - testonly = 1, - srcs = ["timerfd.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "//test/util:posix_error", - "//test/util:test_main", - "//test/util:test_util", - "@com_google_absl//absl/time", - ], -) - -cc_binary( - name = "timers_test", - testonly = 1, - srcs = ["timers.cc"], - linkstatic = 1, - deps = [ - "//test/util:cleanup", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/time", - gtest, - "//test/util:logging", - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:signal_util", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "tkill_test", - testonly = 1, - srcs = ["tkill.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:logging", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "truncate_test", - testonly = 1, - srcs = ["truncate.cc"], - linkstatic = 1, - deps = [ - ":file_base", - "//test/util:capability_util", - "//test/util:cleanup", - "//test/util:file_descriptor", - "@com_google_absl//absl/strings", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "tuntap_test", - testonly = 1, - srcs = ["tuntap.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - gtest, - "//test/syscalls/linux:socket_netlink_route_util", - "//test/util:capability_util", - "//test/util:file_descriptor", - "//test/util:fs_util", - "//test/util:posix_error", - "//test/util:test_main", - "//test/util:test_util", - "@com_google_absl//absl/strings", - ], -) - -cc_binary( - name = "tuntap_hostinet_test", - testonly = 1, - srcs = ["tuntap_hostinet.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_library( - name = "udp_socket_test_cases", - testonly = 1, - srcs = [ - "udp_socket_errqueue_test_case.cc", - "udp_socket_test_cases.cc", - ], - hdrs = ["udp_socket_test_cases.h"], - defines = select_system(), - deps = [ - ":socket_test_util", - ":unix_domain_socket_test_util", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/time", - gtest, - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], - alwayslink = 1, -) - -cc_binary( - name = "udp_socket_test", - testonly = 1, - srcs = ["udp_socket.cc"], - linkstatic = 1, - deps = [ - ":udp_socket_test_cases", - ], -) - -cc_binary( - name = "udp_bind_test", - testonly = 1, - srcs = ["udp_bind.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - "//test/util:file_descriptor", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "uidgid_test", - testonly = 1, - srcs = ["uidgid.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/strings", - gtest, - "//test/util:posix_error", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - "//test/util:uid_util", - ], -) - -cc_binary( - name = "uname_test", - testonly = 1, - srcs = ["uname.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "@com_google_absl//absl/strings", - gtest, - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "unlink_test", - testonly = 1, - srcs = ["unlink.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "unshare_test", - testonly = 1, - srcs = ["unshare.cc"], - linkstatic = 1, - deps = [ - "@com_google_absl//absl/synchronization", - gtest, - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "utimes_test", - testonly = 1, - srcs = ["utimes.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "//test/util:fs_util", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "@com_google_absl//absl/time", - ], -) - -cc_binary( - name = "vdso_test", - testonly = 1, - srcs = ["vdso.cc"], - linkstatic = 1, - deps = [ - "//test/util:fs_util", - gtest, - "//test/util:posix_error", - "//test/util:proc_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "vfork_test", - testonly = 1, - srcs = ["vfork.cc"], - linkstatic = 1, - deps = [ - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/time", - gtest, - "//test/util:logging", - "//test/util:multiprocess_util", - "//test/util:test_util", - "//test/util:time_util", - ], -) - -cc_binary( - name = "wait_test", - testonly = 1, - srcs = ["wait.cc"], - linkstatic = 1, - deps = [ - "//test/util:cleanup", - "//test/util:file_descriptor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - gtest, - "//test/util:logging", - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:signal_util", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - "//test/util:time_util", - ], -) - -cc_binary( - name = "write_test", - testonly = 1, - srcs = ["write.cc"], - linkstatic = 1, - deps = [ - "//test/util:cleanup", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "memory_accounting_test", - testonly = 1, - srcs = ["memory_accounting.cc"], - linkstatic = 1, - deps = [ - "//test/util:fs_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - gtest, - "//test/util:posix_error", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "network_namespace_test", - testonly = 1, - srcs = ["network_namespace.cc"], - linkstatic = 1, - deps = [ - ":socket_test_util", - gtest, - "//test/util:capability_util", - "//test/util:posix_error", - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "semaphore_test", - testonly = 1, - srcs = ["semaphore.cc"], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - gtest, - "//test/util:test_main", - "//test/util:test_util", - "//test/util:thread_util", - ], -) - -cc_binary( - name = "shm_test", - testonly = 1, - srcs = ["shm.cc"], - linkstatic = 1, - deps = [ - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - "@com_google_absl//absl/time", - ], -) - -cc_binary( - name = "fadvise64_test", - testonly = 1, - srcs = ["fadvise64.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - gtest, - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "vdso_clock_gettime_test", - testonly = 1, - srcs = ["vdso_clock_gettime.cc"], - linkstatic = 1, - deps = [ - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "vsyscall_test", - testonly = 1, - srcs = ["vsyscall.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:proc_util", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "proc_net_unix_test", - testonly = 1, - srcs = ["proc_net_unix.cc"], - linkstatic = 1, - deps = [ - ":unix_domain_socket_test_util", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "memfd_test", - testonly = 1, - srcs = ["memfd.cc"], - linkstatic = 1, - deps = [ - "//test/util:file_descriptor", - "//test/util:fs_util", - gtest, - "//test/util:memory_util", - "//test/util:multiprocess_util", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "proc_net_tcp_test", - testonly = 1, - srcs = ["proc_net_tcp.cc"], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - "//test/util:file_descriptor", - "@com_google_absl//absl/strings", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "proc_net_udp_test", - testonly = 1, - srcs = ["proc_net_udp.cc"], - linkstatic = 1, - deps = [ - ":ip_socket_test_util", - "//test/util:file_descriptor", - "@com_google_absl//absl/strings", - gtest, - "//test/util:test_main", - "//test/util:test_util", - ], -) - -cc_binary( - name = "xattr_test", - testonly = 1, - srcs = [ - "file_base.h", - "xattr.cc", - ], - linkstatic = 1, - deps = [ - "//test/util:capability_util", - "//test/util:file_descriptor", - "//test/util:fs_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - gtest, - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_main", - "//test/util:test_util", - ], -) diff --git a/test/syscalls/linux/accept_bind.cc b/test/syscalls/linux/accept_bind.cc deleted file mode 100644 index e08c578f0..000000000 --- a/test/syscalls/linux/accept_bind.cc +++ /dev/null @@ -1,599 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdio.h> -#include <sys/un.h> - -#include <algorithm> -#include <vector> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST_P(AllSocketPairTest, Listen) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), /* backlog = */ 5), - SyscallSucceeds()); -} - -TEST_P(AllSocketPairTest, ListenIncreaseBacklog) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), /* backlog = */ 5), - SyscallSucceeds()); - ASSERT_THAT(listen(sockets->first_fd(), /* backlog = */ 10), - SyscallSucceeds()); -} - -TEST_P(AllSocketPairTest, ListenDecreaseBacklog) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), /* backlog = */ 5), - SyscallSucceeds()); - ASSERT_THAT(listen(sockets->first_fd(), /* backlog = */ 1), - SyscallSucceeds()); -} - -TEST_P(AllSocketPairTest, ListenWithoutBind) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(listen(sockets->first_fd(), 0), SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(AllSocketPairTest, DoubleBind) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->second_addr(), - sockets->second_addr_size()), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(AllSocketPairTest, BindListenBind) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->second_addr(), - sockets->second_addr_size()), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(AllSocketPairTest, DoubleListen) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); -} - -TEST_P(AllSocketPairTest, DoubleConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallFailsWithErrno(EISCONN)); -} - -TEST_P(AllSocketPairTest, Connect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); -} - -TEST_P(AllSocketPairTest, ConnectNonListening) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallFailsWithErrno(ECONNREFUSED)); -} - -TEST_P(AllSocketPairTest, ConnectToFilePath) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct sockaddr_un addr = {}; - addr.sun_family = AF_UNIX; - constexpr char kPath[] = "/tmp"; - memcpy(addr.sun_path, kPath, sizeof(kPath)); - - ASSERT_THAT( - connect(sockets->second_fd(), - reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)), - SyscallFailsWithErrno(ECONNREFUSED)); -} - -TEST_P(AllSocketPairTest, ConnectToInvalidAbstractPath) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct sockaddr_un addr = {}; - addr.sun_family = AF_UNIX; - constexpr char kPath[] = "\0nonexistent"; - memcpy(addr.sun_path, kPath, sizeof(kPath)); - - ASSERT_THAT( - connect(sockets->second_fd(), - reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)), - SyscallFailsWithErrno(ECONNREFUSED)); -} - -TEST_P(AllSocketPairTest, SelfConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(AllSocketPairTest, ConnectWithoutListen) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallFailsWithErrno(ECONNREFUSED)); -} - -TEST_P(AllSocketPairTest, Accept) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - int accepted = -1; - ASSERT_THAT(accepted = accept(sockets->first_fd(), nullptr, nullptr), - SyscallSucceeds()); - ASSERT_THAT(close(accepted), SyscallSucceeds()); -} - -TEST_P(AllSocketPairTest, AcceptValidAddrLen) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - int accepted = -1; - struct sockaddr_un addr = {}; - socklen_t addr_len = sizeof(addr); - ASSERT_THAT( - accepted = accept(sockets->first_fd(), - reinterpret_cast<struct sockaddr*>(&addr), &addr_len), - SyscallSucceeds()); - ASSERT_THAT(close(accepted), SyscallSucceeds()); -} - -TEST_P(AllSocketPairTest, AcceptNegativeAddrLen) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - // With a negative addr_len, accept returns EINVAL, - struct sockaddr_un addr = {}; - socklen_t addr_len = -1; - ASSERT_THAT(accept(sockets->first_fd(), - reinterpret_cast<struct sockaddr*>(&addr), &addr_len), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(AllSocketPairTest, AcceptLargePositiveAddrLen) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - // With a large (positive) addr_len, accept does not return EINVAL. - int accepted = -1; - char addr_buf[200]; - socklen_t addr_len = sizeof(addr_buf); - ASSERT_THAT(accepted = accept(sockets->first_fd(), - reinterpret_cast<struct sockaddr*>(addr_buf), - &addr_len), - SyscallSucceeds()); - // addr_len should have been updated by accept(). - EXPECT_LT(addr_len, sizeof(addr_buf)); - ASSERT_THAT(close(accepted), SyscallSucceeds()); -} - -TEST_P(AllSocketPairTest, AcceptVeryLargePositiveAddrLen) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - // With a large (positive) addr_len, accept does not return EINVAL. - int accepted = -1; - char addr_buf[2000]; - socklen_t addr_len = sizeof(addr_buf); - ASSERT_THAT(accepted = accept(sockets->first_fd(), - reinterpret_cast<struct sockaddr*>(addr_buf), - &addr_len), - SyscallSucceeds()); - // addr_len should have been updated by accept(). - EXPECT_LT(addr_len, sizeof(addr_buf)); - ASSERT_THAT(close(accepted), SyscallSucceeds()); -} - -TEST_P(AllSocketPairTest, AcceptWithoutBind) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(accept(sockets->first_fd(), nullptr, nullptr), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(AllSocketPairTest, AcceptWithoutListen) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - ASSERT_THAT(accept(sockets->first_fd(), nullptr, nullptr), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(AllSocketPairTest, GetRemoteAddress) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - socklen_t addr_len = sockets->first_addr_size(); - struct sockaddr_storage addr = {}; - ASSERT_THAT( - getpeername(sockets->second_fd(), (struct sockaddr*)(&addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, sockets->first_addr_len()); - EXPECT_EQ(0, memcmp(&addr, sockets->first_addr(), sockets->first_addr_len())); -} - -TEST_P(AllSocketPairTest, UnboundGetLocalAddress) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - socklen_t addr_len = sockets->first_addr_size(); - struct sockaddr_storage addr = {}; - ASSERT_THAT( - getsockname(sockets->second_fd(), (struct sockaddr*)(&addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, 2); - EXPECT_EQ( - memcmp(&addr, sockets->second_addr(), - std::min((size_t)addr_len, (size_t)sockets->second_addr_len())), - 0); -} - -TEST_P(AllSocketPairTest, BoundGetLocalAddress) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(bind(sockets->second_fd(), sockets->second_addr(), - sockets->second_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - socklen_t addr_len = sockets->first_addr_size(); - struct sockaddr_storage addr = {}; - ASSERT_THAT( - getsockname(sockets->second_fd(), (struct sockaddr*)(&addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, sockets->second_addr_len()); - EXPECT_EQ( - memcmp(&addr, sockets->second_addr(), - std::min((size_t)addr_len, (size_t)sockets->second_addr_len())), - 0); -} - -TEST_P(AllSocketPairTest, BoundConnector) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(bind(sockets->second_fd(), sockets->second_addr(), - sockets->second_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); -} - -TEST_P(AllSocketPairTest, UnboundSenderAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - int accepted = -1; - ASSERT_THAT(accepted = accept(sockets->first_fd(), nullptr, nullptr), - SyscallSucceeds()); - FileDescriptor accepted_fd(accepted); - - int i = 0; - ASSERT_THAT(RetryEINTR(send)(sockets->second_fd(), &i, sizeof(i), 0), - SyscallSucceedsWithValue(sizeof(i))); - - struct sockaddr_storage addr; - socklen_t addr_len = sizeof(addr); - ASSERT_THAT( - RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0, - reinterpret_cast<sockaddr*>(&addr), &addr_len), - SyscallSucceedsWithValue(sizeof(i))); - EXPECT_EQ(addr_len, 0); -} - -TEST_P(AllSocketPairTest, BoundSenderAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(bind(sockets->second_fd(), sockets->second_addr(), - sockets->second_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - int accepted = -1; - ASSERT_THAT(accepted = accept(sockets->first_fd(), nullptr, nullptr), - SyscallSucceeds()); - FileDescriptor accepted_fd(accepted); - - int i = 0; - ASSERT_THAT(RetryEINTR(send)(sockets->second_fd(), &i, sizeof(i), 0), - SyscallSucceedsWithValue(sizeof(i))); - - struct sockaddr_storage addr; - socklen_t addr_len = sizeof(addr); - ASSERT_THAT( - RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0, - reinterpret_cast<sockaddr*>(&addr), &addr_len), - SyscallSucceedsWithValue(sizeof(i))); - EXPECT_EQ(addr_len, sockets->second_addr_len()); - EXPECT_EQ( - memcmp(&addr, sockets->second_addr(), - std::min((size_t)addr_len, (size_t)sockets->second_addr_len())), - 0); -} - -TEST_P(AllSocketPairTest, BindAfterConnectSenderAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(bind(sockets->second_fd(), sockets->second_addr(), - sockets->second_addr_size()), - SyscallSucceeds()); - - int accepted = -1; - ASSERT_THAT(accepted = accept(sockets->first_fd(), nullptr, nullptr), - SyscallSucceeds()); - FileDescriptor accepted_fd(accepted); - - int i = 0; - ASSERT_THAT(RetryEINTR(send)(sockets->second_fd(), &i, sizeof(i), 0), - SyscallSucceedsWithValue(sizeof(i))); - - struct sockaddr_storage addr; - socklen_t addr_len = sizeof(addr); - ASSERT_THAT( - RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0, - reinterpret_cast<sockaddr*>(&addr), &addr_len), - SyscallSucceedsWithValue(sizeof(i))); - EXPECT_EQ(addr_len, sockets->second_addr_len()); - EXPECT_EQ( - memcmp(&addr, sockets->second_addr(), - std::min((size_t)addr_len, (size_t)sockets->second_addr_len())), - 0); -} - -TEST_P(AllSocketPairTest, BindAfterAcceptSenderAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - int accepted = -1; - ASSERT_THAT(accepted = accept(sockets->first_fd(), nullptr, nullptr), - SyscallSucceeds()); - FileDescriptor accepted_fd(accepted); - - ASSERT_THAT(bind(sockets->second_fd(), sockets->second_addr(), - sockets->second_addr_size()), - SyscallSucceeds()); - - int i = 0; - ASSERT_THAT(RetryEINTR(send)(sockets->second_fd(), &i, sizeof(i), 0), - SyscallSucceedsWithValue(sizeof(i))); - - struct sockaddr_storage addr; - socklen_t addr_len = sizeof(addr); - ASSERT_THAT( - RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0, - reinterpret_cast<sockaddr*>(&addr), &addr_len), - SyscallSucceedsWithValue(sizeof(i))); - EXPECT_EQ(addr_len, sockets->second_addr_len()); - EXPECT_EQ( - memcmp(&addr, sockets->second_addr(), - std::min((size_t)addr_len, (size_t)sockets->second_addr_len())), - 0); -} - -INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, AllSocketPairTest, - ::testing::ValuesIn(VecCat<SocketPairKind>( - ApplyVec<SocketPairKind>( - FilesystemUnboundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_SEQPACKET}, - List<int>{0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - AbstractUnboundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_SEQPACKET}, - List<int>{0, SOCK_NONBLOCK}))))); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/accept_bind_stream.cc b/test/syscalls/linux/accept_bind_stream.cc deleted file mode 100644 index 4857f160b..000000000 --- a/test/syscalls/linux/accept_bind_stream.cc +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdio.h> -#include <sys/un.h> - -#include <algorithm> -#include <vector> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST_P(AllSocketPairTest, BoundSenderAddrCoalesced) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - int accepted = -1; - ASSERT_THAT(accepted = accept(sockets->first_fd(), nullptr, nullptr), - SyscallSucceeds()); - FileDescriptor closer(accepted); - - int i = 0; - ASSERT_THAT(RetryEINTR(send)(sockets->second_fd(), &i, sizeof(i), 0), - SyscallSucceedsWithValue(sizeof(i))); - - ASSERT_THAT(bind(sockets->second_fd(), sockets->second_addr(), - sockets->second_addr_size()), - SyscallSucceeds()); - - i = 0; - ASSERT_THAT(RetryEINTR(send)(sockets->second_fd(), &i, sizeof(i), 0), - SyscallSucceedsWithValue(sizeof(i))); - - int ri[2] = {0, 0}; - struct sockaddr_storage addr; - socklen_t addr_len = sizeof(addr); - ASSERT_THAT( - RetryEINTR(recvfrom)(accepted, ri, sizeof(ri), 0, - reinterpret_cast<sockaddr*>(&addr), &addr_len), - SyscallSucceedsWithValue(sizeof(ri))); - EXPECT_EQ(addr_len, sockets->second_addr_len()); - - EXPECT_EQ( - memcmp(&addr, sockets->second_addr(), - std::min((size_t)addr_len, (size_t)sockets->second_addr_len())), - 0); -} - -INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, AllSocketPairTest, - ::testing::ValuesIn(VecCat<SocketPairKind>( - ApplyVec<SocketPairKind>(FilesystemUnboundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{ - 0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - AbstractUnboundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{0, SOCK_NONBLOCK}))))); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/access.cc b/test/syscalls/linux/access.cc deleted file mode 100644 index bcc25cef4..000000000 --- a/test/syscalls/linux/access.cc +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <stdlib.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/capability_util.h" -#include "test/util/fs_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -using ::testing::Ge; - -namespace gvisor { -namespace testing { - -namespace { - -class AccessTest : public ::testing::Test { - public: - std::string CreateTempFile(int perm) { - const std::string path = NewTempAbsPath(); - const int fd = open(path.c_str(), O_CREAT | O_RDONLY, perm); - TEST_PCHECK(fd > 0); - TEST_PCHECK(close(fd) == 0); - return path; - } - - protected: - // SetUp creates various configurations of files. - void SetUp() override { - // Move to the temporary directory. This allows us to reason more easily - // about absolute and relative paths. - ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds()); - - // Create an empty file, standard permissions. - relfile_ = NewTempRelPath(); - int fd; - ASSERT_THAT(fd = open(relfile_.c_str(), O_CREAT | O_TRUNC, 0644), - SyscallSucceedsWithValue(Ge(0))); - ASSERT_THAT(close(fd), SyscallSucceeds()); - absfile_ = GetAbsoluteTestTmpdir() + "/" + relfile_; - - // Create an empty directory, no writable permissions. - absdir_ = NewTempAbsPath(); - reldir_ = JoinPath(Basename(absdir_), ""); - ASSERT_THAT(mkdir(reldir_.c_str(), 0555), SyscallSucceeds()); - - // This file doesn't exist. - relnone_ = NewTempRelPath(); - absnone_ = GetAbsoluteTestTmpdir() + "/" + relnone_; - } - - // TearDown unlinks created files. - void TearDown() override { - ASSERT_THAT(unlink(absfile_.c_str()), SyscallSucceeds()); - ASSERT_THAT(rmdir(absdir_.c_str()), SyscallSucceeds()); - } - - std::string relfile_; - std::string reldir_; - - std::string absfile_; - std::string absdir_; - - std::string relnone_; - std::string absnone_; -}; - -TEST_F(AccessTest, RelativeFile) { - EXPECT_THAT(access(relfile_.c_str(), R_OK), SyscallSucceeds()); -} - -TEST_F(AccessTest, RelativeDir) { - EXPECT_THAT(access(reldir_.c_str(), R_OK | X_OK), SyscallSucceeds()); -} - -TEST_F(AccessTest, AbsFile) { - EXPECT_THAT(access(absfile_.c_str(), R_OK), SyscallSucceeds()); -} - -TEST_F(AccessTest, AbsDir) { - EXPECT_THAT(access(absdir_.c_str(), R_OK | X_OK), SyscallSucceeds()); -} - -TEST_F(AccessTest, RelDoesNotExist) { - EXPECT_THAT(access(relnone_.c_str(), R_OK), SyscallFailsWithErrno(ENOENT)); -} - -TEST_F(AccessTest, AbsDoesNotExist) { - EXPECT_THAT(access(absnone_.c_str(), R_OK), SyscallFailsWithErrno(ENOENT)); -} - -TEST_F(AccessTest, InvalidMode) { - EXPECT_THAT(access(relfile_.c_str(), 0xffffffff), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_F(AccessTest, NoPerms) { - // Drop capabilities that allow us to override permissions. We must drop - // PERMITTED because access() checks those instead of EFFECTIVE. - ASSERT_NO_ERRNO(DropPermittedCapability(CAP_DAC_OVERRIDE)); - ASSERT_NO_ERRNO(DropPermittedCapability(CAP_DAC_READ_SEARCH)); - - EXPECT_THAT(access(absdir_.c_str(), W_OK), SyscallFailsWithErrno(EACCES)); -} - -TEST_F(AccessTest, InvalidName) { - EXPECT_THAT(access(reinterpret_cast<char*>(0x1234), W_OK), - SyscallFailsWithErrno(EFAULT)); -} - -TEST_F(AccessTest, UsrReadOnly) { - // Drop capabilities that allow us to override permissions. We must drop - // PERMITTED because access() checks those instead of EFFECTIVE. - ASSERT_NO_ERRNO(DropPermittedCapability(CAP_DAC_OVERRIDE)); - ASSERT_NO_ERRNO(DropPermittedCapability(CAP_DAC_READ_SEARCH)); - - const std::string filename = CreateTempFile(0400); - EXPECT_THAT(access(filename.c_str(), R_OK), SyscallSucceeds()); - EXPECT_THAT(access(filename.c_str(), W_OK), SyscallFailsWithErrno(EACCES)); - EXPECT_THAT(access(filename.c_str(), X_OK), SyscallFailsWithErrno(EACCES)); - EXPECT_THAT(unlink(filename.c_str()), SyscallSucceeds()); -} - -TEST_F(AccessTest, UsrReadExec) { - // Drop capabilities that allow us to override permissions. We must drop - // PERMITTED because access() checks those instead of EFFECTIVE. - ASSERT_NO_ERRNO(DropPermittedCapability(CAP_DAC_OVERRIDE)); - ASSERT_NO_ERRNO(DropPermittedCapability(CAP_DAC_READ_SEARCH)); - - const std::string filename = CreateTempFile(0500); - EXPECT_THAT(access(filename.c_str(), R_OK | X_OK), SyscallSucceeds()); - EXPECT_THAT(access(filename.c_str(), W_OK), SyscallFailsWithErrno(EACCES)); - EXPECT_THAT(unlink(filename.c_str()), SyscallSucceeds()); -} - -TEST_F(AccessTest, UsrReadWrite) { - const std::string filename = CreateTempFile(0600); - EXPECT_THAT(access(filename.c_str(), R_OK | W_OK), SyscallSucceeds()); - EXPECT_THAT(access(filename.c_str(), X_OK), SyscallFailsWithErrno(EACCES)); - EXPECT_THAT(unlink(filename.c_str()), SyscallSucceeds()); -} - -TEST_F(AccessTest, UsrReadWriteExec) { - const std::string filename = CreateTempFile(0700); - EXPECT_THAT(access(filename.c_str(), R_OK | W_OK | X_OK), SyscallSucceeds()); - EXPECT_THAT(unlink(filename.c_str()), SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/affinity.cc b/test/syscalls/linux/affinity.cc deleted file mode 100644 index 128364c34..000000000 --- a/test/syscalls/linux/affinity.cc +++ /dev/null @@ -1,242 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sched.h> -#include <sys/syscall.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "absl/strings/str_split.h" -#include "test/util/cleanup.h" -#include "test/util/fs_util.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { -namespace { - -// These tests are for both the sched_getaffinity(2) and sched_setaffinity(2) -// syscalls. -class AffinityTest : public ::testing::Test { - protected: - void SetUp() override { - EXPECT_THAT( - // Needs use the raw syscall to get the actual size. - cpuset_size_ = syscall(SYS_sched_getaffinity, /*pid=*/0, - sizeof(cpu_set_t), &mask_), - SyscallSucceeds()); - // Lots of tests rely on having more than 1 logical processor available. - EXPECT_GT(CPU_COUNT(&mask_), 1); - } - - static PosixError ClearLowestBit(cpu_set_t* mask, size_t cpus) { - const size_t mask_size = CPU_ALLOC_SIZE(cpus); - for (size_t n = 0; n < cpus; ++n) { - if (CPU_ISSET_S(n, mask_size, mask)) { - CPU_CLR_S(n, mask_size, mask); - return NoError(); - } - } - return PosixError(EINVAL, "No bit to clear, mask is empty"); - } - - PosixError ClearLowestBit() { return ClearLowestBit(&mask_, CPU_SETSIZE); } - - // Stores the initial cpu mask for this process. - cpu_set_t mask_ = {}; - int cpuset_size_ = 0; -}; - -// sched_getaffinity(2) is implemented. -TEST_F(AffinityTest, SchedGetAffinityImplemented) { - EXPECT_THAT(sched_getaffinity(/*pid=*/0, sizeof(cpu_set_t), &mask_), - SyscallSucceeds()); -} - -// PID is not found. -TEST_F(AffinityTest, SchedGetAffinityInvalidPID) { - // Flaky, but it's tough to avoid a race condition when finding an unused pid - EXPECT_THAT(sched_getaffinity(/*pid=*/INT_MAX - 1, sizeof(cpu_set_t), &mask_), - SyscallFailsWithErrno(ESRCH)); -} - -// PID is not found. -TEST_F(AffinityTest, SchedSetAffinityInvalidPID) { - // Flaky, but it's tough to avoid a race condition when finding an unused pid - EXPECT_THAT(sched_setaffinity(/*pid=*/INT_MAX - 1, sizeof(cpu_set_t), &mask_), - SyscallFailsWithErrno(ESRCH)); -} - -TEST_F(AffinityTest, SchedSetAffinityZeroMask) { - CPU_ZERO(&mask_); - EXPECT_THAT(sched_setaffinity(/*pid=*/0, sizeof(cpu_set_t), &mask_), - SyscallFailsWithErrno(EINVAL)); -} - -// N.B. This test case relies on cpuset_size_ larger than the actual number of -// of all existing CPUs. Check your machine if the test fails. -TEST_F(AffinityTest, SchedSetAffinityNonexistentCPUDropped) { - cpu_set_t mask = mask_; - // Add a nonexistent CPU. - // - // The number needs to be larger than the possible number of CPU available, - // but smaller than the number of the CPU that the kernel claims to support -- - // it's implicitly returned by raw sched_getaffinity syscall. - CPU_SET(cpuset_size_ * 8 - 1, &mask); - EXPECT_THAT( - // Use raw syscall because it will be rejected by the libc wrapper - // otherwise. - syscall(SYS_sched_setaffinity, /*pid=*/0, sizeof(cpu_set_t), &mask), - SyscallSucceeds()) - << "failed with cpumask : " << CPUSetToString(mask) - << ", cpuset_size_ : " << cpuset_size_; - cpu_set_t newmask; - EXPECT_THAT(sched_getaffinity(/*pid=*/0, sizeof(cpu_set_t), &newmask), - SyscallSucceeds()); - EXPECT_TRUE(CPU_EQUAL(&mask_, &newmask)) - << "got: " << CPUSetToString(newmask) - << " != expected: " << CPUSetToString(mask_); -} - -TEST_F(AffinityTest, SchedSetAffinityOnlyNonexistentCPUFails) { - // Make an empty cpu set. - CPU_ZERO(&mask_); - // Add a nonexistent CPU. - // - // The number needs to be larger than the possible number of CPU available, - // but smaller than the number of the CPU that the kernel claims to support -- - // it's implicitly returned by raw sched_getaffinity syscall. - int cpu = cpuset_size_ * 8 - 1; - if (cpu <= NumCPUs()) { - GTEST_SKIP() << "Skipping test: cpu " << cpu << " exists"; - } - CPU_SET(cpu, &mask_); - EXPECT_THAT( - // Use raw syscall because it will be rejected by the libc wrapper - // otherwise. - syscall(SYS_sched_setaffinity, /*pid=*/0, sizeof(cpu_set_t), &mask_), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_F(AffinityTest, SchedSetAffinityInvalidSize) { - EXPECT_GT(cpuset_size_, 0); - // Not big enough. - EXPECT_THAT(sched_getaffinity(/*pid=*/0, cpuset_size_ - 1, &mask_), - SyscallFailsWithErrno(EINVAL)); - // Not a multiple of word size. - EXPECT_THAT(sched_getaffinity(/*pid=*/0, cpuset_size_ + 1, &mask_), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_F(AffinityTest, Sanity) { - ASSERT_NO_ERRNO(ClearLowestBit()); - EXPECT_THAT(sched_setaffinity(/*pid=*/0, sizeof(cpu_set_t), &mask_), - SyscallSucceeds()); - cpu_set_t newmask; - EXPECT_THAT(sched_getaffinity(/*pid=*/0, sizeof(cpu_set_t), &newmask), - SyscallSucceeds()); - EXPECT_TRUE(CPU_EQUAL(&mask_, &newmask)) - << "got: " << CPUSetToString(newmask) - << " != expected: " << CPUSetToString(mask_); -} - -TEST_F(AffinityTest, NewThread) { - SKIP_IF(CPU_COUNT(&mask_) < 3); - ASSERT_NO_ERRNO(ClearLowestBit()); - ASSERT_NO_ERRNO(ClearLowestBit()); - EXPECT_THAT(sched_setaffinity(/*pid=*/0, sizeof(cpu_set_t), &mask_), - SyscallSucceeds()); - ScopedThread([this]() { - cpu_set_t child_mask; - ASSERT_THAT(sched_getaffinity(/*pid=*/0, sizeof(cpu_set_t), &child_mask), - SyscallSucceeds()); - ASSERT_TRUE(CPU_EQUAL(&child_mask, &mask_)) - << "child cpu mask: " << CPUSetToString(child_mask) - << " != parent cpu mask: " << CPUSetToString(mask_); - }); -} - -TEST_F(AffinityTest, ConsistentWithProcCpuInfo) { - // Count how many cpus are shown in /proc/cpuinfo. - std::string cpuinfo = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/cpuinfo")); - int count = 0; - for (auto const& line : absl::StrSplit(cpuinfo, '\n')) { - if (absl::StartsWith(line, "processor")) { - count++; - } - } - EXPECT_GE(count, CPU_COUNT(&mask_)); -} - -TEST_F(AffinityTest, ConsistentWithProcStat) { - // Count how many cpus are shown in /proc/stat. - std::string stat = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/stat")); - int count = 0; - for (auto const& line : absl::StrSplit(stat, '\n')) { - if (absl::StartsWith(line, "cpu") && !absl::StartsWith(line, "cpu ")) { - count++; - } - } - EXPECT_GE(count, CPU_COUNT(&mask_)); -} - -TEST_F(AffinityTest, SmallCpuMask) { - const int num_cpus = NumCPUs(); - const size_t mask_size = CPU_ALLOC_SIZE(num_cpus); - cpu_set_t* mask = CPU_ALLOC(num_cpus); - ASSERT_NE(mask, nullptr); - const auto free_mask = Cleanup([&] { CPU_FREE(mask); }); - - CPU_ZERO_S(mask_size, mask); - ASSERT_THAT(sched_getaffinity(0, mask_size, mask), SyscallSucceeds()); -} - -TEST_F(AffinityTest, LargeCpuMask) { - // Allocate mask bigger than cpu_set_t normally allocates. - const size_t cpus = CPU_SETSIZE * 8; - const size_t mask_size = CPU_ALLOC_SIZE(cpus); - - cpu_set_t* large_mask = CPU_ALLOC(cpus); - auto free_mask = Cleanup([large_mask] { CPU_FREE(large_mask); }); - CPU_ZERO_S(mask_size, large_mask); - - // Check that get affinity with large mask works as expected. - ASSERT_THAT(sched_getaffinity(/*pid=*/0, mask_size, large_mask), - SyscallSucceeds()); - EXPECT_TRUE(CPU_EQUAL(&mask_, large_mask)) - << "got: " << CPUSetToString(*large_mask, cpus) - << " != expected: " << CPUSetToString(mask_); - - // Check that set affinity with large mask works as expected. - ASSERT_NO_ERRNO(ClearLowestBit(large_mask, cpus)); - EXPECT_THAT(sched_setaffinity(/*pid=*/0, mask_size, large_mask), - SyscallSucceeds()); - - cpu_set_t* new_mask = CPU_ALLOC(cpus); - auto free_new_mask = Cleanup([new_mask] { CPU_FREE(new_mask); }); - CPU_ZERO_S(mask_size, new_mask); - EXPECT_THAT(sched_getaffinity(/*pid=*/0, mask_size, new_mask), - SyscallSucceeds()); - - EXPECT_TRUE(CPU_EQUAL_S(mask_size, large_mask, new_mask)) - << "got: " << CPUSetToString(*new_mask, cpus) - << " != expected: " << CPUSetToString(*large_mask, cpus); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/aio.cc b/test/syscalls/linux/aio.cc deleted file mode 100644 index a33daff17..000000000 --- a/test/syscalls/linux/aio.cc +++ /dev/null @@ -1,424 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <linux/aio_abi.h> -#include <sys/mman.h> -#include <sys/syscall.h> -#include <sys/types.h> -#include <unistd.h> - -#include <algorithm> -#include <string> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/file_base.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/memory_util.h" -#include "test/util/posix_error.h" -#include "test/util/proc_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -using ::testing::_; - -namespace gvisor { -namespace testing { -namespace { - -// Returns the size of the VMA containing the given address. -PosixErrorOr<size_t> VmaSizeAt(uintptr_t addr) { - ASSIGN_OR_RETURN_ERRNO(std::string proc_self_maps, - GetContents("/proc/self/maps")); - ASSIGN_OR_RETURN_ERRNO(auto entries, ParseProcMaps(proc_self_maps)); - // Use binary search to find the first VMA that might contain addr. - ProcMapsEntry target = {}; - target.end = addr; - auto it = - std::upper_bound(entries.begin(), entries.end(), target, - [](const ProcMapsEntry& x, const ProcMapsEntry& y) { - return x.end < y.end; - }); - // Check that it actually contains addr. - if (it == entries.end() || addr < it->start) { - return PosixError(ENOENT, absl::StrCat("no VMA contains address ", addr)); - } - return it->end - it->start; -} - -constexpr char kData[] = "hello world!"; - -int SubmitCtx(aio_context_t ctx, long nr, struct iocb** iocbpp) { - return syscall(__NR_io_submit, ctx, nr, iocbpp); -} - -class AIOTest : public FileTest { - public: - AIOTest() : ctx_(0) {} - - int SetupContext(unsigned int nr) { - return syscall(__NR_io_setup, nr, &ctx_); - } - - int Submit(long nr, struct iocb** iocbpp) { - return SubmitCtx(ctx_, nr, iocbpp); - } - - int GetEvents(long min, long max, struct io_event* events, - struct timespec* timeout) { - return RetryEINTR(syscall)(__NR_io_getevents, ctx_, min, max, events, - timeout); - } - - int DestroyContext() { return syscall(__NR_io_destroy, ctx_); } - - void TearDown() override { - FileTest::TearDown(); - if (ctx_ != 0) { - ASSERT_THAT(DestroyContext(), SyscallSucceeds()); - } - } - - struct iocb CreateCallback() { - struct iocb cb = {}; - cb.aio_data = 0x123; - cb.aio_fildes = test_file_fd_.get(); - cb.aio_lio_opcode = IOCB_CMD_PWRITE; - cb.aio_buf = reinterpret_cast<uint64_t>(kData); - cb.aio_offset = 0; - cb.aio_nbytes = strlen(kData); - return cb; - } - - protected: - aio_context_t ctx_; -}; - -TEST_F(AIOTest, BasicWrite) { - // Copied from fs/aio.c. - constexpr unsigned AIO_RING_MAGIC = 0xa10a10a1; - struct aio_ring { - unsigned id; - unsigned nr; - unsigned head; - unsigned tail; - unsigned magic; - unsigned compat_features; - unsigned incompat_features; - unsigned header_length; - struct io_event io_events[0]; - }; - - // Setup a context that is 128 entries deep. - ASSERT_THAT(SetupContext(128), SyscallSucceeds()); - - // Check that 'ctx_' points to a valid address. libaio uses it to check if - // aio implementation uses aio_ring. gVisor doesn't and returns all zeroes. - // Linux implements aio_ring, so skip the zeroes check. - // - // TODO(gvisor.dev/issue/204): Remove when gVisor implements aio_ring. - auto ring = reinterpret_cast<struct aio_ring*>(ctx_); - auto magic = IsRunningOnGvisor() ? 0 : AIO_RING_MAGIC; - EXPECT_EQ(ring->magic, magic); - - struct iocb cb = CreateCallback(); - struct iocb* cbs[1] = {&cb}; - - // Submit the request. - ASSERT_THAT(Submit(1, cbs), SyscallSucceedsWithValue(1)); - - // Get the reply. - struct io_event events[1]; - ASSERT_THAT(GetEvents(1, 1, events, nullptr), SyscallSucceedsWithValue(1)); - - // Verify that it is as expected. - EXPECT_EQ(events[0].data, 0x123); - EXPECT_EQ(events[0].obj, reinterpret_cast<long>(&cb)); - EXPECT_EQ(events[0].res, strlen(kData)); - - // Verify that the file contains the contents. - char verify_buf[sizeof(kData)] = {}; - ASSERT_THAT(read(test_file_fd_.get(), verify_buf, sizeof(kData)), - SyscallSucceedsWithValue(strlen(kData))); - EXPECT_STREQ(verify_buf, kData); -} - -TEST_F(AIOTest, BadWrite) { - // Create a pipe and immediately close the read end. - int pipefd[2]; - ASSERT_THAT(pipe(pipefd), SyscallSucceeds()); - - FileDescriptor rfd(pipefd[0]); - FileDescriptor wfd(pipefd[1]); - - rfd.reset(); // Close the read end. - - // Setup a context that is 128 entries deep. - ASSERT_THAT(SetupContext(128), SyscallSucceeds()); - - struct iocb cb = CreateCallback(); - // Try to write to the read end. - cb.aio_fildes = wfd.get(); - struct iocb* cbs[1] = {&cb}; - - // Submit the request. - ASSERT_THAT(Submit(1, cbs), SyscallSucceedsWithValue(1)); - - // Get the reply. - struct io_event events[1]; - ASSERT_THAT(GetEvents(1, 1, events, nullptr), SyscallSucceedsWithValue(1)); - - // Verify that it fails with the right error code. - EXPECT_EQ(events[0].data, 0x123); - EXPECT_EQ(events[0].obj, reinterpret_cast<uint64_t>(&cb)); - EXPECT_LT(events[0].res, 0); -} - -TEST_F(AIOTest, ExitWithPendingIo) { - // Setup a context that is 5 entries deep. - ASSERT_THAT(SetupContext(5), SyscallSucceeds()); - - struct iocb cb = CreateCallback(); - struct iocb* cbs[] = {&cb}; - - // Submit a request but don't complete it to make it pending. - EXPECT_THAT(Submit(1, cbs), SyscallSucceeds()); -} - -int Submitter(void* arg) { - auto test = reinterpret_cast<AIOTest*>(arg); - - struct iocb cb = test->CreateCallback(); - struct iocb* cbs[1] = {&cb}; - - // Submit the request. - TEST_CHECK(test->Submit(1, cbs) == 1); - return 0; -} - -TEST_F(AIOTest, CloneVm) { - // Setup a context that is 128 entries deep. - ASSERT_THAT(SetupContext(128), SyscallSucceeds()); - - const size_t kStackSize = 5 * kPageSize; - std::unique_ptr<char[]> stack(new char[kStackSize]); - char* bp = stack.get() + kStackSize; - pid_t child; - ASSERT_THAT(child = clone(Submitter, bp, CLONE_VM | SIGCHLD, - reinterpret_cast<void*>(this)), - SyscallSucceeds()); - - // Get the reply. - struct io_event events[1]; - ASSERT_THAT(GetEvents(1, 1, events, nullptr), SyscallSucceedsWithValue(1)); - - // Verify that it is as expected. - EXPECT_EQ(events[0].data, 0x123); - EXPECT_EQ(events[0].res, strlen(kData)); - - // Verify that the file contains the contents. - char verify_buf[32] = {}; - ASSERT_THAT(read(test_file_fd_.get(), &verify_buf[0], strlen(kData)), - SyscallSucceeds()); - EXPECT_EQ(strcmp(kData, &verify_buf[0]), 0); - - int status; - ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), - SyscallSucceedsWithValue(child)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << " status " << status; -} - -// Tests that AIO context can be remapped to a different address. -TEST_F(AIOTest, Mremap) { - // Setup a context that is 128 entries deep. - ASSERT_THAT(SetupContext(128), SyscallSucceeds()); - const size_t ctx_size = - ASSERT_NO_ERRNO_AND_VALUE(VmaSizeAt(reinterpret_cast<uintptr_t>(ctx_))); - - struct iocb cb = CreateCallback(); - struct iocb* cbs[1] = {&cb}; - - // Reserve address space for the mremap target so we have something safe to - // map over. - Mapping dst = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(ctx_size, PROT_READ, MAP_PRIVATE)); - - // Remap context 'handle' to a different address. - ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(), - MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()), - IsPosixErrorOkAndHolds(dst.ptr())); - aio_context_t old_ctx = ctx_; - ctx_ = reinterpret_cast<aio_context_t>(dst.addr()); - // io_destroy() will unmap dst now. - dst.release(); - - // Check that submitting the request with the old 'ctx_' fails. - ASSERT_THAT(SubmitCtx(old_ctx, 1, cbs), SyscallFailsWithErrno(EINVAL)); - - // Submit the request with the new 'ctx_'. - ASSERT_THAT(Submit(1, cbs), SyscallSucceedsWithValue(1)); - - // Remap again. - dst = ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(ctx_size, PROT_READ, MAP_PRIVATE)); - ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(), - MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()), - IsPosixErrorOkAndHolds(dst.ptr())); - ctx_ = reinterpret_cast<aio_context_t>(dst.addr()); - dst.release(); - - // Get the reply with yet another 'ctx_' and verify it. - struct io_event events[1]; - ASSERT_THAT(GetEvents(1, 1, events, nullptr), SyscallSucceedsWithValue(1)); - EXPECT_EQ(events[0].data, 0x123); - EXPECT_EQ(events[0].obj, reinterpret_cast<long>(&cb)); - EXPECT_EQ(events[0].res, strlen(kData)); - - // Verify that the file contains the contents. - char verify_buf[sizeof(kData)] = {}; - ASSERT_THAT(read(test_file_fd_.get(), verify_buf, sizeof(kData)), - SyscallSucceedsWithValue(strlen(kData))); - EXPECT_STREQ(verify_buf, kData); -} - -// Tests that AIO context cannot be expanded with mremap. -TEST_F(AIOTest, MremapExpansion) { - // Setup a context that is 128 entries deep. - ASSERT_THAT(SetupContext(128), SyscallSucceeds()); - const size_t ctx_size = - ASSERT_NO_ERRNO_AND_VALUE(VmaSizeAt(reinterpret_cast<uintptr_t>(ctx_))); - - // Reserve address space for the mremap target so we have something safe to - // map over. - Mapping dst = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(ctx_size + kPageSize, PROT_NONE, MAP_PRIVATE)); - - // Test that remapping to a larger address range fails. - ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(), - MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()), - PosixErrorIs(EFAULT, _)); - - // mm/mremap.c:sys_mremap() => mremap_to() does do_munmap() of the destination - // before it hits the VM_DONTEXPAND check in vma_to_resize(), so we should no - // longer munmap it (another thread may have created a mapping there). - dst.release(); -} - -// Tests that AIO calls fail if context's address is inaccessible. -TEST_F(AIOTest, Mprotect) { - // Setup a context that is 128 entries deep. - ASSERT_THAT(SetupContext(128), SyscallSucceeds()); - - struct iocb cb = CreateCallback(); - struct iocb* cbs[1] = {&cb}; - - ASSERT_THAT(Submit(1, cbs), SyscallSucceedsWithValue(1)); - - // Makes the context 'handle' inaccessible and check that all subsequent - // calls fail. - ASSERT_THAT(mprotect(reinterpret_cast<void*>(ctx_), kPageSize, PROT_NONE), - SyscallSucceeds()); - struct io_event events[1]; - EXPECT_THAT(GetEvents(1, 1, events, nullptr), SyscallFailsWithErrno(EINVAL)); - ASSERT_THAT(Submit(1, cbs), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(DestroyContext(), SyscallFailsWithErrno(EINVAL)); - - // Prevent TearDown from attempting to destroy the context and fail. - ctx_ = 0; -} - -TEST_F(AIOTest, Timeout) { - // Setup a context that is 128 entries deep. - ASSERT_THAT(SetupContext(128), SyscallSucceeds()); - - struct timespec timeout; - timeout.tv_sec = 0; - timeout.tv_nsec = 10; - struct io_event events[1]; - ASSERT_THAT(GetEvents(1, 1, events, &timeout), SyscallSucceedsWithValue(0)); -} - -class AIOReadWriteParamTest : public AIOTest, - public ::testing::WithParamInterface<int> {}; - -TEST_P(AIOReadWriteParamTest, BadOffset) { - // Setup a context that is 128 entries deep. - ASSERT_THAT(SetupContext(128), SyscallSucceeds()); - - struct iocb cb = CreateCallback(); - struct iocb* cbs[1] = {&cb}; - - // Create a buffer that we can write to. - char buf[] = "hello world!"; - cb.aio_buf = reinterpret_cast<uint64_t>(buf); - - // Set the operation on the callback and give a negative offset. - const int opcode = GetParam(); - cb.aio_lio_opcode = opcode; - - iovec iov = {}; - if (opcode == IOCB_CMD_PREADV || opcode == IOCB_CMD_PWRITEV) { - // Create a valid iovec and set it in the callback. - iov.iov_base = reinterpret_cast<void*>(buf); - iov.iov_len = 1; - cb.aio_buf = reinterpret_cast<uint64_t>(&iov); - // aio_nbytes is the number of iovecs. - cb.aio_nbytes = 1; - } - - // Pass a negative offset. - cb.aio_offset = -1; - - // Should get error on submission. - ASSERT_THAT(Submit(1, cbs), SyscallFailsWithErrno(EINVAL)); -} - -INSTANTIATE_TEST_SUITE_P(BadOffset, AIOReadWriteParamTest, - ::testing::Values(IOCB_CMD_PREAD, IOCB_CMD_PWRITE, - IOCB_CMD_PREADV, IOCB_CMD_PWRITEV)); - -class AIOVectorizedParamTest : public AIOTest, - public ::testing::WithParamInterface<int> {}; - -TEST_P(AIOVectorizedParamTest, BadIOVecs) { - // Setup a context that is 128 entries deep. - ASSERT_THAT(SetupContext(128), SyscallSucceeds()); - - struct iocb cb = CreateCallback(); - struct iocb* cbs[1] = {&cb}; - - // Modify the callback to use the operation from the param. - cb.aio_lio_opcode = GetParam(); - - // Create an iovec with address in kernel range, and pass that as the buffer. - iovec iov = {}; - iov.iov_base = reinterpret_cast<void*>(0xFFFFFFFF00000000); - iov.iov_len = 1; - cb.aio_buf = reinterpret_cast<uint64_t>(&iov); - // aio_nbytes is the number of iovecs. - cb.aio_nbytes = 1; - - // Should get error on submission. - ASSERT_THAT(Submit(1, cbs), SyscallFailsWithErrno(EFAULT)); -} - -INSTANTIATE_TEST_SUITE_P(BadIOVecs, AIOVectorizedParamTest, - ::testing::Values(IOCB_CMD_PREADV, IOCB_CMD_PWRITEV)); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/alarm.cc b/test/syscalls/linux/alarm.cc deleted file mode 100644 index 940c97285..000000000 --- a/test/syscalls/linux/alarm.cc +++ /dev/null @@ -1,192 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <signal.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/file_descriptor.h" -#include "test/util/logging.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// N.B. Below, main blocks SIGALRM. Test cases must unblock it if they want -// delivery. - -void do_nothing_handler(int sig, siginfo_t* siginfo, void* arg) {} - -// No random save as the test relies on alarm timing. Cooperative save tests -// already cover the save between alarm and read. -TEST(AlarmTest, Interrupt_NoRandomSave) { - int pipe_fds[2]; - ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds()); - - FileDescriptor read_fd(pipe_fds[0]); - FileDescriptor write_fd(pipe_fds[1]); - - // Use a signal handler that interrupts but does nothing rather than using the - // default terminate action. - struct sigaction sa; - sa.sa_sigaction = do_nothing_handler; - sigfillset(&sa.sa_mask); - sa.sa_flags = 0; - auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa)); - - // Actually allow SIGALRM delivery. - auto mask_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGALRM)); - - // Alarm in 20 second, which should be well after read blocks below. - ASSERT_THAT(alarm(20), SyscallSucceeds()); - - char buf; - ASSERT_THAT(read(read_fd.get(), &buf, 1), SyscallFailsWithErrno(EINTR)); -} - -/* Count of the number of SIGALARMS handled. */ -static volatile int alarms_received = 0; - -void inc_alarms_handler(int sig, siginfo_t* siginfo, void* arg) { - alarms_received++; -} - -// No random save as the test relies on alarm timing. Cooperative save tests -// already cover the save between alarm and read. -TEST(AlarmTest, Restart_NoRandomSave) { - alarms_received = 0; - - int pipe_fds[2]; - ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds()); - - FileDescriptor read_fd(pipe_fds[0]); - // Write end closed by thread below. - - struct sigaction sa; - sa.sa_sigaction = inc_alarms_handler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_RESTART; - auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa)); - - // Spawn a thread to eventually unblock the read below. - ScopedThread t([pipe_fds] { - absl::SleepFor(absl::Seconds(30)); - EXPECT_THAT(close(pipe_fds[1]), SyscallSucceeds()); - }); - - // Actually allow SIGALRM delivery. - auto mask_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGALRM)); - - // Alarm in 20 second, which should be well after read blocks below, but - // before it returns. - ASSERT_THAT(alarm(20), SyscallSucceeds()); - - // Read and eventually get an EOF from the writer closing. If SA_RESTART - // didn't work, then the alarm would not have fired and we wouldn't increment - // our alarms_received count in our signal handler, or we would have not - // restarted the syscall gracefully, which we expect below in order to be - // able to get the final EOF on the pipe. - char buf; - ASSERT_THAT(read(read_fd.get(), &buf, 1), SyscallSucceeds()); - EXPECT_EQ(alarms_received, 1); - - t.Join(); -} - -// No random save as the test relies on alarm timing. Cooperative save tests -// already cover the save between alarm and pause. -TEST(AlarmTest, SaSiginfo_NoRandomSave) { - // Use a signal handler that interrupts but does nothing rather than using the - // default terminate action. - struct sigaction sa; - sa.sa_sigaction = do_nothing_handler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_SIGINFO; - auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa)); - - // Actually allow SIGALRM delivery. - auto mask_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGALRM)); - - // Alarm in 20 second, which should be well after pause blocks below. - ASSERT_THAT(alarm(20), SyscallSucceeds()); - ASSERT_THAT(pause(), SyscallFailsWithErrno(EINTR)); -} - -// No random save as the test relies on alarm timing. Cooperative save tests -// already cover the save between alarm and pause. -TEST(AlarmTest, SaInterrupt_NoRandomSave) { - // Use a signal handler that interrupts but does nothing rather than using the - // default terminate action. - struct sigaction sa; - sa.sa_sigaction = do_nothing_handler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_INTERRUPT; - auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa)); - - // Actually allow SIGALRM delivery. - auto mask_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGALRM)); - - // Alarm in 20 second, which should be well after pause blocks below. - ASSERT_THAT(alarm(20), SyscallSucceeds()); - ASSERT_THAT(pause(), SyscallFailsWithErrno(EINTR)); -} - -TEST(AlarmTest, UserModeSpinning) { - alarms_received = 0; - - struct sigaction sa = {}; - sa.sa_sigaction = inc_alarms_handler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_SIGINFO; - auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa)); - - // Actually allow SIGALRM delivery. - auto mask_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGALRM)); - - // Alarm in 20 second, which should be well into the loop below. - ASSERT_THAT(alarm(20), SyscallSucceeds()); - // Make sure that the signal gets delivered even if we are spinning in user - // mode when it arrives. - while (!alarms_received) { - } -} - -} // namespace - -} // namespace testing -} // namespace gvisor - -int main(int argc, char** argv) { - // These tests depend on delivering SIGALRM to the main thread. Block SIGALRM - // so that any other threads created by TestInit will also have SIGALRM - // blocked. - sigset_t set; - sigemptyset(&set); - sigaddset(&set, SIGALRM); - TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); - - gvisor::testing::TestInit(&argc, &argv); - return gvisor::testing::RunAllTests(); -} diff --git a/test/syscalls/linux/arch_prctl.cc b/test/syscalls/linux/arch_prctl.cc deleted file mode 100644 index 81bf5a775..000000000 --- a/test/syscalls/linux/arch_prctl.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <asm/prctl.h> -#include <sys/prctl.h> - -#include "gtest/gtest.h" -#include "test/util/test_util.h" - -// glibc does not provide a prototype for arch_prctl() so declare it here. -extern "C" int arch_prctl(int code, uintptr_t addr); - -namespace gvisor { -namespace testing { - -namespace { - -TEST(ArchPrctlTest, GetSetFS) { - uintptr_t orig; - const uintptr_t kNonCanonicalFsbase = 0x4141414142424242; - - // Get the original FS.base and then set it to the same value (this is - // intentional because FS.base is the TLS pointer so we cannot change it - // arbitrarily). - ASSERT_THAT(arch_prctl(ARCH_GET_FS, reinterpret_cast<uintptr_t>(&orig)), - SyscallSucceeds()); - ASSERT_THAT(arch_prctl(ARCH_SET_FS, orig), SyscallSucceeds()); - - // Trying to set FS.base to a non-canonical value should return an error. - ASSERT_THAT(arch_prctl(ARCH_SET_FS, kNonCanonicalFsbase), - SyscallFailsWithErrno(EPERM)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/bad.cc b/test/syscalls/linux/bad.cc deleted file mode 100644 index a26fc6af3..000000000 --- a/test/syscalls/linux/bad.cc +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/syscall.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { -#ifdef __x86_64__ -// get_kernel_syms is not supported in Linux > 2.6, and not implemented in -// gVisor. -constexpr uint32_t kNotImplementedSyscall = SYS_get_kernel_syms; -#elif __aarch64__ -// Use the last of arch_specific_syscalls which are not implemented on arm64. -constexpr uint32_t kNotImplementedSyscall = __NR_arch_specific_syscall + 15; -#endif - -TEST(BadSyscallTest, NotImplemented) { - EXPECT_THAT(syscall(kNotImplementedSyscall), SyscallFailsWithErrno(ENOSYS)); -} - -TEST(BadSyscallTest, NegativeOne) { - EXPECT_THAT(syscall(-1), SyscallFailsWithErrno(ENOSYS)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/base_poll_test.cc b/test/syscalls/linux/base_poll_test.cc deleted file mode 100644 index ab7a19dd0..000000000 --- a/test/syscalls/linux/base_poll_test.cc +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/base_poll_test.h" - -#include <sys/syscall.h> -#include <sys/types.h> -#include <syscall.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "absl/memory/memory.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -static volatile int timer_fired = 0; -static void SigAlarmHandler(int, siginfo_t*, void*) { timer_fired = 1; } - -BasePollTest::BasePollTest() { - // Register our SIGALRM handler, but save the original so we can restore in - // the destructor. - struct sigaction sa = {}; - sa.sa_sigaction = SigAlarmHandler; - sigfillset(&sa.sa_mask); - TEST_PCHECK(sigaction(SIGALRM, &sa, &original_alarm_sa_) == 0); -} - -BasePollTest::~BasePollTest() { - ClearTimer(); - TEST_PCHECK(sigaction(SIGALRM, &original_alarm_sa_, nullptr) == 0); -} - -void BasePollTest::SetTimer(absl::Duration duration) { - pid_t tgid = getpid(); - pid_t tid = gettid(); - ClearTimer(); - - // Create a new timer thread. - timer_ = absl::make_unique<TimerThread>(absl::Now() + duration, tgid, tid); -} - -bool BasePollTest::TimerFired() const { return timer_fired; } - -void BasePollTest::ClearTimer() { - timer_.reset(); - timer_fired = 0; -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/base_poll_test.h b/test/syscalls/linux/base_poll_test.h deleted file mode 100644 index 0d4a6701e..000000000 --- a/test/syscalls/linux/base_poll_test.h +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_BASE_POLL_TEST_H_ -#define GVISOR_TEST_SYSCALLS_BASE_POLL_TEST_H_ - -#include <signal.h> -#include <sys/syscall.h> -#include <sys/types.h> -#include <syscall.h> -#include <time.h> -#include <unistd.h> - -#include <memory> - -#include "gtest/gtest.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/time.h" -#include "test/util/logging.h" -#include "test/util/signal_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -// TimerThread is a cancelable timer. -class TimerThread { - public: - TimerThread(absl::Time deadline, pid_t tgid, pid_t tid) - : thread_([=] { - mu_.Lock(); - mu_.AwaitWithDeadline(absl::Condition(&cancel_), deadline); - if (!cancel_) { - TEST_PCHECK(tgkill(tgid, tid, SIGALRM) == 0); - } - mu_.Unlock(); - }) {} - - ~TimerThread() { Cancel(); } - - void Cancel() { - absl::MutexLock ml(&mu_); - cancel_ = true; - } - - private: - mutable absl::Mutex mu_; - bool cancel_ ABSL_GUARDED_BY(mu_) = false; - - // Must be last to ensure that the destructor for the thread is run before - // any other member of the object is destroyed. - ScopedThread thread_; -}; - -// Base test fixture for poll, select, ppoll, and pselect tests. -// -// This fixture makes use of SIGALRM. The handler is saved in SetUp() and -// restored in TearDown(). -class BasePollTest : public ::testing::Test { - protected: - BasePollTest(); - ~BasePollTest() override; - - // Sets a timer that will send a signal to the calling thread after - // `duration`. - void SetTimer(absl::Duration duration); - - // Returns true if the timer has fired. - bool TimerFired() const; - - // Stops the pending timer (if any) and clear the "fired" state. - void ClearTimer(); - - private: - // Thread that implements the timer. If the timer is stopped, timer_ is null. - // - // We have to use a thread for this purpose because tests using this fixture - // expect to be interrupted by the timer signal, but itimers/alarm(2) send - // thread-group-directed signals, which may be handled by any thread in the - // test process. - std::unique_ptr<TimerThread> timer_; - - // The original SIGALRM handler, to restore in destructor. - struct sigaction original_alarm_sa_; -}; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_BASE_POLL_TEST_H_ diff --git a/test/syscalls/linux/bind.cc b/test/syscalls/linux/bind.cc deleted file mode 100644 index 9547c4ab2..000000000 --- a/test/syscalls/linux/bind.cc +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdio.h> -#include <sys/socket.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST_P(AllSocketPairTest, Bind) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); -} - -TEST_P(AllSocketPairTest, BindTooLong) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - // first_addr is a sockaddr_storage being used as a sockaddr_un. Use the full - // length which is longer than expected for a Unix socket. - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sizeof(sockaddr_storage)), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(AllSocketPairTest, DoubleBindSocket) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - EXPECT_THAT( - bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - // Linux 4.09 returns EINVAL here, but some time before 4.19 it switched - // to EADDRINUSE. - AnyOf(SyscallFailsWithErrno(EADDRINUSE), SyscallFailsWithErrno(EINVAL))); -} - -TEST_P(AllSocketPairTest, GetLocalAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - socklen_t addressLength = sockets->first_addr_size(); - struct sockaddr_storage address = {}; - ASSERT_THAT(getsockname(sockets->first_fd(), (struct sockaddr*)(&address), - &addressLength), - SyscallSucceeds()); - EXPECT_EQ( - 0, memcmp(&address, sockets->first_addr(), sockets->first_addr_size())); -} - -TEST_P(AllSocketPairTest, GetLocalAddrWithoutBind) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - socklen_t addressLength = sockets->first_addr_size(); - struct sockaddr_storage received_address = {}; - ASSERT_THAT( - getsockname(sockets->first_fd(), (struct sockaddr*)(&received_address), - &addressLength), - SyscallSucceeds()); - struct sockaddr_storage want_address = {}; - want_address.ss_family = sockets->first_addr()->sa_family; - EXPECT_EQ(0, memcmp(&received_address, &want_address, addressLength)); -} - -TEST_P(AllSocketPairTest, GetRemoteAddressWithoutConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - socklen_t addressLength = sockets->first_addr_size(); - struct sockaddr_storage address = {}; - ASSERT_THAT(getpeername(sockets->second_fd(), (struct sockaddr*)(&address), - &addressLength), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(AllSocketPairTest, DoubleBindAddress) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - EXPECT_THAT(bind(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallFailsWithErrno(EADDRINUSE)); -} - -TEST_P(AllSocketPairTest, Unbind) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); - - // Filesystem Unix sockets do not release their address when closed. - if (sockets->first_addr()->sa_data[0] != 0) { - ASSERT_THAT(bind(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallFailsWithErrno(EADDRINUSE)); - return; - } - - ASSERT_THAT(bind(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); -} - -INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, AllSocketPairTest, - ::testing::ValuesIn(VecCat<SocketPairKind>( - ApplyVec<SocketPairKind>( - FilesystemUnboundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM, - SOCK_SEQPACKET}, - List<int>{0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - AbstractUnboundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM, - SOCK_SEQPACKET}, - List<int>{0, SOCK_NONBLOCK}))))); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/brk.cc b/test/syscalls/linux/brk.cc deleted file mode 100644 index a03a44465..000000000 --- a/test/syscalls/linux/brk.cc +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdint.h> -#include <sys/syscall.h> -#include <unistd.h> - -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -TEST(BrkTest, BrkSyscallReturnsOldBrkOnFailure) { - auto old_brk = sbrk(0); - EXPECT_THAT(syscall(SYS_brk, reinterpret_cast<void*>(-1)), - SyscallSucceedsWithValue(reinterpret_cast<uintptr_t>(old_brk))); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/chdir.cc b/test/syscalls/linux/chdir.cc deleted file mode 100644 index 3182c228b..000000000 --- a/test/syscalls/linux/chdir.cc +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <linux/limits.h> -#include <sys/socket.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/util/capability_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(ChdirTest, Success) { - auto old_dir = GetAbsoluteTestTmpdir(); - auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(chdir(temp_dir.path().c_str()), SyscallSucceeds()); - // Temp path destructor deletes the newly created tmp dir and Sentry rejects - // saving when its current dir is still pointing to the path. Switch to a - // permanent path here. - EXPECT_THAT(chdir(old_dir.c_str()), SyscallSucceeds()); -} - -TEST(ChdirTest, PermissionDenied) { - // Drop capabilities that allow us to override directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0666 /* mode */)); - EXPECT_THAT(chdir(temp_dir.path().c_str()), SyscallFailsWithErrno(EACCES)); -} - -TEST(ChdirTest, NotDir) { - auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - EXPECT_THAT(chdir(temp_file.path().c_str()), SyscallFailsWithErrno(ENOTDIR)); -} - -TEST(ChdirTest, NotExist) { - EXPECT_THAT(chdir("/foo/bar"), SyscallFailsWithErrno(ENOENT)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/chmod.cc b/test/syscalls/linux/chmod.cc deleted file mode 100644 index a06b5cfd6..000000000 --- a/test/syscalls/linux/chmod.cc +++ /dev/null @@ -1,264 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include <string> - -#include "gtest/gtest.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(ChmodTest, ChmodFileSucceeds) { - // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - ASSERT_THAT(chmod(file.path().c_str(), 0466), SyscallSucceeds()); - EXPECT_THAT(open(file.path().c_str(), O_RDWR), SyscallFailsWithErrno(EACCES)); -} - -TEST(ChmodTest, ChmodDirSucceeds) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const std::string fileInDir = NewTempAbsPathInDir(dir.path()); - - ASSERT_THAT(chmod(dir.path().c_str(), 0466), SyscallSucceeds()); - EXPECT_THAT(open(fileInDir.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES)); -} - -TEST(ChmodTest, FchmodFileSucceeds_NoRandomSave) { - // Drop capabilities that allow us to file directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666)); - int fd; - ASSERT_THAT(fd = open(file.path().c_str(), O_RDWR), SyscallSucceeds()); - - { - const DisableSave ds; // File permissions are reduced. - ASSERT_THAT(fchmod(fd, 0444), SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - } - - EXPECT_THAT(open(file.path().c_str(), O_RDWR), SyscallFailsWithErrno(EACCES)); -} - -TEST(ChmodTest, FchmodDirSucceeds_NoRandomSave) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - int fd; - ASSERT_THAT(fd = open(dir.path().c_str(), O_RDONLY | O_DIRECTORY), - SyscallSucceeds()); - - { - const DisableSave ds; // File permissions are reduced. - ASSERT_THAT(fchmod(fd, 0), SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - } - - EXPECT_THAT(open(dir.path().c_str(), O_RDONLY), - SyscallFailsWithErrno(EACCES)); -} - -TEST(ChmodTest, FchmodBadF) { - ASSERT_THAT(fchmod(-1, 0444), SyscallFailsWithErrno(EBADF)); -} - -TEST(ChmodTest, FchmodatBadF) { - ASSERT_THAT(fchmodat(-1, "foo", 0444, 0), SyscallFailsWithErrno(EBADF)); -} - -TEST(ChmodTest, FchmodatNotDir) { - ASSERT_THAT(fchmodat(-1, "", 0444, 0), SyscallFailsWithErrno(ENOENT)); -} - -TEST(ChmodTest, FchmodatFileAbsolutePath) { - // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - ASSERT_THAT(fchmodat(-1, file.path().c_str(), 0444, 0), SyscallSucceeds()); - EXPECT_THAT(open(file.path().c_str(), O_RDWR), SyscallFailsWithErrno(EACCES)); -} - -TEST(ChmodTest, FchmodatDirAbsolutePath) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - int fd; - ASSERT_THAT(fd = open(dir.path().c_str(), O_RDONLY | O_DIRECTORY), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - - ASSERT_THAT(fchmodat(-1, dir.path().c_str(), 0, 0), SyscallSucceeds()); - EXPECT_THAT(open(dir.path().c_str(), O_RDONLY), - SyscallFailsWithErrno(EACCES)); -} - -TEST(ChmodTest, FchmodatFile) { - // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - - auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - int parent_fd; - ASSERT_THAT( - parent_fd = open(GetAbsoluteTestTmpdir().c_str(), O_RDONLY | O_DIRECTORY), - SyscallSucceeds()); - - ASSERT_THAT( - fchmodat(parent_fd, std::string(Basename(temp_file.path())).c_str(), 0444, - 0), - SyscallSucceeds()); - EXPECT_THAT(close(parent_fd), SyscallSucceeds()); - - EXPECT_THAT(open(temp_file.path().c_str(), O_RDWR), - SyscallFailsWithErrno(EACCES)); -} - -TEST(ChmodTest, FchmodatDir) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - int parent_fd; - ASSERT_THAT( - parent_fd = open(GetAbsoluteTestTmpdir().c_str(), O_RDONLY | O_DIRECTORY), - SyscallSucceeds()); - - int fd; - ASSERT_THAT(fd = open(dir.path().c_str(), O_RDONLY | O_DIRECTORY), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - - ASSERT_THAT( - fchmodat(parent_fd, std::string(Basename(dir.path())).c_str(), 0, 0), - SyscallSucceeds()); - EXPECT_THAT(close(parent_fd), SyscallSucceeds()); - - EXPECT_THAT(open(dir.path().c_str(), O_RDONLY | O_DIRECTORY), - SyscallFailsWithErrno(EACCES)); -} - -TEST(ChmodTest, ChmodDowngradeWritability_NoRandomSave) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666)); - - int fd; - ASSERT_THAT(fd = open(file.path().c_str(), O_RDWR), SyscallSucceeds()); - - const DisableSave ds; // Permissions are dropped. - ASSERT_THAT(chmod(file.path().c_str(), 0444), SyscallSucceeds()); - EXPECT_THAT(write(fd, "hello", 5), SyscallSucceedsWithValue(5)); - - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST(ChmodTest, ChmodFileToNoPermissionsSucceeds) { - // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666)); - - ASSERT_THAT(chmod(file.path().c_str(), 0), SyscallSucceeds()); - - EXPECT_THAT(open(file.path().c_str(), O_RDONLY), - SyscallFailsWithErrno(EACCES)); -} - -TEST(ChmodTest, FchmodDowngradeWritability_NoRandomSave) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - int fd; - ASSERT_THAT(fd = open(file.path().c_str(), O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - - const DisableSave ds; // Permissions are dropped. - ASSERT_THAT(fchmod(fd, 0444), SyscallSucceeds()); - EXPECT_THAT(write(fd, "hello", 5), SyscallSucceedsWithValue(5)); - - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST(ChmodTest, FchmodFileToNoPermissionsSucceeds_NoRandomSave) { - // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666)); - - int fd; - ASSERT_THAT(fd = open(file.path().c_str(), O_RDWR), SyscallSucceeds()); - - { - const DisableSave ds; // Permissions are dropped. - ASSERT_THAT(fchmod(fd, 0), SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - } - - EXPECT_THAT(open(file.path().c_str(), O_RDONLY), - SyscallFailsWithErrno(EACCES)); -} - -// Verify that we can get a RW FD after chmod, even if a RO fd is left open. -TEST(ChmodTest, ChmodWritableWithOpenFD) { - // FIXME(b/72455313): broken on hostfs. - if (IsRunningOnGvisor()) { - return; - } - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0444)); - - FileDescriptor fd1 = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); - - ASSERT_THAT(fchmod(fd1.get(), 0644), SyscallSucceeds()); - - // This FD is writable, even though fd1 has a read-only reference to the file. - FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); - - // fd1 is not writable, but fd2 is. - char c = 'a'; - EXPECT_THAT(WriteFd(fd1.get(), &c, 1), SyscallFailsWithErrno(EBADF)); - EXPECT_THAT(WriteFd(fd2.get(), &c, 1), SyscallSucceedsWithValue(1)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/chown.cc b/test/syscalls/linux/chown.cc deleted file mode 100644 index 7a28b674d..000000000 --- a/test/syscalls/linux/chown.cc +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <grp.h> -#include <sys/types.h> -#include <unistd.h> - -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/flags/flag.h" -#include "absl/synchronization/notification.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID"); -ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID"); -ABSL_FLAG(int32_t, scratch_gid, 65534, "first scratch GID"); - -namespace gvisor { -namespace testing { - -namespace { - -TEST(ChownTest, FchownBadF) { - ASSERT_THAT(fchown(-1, 0, 0), SyscallFailsWithErrno(EBADF)); -} - -TEST(ChownTest, FchownatBadF) { - ASSERT_THAT(fchownat(-1, "fff", 0, 0, 0), SyscallFailsWithErrno(EBADF)); -} - -TEST(ChownTest, FchownatEmptyPath) { - const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const auto fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY | O_RDONLY)); - ASSERT_THAT(fchownat(fd.get(), "", 0, 0, 0), SyscallFailsWithErrno(ENOENT)); -} - -using Chown = - std::function<PosixError(const std::string&, uid_t owner, gid_t group)>; - -class ChownParamTest : public ::testing::TestWithParam<Chown> {}; - -TEST_P(ChownParamTest, ChownFileSucceeds) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_CHOWN))) { - ASSERT_NO_ERRNO(SetCapability(CAP_CHOWN, false)); - } - - const auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // At least *try* setting to a group other than the EGID. - gid_t gid; - EXPECT_THAT(gid = getegid(), SyscallSucceeds()); - int num_groups; - EXPECT_THAT(num_groups = getgroups(0, nullptr), SyscallSucceeds()); - if (num_groups > 0) { - std::vector<gid_t> list(num_groups); - EXPECT_THAT(getgroups(list.size(), list.data()), SyscallSucceeds()); - gid = list[0]; - } - - EXPECT_NO_ERRNO(GetParam()(file.path(), geteuid(), gid)); - - struct stat s = {}; - ASSERT_THAT(stat(file.path().c_str(), &s), SyscallSucceeds()); - EXPECT_EQ(s.st_uid, geteuid()); - EXPECT_EQ(s.st_gid, gid); -} - -TEST_P(ChownParamTest, ChownFilePermissionDenied) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); - - const auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0777)); - - // Drop privileges and change IDs only in child thread, or else this parent - // thread won't be able to open some log files after the test ends. - ScopedThread([&] { - // Drop privileges. - if (HaveCapability(CAP_CHOWN).ValueOrDie()) { - EXPECT_NO_ERRNO(SetCapability(CAP_CHOWN, false)); - } - - // Change EUID and EGID. - // - // See note about POSIX below. - EXPECT_THAT( - syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1), - SyscallSucceeds()); - EXPECT_THAT( - syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid1), -1), - SyscallSucceeds()); - - EXPECT_THAT(GetParam()(file.path(), geteuid(), getegid()), - PosixErrorIs(EPERM, ::testing::ContainsRegex("chown"))); - }); -} - -TEST_P(ChownParamTest, ChownFileSucceedsAsRoot) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_CHOWN)))); - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_SETUID)))); - - const std::string filename = NewTempAbsPath(); - - absl::Notification fileCreated, fileChowned; - // Change UID only in child thread, or else this parent thread won't be able - // to open some log files after the test ends. - ScopedThread t([&] { - // POSIX requires that all threads in a process share the same UIDs, so - // the NPTL setresuid wrappers use signals to make all threads execute the - // setresuid syscall. However, we want this thread to have its own set of - // credentials different from the parent process, so we use the raw - // syscall. - EXPECT_THAT( - syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid2), -1), - SyscallSucceeds()); - - // Create file and immediately close it. - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_CREAT | O_RDWR, 0644)); - fd.reset(); // Close the fd. - - fileCreated.Notify(); - fileChowned.WaitForNotification(); - - EXPECT_THAT(open(filename.c_str(), O_RDWR), SyscallFailsWithErrno(EACCES)); - FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_RDONLY)); - }); - - fileCreated.WaitForNotification(); - - // Set file's owners to someone different. - EXPECT_NO_ERRNO(GetParam()(filename, absl::GetFlag(FLAGS_scratch_uid1), - absl::GetFlag(FLAGS_scratch_gid))); - - struct stat s; - EXPECT_THAT(stat(filename.c_str(), &s), SyscallSucceeds()); - EXPECT_EQ(s.st_uid, absl::GetFlag(FLAGS_scratch_uid1)); - EXPECT_EQ(s.st_gid, absl::GetFlag(FLAGS_scratch_gid)); - - fileChowned.Notify(); -} - -PosixError errorFromReturn(const std::string& name, int ret) { - if (ret == -1) { - return PosixError(errno, absl::StrCat(name, " failed")); - } - return NoError(); -} - -INSTANTIATE_TEST_SUITE_P( - ChownKinds, ChownParamTest, - ::testing::Values( - [](const std::string& path, uid_t owner, gid_t group) -> PosixError { - int rc = chown(path.c_str(), owner, group); - MaybeSave(); - return errorFromReturn("chown", rc); - }, - [](const std::string& path, uid_t owner, gid_t group) -> PosixError { - int rc = lchown(path.c_str(), owner, group); - MaybeSave(); - return errorFromReturn("lchown", rc); - }, - [](const std::string& path, uid_t owner, gid_t group) -> PosixError { - ASSIGN_OR_RETURN_ERRNO(auto fd, Open(path, O_RDWR)); - int rc = fchown(fd.get(), owner, group); - MaybeSave(); - return errorFromReturn("fchown", rc); - }, - [](const std::string& path, uid_t owner, gid_t group) -> PosixError { - ASSIGN_OR_RETURN_ERRNO(auto fd, Open(path, O_RDWR)); - int rc = fchownat(fd.get(), "", owner, group, AT_EMPTY_PATH); - MaybeSave(); - return errorFromReturn("fchownat-fd", rc); - }, - [](const std::string& path, uid_t owner, gid_t group) -> PosixError { - ASSIGN_OR_RETURN_ERRNO(auto dirfd, Open(std::string(Dirname(path)), - O_DIRECTORY | O_RDONLY)); - int rc = fchownat(dirfd.get(), std::string(Basename(path)).c_str(), - owner, group, 0); - MaybeSave(); - return errorFromReturn("fchownat-dirfd", rc); - })); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/chroot.cc b/test/syscalls/linux/chroot.cc deleted file mode 100644 index 85ec013d5..000000000 --- a/test/syscalls/linux/chroot.cc +++ /dev/null @@ -1,366 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <stddef.h> -#include <sys/mman.h> -#include <sys/stat.h> -#include <syscall.h> -#include <unistd.h> - -#include <string> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "test/util/capability_util.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/mount_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -using ::testing::HasSubstr; -using ::testing::Not; - -namespace gvisor { -namespace testing { - -namespace { - -TEST(ChrootTest, Success) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); - - auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(chroot(temp_dir.path().c_str()), SyscallSucceeds()); -} - -TEST(ChrootTest, PermissionDenied) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); - - // CAP_DAC_READ_SEARCH and CAP_DAC_OVERRIDE may override Execute permission on - // directories. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - - auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0666 /* mode */)); - EXPECT_THAT(chroot(temp_dir.path().c_str()), SyscallFailsWithErrno(EACCES)); -} - -TEST(ChrootTest, NotDir) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); - - auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - EXPECT_THAT(chroot(temp_file.path().c_str()), SyscallFailsWithErrno(ENOTDIR)); -} - -TEST(ChrootTest, NotExist) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); - - EXPECT_THAT(chroot("/foo/bar"), SyscallFailsWithErrno(ENOENT)); -} - -TEST(ChrootTest, WithoutCapability) { - // Unset CAP_SYS_CHROOT. - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_CHROOT, false)); - - auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(chroot(temp_dir.path().c_str()), SyscallFailsWithErrno(EPERM)); -} - -TEST(ChrootTest, CreatesNewRoot) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); - - // Grab the initial cwd. - char initial_cwd[1024]; - ASSERT_THAT(syscall(__NR_getcwd, initial_cwd, sizeof(initial_cwd)), - SyscallSucceeds()); - - auto new_root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto file_in_new_root = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(new_root.path())); - - // chroot into new_root. - ASSERT_THAT(chroot(new_root.path().c_str()), SyscallSucceeds()); - - // getcwd should return "(unreachable)" followed by the initial_cwd. - char cwd[1024]; - ASSERT_THAT(syscall(__NR_getcwd, cwd, sizeof(cwd)), SyscallSucceeds()); - std::string expected_cwd = "(unreachable)"; - expected_cwd += initial_cwd; - EXPECT_STREQ(cwd, expected_cwd.c_str()); - - // Should not be able to stat file by its full path. - struct stat statbuf; - EXPECT_THAT(stat(file_in_new_root.path().c_str(), &statbuf), - SyscallFailsWithErrno(ENOENT)); - - // Should be able to stat file at new rooted path. - auto basename = std::string(Basename(file_in_new_root.path())); - auto rootedFile = "/" + basename; - ASSERT_THAT(stat(rootedFile.c_str(), &statbuf), SyscallSucceeds()); - - // Should be able to stat cwd at '.' even though it's outside root. - ASSERT_THAT(stat(".", &statbuf), SyscallSucceeds()); - - // chdir into new root. - ASSERT_THAT(chdir("/"), SyscallSucceeds()); - - // getcwd should return "/". - EXPECT_THAT(syscall(__NR_getcwd, cwd, sizeof(cwd)), SyscallSucceeds()); - EXPECT_STREQ(cwd, "/"); - - // Statting '.', '..', '/', and '/..' all return the same dev and inode. - struct stat statbuf_dot; - ASSERT_THAT(stat(".", &statbuf_dot), SyscallSucceeds()); - struct stat statbuf_dotdot; - ASSERT_THAT(stat("..", &statbuf_dotdot), SyscallSucceeds()); - EXPECT_EQ(statbuf_dot.st_dev, statbuf_dotdot.st_dev); - EXPECT_EQ(statbuf_dot.st_ino, statbuf_dotdot.st_ino); - struct stat statbuf_slash; - ASSERT_THAT(stat("/", &statbuf_slash), SyscallSucceeds()); - EXPECT_EQ(statbuf_dot.st_dev, statbuf_slash.st_dev); - EXPECT_EQ(statbuf_dot.st_ino, statbuf_slash.st_ino); - struct stat statbuf_slashdotdot; - ASSERT_THAT(stat("/..", &statbuf_slashdotdot), SyscallSucceeds()); - EXPECT_EQ(statbuf_dot.st_dev, statbuf_slashdotdot.st_dev); - EXPECT_EQ(statbuf_dot.st_ino, statbuf_slashdotdot.st_ino); -} - -TEST(ChrootTest, DotDotFromOpenFD) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); - - auto dir_outside_root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto fd = ASSERT_NO_ERRNO_AND_VALUE( - Open(dir_outside_root.path(), O_RDONLY | O_DIRECTORY)); - auto new_root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - // chroot into new_root. - ASSERT_THAT(chroot(new_root.path().c_str()), SyscallSucceeds()); - - // openat on fd with path .. will succeed. - int other_fd; - ASSERT_THAT(other_fd = openat(fd.get(), "..", O_RDONLY), SyscallSucceeds()); - EXPECT_THAT(close(other_fd), SyscallSucceeds()); - - // getdents on fd should not error. - char buf[1024]; - ASSERT_THAT(syscall(SYS_getdents64, fd.get(), buf, sizeof(buf)), - SyscallSucceeds()); -} - -// Test that link resolution in a chroot can escape the root by following an -// open proc fd. Regression test for b/32316719. -TEST(ChrootTest, ProcFdLinkResolutionInChroot) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); - - const TempPath file_outside_chroot = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file_outside_chroot.path(), O_RDONLY)); - - const FileDescriptor proc_fd = ASSERT_NO_ERRNO_AND_VALUE( - Open("/proc", O_DIRECTORY | O_RDONLY | O_CLOEXEC)); - - auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - ASSERT_THAT(chroot(temp_dir.path().c_str()), SyscallSucceeds()); - - // Opening relative to an already open fd to a node outside the chroot works. - const FileDescriptor proc_self_fd = ASSERT_NO_ERRNO_AND_VALUE( - OpenAt(proc_fd.get(), "self/fd", O_DIRECTORY | O_RDONLY | O_CLOEXEC)); - - // Proc fd symlinks can escape the chroot if the fd the symlink refers to - // refers to an object outside the chroot. - struct stat s = {}; - EXPECT_THAT( - fstatat(proc_self_fd.get(), absl::StrCat(fd.get()).c_str(), &s, 0), - SyscallSucceeds()); - - // Try to stat the stdin fd. Internally, this is handled differently from a - // proc fd entry pointing to a file, since stdin is backed by a host fd, and - // isn't a walkable path on the filesystem inside the sandbox. - EXPECT_THAT(fstatat(proc_self_fd.get(), "0", &s, 0), SyscallSucceeds()); -} - -// This test will verify that when you hold a fd to proc before entering -// a chroot that any files inside the chroot will appear rooted to the -// base chroot when examining /proc/self/fd/{num}. -TEST(ChrootTest, ProcMemSelfFdsNoEscapeProcOpen) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); - - // Get a FD to /proc before we enter the chroot. - const FileDescriptor proc = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc", O_RDONLY)); - - // Create and enter a chroot directory. - const auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - ASSERT_THAT(chroot(temp_dir.path().c_str()), SyscallSucceeds()); - - // Open a file inside the chroot at /foo. - const FileDescriptor foo = - ASSERT_NO_ERRNO_AND_VALUE(Open("/foo", O_CREAT | O_RDONLY, 0644)); - - // Examine /proc/self/fd/{foo_fd} to see if it exposes the fact that we're - // inside a chroot, the path should be /foo and NOT {chroot_dir}/foo. - const std::string fd_path = absl::StrCat("self/fd/", foo.get()); - char buf[1024] = {}; - size_t bytes_read = 0; - ASSERT_THAT(bytes_read = - readlinkat(proc.get(), fd_path.c_str(), buf, sizeof(buf) - 1), - SyscallSucceeds()); - - // The link should resolve to something. - ASSERT_GT(bytes_read, 0); - - // Assert that the link doesn't contain the chroot path and is only /foo. - EXPECT_STREQ(buf, "/foo"); -} - -// This test will verify that a file inside a chroot when mmapped will not -// expose the full file path via /proc/self/maps and instead honor the chroot. -TEST(ChrootTest, ProcMemSelfMapsNoEscapeProcOpen) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); - - // Get a FD to /proc before we enter the chroot. - const FileDescriptor proc = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc", O_RDONLY)); - - // Create and enter a chroot directory. - const auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - ASSERT_THAT(chroot(temp_dir.path().c_str()), SyscallSucceeds()); - - // Open a file inside the chroot at /foo. - const FileDescriptor foo = - ASSERT_NO_ERRNO_AND_VALUE(Open("/foo", O_CREAT | O_RDONLY, 0644)); - - // Mmap the newly created file. - void* foo_map = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, - foo.get(), 0); - ASSERT_THAT(reinterpret_cast<int64_t>(foo_map), SyscallSucceeds()); - - // Always unmap. - auto cleanup_map = Cleanup( - [&] { EXPECT_THAT(munmap(foo_map, kPageSize), SyscallSucceeds()); }); - - // Examine /proc/self/maps to be sure that /foo doesn't appear to be - // mapped with the full chroot path. - const FileDescriptor maps = - ASSERT_NO_ERRNO_AND_VALUE(OpenAt(proc.get(), "self/maps", O_RDONLY)); - - size_t bytes_read = 0; - char buf[8 * 1024] = {}; - ASSERT_THAT(bytes_read = ReadFd(maps.get(), buf, sizeof(buf)), - SyscallSucceeds()); - - // The maps file should have something. - ASSERT_GT(bytes_read, 0); - - // Finally we want to make sure the maps don't contain the chroot path - ASSERT_EQ(std::string(buf, bytes_read).find(temp_dir.path()), - std::string::npos); -} - -// Test that mounts outside the chroot will not appear in /proc/self/mounts or -// /proc/self/mountinfo. -TEST(ChrootTest, ProcMountsMountinfoNoEscape) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); - - // We are going to create some mounts and then chroot. In order to be able to - // unmount the mounts after the test run, we must chdir to the root and use - // relative paths for all mounts. That way, as long as we never chdir into - // the new root, we can access the mounts via relative paths and unmount them. - ASSERT_THAT(chdir("/"), SyscallSucceeds()); - - // Create nested tmpfs mounts. Note the use of relative paths in Mount calls. - auto const outer_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto const outer_mount = ASSERT_NO_ERRNO_AND_VALUE(Mount( - "none", JoinPath(".", outer_dir.path()), "tmpfs", 0, "mode=0700", 0)); - - auto const inner_dir = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(outer_dir.path())); - auto const inner_mount = ASSERT_NO_ERRNO_AND_VALUE(Mount( - "none", JoinPath(".", inner_dir.path()), "tmpfs", 0, "mode=0700", 0)); - - // Filenames that will be checked for mounts, all relative to /proc dir. - std::string paths[3] = {"mounts", "self/mounts", "self/mountinfo"}; - - for (const std::string& path : paths) { - // We should have both inner and outer mounts. - const std::string contents = - ASSERT_NO_ERRNO_AND_VALUE(GetContents(JoinPath("/proc", path))); - EXPECT_THAT(contents, AllOf(HasSubstr(outer_dir.path()), - HasSubstr(inner_dir.path()))); - // We better have at least two mounts: the mounts we created plus the root. - std::vector<absl::string_view> submounts = - absl::StrSplit(contents, '\n', absl::SkipWhitespace()); - EXPECT_GT(submounts.size(), 2); - } - - // Get a FD to /proc before we enter the chroot. - const FileDescriptor proc = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc", O_RDONLY)); - - // Chroot to outer mount. - ASSERT_THAT(chroot(outer_dir.path().c_str()), SyscallSucceeds()); - - for (const std::string& path : paths) { - const FileDescriptor proc_file = - ASSERT_NO_ERRNO_AND_VALUE(OpenAt(proc.get(), path, O_RDONLY)); - - // Only two mounts visible from this chroot: the inner and outer. Both - // paths should be relative to the new chroot. - const std::string contents = - ASSERT_NO_ERRNO_AND_VALUE(GetContentsFD(proc_file.get())); - EXPECT_THAT(contents, - AllOf(HasSubstr(absl::StrCat(Basename(inner_dir.path()))), - Not(HasSubstr(outer_dir.path())), - Not(HasSubstr(inner_dir.path())))); - std::vector<absl::string_view> submounts = - absl::StrSplit(contents, '\n', absl::SkipWhitespace()); - EXPECT_EQ(submounts.size(), 2); - } - - // Chroot to inner mount. We must use an absolute path accessible to our - // chroot. - const std::string inner_dir_basename = - absl::StrCat("/", Basename(inner_dir.path())); - ASSERT_THAT(chroot(inner_dir_basename.c_str()), SyscallSucceeds()); - - for (const std::string& path : paths) { - const FileDescriptor proc_file = - ASSERT_NO_ERRNO_AND_VALUE(OpenAt(proc.get(), path, O_RDONLY)); - const std::string contents = - ASSERT_NO_ERRNO_AND_VALUE(GetContentsFD(proc_file.get())); - - // Only the inner mount visible from this chroot. - std::vector<absl::string_view> submounts = - absl::StrSplit(contents, '\n', absl::SkipWhitespace()); - EXPECT_EQ(submounts.size(), 1); - } - - // Chroot back to ".". - ASSERT_THAT(chroot("."), SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/clock_getres.cc b/test/syscalls/linux/clock_getres.cc deleted file mode 100644 index c408b936c..000000000 --- a/test/syscalls/linux/clock_getres.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/time.h> -#include <time.h> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// clock_getres works regardless of whether or not a timespec is passed. -TEST(ClockGetres, Timespec) { - struct timespec ts; - EXPECT_THAT(clock_getres(CLOCK_MONOTONIC, &ts), SyscallSucceeds()); - EXPECT_THAT(clock_getres(CLOCK_MONOTONIC, nullptr), SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/clock_gettime.cc b/test/syscalls/linux/clock_gettime.cc deleted file mode 100644 index 7f6015049..000000000 --- a/test/syscalls/linux/clock_gettime.cc +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <pthread.h> -#include <sys/time.h> - -#include <cerrno> -#include <cstdint> -#include <ctime> -#include <list> -#include <memory> -#include <string> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -int64_t clock_gettime_nsecs(clockid_t id) { - struct timespec ts; - TEST_PCHECK(clock_gettime(id, &ts) == 0); - return (ts.tv_sec * 1000000000 + ts.tv_nsec); -} - -// Spin on the CPU for at least ns nanoseconds, based on -// CLOCK_THREAD_CPUTIME_ID. -void spin_ns(int64_t ns) { - int64_t start = clock_gettime_nsecs(CLOCK_THREAD_CPUTIME_ID); - int64_t end = start + ns; - - do { - constexpr int kLoopCount = 1000000; // large and arbitrary - // volatile to prevent the compiler from skipping this loop. - for (volatile int i = 0; i < kLoopCount; i++) { - } - } while (clock_gettime_nsecs(CLOCK_THREAD_CPUTIME_ID) < end); -} - -// Test that CLOCK_PROCESS_CPUTIME_ID is a superset of CLOCK_THREAD_CPUTIME_ID. -TEST(ClockGettime, CputimeId) { - constexpr int kNumThreads = 13; // arbitrary - - absl::Duration spin_time = absl::Seconds(1); - - // Start off the worker threads and compute the aggregate time spent by - // the workers. Note that we test CLOCK_PROCESS_CPUTIME_ID by having the - // workers execute in parallel and verifying that CLOCK_PROCESS_CPUTIME_ID - // accumulates the runtime of all threads. - int64_t start = clock_gettime_nsecs(CLOCK_PROCESS_CPUTIME_ID); - - // Create a kNumThreads threads. - std::list<ScopedThread> threads; - for (int i = 0; i < kNumThreads; i++) { - threads.emplace_back( - [spin_time] { spin_ns(absl::ToInt64Nanoseconds(spin_time)); }); - } - for (auto& t : threads) { - t.Join(); - } - - int64_t end = clock_gettime_nsecs(CLOCK_PROCESS_CPUTIME_ID); - - // The aggregate time spent in the worker threads must be at least - // 'kNumThreads' times the time each thread spun. - ASSERT_GE(end - start, kNumThreads * absl::ToInt64Nanoseconds(spin_time)); -} - -TEST(ClockGettime, JavaThreadTime) { - clockid_t clockid; - ASSERT_EQ(0, pthread_getcpuclockid(pthread_self(), &clockid)); - struct timespec tp; - ASSERT_THAT(clock_getres(clockid, &tp), SyscallSucceeds()); - EXPECT_TRUE(tp.tv_sec > 0 || tp.tv_nsec > 0); - // A thread cputime is updated each 10msec and there is no approximation - // if a task is running. - do { - ASSERT_THAT(clock_gettime(clockid, &tp), SyscallSucceeds()); - } while (tp.tv_sec == 0 && tp.tv_nsec == 0); - EXPECT_TRUE(tp.tv_sec > 0 || tp.tv_nsec > 0); -} - -// There is not much to test here, since CLOCK_REALTIME may be discontiguous. -TEST(ClockGettime, RealtimeWorks) { - struct timespec tp; - EXPECT_THAT(clock_gettime(CLOCK_REALTIME, &tp), SyscallSucceeds()); -} - -class MonotonicClockTest : public ::testing::TestWithParam<clockid_t> {}; - -TEST_P(MonotonicClockTest, IsMonotonic) { - auto end = absl::Now() + absl::Seconds(5); - - struct timespec tp; - EXPECT_THAT(clock_gettime(GetParam(), &tp), SyscallSucceeds()); - - auto prev = absl::TimeFromTimespec(tp); - while (absl::Now() < end) { - EXPECT_THAT(clock_gettime(GetParam(), &tp), SyscallSucceeds()); - auto now = absl::TimeFromTimespec(tp); - EXPECT_GE(now, prev); - prev = now; - } -} - -std::string PrintClockId(::testing::TestParamInfo<clockid_t> info) { - switch (info.param) { - case CLOCK_MONOTONIC: - return "CLOCK_MONOTONIC"; - case CLOCK_MONOTONIC_COARSE: - return "CLOCK_MONOTONIC_COARSE"; - case CLOCK_MONOTONIC_RAW: - return "CLOCK_MONOTONIC_RAW"; - case CLOCK_BOOTTIME: - // CLOCK_BOOTTIME is a monotonic clock. - return "CLOCK_BOOTTIME"; - default: - return absl::StrCat(info.param); - } -} - -INSTANTIATE_TEST_SUITE_P(ClockGettime, MonotonicClockTest, - ::testing::Values(CLOCK_MONOTONIC, - CLOCK_MONOTONIC_COARSE, - CLOCK_MONOTONIC_RAW, CLOCK_BOOTTIME), - PrintClockId); - -TEST(ClockGettime, UnimplementedReturnsEINVAL) { - SKIP_IF(!IsRunningOnGvisor()); - - struct timespec tp; - EXPECT_THAT(clock_gettime(CLOCK_REALTIME_ALARM, &tp), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(clock_gettime(CLOCK_BOOTTIME_ALARM, &tp), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(ClockGettime, InvalidClockIDReturnsEINVAL) { - struct timespec tp; - EXPECT_THAT(clock_gettime(-1, &tp), SyscallFailsWithErrno(EINVAL)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/clock_nanosleep.cc b/test/syscalls/linux/clock_nanosleep.cc deleted file mode 100644 index b55cddc52..000000000 --- a/test/syscalls/linux/clock_nanosleep.cc +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <time.h> - -#include <atomic> -#include <utility> - -#include "gtest/gtest.h" -#include "absl/time/time.h" -#include "test/util/cleanup.h" -#include "test/util/posix_error.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" -#include "test/util/timer_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// sys_clock_nanosleep is defined because the glibc clock_nanosleep returns -// error numbers directly and does not set errno. This makes our Syscall -// matchers look a little weird when expecting failure: -// "SyscallSucceedsWithValue(ERRNO)". -int sys_clock_nanosleep(clockid_t clkid, int flags, - const struct timespec* request, - struct timespec* remain) { - return syscall(SYS_clock_nanosleep, clkid, flags, request, remain); -} - -PosixErrorOr<absl::Time> GetTime(clockid_t clk) { - struct timespec ts = {}; - const int rc = clock_gettime(clk, &ts); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "clock_gettime"); - } - return absl::TimeFromTimespec(ts); -} - -class WallClockNanosleepTest : public ::testing::TestWithParam<clockid_t> {}; - -TEST_P(WallClockNanosleepTest, InvalidValues) { - const struct timespec invalid[] = { - {.tv_sec = -1, .tv_nsec = -1}, {.tv_sec = 0, .tv_nsec = INT32_MIN}, - {.tv_sec = 0, .tv_nsec = INT32_MAX}, {.tv_sec = 0, .tv_nsec = -1}, - {.tv_sec = -1, .tv_nsec = 0}, - }; - - for (auto const ts : invalid) { - EXPECT_THAT(sys_clock_nanosleep(GetParam(), 0, &ts, nullptr), - SyscallFailsWithErrno(EINVAL)); - } -} - -TEST_P(WallClockNanosleepTest, SleepOneSecond) { - constexpr absl::Duration kSleepDuration = absl::Seconds(1); - struct timespec duration = absl::ToTimespec(kSleepDuration); - - const absl::Time before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); - EXPECT_THAT( - RetryEINTR(sys_clock_nanosleep)(GetParam(), 0, &duration, &duration), - SyscallSucceeds()); - const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); - - EXPECT_GE(after - before, kSleepDuration); -} - -TEST_P(WallClockNanosleepTest, InterruptedNanosleep) { - constexpr absl::Duration kSleepDuration = absl::Seconds(60); - struct timespec duration = absl::ToTimespec(kSleepDuration); - - // Install no-op signal handler for SIGALRM. - struct sigaction sa = {}; - sigfillset(&sa.sa_mask); - sa.sa_handler = +[](int signo) {}; - const auto cleanup_sa = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa)); - - // Measure time since setting the alarm, since the alarm will interrupt the - // sleep and hence determine how long we sleep. - const absl::Time before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); - - // Set an alarm to go off while sleeping. - struct itimerval timer = {}; - timer.it_value.tv_sec = 1; - timer.it_value.tv_usec = 0; - timer.it_interval.tv_sec = 1; - timer.it_interval.tv_usec = 0; - const auto cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_REAL, timer)); - - EXPECT_THAT(sys_clock_nanosleep(GetParam(), 0, &duration, &duration), - SyscallFailsWithErrno(EINTR)); - const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); - - // Remaining time updated. - const absl::Duration remaining = absl::DurationFromTimespec(duration); - EXPECT_GE(after - before + remaining, kSleepDuration); -} - -// Remaining time is *not* updated if nanosleep completes uninterrupted. -TEST_P(WallClockNanosleepTest, UninterruptedNanosleep) { - constexpr absl::Duration kSleepDuration = absl::Milliseconds(10); - const struct timespec duration = absl::ToTimespec(kSleepDuration); - - while (true) { - constexpr int kRemainingMagic = 42; - struct timespec remaining; - remaining.tv_sec = kRemainingMagic; - remaining.tv_nsec = kRemainingMagic; - - int ret = sys_clock_nanosleep(GetParam(), 0, &duration, &remaining); - if (ret == EINTR) { - // Retry from beginning. We want a single uninterrupted call. - continue; - } - - EXPECT_THAT(ret, SyscallSucceeds()); - EXPECT_EQ(remaining.tv_sec, kRemainingMagic); - EXPECT_EQ(remaining.tv_nsec, kRemainingMagic); - break; - } -} - -TEST_P(WallClockNanosleepTest, SleepUntil) { - const absl::Time now = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); - const absl::Time until = now + absl::Seconds(2); - const struct timespec ts = absl::ToTimespec(until); - - EXPECT_THAT( - RetryEINTR(sys_clock_nanosleep)(GetParam(), TIMER_ABSTIME, &ts, nullptr), - SyscallSucceeds()); - const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); - - EXPECT_GE(after, until); -} - -INSTANTIATE_TEST_SUITE_P(Sleepers, WallClockNanosleepTest, - ::testing::Values(CLOCK_REALTIME, CLOCK_MONOTONIC)); - -TEST(ClockNanosleepProcessTest, SleepFiveSeconds) { - const absl::Duration kSleepDuration = absl::Seconds(5); - struct timespec duration = absl::ToTimespec(kSleepDuration); - - // Ensure that CLOCK_PROCESS_CPUTIME_ID advances. - std::atomic<bool> done(false); - ScopedThread t([&] { - while (!done.load()) { - } - }); - const auto cleanup_done = Cleanup([&] { done.store(true); }); - - const absl::Time before = - ASSERT_NO_ERRNO_AND_VALUE(GetTime(CLOCK_PROCESS_CPUTIME_ID)); - EXPECT_THAT(RetryEINTR(sys_clock_nanosleep)(CLOCK_PROCESS_CPUTIME_ID, 0, - &duration, &duration), - SyscallSucceeds()); - const absl::Time after = - ASSERT_NO_ERRNO_AND_VALUE(GetTime(CLOCK_PROCESS_CPUTIME_ID)); - EXPECT_GE(after - before, kSleepDuration); -} -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/concurrency.cc b/test/syscalls/linux/concurrency.cc deleted file mode 100644 index 7cd6a75bd..000000000 --- a/test/syscalls/linux/concurrency.cc +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <signal.h> - -#include <atomic> - -#include "gtest/gtest.h" -#include "absl/strings/string_view.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/platform_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { -namespace { - -// Test that a thread that never yields to the OS does not prevent other threads -// from running. -TEST(ConcurrencyTest, SingleProcessMultithreaded) { - std::atomic<int> a(0); - - ScopedThread t([&a]() { - while (!a.load()) { - } - }); - - absl::SleepFor(absl::Seconds(1)); - - // We are still able to execute code in this thread. The other hasn't - // permanently hung execution in both threads. - a.store(1); -} - -// Test that multiple threads in this process continue to execute in parallel, -// even if an unrelated second process is spawned. Regression test for -// b/32119508. -TEST(ConcurrencyTest, MultiProcessMultithreaded) { - // In PID 1, start TIDs 1 and 2, and put both to sleep. - // - // Start PID 3, which spins for 5 seconds, then exits. - // - // TIDs 1 and 2 wake and attempt to Activate, which cannot occur until PID 3 - // exits. - // - // Both TIDs 1 and 2 should be woken. If they are not both woken, the test - // hangs. - // - // This is all fundamentally racy. If we are failing to wake all threads, the - // expectation is that this test becomes flaky, rather than consistently - // failing. - // - // If additional background threads fail to block, we may never schedule the - // child, at which point this test effectively becomes - // MultiProcessConcurrency. That's not expected to occur. - - std::atomic<int> a(0); - ScopedThread t([&a]() { - // Block so that PID 3 can execute and we can wait on its exit. - absl::SleepFor(absl::Seconds(1)); - while (!a.load()) { - } - }); - - pid_t child_pid = fork(); - if (child_pid == 0) { - // Busy wait without making any blocking syscalls. - auto end = absl::Now() + absl::Seconds(5); - while (absl::Now() < end) { - } - _exit(0); - } - ASSERT_THAT(child_pid, SyscallSucceeds()); - - absl::SleepFor(absl::Seconds(1)); - - // If only TID 1 is woken, thread.Join will hang. - // If only TID 2 is woken, both will hang. - a.store(1); - t.Join(); - - int status = 0; - EXPECT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status)); - EXPECT_EQ(WEXITSTATUS(status), 0); -} - -// Test that multiple processes can execute concurrently, even if one process -// never yields. -TEST(ConcurrencyTest, MultiProcessConcurrency) { - SKIP_IF(PlatformSupportMultiProcess() == PlatformSupport::NotSupported); - - pid_t child_pid = fork(); - if (child_pid == 0) { - while (true) { - } - } - ASSERT_THAT(child_pid, SyscallSucceeds()); - - absl::SleepFor(absl::Seconds(5)); - - // We are still able to execute code in this process. The other hasn't - // permanently hung execution in both processes. - ASSERT_THAT(kill(child_pid, SIGKILL), SyscallSucceeds()); - int status = 0; - - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - ASSERT_TRUE(WIFSIGNALED(status)); - ASSERT_EQ(WTERMSIG(status), SIGKILL); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/connect_external.cc b/test/syscalls/linux/connect_external.cc deleted file mode 100644 index 1edb50e47..000000000 --- a/test/syscalls/linux/connect_external.cc +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <stdlib.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include <string> -#include <tuple> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/test_util.h" - -// This file contains tests specific to connecting to host UDS managed outside -// the sandbox / test. -// -// A set of ultity sockets will be created externally in $TEST_UDS_TREE and -// $TEST_UDS_ATTACH_TREE for these tests to interact with. - -namespace gvisor { -namespace testing { - -namespace { - -struct ProtocolSocket { - int protocol; - std::string name; -}; - -// Parameter is (socket root dir, ProtocolSocket). -using GoferStreamSeqpacketTest = - ::testing::TestWithParam<std::tuple<std::string, ProtocolSocket>>; - -// Connect to a socket and verify that write/read work. -// -// An "echo" socket doesn't work for dgram sockets because our socket is -// unnamed. The server thus has no way to reply to us. -TEST_P(GoferStreamSeqpacketTest, Echo) { - std::string env; - ProtocolSocket proto; - std::tie(env, proto) = GetParam(); - - char* val = getenv(env.c_str()); - ASSERT_NE(val, nullptr); - std::string root(val); - - FileDescriptor sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, proto.protocol, 0)); - - std::string socket_path = JoinPath(root, proto.name, "echo"); - - struct sockaddr_un addr = {}; - addr.sun_family = AF_UNIX; - memcpy(addr.sun_path, socket_path.c_str(), socket_path.length()); - - ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr*>(&addr), - sizeof(addr)), - SyscallSucceeds()); - - constexpr int kBufferSize = 64; - char send_buffer[kBufferSize]; - memset(send_buffer, 'a', sizeof(send_buffer)); - - ASSERT_THAT(WriteFd(sock.get(), send_buffer, sizeof(send_buffer)), - SyscallSucceedsWithValue(sizeof(send_buffer))); - - char recv_buffer[kBufferSize]; - ASSERT_THAT(ReadFd(sock.get(), recv_buffer, sizeof(recv_buffer)), - SyscallSucceedsWithValue(sizeof(recv_buffer))); - ASSERT_EQ(0, memcmp(send_buffer, recv_buffer, sizeof(send_buffer))); -} - -// It is not possible to connect to a bound but non-listening socket. -TEST_P(GoferStreamSeqpacketTest, NonListening) { - std::string env; - ProtocolSocket proto; - std::tie(env, proto) = GetParam(); - - char* val = getenv(env.c_str()); - ASSERT_NE(val, nullptr); - std::string root(val); - - FileDescriptor sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, proto.protocol, 0)); - - std::string socket_path = JoinPath(root, proto.name, "nonlistening"); - - struct sockaddr_un addr = {}; - addr.sun_family = AF_UNIX; - memcpy(addr.sun_path, socket_path.c_str(), socket_path.length()); - - ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr*>(&addr), - sizeof(addr)), - SyscallFailsWithErrno(ECONNREFUSED)); -} - -INSTANTIATE_TEST_SUITE_P( - StreamSeqpacket, GoferStreamSeqpacketTest, - ::testing::Combine( - // Test access via standard path and attach point. - ::testing::Values("TEST_UDS_TREE", "TEST_UDS_ATTACH_TREE"), - ::testing::Values(ProtocolSocket{SOCK_STREAM, "stream"}, - ProtocolSocket{SOCK_SEQPACKET, "seqpacket"}))); - -// Parameter is socket root dir. -using GoferDgramTest = ::testing::TestWithParam<std::string>; - -// Connect to a socket and verify that write works. -// -// An "echo" socket doesn't work for dgram sockets because our socket is -// unnamed. The server thus has no way to reply to us. -TEST_P(GoferDgramTest, Null) { - std::string env = GetParam(); - char* val = getenv(env.c_str()); - ASSERT_NE(val, nullptr); - std::string root(val); - - FileDescriptor sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_DGRAM, 0)); - - std::string socket_path = JoinPath(root, "dgram/null"); - - struct sockaddr_un addr = {}; - addr.sun_family = AF_UNIX; - memcpy(addr.sun_path, socket_path.c_str(), socket_path.length()); - - ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr*>(&addr), - sizeof(addr)), - SyscallSucceeds()); - - constexpr int kBufferSize = 64; - char send_buffer[kBufferSize]; - memset(send_buffer, 'a', sizeof(send_buffer)); - - ASSERT_THAT(WriteFd(sock.get(), send_buffer, sizeof(send_buffer)), - SyscallSucceedsWithValue(sizeof(send_buffer))); -} - -INSTANTIATE_TEST_SUITE_P(Dgram, GoferDgramTest, - // Test access via standard path and attach point. - ::testing::Values("TEST_UDS_TREE", - "TEST_UDS_ATTACH_TREE")); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/creat.cc b/test/syscalls/linux/creat.cc deleted file mode 100644 index 3c270d6da..000000000 --- a/test/syscalls/linux/creat.cc +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <sys/stat.h> -#include <sys/types.h> - -#include <string> - -#include "gtest/gtest.h" -#include "test/util/fs_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -constexpr int kMode = 0666; - -TEST(CreatTest, CreatCreatesNewFile) { - std::string const path = NewTempAbsPath(); - struct stat buf; - int fd; - ASSERT_THAT(stat(path.c_str(), &buf), SyscallFailsWithErrno(ENOENT)); - ASSERT_THAT(fd = creat(path.c_str(), kMode), SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - EXPECT_THAT(stat(path.c_str(), &buf), SyscallSucceeds()); -} - -TEST(CreatTest, CreatTruncatesExistingFile) { - auto temp_path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - int fd; - ASSERT_NO_ERRNO(SetContents(temp_path.path(), "non-empty")); - ASSERT_THAT(fd = creat(temp_path.path().c_str(), kMode), SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - std::string new_contents; - ASSERT_NO_ERRNO(GetContents(temp_path.path(), &new_contents)); - EXPECT_EQ("", new_contents); -} - -TEST(CreatTest, CreatWithNameTooLong) { - // Start with a unique name, and pad it to NAME_MAX + 1; - std::string name = NewTempRelPath(); - int padding = (NAME_MAX + 1) - name.size(); - name.append(padding, 'x'); - const std::string& path = JoinPath(GetAbsoluteTestTmpdir(), name); - - // Creation should return ENAMETOOLONG. - ASSERT_THAT(creat(path.c_str(), kMode), SyscallFailsWithErrno(ENAMETOOLONG)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc deleted file mode 100644 index 4dd302eed..000000000 --- a/test/syscalls/linux/dev.cc +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(DevTest, LseekDevUrandom) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/urandom", O_RDONLY)); - EXPECT_THAT(lseek(fd.get(), -10, SEEK_CUR), SyscallSucceeds()); - EXPECT_THAT(lseek(fd.get(), -10, SEEK_SET), SyscallSucceeds()); - EXPECT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds()); -} - -TEST(DevTest, LseekDevNull) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_RDONLY)); - EXPECT_THAT(lseek(fd.get(), -10, SEEK_CUR), SyscallSucceeds()); - EXPECT_THAT(lseek(fd.get(), -10, SEEK_SET), SyscallSucceeds()); - EXPECT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds()); - EXPECT_THAT(lseek(fd.get(), 0, SEEK_END), SyscallSucceeds()); -} - -TEST(DevTest, LseekDevZero) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDONLY)); - EXPECT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds()); - EXPECT_THAT(lseek(fd.get(), 0, SEEK_END), SyscallSucceeds()); -} - -TEST(DevTest, LseekDevFull) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/full", O_RDONLY)); - EXPECT_THAT(lseek(fd.get(), 123, SEEK_SET), SyscallSucceedsWithValue(0)); - EXPECT_THAT(lseek(fd.get(), 123, SEEK_CUR), SyscallSucceedsWithValue(0)); - EXPECT_THAT(lseek(fd.get(), 123, SEEK_END), SyscallSucceedsWithValue(0)); -} - -TEST(DevTest, LseekDevNullFreshFile) { - // Seeks to /dev/null always return 0. - const FileDescriptor fd1 = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_RDONLY)); - const FileDescriptor fd2 = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_RDONLY)); - - EXPECT_THAT(lseek(fd1.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); - EXPECT_THAT(lseek(fd1.get(), 1000, SEEK_CUR), SyscallSucceedsWithValue(0)); - EXPECT_THAT(lseek(fd2.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); - - const FileDescriptor fd3 = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_RDONLY)); - EXPECT_THAT(lseek(fd3.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); -} - -TEST(DevTest, OpenTruncate) { - // Truncation is ignored on linux and gvisor for device files. - ASSERT_NO_ERRNO_AND_VALUE( - Open("/dev/null", O_CREAT | O_TRUNC | O_WRONLY, 0644)); - ASSERT_NO_ERRNO_AND_VALUE( - Open("/dev/zero", O_CREAT | O_TRUNC | O_WRONLY, 0644)); - ASSERT_NO_ERRNO_AND_VALUE( - Open("/dev/full", O_CREAT | O_TRUNC | O_WRONLY, 0644)); -} - -TEST(DevTest, Pread64DevNull) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_RDONLY)); - char buf[1]; - EXPECT_THAT(pread64(fd.get(), buf, 1, 0), SyscallSucceedsWithValue(0)); -} - -TEST(DevTest, Pread64DevZero) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDONLY)); - char buf[1]; - EXPECT_THAT(pread64(fd.get(), buf, 1, 0), SyscallSucceedsWithValue(1)); -} - -TEST(DevTest, Pread64DevFull) { - // /dev/full behaves like /dev/zero with respect to reads. - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/full", O_RDONLY)); - char buf[1]; - EXPECT_THAT(pread64(fd.get(), buf, 1, 0), SyscallSucceedsWithValue(1)); -} - -TEST(DevTest, ReadDevNull) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_RDONLY)); - std::vector<char> buf(1); - EXPECT_THAT(ReadFd(fd.get(), buf.data(), 1), SyscallSucceeds()); -} - -// Do not allow random save as it could lead to partial reads. -TEST(DevTest, ReadDevZero_NoRandomSave) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDONLY)); - - constexpr int kReadSize = 128 * 1024; - std::vector<char> buf(kReadSize, 1); - EXPECT_THAT(ReadFd(fd.get(), buf.data(), kReadSize), - SyscallSucceedsWithValue(kReadSize)); - EXPECT_EQ(std::vector<char>(kReadSize, 0), buf); -} - -TEST(DevTest, WriteDevNull) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_WRONLY)); - EXPECT_THAT(WriteFd(fd.get(), "a", 1), SyscallSucceedsWithValue(1)); -} - -TEST(DevTest, WriteDevZero) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_WRONLY)); - EXPECT_THAT(WriteFd(fd.get(), "a", 1), SyscallSucceedsWithValue(1)); -} - -TEST(DevTest, WriteDevFull) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/full", O_WRONLY)); - EXPECT_THAT(WriteFd(fd.get(), "a", 1), SyscallFailsWithErrno(ENOSPC)); -} - -TEST(DevTest, TTYExists) { - struct stat statbuf = {}; - ASSERT_THAT(stat("/dev/tty", &statbuf), SyscallSucceeds()); - // Check that it's a character device with rw-rw-rw- permissions. - EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); -} - -} // namespace -} // namespace testing - -} // namespace gvisor diff --git a/test/syscalls/linux/dup.cc b/test/syscalls/linux/dup.cc deleted file mode 100644 index 4f773bc75..000000000 --- a/test/syscalls/linux/dup.cc +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/eventfd_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -PosixErrorOr<FileDescriptor> Dup2(const FileDescriptor& fd, int target_fd) { - int new_fd = dup2(fd.get(), target_fd); - if (new_fd < 0) { - return PosixError(errno, "Dup2"); - } - return FileDescriptor(new_fd); -} - -PosixErrorOr<FileDescriptor> Dup3(const FileDescriptor& fd, int target_fd, - int flags) { - int new_fd = dup3(fd.get(), target_fd, flags); - if (new_fd < 0) { - return PosixError(errno, "Dup2"); - } - return FileDescriptor(new_fd); -} - -void CheckSameFile(const FileDescriptor& fd1, const FileDescriptor& fd2) { - struct stat stat_result1, stat_result2; - ASSERT_THAT(fstat(fd1.get(), &stat_result1), SyscallSucceeds()); - ASSERT_THAT(fstat(fd2.get(), &stat_result2), SyscallSucceeds()); - EXPECT_EQ(stat_result1.st_dev, stat_result2.st_dev); - EXPECT_EQ(stat_result1.st_ino, stat_result2.st_ino); -} - -TEST(DupTest, Dup) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY)); - - // Dup the descriptor and make sure it's the same file. - FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); - ASSERT_NE(fd.get(), nfd.get()); - CheckSameFile(fd, nfd); -} - -TEST(DupTest, DupClearsCloExec) { - // Open an eventfd file descriptor with FD_CLOEXEC descriptor flag set. - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_CLOEXEC)); - EXPECT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); - - // Duplicate the descriptor. Ensure that it doesn't have FD_CLOEXEC set. - FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); - ASSERT_NE(fd.get(), nfd.get()); - CheckSameFile(fd, nfd); - EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(0)); -} - -TEST(DupTest, Dup2) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY)); - - // Regular dup once. - FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); - - ASSERT_NE(fd.get(), nfd.get()); - CheckSameFile(fd, nfd); - - // Dup over the file above. - int target_fd = nfd.release(); - FileDescriptor nfd2 = ASSERT_NO_ERRNO_AND_VALUE(Dup2(fd, target_fd)); - EXPECT_EQ(target_fd, nfd2.get()); - CheckSameFile(fd, nfd2); -} - -TEST(DupTest, Dup2SameFD) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY)); - - // Should succeed. - ASSERT_THAT(dup2(fd.get(), fd.get()), SyscallSucceedsWithValue(fd.get())); -} - -TEST(DupTest, Dup3) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY)); - - // Regular dup once. - FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); - ASSERT_NE(fd.get(), nfd.get()); - CheckSameFile(fd, nfd); - - // Dup over the file above, check that it has no CLOEXEC. - nfd = ASSERT_NO_ERRNO_AND_VALUE(Dup3(fd, nfd.release(), 0)); - CheckSameFile(fd, nfd); - EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(0)); - - // Dup over the file again, check that it does not CLOEXEC. - nfd = ASSERT_NO_ERRNO_AND_VALUE(Dup3(fd, nfd.release(), O_CLOEXEC)); - CheckSameFile(fd, nfd); - EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); -} - -TEST(DupTest, Dup3FailsSameFD) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY)); - - // Only dup3 fails if the new and old fd are the same. - ASSERT_THAT(dup3(fd.get(), fd.get(), 0), SyscallFailsWithErrno(EINVAL)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/epoll.cc b/test/syscalls/linux/epoll.cc deleted file mode 100644 index a4f8f3cec..000000000 --- a/test/syscalls/linux/epoll.cc +++ /dev/null @@ -1,432 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <limits.h> -#include <pthread.h> -#include <signal.h> -#include <stdint.h> -#include <stdio.h> -#include <string.h> -#include <sys/epoll.h> -#include <sys/eventfd.h> -#include <time.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/epoll_util.h" -#include "test/util/eventfd_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -constexpr int kFDsPerEpoll = 3; -constexpr uint64_t kMagicConstant = 0x0102030405060708; - -uint64_t ms_elapsed(const struct timespec* begin, const struct timespec* end) { - return (end->tv_sec - begin->tv_sec) * 1000 + - (end->tv_nsec - begin->tv_nsec) / 1000000; -} - -TEST(EpollTest, AllWritable) { - auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - std::vector<FileDescriptor> eventfds; - for (int i = 0; i < kFDsPerEpoll; i++) { - eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); - ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfds[i].get(), - EPOLLIN | EPOLLOUT, kMagicConstant + i)); - } - - struct epoll_event result[kFDsPerEpoll]; - ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1), - SyscallSucceedsWithValue(kFDsPerEpoll)); - // TODO(edahlgren): Why do some tests check epoll_event::data, and others - // don't? Does Linux actually guarantee that, in any of these test cases, - // epoll_wait will necessarily write out the epoll_events in the order that - // they were registered? - for (int i = 0; i < kFDsPerEpoll; i++) { - ASSERT_EQ(result[i].events, EPOLLOUT); - } -} - -TEST(EpollTest, LastReadable) { - auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - std::vector<FileDescriptor> eventfds; - for (int i = 0; i < kFDsPerEpoll; i++) { - eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); - ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfds[i].get(), - EPOLLIN | EPOLLOUT, kMagicConstant + i)); - } - - uint64_t tmp = 1; - ASSERT_THAT(WriteFd(eventfds[kFDsPerEpoll - 1].get(), &tmp, sizeof(tmp)), - SyscallSucceedsWithValue(sizeof(tmp))); - - struct epoll_event result[kFDsPerEpoll]; - ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1), - SyscallSucceedsWithValue(kFDsPerEpoll)); - - int i; - for (i = 0; i < kFDsPerEpoll - 1; i++) { - EXPECT_EQ(result[i].events, EPOLLOUT); - } - EXPECT_EQ(result[i].events, EPOLLOUT | EPOLLIN); - EXPECT_EQ(result[i].data.u64, kMagicConstant + i); -} - -TEST(EpollTest, LastNonWritable) { - auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - std::vector<FileDescriptor> eventfds; - for (int i = 0; i < kFDsPerEpoll; i++) { - eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); - ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfds[i].get(), - EPOLLIN | EPOLLOUT, kMagicConstant + i)); - } - - // Write the maximum value to the event fd so that writing to it again would - // block. - uint64_t tmp = ULLONG_MAX - 1; - ASSERT_THAT(WriteFd(eventfds[kFDsPerEpoll - 1].get(), &tmp, sizeof(tmp)), - SyscallSucceedsWithValue(sizeof(tmp))); - - struct epoll_event result[kFDsPerEpoll]; - ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1), - SyscallSucceedsWithValue(kFDsPerEpoll)); - - int i; - for (i = 0; i < kFDsPerEpoll - 1; i++) { - EXPECT_EQ(result[i].events, EPOLLOUT); - } - EXPECT_EQ(result[i].events, EPOLLIN); - EXPECT_THAT(ReadFd(eventfds[kFDsPerEpoll - 1].get(), &tmp, sizeof(tmp)), - sizeof(tmp)); - EXPECT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1), - SyscallSucceedsWithValue(kFDsPerEpoll)); - - for (i = 0; i < kFDsPerEpoll; i++) { - EXPECT_EQ(result[i].events, EPOLLOUT); - } -} - -TEST(EpollTest, Timeout_NoRandomSave) { - auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - std::vector<FileDescriptor> eventfds; - for (int i = 0; i < kFDsPerEpoll; i++) { - eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); - ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, - kMagicConstant + i)); - } - - constexpr int kTimeoutMs = 200; - struct timespec begin; - struct timespec end; - struct epoll_event result[kFDsPerEpoll]; - - { - const DisableSave ds; // Timing-related. - EXPECT_THAT(clock_gettime(CLOCK_MONOTONIC, &begin), SyscallSucceeds()); - - ASSERT_THAT( - RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, kTimeoutMs), - SyscallSucceedsWithValue(0)); - EXPECT_THAT(clock_gettime(CLOCK_MONOTONIC, &end), SyscallSucceeds()); - } - - // Check the lower bound on the timeout. Checking for an upper bound is - // fragile because Linux can overrun the timeout due to scheduling delays. - EXPECT_GT(ms_elapsed(&begin, &end), kTimeoutMs - 1); -} - -void* writer(void* arg) { - int fd = *reinterpret_cast<int*>(arg); - uint64_t tmp = 1; - - usleep(200000); - if (WriteFd(fd, &tmp, sizeof(tmp)) != sizeof(tmp)) { - fprintf(stderr, "writer failed: errno %s\n", strerror(errno)); - } - - return nullptr; -} - -TEST(EpollTest, WaitThenUnblock) { - auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - std::vector<FileDescriptor> eventfds; - for (int i = 0; i < kFDsPerEpoll; i++) { - eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); - ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, - kMagicConstant + i)); - } - - // Fire off a thread that will make at least one of the event fds readable. - pthread_t thread; - int make_readable = eventfds[0].get(); - ASSERT_THAT(pthread_create(&thread, nullptr, writer, &make_readable), - SyscallSucceedsWithValue(0)); - - struct epoll_event result[kFDsPerEpoll]; - EXPECT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1), - SyscallSucceedsWithValue(1)); - EXPECT_THAT(pthread_detach(thread), SyscallSucceeds()); -} - -void sighandler(int s) {} - -void* signaler(void* arg) { - pthread_t* t = reinterpret_cast<pthread_t*>(arg); - // Repeatedly send the real-time signal until we are detached, because it's - // difficult to know exactly when epoll_wait on another thread (which this - // is intending to interrupt) has started blocking. - while (1) { - usleep(200000); - pthread_kill(*t, SIGRTMIN); - } - return nullptr; -} - -TEST(EpollTest, UnblockWithSignal) { - auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - std::vector<FileDescriptor> eventfds; - for (int i = 0; i < kFDsPerEpoll; i++) { - eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); - ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, - kMagicConstant + i)); - } - - signal(SIGRTMIN, sighandler); - // Unblock the real time signals that InitGoogle blocks :( - sigset_t unblock; - sigemptyset(&unblock); - sigaddset(&unblock, SIGRTMIN); - ASSERT_THAT(sigprocmask(SIG_UNBLOCK, &unblock, nullptr), SyscallSucceeds()); - - pthread_t thread; - pthread_t cur = pthread_self(); - ASSERT_THAT(pthread_create(&thread, nullptr, signaler, &cur), - SyscallSucceedsWithValue(0)); - - struct epoll_event result[kFDsPerEpoll]; - EXPECT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, -1), - SyscallFailsWithErrno(EINTR)); - EXPECT_THAT(pthread_cancel(thread), SyscallSucceeds()); - EXPECT_THAT(pthread_detach(thread), SyscallSucceeds()); -} - -TEST(EpollTest, TimeoutNoFds) { - auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - struct epoll_event result[kFDsPerEpoll]; - EXPECT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, 100), - SyscallSucceedsWithValue(0)); -} - -struct addr_ctx { - int epollfd; - int eventfd; -}; - -void* fd_adder(void* arg) { - struct addr_ctx* actx = reinterpret_cast<struct addr_ctx*>(arg); - struct epoll_event event; - event.events = EPOLLIN | EPOLLOUT; - event.data.u64 = 0xdeadbeeffacefeed; - - usleep(200000); - if (epoll_ctl(actx->epollfd, EPOLL_CTL_ADD, actx->eventfd, &event) == -1) { - fprintf(stderr, "epoll_ctl failed: %s\n", strerror(errno)); - } - - return nullptr; -} - -TEST(EpollTest, UnblockWithNewFD) { - auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - auto eventfd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()); - - pthread_t thread; - struct addr_ctx actx = {epollfd.get(), eventfd.get()}; - ASSERT_THAT(pthread_create(&thread, nullptr, fd_adder, &actx), - SyscallSucceedsWithValue(0)); - - struct epoll_event result[kFDsPerEpoll]; - // Wait while no FDs are ready, but after 200ms fd_adder will add a ready FD - // to epoll which will wake us up. - EXPECT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1), - SyscallSucceedsWithValue(1)); - EXPECT_THAT(pthread_detach(thread), SyscallSucceeds()); - EXPECT_EQ(result[0].data.u64, 0xdeadbeeffacefeed); -} - -TEST(EpollTest, Oneshot) { - auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - std::vector<FileDescriptor> eventfds; - for (int i = 0; i < kFDsPerEpoll; i++) { - eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); - ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, - kMagicConstant + i)); - } - - struct epoll_event event; - event.events = EPOLLOUT | EPOLLONESHOT; - event.data.u64 = kMagicConstant; - ASSERT_THAT( - epoll_ctl(epollfd.get(), EPOLL_CTL_MOD, eventfds[0].get(), &event), - SyscallSucceeds()); - - struct epoll_event result[kFDsPerEpoll]; - // One-shot entry means that the first epoll_wait should succeed. - ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1), - SyscallSucceedsWithValue(1)); - EXPECT_EQ(result[0].data.u64, kMagicConstant); - - // One-shot entry means that the second epoll_wait should timeout. - EXPECT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, 100), - SyscallSucceedsWithValue(0)); -} - -TEST(EpollTest, EdgeTriggered_NoRandomSave) { - // Test edge-triggered entry: make it edge-triggered, first wait should - // return it, second one should time out, make it writable again, third wait - // should return it, fourth wait should timeout. - auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - auto eventfd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()); - ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfd.get(), - EPOLLOUT | EPOLLET, kMagicConstant)); - - struct epoll_event result[kFDsPerEpoll]; - - { - const DisableSave ds; // May trigger spurious event. - - // Edge-triggered entry means that the first epoll_wait should return the - // event. - ASSERT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, -1), - SyscallSucceedsWithValue(1)); - EXPECT_EQ(result[0].data.u64, kMagicConstant); - - // Edge-triggered entry means that the second epoll_wait should time out. - ASSERT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, 100), - SyscallSucceedsWithValue(0)); - } - - uint64_t tmp = ULLONG_MAX - 1; - - // Make an fd non-writable. - ASSERT_THAT(WriteFd(eventfd.get(), &tmp, sizeof(tmp)), - SyscallSucceedsWithValue(sizeof(tmp))); - - // Make the same fd non-writable to trigger a change, which will trigger an - // edge-triggered event. - ASSERT_THAT(ReadFd(eventfd.get(), &tmp, sizeof(tmp)), - SyscallSucceedsWithValue(sizeof(tmp))); - - { - const DisableSave ds; // May trigger spurious event. - - // An edge-triggered event should now be returned. - ASSERT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, -1), - SyscallSucceedsWithValue(1)); - EXPECT_EQ(result[0].data.u64, kMagicConstant); - - // The edge-triggered event had been consumed above, we don't expect to - // get it again. - ASSERT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, 100), - SyscallSucceedsWithValue(0)); - } -} - -TEST(EpollTest, OneshotAndEdgeTriggered) { - auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - auto eventfd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()); - ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfd.get(), - EPOLLOUT | EPOLLET | EPOLLONESHOT, - kMagicConstant)); - - struct epoll_event result[kFDsPerEpoll]; - // First time one shot edge-triggered entry means that epoll_wait should - // return the event. - ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1), - SyscallSucceedsWithValue(1)); - EXPECT_EQ(result[0].data.u64, kMagicConstant); - - // Edge-triggered entry means that the second epoll_wait should time out. - ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, 100), - SyscallSucceedsWithValue(0)); - - uint64_t tmp = ULLONG_MAX - 1; - // Make an fd non-writable. - ASSERT_THAT(WriteFd(eventfd.get(), &tmp, sizeof(tmp)), - SyscallSucceedsWithValue(sizeof(tmp))); - // Make the same fd non-writable to trigger a change, which will not trigger - // an edge-triggered event because we've also included EPOLLONESHOT. - ASSERT_THAT(ReadFd(eventfd.get(), &tmp, sizeof(tmp)), - SyscallSucceedsWithValue(sizeof(tmp))); - ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, 100), - SyscallSucceedsWithValue(0)); -} - -TEST(EpollTest, CycleOfOneDisallowed) { - auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - - struct epoll_event event; - event.events = EPOLLOUT; - event.data.u64 = kMagicConstant; - - ASSERT_THAT(epoll_ctl(epollfd.get(), EPOLL_CTL_ADD, epollfd.get(), &event), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(EpollTest, CycleOfThreeDisallowed) { - auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - auto epollfd1 = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - auto epollfd2 = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - - ASSERT_NO_ERRNO( - RegisterEpollFD(epollfd.get(), epollfd1.get(), EPOLLIN, kMagicConstant)); - ASSERT_NO_ERRNO( - RegisterEpollFD(epollfd1.get(), epollfd2.get(), EPOLLIN, kMagicConstant)); - - struct epoll_event event; - event.events = EPOLLIN; - event.data.u64 = kMagicConstant; - EXPECT_THAT(epoll_ctl(epollfd2.get(), EPOLL_CTL_ADD, epollfd.get(), &event), - SyscallFailsWithErrno(ELOOP)); -} - -TEST(EpollTest, CloseFile) { - auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - auto eventfd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()); - ASSERT_NO_ERRNO( - RegisterEpollFD(epollfd.get(), eventfd.get(), EPOLLOUT, kMagicConstant)); - - struct epoll_event result[kFDsPerEpoll]; - ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1), - SyscallSucceedsWithValue(1)); - EXPECT_EQ(result[0].data.u64, kMagicConstant); - - // Close the event fd early. - eventfd.reset(); - - EXPECT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, 100), - SyscallSucceedsWithValue(0)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/eventfd.cc b/test/syscalls/linux/eventfd.cc deleted file mode 100644 index 927001eee..000000000 --- a/test/syscalls/linux/eventfd.cc +++ /dev/null @@ -1,205 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <pthread.h> -#include <stdio.h> -#include <stdlib.h> -#include <string.h> -#include <sys/epoll.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/epoll_util.h" -#include "test/util/eventfd_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(EventfdTest, Nonblock) { - FileDescriptor efd = - ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE)); - - uint64_t l; - ASSERT_THAT(read(efd.get(), &l, sizeof(l)), SyscallFailsWithErrno(EAGAIN)); - - l = 1; - ASSERT_THAT(write(efd.get(), &l, sizeof(l)), SyscallSucceeds()); - - l = 0; - ASSERT_THAT(read(efd.get(), &l, sizeof(l)), SyscallSucceeds()); - EXPECT_EQ(l, 1); - - ASSERT_THAT(read(efd.get(), &l, sizeof(l)), SyscallFailsWithErrno(EAGAIN)); -} - -void* read_three_times(void* arg) { - int efd = *reinterpret_cast<int*>(arg); - uint64_t l; - EXPECT_THAT(read(efd, &l, sizeof(l)), SyscallSucceedsWithValue(sizeof(l))); - EXPECT_THAT(read(efd, &l, sizeof(l)), SyscallSucceedsWithValue(sizeof(l))); - EXPECT_THAT(read(efd, &l, sizeof(l)), SyscallSucceedsWithValue(sizeof(l))); - return nullptr; -} - -TEST(EventfdTest, BlockingWrite) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_SEMAPHORE)); - int efd = fd.get(); - - pthread_t p; - ASSERT_THAT(pthread_create(&p, nullptr, read_three_times, - reinterpret_cast<void*>(&efd)), - SyscallSucceeds()); - - uint64_t l = 1; - ASSERT_THAT(write(efd, &l, sizeof(l)), SyscallSucceeds()); - EXPECT_EQ(l, 1); - - ASSERT_THAT(write(efd, &l, sizeof(l)), SyscallSucceeds()); - EXPECT_EQ(l, 1); - - ASSERT_THAT(write(efd, &l, sizeof(l)), SyscallSucceeds()); - EXPECT_EQ(l, 1); - - ASSERT_THAT(pthread_join(p, nullptr), SyscallSucceeds()); -} - -TEST(EventfdTest, SmallWrite) { - FileDescriptor efd = - ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE)); - - uint64_t l = 16; - ASSERT_THAT(write(efd.get(), &l, 4), SyscallFailsWithErrno(EINVAL)); -} - -TEST(EventfdTest, SmallRead) { - FileDescriptor efd = - ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE)); - - uint64_t l = 1; - ASSERT_THAT(write(efd.get(), &l, sizeof(l)), SyscallSucceeds()); - - l = 0; - ASSERT_THAT(read(efd.get(), &l, 4), SyscallFailsWithErrno(EINVAL)); -} - -TEST(EventfdTest, BigWrite) { - FileDescriptor efd = - ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE)); - - uint64_t big[16]; - big[0] = 16; - ASSERT_THAT(write(efd.get(), big, sizeof(big)), SyscallSucceeds()); -} - -TEST(EventfdTest, BigRead) { - FileDescriptor efd = - ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE)); - - uint64_t l = 1; - ASSERT_THAT(write(efd.get(), &l, sizeof(l)), SyscallSucceeds()); - - uint64_t big[16]; - ASSERT_THAT(read(efd.get(), big, sizeof(big)), SyscallSucceeds()); - EXPECT_EQ(big[0], 1); -} - -TEST(EventfdTest, BigWriteBigRead) { - FileDescriptor efd = - ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE)); - - uint64_t l[16]; - l[0] = 16; - ASSERT_THAT(write(efd.get(), l, sizeof(l)), SyscallSucceeds()); - ASSERT_THAT(read(efd.get(), l, sizeof(l)), SyscallSucceeds()); - EXPECT_EQ(l[0], 1); -} - -TEST(EventfdTest, SpliceFromPipePartialSucceeds) { - int pipes[2]; - ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds()); - const FileDescriptor pipe_rfd(pipes[0]); - const FileDescriptor pipe_wfd(pipes[1]); - constexpr uint64_t kVal{1}; - - FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK)); - - uint64_t event_array[2]; - event_array[0] = kVal; - event_array[1] = kVal; - ASSERT_THAT(write(pipe_wfd.get(), event_array, sizeof(event_array)), - SyscallSucceedsWithValue(sizeof(event_array))); - EXPECT_THAT(splice(pipe_rfd.get(), /*__offin=*/nullptr, efd.get(), - /*__offout=*/nullptr, sizeof(event_array[0]) + 1, - SPLICE_F_NONBLOCK), - SyscallSucceedsWithValue(sizeof(event_array[0]))); - - uint64_t val; - ASSERT_THAT(read(efd.get(), &val, sizeof(val)), - SyscallSucceedsWithValue(sizeof(val))); - EXPECT_EQ(val, kVal); -} - -// NotifyNonZero is inherently racy, so random save is disabled. -TEST(EventfdTest, NotifyNonZero_NoRandomSave) { - // Waits will time out at 10 seconds. - constexpr int kEpollTimeoutMs = 10000; - // Create an eventfd descriptor. - FileDescriptor efd = - ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(7, EFD_NONBLOCK | EFD_SEMAPHORE)); - // Create an epoll fd to listen to efd. - FileDescriptor epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - // Add efd to epoll. - ASSERT_NO_ERRNO( - RegisterEpollFD(epollfd.get(), efd.get(), EPOLLIN | EPOLLET, efd.get())); - - // Use epoll to get a value from efd. - struct epoll_event out_ev; - int wait_out = epoll_wait(epollfd.get(), &out_ev, 1, kEpollTimeoutMs); - EXPECT_EQ(wait_out, 1); - EXPECT_EQ(efd.get(), out_ev.data.fd); - uint64_t val = 0; - ASSERT_THAT(read(efd.get(), &val, sizeof(val)), SyscallSucceeds()); - EXPECT_EQ(val, 1); - - // Start a thread that, after this thread blocks on epoll_wait, will write to - // efd. This is racy -- it's possible that this write will happen after - // epoll_wait times out. - ScopedThread t([&efd] { - sleep(5); - uint64_t val = 1; - EXPECT_THAT(write(efd.get(), &val, sizeof(val)), - SyscallSucceedsWithValue(sizeof(val))); - }); - - // epoll_wait should return once the thread writes. - wait_out = epoll_wait(epollfd.get(), &out_ev, 1, kEpollTimeoutMs); - EXPECT_EQ(wait_out, 1); - EXPECT_EQ(efd.get(), out_ev.data.fd); - - val = 0; - ASSERT_THAT(read(efd.get(), &val, sizeof(val)), SyscallSucceeds()); - EXPECT_EQ(val, 1); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/exceptions.cc b/test/syscalls/linux/exceptions.cc deleted file mode 100644 index 420b9543f..000000000 --- a/test/syscalls/linux/exceptions.cc +++ /dev/null @@ -1,367 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <signal.h> - -#include "gtest/gtest.h" -#include "test/util/logging.h" -#include "test/util/platform_util.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -// Default value for the x87 FPU control word. See Intel SDM Vol 1, Ch 8.1.5 -// "x87 FPU Control Word". -constexpr uint16_t kX87ControlWordDefault = 0x37f; - -// Mask for the divide-by-zero exception. -constexpr uint16_t kX87ControlWordDiv0Mask = 1 << 2; - -// Default value for the SSE control register (MXCSR). See Intel SDM Vol 1, Ch -// 11.6.4 "Initialization of SSE/SSE3 Extensions". -constexpr uint32_t kMXCSRDefault = 0x1f80; - -// Mask for the divide-by-zero exception. -constexpr uint32_t kMXCSRDiv0Mask = 1 << 9; - -// Flag for a pending divide-by-zero exception. -constexpr uint32_t kMXCSRDiv0Flag = 1 << 2; - -void inline Halt() { asm("hlt\r\n"); } - -void inline SetAlignmentCheck() { - asm("subq $128, %%rsp\r\n" // Avoid potential red zone clobber - "pushf\r\n" - "pop %%rax\r\n" - "or $0x40000, %%rax\r\n" - "push %%rax\r\n" - "popf\r\n" - "addq $128, %%rsp\r\n" - : - : - : "ax"); -} - -void inline ClearAlignmentCheck() { - asm("subq $128, %%rsp\r\n" // Avoid potential red zone clobber - "pushf\r\n" - "pop %%rax\r\n" - "mov $0x40000, %%rbx\r\n" - "not %%rbx\r\n" - "and %%rbx, %%rax\r\n" - "push %%rax\r\n" - "popf\r\n" - "addq $128, %%rsp\r\n" - : - : - : "ax", "bx"); -} - -void inline Int3Normal() { asm(".byte 0xcd, 0x03\r\n"); } - -void inline Int3Compact() { asm(".byte 0xcc\r\n"); } - -void InIOHelper(int width, int value) { - EXPECT_EXIT( - { - switch (width) { - case 1: - asm volatile("inb %%dx, %%al" ::"d"(value) : "%eax"); - break; - case 2: - asm volatile("inw %%dx, %%ax" ::"d"(value) : "%eax"); - break; - case 4: - asm volatile("inl %%dx, %%eax" ::"d"(value) : "%eax"); - break; - default: - FAIL() << "invalid input width, only 1, 2 or 4 is allowed"; - } - }, - ::testing::KilledBySignal(SIGSEGV), ""); -} - -TEST(ExceptionTest, Halt) { - // In order to prevent the regular handler from messing with things (and - // perhaps refaulting until some other signal occurs), we reset the handler to - // the default action here and ensure that it dies correctly. - struct sigaction sa = {}; - sa.sa_handler = SIG_DFL; - auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGSEGV, sa)); - - EXPECT_EXIT(Halt(), ::testing::KilledBySignal(SIGSEGV), ""); -} - -TEST(ExceptionTest, DivideByZero) { - // See above. - struct sigaction sa = {}; - sa.sa_handler = SIG_DFL; - auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa)); - - EXPECT_EXIT( - { - uint32_t remainder; - uint32_t quotient; - uint32_t divisor = 0; - uint64_t value = 1; - asm("divl 0(%2)\r\n" - : "=d"(remainder), "=a"(quotient) - : "r"(&divisor), "d"(value >> 32), "a"(value)); - TEST_CHECK(quotient > 0); // Force dependency. - }, - ::testing::KilledBySignal(SIGFPE), ""); -} - -// By default, x87 exceptions are masked and simply return a default value. -TEST(ExceptionTest, X87DivideByZeroMasked) { - int32_t quotient; - int32_t value = 1; - int32_t divisor = 0; - asm("fildl %[value]\r\n" - "fidivl %[divisor]\r\n" - "fistpl %[quotient]\r\n" - : [ quotient ] "=m"(quotient) - : [ value ] "m"(value), [ divisor ] "m"(divisor)); - - EXPECT_EQ(quotient, INT32_MIN); -} - -// When unmasked, division by zero raises SIGFPE. -TEST(ExceptionTest, X87DivideByZeroUnmasked) { - // See above. - struct sigaction sa = {}; - sa.sa_handler = SIG_DFL; - auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa)); - - EXPECT_EXIT( - { - // Clear the divide by zero exception mask. - constexpr uint16_t kControlWord = - kX87ControlWordDefault & ~kX87ControlWordDiv0Mask; - - int32_t quotient; - int32_t value = 1; - int32_t divisor = 0; - asm volatile( - "fldcw %[cw]\r\n" - "fildl %[value]\r\n" - "fidivl %[divisor]\r\n" - "fistpl %[quotient]\r\n" - : [ quotient ] "=m"(quotient) - : [ cw ] "m"(kControlWord), [ value ] "m"(value), - [ divisor ] "m"(divisor)); - }, - ::testing::KilledBySignal(SIGFPE), ""); -} - -// Pending exceptions in the x87 status register are not clobbered by syscalls. -TEST(ExceptionTest, X87StatusClobber) { - // See above. - struct sigaction sa = {}; - sa.sa_handler = SIG_DFL; - auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa)); - - EXPECT_EXIT( - { - // Clear the divide by zero exception mask. - constexpr uint16_t kControlWord = - kX87ControlWordDefault & ~kX87ControlWordDiv0Mask; - - int32_t quotient; - int32_t value = 1; - int32_t divisor = 0; - asm volatile( - "fildl %[value]\r\n" - "fidivl %[divisor]\r\n" - // Exception is masked, so it does not occur here. - "fistpl %[quotient]\r\n" - - // SYS_getpid placed in rax by constraint. - "syscall\r\n" - - // Unmask exception. The syscall didn't clobber the pending - // exception, so now it can be raised. - // - // N.B. "a floating-point exception will be generated upon execution - // of the *next* floating-point instruction". - "fldcw %[cw]\r\n" - "fwait\r\n" - : [ quotient ] "=m"(quotient) - : [ value ] "m"(value), [ divisor ] "m"(divisor), "a"(SYS_getpid), - [ cw ] "m"(kControlWord) - : "rcx", "r11"); - }, - ::testing::KilledBySignal(SIGFPE), ""); -} - -// By default, SSE exceptions are masked and simply return a default value. -TEST(ExceptionTest, SSEDivideByZeroMasked) { - uint32_t status; - int32_t quotient; - int32_t value = 1; - int32_t divisor = 0; - asm("cvtsi2ssl %[value], %%xmm0\r\n" - "cvtsi2ssl %[divisor], %%xmm1\r\n" - "divss %%xmm1, %%xmm0\r\n" - "cvtss2sil %%xmm0, %[quotient]\r\n" - : [ quotient ] "=r"(quotient), [ status ] "=r"(status) - : [ value ] "r"(value), [ divisor ] "r"(divisor) - : "xmm0", "xmm1"); - - EXPECT_EQ(quotient, INT32_MIN); -} - -// When unmasked, division by zero raises SIGFPE. -TEST(ExceptionTest, SSEDivideByZeroUnmasked) { - // See above. - struct sigaction sa = {}; - sa.sa_handler = SIG_DFL; - auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa)); - - EXPECT_EXIT( - { - // Clear the divide by zero exception mask. - constexpr uint32_t kMXCSR = kMXCSRDefault & ~kMXCSRDiv0Mask; - - int32_t quotient; - int32_t value = 1; - int32_t divisor = 0; - asm volatile( - "ldmxcsr %[mxcsr]\r\n" - "cvtsi2ssl %[value], %%xmm0\r\n" - "cvtsi2ssl %[divisor], %%xmm1\r\n" - "divss %%xmm1, %%xmm0\r\n" - "cvtss2sil %%xmm0, %[quotient]\r\n" - : [ quotient ] "=r"(quotient) - : [ mxcsr ] "m"(kMXCSR), [ value ] "r"(value), - [ divisor ] "r"(divisor) - : "xmm0", "xmm1"); - }, - ::testing::KilledBySignal(SIGFPE), ""); -} - -// Pending exceptions in the SSE status register are not clobbered by syscalls. -TEST(ExceptionTest, SSEStatusClobber) { - uint32_t mxcsr; - int32_t quotient; - int32_t value = 1; - int32_t divisor = 0; - asm("cvtsi2ssl %[value], %%xmm0\r\n" - "cvtsi2ssl %[divisor], %%xmm1\r\n" - "divss %%xmm1, %%xmm0\r\n" - // Exception is masked, so it does not occur here. - "cvtss2sil %%xmm0, %[quotient]\r\n" - - // SYS_getpid placed in rax by constraint. - "syscall\r\n" - - // Intel SDM Vol 1, Ch 10.2.3.1 "SIMD Floating-Point Mask and Flag Bits": - // "If LDMXCSR or FXRSTOR clears a mask bit and sets the corresponding - // exception flag bit, a SIMD floating-point exception will not be - // generated as a result of this change. The unmasked exception will be - // generated only upon the execution of the next SSE/SSE2/SSE3 instruction - // that detects the unmasked exception condition." - // - // Though ambiguous, empirical evidence indicates that this means that - // exception flags set in the status register will never cause an - // exception to be raised; only a new exception condition will do so. - // - // Thus here we just check for the flag itself rather than trying to raise - // the exception. - "stmxcsr %[mxcsr]\r\n" - : [ quotient ] "=r"(quotient), [ mxcsr ] "+m"(mxcsr) - : [ value ] "r"(value), [ divisor ] "r"(divisor), "a"(SYS_getpid) - : "xmm0", "xmm1", "rcx", "r11"); - - EXPECT_TRUE(mxcsr & kMXCSRDiv0Flag); -} - -TEST(ExceptionTest, IOAccessFault) { - // See above. - struct sigaction sa = {}; - sa.sa_handler = SIG_DFL; - auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGSEGV, sa)); - - InIOHelper(1, 0x0); - InIOHelper(2, 0x7); - InIOHelper(4, 0x6); - InIOHelper(1, 0xffff); - InIOHelper(2, 0xffff); - InIOHelper(4, 0xfffd); -} - -TEST(ExceptionTest, Alignment) { - SetAlignmentCheck(); - ClearAlignmentCheck(); -} - -TEST(ExceptionTest, AlignmentHalt) { - // See above. - struct sigaction sa = {}; - sa.sa_handler = SIG_DFL; - auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGSEGV, sa)); - - // Reported upstream. We need to ensure that bad flags are cleared even in - // fault paths. Set the alignment flag and then generate an exception. - EXPECT_EXIT( - { - SetAlignmentCheck(); - Halt(); - }, - ::testing::KilledBySignal(SIGSEGV), ""); -} - -TEST(ExceptionTest, AlignmentCheck) { - SKIP_IF(PlatformSupportAlignmentCheck() != PlatformSupport::Allowed); - - // See above. - struct sigaction sa = {}; - sa.sa_handler = SIG_DFL; - auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGBUS, sa)); - - EXPECT_EXIT( - { - char array[16]; - SetAlignmentCheck(); - for (int i = 0; i < 8; i++) { - // At least 7/8 offsets will be unaligned here. - uint64_t* ptr = reinterpret_cast<uint64_t*>(&array[i]); - asm("mov %0, 0(%0)\r\n" : : "r"(ptr) : "ax"); - } - }, - ::testing::KilledBySignal(SIGBUS), ""); -} - -TEST(ExceptionTest, Int3Normal) { - // See above. - struct sigaction sa = {}; - sa.sa_handler = SIG_DFL; - auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGTRAP, sa)); - - EXPECT_EXIT(Int3Normal(), ::testing::KilledBySignal(SIGTRAP), ""); -} - -TEST(ExceptionTest, Int3Compact) { - // See above. - struct sigaction sa = {}; - sa.sa_handler = SIG_DFL; - auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGTRAP, sa)); - - EXPECT_EXIT(Int3Compact(), ::testing::KilledBySignal(SIGTRAP), ""); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc deleted file mode 100644 index 07bd527e6..000000000 --- a/test/syscalls/linux/exec.cc +++ /dev/null @@ -1,872 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/exec.h" - -#include <errno.h> -#include <fcntl.h> -#include <sys/eventfd.h> -#include <sys/resource.h> -#include <sys/time.h> -#include <unistd.h> - -#include <iostream> -#include <memory> -#include <string> -#include <vector> - -#include "gtest/gtest.h" -#include "absl/strings/match.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/optional.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -constexpr char kBasicWorkload[] = "test/syscalls/linux/exec_basic_workload"; -constexpr char kExitScript[] = "test/syscalls/linux/exit_script"; -constexpr char kStateWorkload[] = "test/syscalls/linux/exec_state_workload"; -constexpr char kProcExeWorkload[] = - "test/syscalls/linux/exec_proc_exe_workload"; -constexpr char kAssertClosedWorkload[] = - "test/syscalls/linux/exec_assert_closed_workload"; -constexpr char kPriorityWorkload[] = "test/syscalls/linux/priority_execve"; - -constexpr char kExit42[] = "--exec_exit_42"; -constexpr char kExecWithThread[] = "--exec_exec_with_thread"; -constexpr char kExecFromThread[] = "--exec_exec_from_thread"; - -// Runs file specified by dirfd and pathname with argv and checks that the exit -// status is expect_status and that stderr contains expect_stderr. -void CheckExecHelper(const absl::optional<int32_t> dirfd, - const std::string& pathname, const ExecveArray& argv, - const ExecveArray& envv, const int flags, - int expect_status, const std::string& expect_stderr) { - int pipe_fds[2]; - ASSERT_THAT(pipe2(pipe_fds, O_CLOEXEC), SyscallSucceeds()); - - FileDescriptor read_fd(pipe_fds[0]); - FileDescriptor write_fd(pipe_fds[1]); - - pid_t child; - int execve_errno; - - const auto remap_stderr = [pipe_fds] { - // Remap stdin and stdout to /dev/null. - int fd = open("/dev/null", O_RDWR | O_CLOEXEC); - if (fd < 0) { - _exit(errno); - } - - int ret = dup2(fd, 0); - if (ret < 0) { - _exit(errno); - } - - ret = dup2(fd, 1); - if (ret < 0) { - _exit(errno); - } - - // And stderr to the pipe. - ret = dup2(pipe_fds[1], 2); - if (ret < 0) { - _exit(errno); - } - - // Here, we'd ideally close all other FDs inherited from the parent. - // However, that's not worth the effort and CloexecNormalFile and - // CloexecEventfd depend on that not happening. - }; - - Cleanup kill; - if (dirfd.has_value()) { - kill = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(*dirfd, pathname, argv, - envv, flags, remap_stderr, - &child, &execve_errno)); - } else { - kill = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(pathname, argv, envv, remap_stderr, &child, &execve_errno)); - } - - ASSERT_EQ(0, execve_errno); - - // Not needed anymore. - write_fd.reset(); - - // Read stderr until the child exits. - std::string output; - constexpr int kSize = 128; - char buf[kSize]; - int n; - do { - ASSERT_THAT(n = ReadFd(read_fd.get(), buf, kSize), SyscallSucceeds()); - if (n > 0) { - output.append(buf, n); - } - } while (n > 0); - - int status; - ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); - EXPECT_EQ(status, expect_status); - - // Process cleanup no longer needed. - kill.Release(); - - EXPECT_TRUE(absl::StrContains(output, expect_stderr)) << output; -} - -void CheckExec(const std::string& filename, const ExecveArray& argv, - const ExecveArray& envv, int expect_status, - const std::string& expect_stderr) { - CheckExecHelper(/*dirfd=*/absl::optional<int32_t>(), filename, argv, envv, - /*flags=*/0, expect_status, expect_stderr); -} - -void CheckExecveat(const int32_t dirfd, const std::string& pathname, - const ExecveArray& argv, const ExecveArray& envv, - const int flags, int expect_status, - const std::string& expect_stderr) { - CheckExecHelper(absl::optional<int32_t>(dirfd), pathname, argv, envv, flags, - expect_status, expect_stderr); -} - -TEST(ExecTest, EmptyPath) { - int execve_errno; - ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec("", {}, {}, nullptr, &execve_errno)); - EXPECT_EQ(execve_errno, ENOENT); -} - -TEST(ExecTest, Basic) { - CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)}, {}, - ArgEnvExitStatus(0, 0), - absl::StrCat(RunfilePath(kBasicWorkload), "\n")); -} - -TEST(ExecTest, OneArg) { - CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload), "1"}, {}, - ArgEnvExitStatus(1, 0), - absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n")); -} - -TEST(ExecTest, FiveArg) { - CheckExec(RunfilePath(kBasicWorkload), - {RunfilePath(kBasicWorkload), "1", "2", "3", "4", "5"}, {}, - ArgEnvExitStatus(5, 0), - absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); -} - -TEST(ExecTest, OneEnv) { - CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)}, {"1"}, - ArgEnvExitStatus(0, 1), - absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n")); -} - -TEST(ExecTest, FiveEnv) { - CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)}, - {"1", "2", "3", "4", "5"}, ArgEnvExitStatus(0, 5), - absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); -} - -TEST(ExecTest, OneArgOneEnv) { - CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload), "arg"}, - {"env"}, ArgEnvExitStatus(1, 1), - absl::StrCat(RunfilePath(kBasicWorkload), "\narg\nenv\n")); -} - -TEST(ExecTest, InterpreterScript) { - CheckExec(RunfilePath(kExitScript), {RunfilePath(kExitScript), "25"}, {}, - ArgEnvExitStatus(25, 0), ""); -} - -// Everything after the path in the interpreter script is a single argument. -TEST(ExecTest, InterpreterScriptArgSplit) { - // Symlink through /tmp to ensure the path is short enough. - TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo bar"), - 0755)); - - CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0), - absl::StrCat(link.path(), "\nfoo bar\n", script.path(), "\n")); -} - -// Original argv[0] is replaced with the script path. -TEST(ExecTest, InterpreterScriptArgvZero) { - // Symlink through /tmp to ensure the path is short enough. - TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755)); - - CheckExec(script.path(), {"REPLACED"}, {}, ArgEnvExitStatus(1, 0), - absl::StrCat(link.path(), "\n", script.path(), "\n")); -} - -// Original argv[0] is replaced with the script path, exactly as passed to -// execve. -TEST(ExecTest, InterpreterScriptArgvZeroRelative) { - // Symlink through /tmp to ensure the path is short enough. - TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755)); - - auto cwd = ASSERT_NO_ERRNO_AND_VALUE(GetCWD()); - auto script_relative = - ASSERT_NO_ERRNO_AND_VALUE(GetRelativePath(cwd, script.path())); - - CheckExec(script_relative, {"REPLACED"}, {}, ArgEnvExitStatus(1, 0), - absl::StrCat(link.path(), "\n", script_relative, "\n")); -} - -// argv[0] is added as the script path, even if there was none. -TEST(ExecTest, InterpreterScriptArgvZeroAdded) { - // Symlink through /tmp to ensure the path is short enough. - TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755)); - - CheckExec(script.path(), {}, {}, ArgEnvExitStatus(1, 0), - absl::StrCat(link.path(), "\n", script.path(), "\n")); -} - -// A NUL byte in the script line ends parsing. -TEST(ExecTest, InterpreterScriptArgNUL) { - // Symlink through /tmp to ensure the path is short enough. - TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), - absl::StrCat("#!", link.path(), " foo", std::string(1, '\0'), "bar"), - 0755)); - - CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0), - absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n")); -} - -// Trailing whitespace following interpreter path is ignored. -TEST(ExecTest, InterpreterScriptTrailingWhitespace) { - // Symlink through /tmp to ensure the path is short enough. - TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " "), 0755)); - - CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(1, 0), - absl::StrCat(link.path(), "\n", script.path(), "\n")); -} - -// Multiple whitespace characters between interpreter and arg allowed. -TEST(ExecTest, InterpreterScriptArgWhitespace) { - // Symlink through /tmp to ensure the path is short enough. - TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo"), 0755)); - - CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0), - absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n")); -} - -TEST(ExecTest, InterpreterScriptNoPath) { - TempPath script = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), "#!", 0755)); - - int execve_errno; - ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(script.path(), {script.path()}, {}, nullptr, &execve_errno)); - EXPECT_EQ(execve_errno, ENOEXEC); -} - -// AT_EXECFN is the path passed to execve. -TEST(ExecTest, ExecFn) { - // Symlink through /tmp to ensure the path is short enough. - TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", RunfilePath(kStateWorkload))); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " PrintExecFn"), - 0755)); - - // Pass the script as a relative path and assert that is what appears in - // AT_EXECFN. - auto cwd = ASSERT_NO_ERRNO_AND_VALUE(GetCWD()); - auto script_relative = - ASSERT_NO_ERRNO_AND_VALUE(GetRelativePath(cwd, script.path())); - - CheckExec(script_relative, {script_relative}, {}, ArgEnvExitStatus(0, 0), - absl::StrCat(script_relative, "\n")); -} - -TEST(ExecTest, ExecName) { - std::string path = RunfilePath(kStateWorkload); - - CheckExec(path, {path, "PrintExecName"}, {}, ArgEnvExitStatus(0, 0), - absl::StrCat(Basename(path).substr(0, 15), "\n")); -} - -TEST(ExecTest, ExecNameScript) { - // Symlink through /tmp to ensure the path is short enough. - TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", RunfilePath(kStateWorkload))); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), - absl::StrCat("#!", link.path(), " PrintExecName"), 0755)); - - std::string script_path = script.path(); - - CheckExec(script_path, {script_path}, {}, ArgEnvExitStatus(0, 0), - absl::StrCat(Basename(script_path).substr(0, 15), "\n")); -} - -// execve may be called by a multithreaded process. -TEST(ExecTest, WithSiblingThread) { - CheckExec("/proc/self/exe", {"/proc/self/exe", kExecWithThread}, {}, - W_EXITCODE(42, 0), ""); -} - -// execve may be called from a thread other than the leader of a multithreaded -// process. -TEST(ExecTest, FromSiblingThread) { - CheckExec("/proc/self/exe", {"/proc/self/exe", kExecFromThread}, {}, - W_EXITCODE(42, 0), ""); -} - -TEST(ExecTest, NotFound) { - char* const argv[] = {nullptr}; - char* const envp[] = {nullptr}; - EXPECT_THAT(execve("/file/does/not/exist", argv, envp), - SyscallFailsWithErrno(ENOENT)); -} - -TEST(ExecTest, NoExecPerm) { - char* const argv[] = {nullptr}; - char* const envp[] = {nullptr}; - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - EXPECT_THAT(execve(f.path().c_str(), argv, envp), - SyscallFailsWithErrno(EACCES)); -} - -// A signal handler we never expect to be called. -void SignalHandler(int signo) { - std::cerr << "Signal " << signo << " raised." << std::endl; - exit(1); -} - -// Signal handlers are reset on execve(2), unless they have default or ignored -// disposition. -TEST(ExecStateTest, HandlerReset) { - struct sigaction sa; - sa.sa_handler = SignalHandler; - ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds()); - - ExecveArray args = { - RunfilePath(kStateWorkload), - "CheckSigHandler", - absl::StrCat(SIGUSR1), - absl::StrCat(absl::Hex(reinterpret_cast<uintptr_t>(SIG_DFL))), - }; - - CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); -} - -// Ignored signal dispositions are not reset. -TEST(ExecStateTest, IgnorePreserved) { - struct sigaction sa; - sa.sa_handler = SIG_IGN; - ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds()); - - ExecveArray args = { - RunfilePath(kStateWorkload), - "CheckSigHandler", - absl::StrCat(SIGUSR1), - absl::StrCat(absl::Hex(reinterpret_cast<uintptr_t>(SIG_IGN))), - }; - - CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); -} - -// Signal masks are not reset on exec -TEST(ExecStateTest, SignalMask) { - sigset_t s; - sigemptyset(&s); - sigaddset(&s, SIGUSR1); - ASSERT_THAT(sigprocmask(SIG_BLOCK, &s, nullptr), SyscallSucceeds()); - - ExecveArray args = { - RunfilePath(kStateWorkload), - "CheckSigBlocked", - absl::StrCat(SIGUSR1), - }; - - CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); -} - -// itimers persist across execve. -// N.B. Timers created with timer_create(2) should not be preserved! -TEST(ExecStateTest, ItimerPreserved) { - // The fork in ForkAndExec clears itimers, so only set them up after fork. - auto setup_itimer = [] { - // Ignore SIGALRM, as we don't actually care about timer - // expirations. - struct sigaction sa; - sa.sa_handler = SIG_IGN; - int ret = sigaction(SIGALRM, &sa, nullptr); - if (ret < 0) { - _exit(errno); - } - - struct itimerval itv; - itv.it_interval.tv_sec = 1; - itv.it_interval.tv_usec = 0; - itv.it_value.tv_sec = 1; - itv.it_value.tv_usec = 0; - ret = setitimer(ITIMER_REAL, &itv, nullptr); - if (ret < 0) { - _exit(errno); - } - }; - - std::string filename = RunfilePath(kStateWorkload); - ExecveArray argv = { - filename, - "CheckItimerEnabled", - absl::StrCat(ITIMER_REAL), - }; - - pid_t child; - int execve_errno; - auto kill = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(filename, argv, {}, setup_itimer, &child, &execve_errno)); - ASSERT_EQ(0, execve_errno); - - int status; - ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); - EXPECT_EQ(0, status); - - // Process cleanup no longer needed. - kill.Release(); -} - -TEST(ProcSelfExe, ChangesAcrossExecve) { - // See exec_proc_exe_workload for more details. We simply - // assert that the /proc/self/exe link changes across execve. - CheckExec(RunfilePath(kProcExeWorkload), - {RunfilePath(kProcExeWorkload), - ASSERT_NO_ERRNO_AND_VALUE(ProcessExePath(getpid()))}, - {}, W_EXITCODE(0, 0), ""); -} - -TEST(ExecTest, CloexecNormalFile) { - TempPath tempFile = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), "bar", 0755)); - const FileDescriptor fd_closed_on_exec = - ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY | O_CLOEXEC)); - - CheckExec(RunfilePath(kAssertClosedWorkload), - {RunfilePath(kAssertClosedWorkload), - absl::StrCat(fd_closed_on_exec.get())}, - {}, W_EXITCODE(0, 0), ""); - - // The assert closed workload exits with code 2 if the file still exists. We - // can use this to do a negative test. - const FileDescriptor fd_open_on_exec = - ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY)); - - CheckExec( - RunfilePath(kAssertClosedWorkload), - {RunfilePath(kAssertClosedWorkload), absl::StrCat(fd_open_on_exec.get())}, - {}, W_EXITCODE(2, 0), ""); -} - -TEST(ExecTest, CloexecEventfd) { - int efd; - ASSERT_THAT(efd = eventfd(0, EFD_CLOEXEC), SyscallSucceeds()); - FileDescriptor fd(efd); - - CheckExec(RunfilePath(kAssertClosedWorkload), - {RunfilePath(kAssertClosedWorkload), absl::StrCat(fd.get())}, {}, - W_EXITCODE(0, 0), ""); -} - -constexpr int kLinuxMaxSymlinks = 40; - -TEST(ExecTest, SymlinkLimitExceeded) { - std::string path = RunfilePath(kBasicWorkload); - - // Hold onto TempPath objects so they are not destructed prematurely. - std::vector<TempPath> symlinks; - for (int i = 0; i < kLinuxMaxSymlinks + 1; i++) { - symlinks.push_back( - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateSymlinkTo("/tmp", path))); - path = symlinks[i].path(); - } - - int execve_errno; - ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(path, {path}, {}, /*child=*/nullptr, &execve_errno)); - EXPECT_EQ(execve_errno, ELOOP); -} - -TEST(ExecTest, SymlinkLimitRefreshedForInterpreter) { - std::string tmp_dir = "/tmp"; - std::string interpreter_path = "/bin/echo"; - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - tmp_dir, absl::StrCat("#!", interpreter_path), 0755)); - std::string script_path = script.path(); - - // Hold onto TempPath objects so they are not destructed prematurely. - std::vector<TempPath> interpreter_symlinks; - std::vector<TempPath> script_symlinks; - for (int i = 0; i < kLinuxMaxSymlinks; i++) { - interpreter_symlinks.push_back(ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo(tmp_dir, interpreter_path))); - interpreter_path = interpreter_symlinks[i].path(); - script_symlinks.push_back(ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo(tmp_dir, script_path))); - script_path = script_symlinks[i].path(); - } - - CheckExec(script_path, {script_path}, {}, ArgEnvExitStatus(0, 0), ""); -} - -TEST(ExecveatTest, BasicWithFDCWD) { - std::string path = RunfilePath(kBasicWorkload); - CheckExecveat(AT_FDCWD, path, {path}, {}, /*flags=*/0, ArgEnvExitStatus(0, 0), - absl::StrCat(path, "\n")); -} - -TEST(ExecveatTest, Basic) { - std::string absolute_path = RunfilePath(kBasicWorkload); - std::string parent_dir = std::string(Dirname(absolute_path)); - std::string base = std::string(Basename(absolute_path)); - const FileDescriptor dirfd = - ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); - - CheckExecveat(dirfd.get(), base, {absolute_path}, {}, /*flags=*/0, - ArgEnvExitStatus(0, 0), absl::StrCat(absolute_path, "\n")); -} - -TEST(ExecveatTest, FDNotADirectory) { - std::string absolute_path = RunfilePath(kBasicWorkload); - std::string base = std::string(Basename(absolute_path)); - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(absolute_path, 0)); - - int execve_errno; - ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(fd.get(), base, {absolute_path}, {}, - /*flags=*/0, /*child=*/nullptr, - &execve_errno)); - EXPECT_EQ(execve_errno, ENOTDIR); -} - -TEST(ExecveatTest, AbsolutePathWithFDCWD) { - std::string path = RunfilePath(kBasicWorkload); - CheckExecveat(AT_FDCWD, path, {path}, {}, ArgEnvExitStatus(0, 0), 0, - absl::StrCat(path, "\n")); -} - -TEST(ExecveatTest, AbsolutePath) { - std::string path = RunfilePath(kBasicWorkload); - // File descriptor should be ignored when an absolute path is given. - const int32_t badFD = -1; - CheckExecveat(badFD, path, {path}, {}, ArgEnvExitStatus(0, 0), 0, - absl::StrCat(path, "\n")); -} - -TEST(ExecveatTest, EmptyPathBasic) { - std::string path = RunfilePath(kBasicWorkload); - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); - - CheckExecveat(fd.get(), "", {path}, {}, AT_EMPTY_PATH, ArgEnvExitStatus(0, 0), - absl::StrCat(path, "\n")); -} - -TEST(ExecveatTest, EmptyPathWithDirFD) { - std::string path = RunfilePath(kBasicWorkload); - std::string parent_dir = std::string(Dirname(path)); - const FileDescriptor dirfd = - ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); - - int execve_errno; - ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), "", {path}, {}, - AT_EMPTY_PATH, - /*child=*/nullptr, &execve_errno)); - EXPECT_EQ(execve_errno, EACCES); -} - -TEST(ExecveatTest, EmptyPathWithoutEmptyPathFlag) { - std::string path = RunfilePath(kBasicWorkload); - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); - - int execve_errno; - ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat( - fd.get(), "", {path}, {}, /*flags=*/0, /*child=*/nullptr, &execve_errno)); - EXPECT_EQ(execve_errno, ENOENT); -} - -TEST(ExecveatTest, AbsolutePathWithEmptyPathFlag) { - std::string path = RunfilePath(kBasicWorkload); - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); - - CheckExecveat(fd.get(), path, {path}, {}, AT_EMPTY_PATH, - ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n")); -} - -TEST(ExecveatTest, RelativePathWithEmptyPathFlag) { - std::string absolute_path = RunfilePath(kBasicWorkload); - std::string parent_dir = std::string(Dirname(absolute_path)); - std::string base = std::string(Basename(absolute_path)); - const FileDescriptor dirfd = - ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); - - CheckExecveat(dirfd.get(), base, {absolute_path}, {}, AT_EMPTY_PATH, - ArgEnvExitStatus(0, 0), absl::StrCat(absolute_path, "\n")); -} - -TEST(ExecveatTest, SymlinkNoFollowWithRelativePath) { - std::string parent_dir = "/tmp"; - TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo(parent_dir, RunfilePath(kBasicWorkload))); - const FileDescriptor dirfd = - ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); - std::string base = std::string(Basename(link.path())); - - int execve_errno; - ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), base, {base}, {}, - AT_SYMLINK_NOFOLLOW, - /*child=*/nullptr, &execve_errno)); - EXPECT_EQ(execve_errno, ELOOP); -} - -TEST(ExecveatTest, SymlinkNoFollowWithAbsolutePath) { - std::string parent_dir = "/tmp"; - TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo(parent_dir, RunfilePath(kBasicWorkload))); - std::string path = link.path(); - - int execve_errno; - ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(AT_FDCWD, path, {path}, {}, - AT_SYMLINK_NOFOLLOW, - /*child=*/nullptr, &execve_errno)); - EXPECT_EQ(execve_errno, ELOOP); -} - -TEST(ExecveatTest, SymlinkNoFollowAndEmptyPath) { - TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); - std::string path = link.path(); - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, 0)); - - CheckExecveat(fd.get(), "", {path}, {}, AT_EMPTY_PATH | AT_SYMLINK_NOFOLLOW, - ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n")); -} - -TEST(ExecveatTest, SymlinkNoFollowIgnoreSymlinkAncestor) { - TempPath parent_link = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateSymlinkTo("/tmp", "/bin")); - std::string path_with_symlink = JoinPath(parent_link.path(), "echo"); - - CheckExecveat(AT_FDCWD, path_with_symlink, {path_with_symlink}, {}, - AT_SYMLINK_NOFOLLOW, ArgEnvExitStatus(0, 0), ""); -} - -TEST(ExecveatTest, SymlinkNoFollowWithNormalFile) { - const FileDescriptor dirfd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/bin", O_DIRECTORY)); - - CheckExecveat(dirfd.get(), "echo", {"echo"}, {}, AT_SYMLINK_NOFOLLOW, - ArgEnvExitStatus(0, 0), ""); -} - -TEST(ExecveatTest, BasicWithCloexecFD) { - std::string path = RunfilePath(kBasicWorkload); - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CLOEXEC)); - - CheckExecveat(fd.get(), "", {path}, {}, AT_SYMLINK_NOFOLLOW | AT_EMPTY_PATH, - ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n")); -} - -TEST(ExecveatTest, InterpreterScriptWithCloexecFD) { - std::string path = RunfilePath(kExitScript); - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CLOEXEC)); - - int execve_errno; - ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(fd.get(), "", {path}, {}, - AT_EMPTY_PATH, /*child=*/nullptr, - &execve_errno)); - EXPECT_EQ(execve_errno, ENOENT); -} - -TEST(ExecveatTest, InterpreterScriptWithCloexecDirFD) { - std::string absolute_path = RunfilePath(kExitScript); - std::string parent_dir = std::string(Dirname(absolute_path)); - std::string base = std::string(Basename(absolute_path)); - const FileDescriptor dirfd = - ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_CLOEXEC | O_DIRECTORY)); - - int execve_errno; - ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), base, {base}, {}, - /*flags=*/0, /*child=*/nullptr, - &execve_errno)); - EXPECT_EQ(execve_errno, ENOENT); -} - -TEST(ExecveatTest, InvalidFlags) { - int execve_errno; - ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat( - /*dirfd=*/-1, "", {}, {}, /*flags=*/0xFFFF, /*child=*/nullptr, - &execve_errno)); - EXPECT_EQ(execve_errno, EINVAL); -} - -// Priority consistent across calls to execve() -TEST(GetpriorityTest, ExecveMaintainsPriority) { - int prio = 16; - ASSERT_THAT(setpriority(PRIO_PROCESS, getpid(), prio), SyscallSucceeds()); - - // To avoid trying to use negative exit values, check for - // 20 - prio. Since prio should always be in the range [-20, 19], - // this leave expected_exit_code in the range [1, 40]. - int expected_exit_code = 20 - prio; - - // Program run (priority_execve) will exit(X) where - // X=getpriority(PRIO_PROCESS,0). Check that this exit value is prio. - CheckExec(RunfilePath(kPriorityWorkload), {RunfilePath(kPriorityWorkload)}, - {}, W_EXITCODE(expected_exit_code, 0), ""); -} - -void ExecWithThread() { - // Used to ensure that the thread has actually started. - absl::Mutex mu; - bool started = false; - - ScopedThread t([&] { - mu.Lock(); - started = true; - mu.Unlock(); - - while (true) { - pause(); - } - }); - - mu.LockWhen(absl::Condition(&started)); - mu.Unlock(); - - const ExecveArray argv = {"/proc/self/exe", kExit42}; - const ExecveArray envv; - - execve("/proc/self/exe", argv.get(), envv.get()); - exit(errno); -} - -void ExecFromThread() { - ScopedThread t([] { - const ExecveArray argv = {"/proc/self/exe", kExit42}; - const ExecveArray envv; - - execve("/proc/self/exe", argv.get(), envv.get()); - exit(errno); - }); - - while (true) { - pause(); - } -} - -bool ValidateProcCmdlineVsArgv(const int argc, const char* const* argv) { - auto contents_or = GetContents("/proc/self/cmdline"); - if (!contents_or.ok()) { - std::cerr << "Unable to get /proc/self/cmdline: " << contents_or.error(); - return false; - } - auto contents = contents_or.ValueOrDie(); - if (contents.back() != '\0') { - std::cerr << "Non-null terminated /proc/self/cmdline!"; - return false; - } - contents.pop_back(); - std::vector<std::string> procfs_cmdline = absl::StrSplit(contents, '\0'); - - if (static_cast<int>(procfs_cmdline.size()) != argc) { - std::cerr << "argc = " << argc << " != " << procfs_cmdline.size(); - return false; - } - - for (int i = 0; i < argc; ++i) { - if (procfs_cmdline[i] != argv[i]) { - std::cerr << "Procfs command line argument " << i << " mismatch " - << procfs_cmdline[i] << " != " << argv[i]; - return false; - } - } - return true; -} - -} // namespace - -} // namespace testing -} // namespace gvisor - -int main(int argc, char** argv) { - // Start by validating that the stack argv is consistent with procfs. - if (!gvisor::testing::ValidateProcCmdlineVsArgv(argc, argv)) { - return 1; - } - - // Some of these tests require no background threads, so check for them before - // TestInit. - for (int i = 0; i < argc; i++) { - absl::string_view arg(argv[i]); - - if (arg == gvisor::testing::kExit42) { - return 42; - } - if (arg == gvisor::testing::kExecWithThread) { - gvisor::testing::ExecWithThread(); - return 1; - } - if (arg == gvisor::testing::kExecFromThread) { - gvisor::testing::ExecFromThread(); - return 1; - } - } - - gvisor::testing::TestInit(&argc, &argv); - return gvisor::testing::RunAllTests(); -} diff --git a/test/syscalls/linux/exec.h b/test/syscalls/linux/exec.h deleted file mode 100644 index 5c0f7e654..000000000 --- a/test/syscalls/linux/exec.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_EXEC_H_ -#define GVISOR_TEST_SYSCALLS_EXEC_H_ - -#include <sys/wait.h> - -namespace gvisor { -namespace testing { - -// Returns the exit code used by exec_basic_workload. -inline int ArgEnvExitCode(int args, int envs) { return args + envs * 10; } - -// Returns the exit status used by exec_basic_workload. -inline int ArgEnvExitStatus(int args, int envs) { - return W_EXITCODE(ArgEnvExitCode(args, envs), 0); -} - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_EXEC_H_ diff --git a/test/syscalls/linux/exec_assert_closed_workload.cc b/test/syscalls/linux/exec_assert_closed_workload.cc deleted file mode 100644 index 95643618d..000000000 --- a/test/syscalls/linux/exec_assert_closed_workload.cc +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <stdlib.h> -#include <sys/stat.h> -#include <unistd.h> - -#include <iostream> - -#include "absl/strings/numbers.h" - -int main(int argc, char** argv) { - if (argc != 2) { - std::cerr << "need two arguments, got " << argc; - exit(1); - } - int fd; - if (!absl::SimpleAtoi(argv[1], &fd)) { - std::cerr << "fd: " << argv[1] << " could not be parsed" << std::endl; - exit(1); - } - struct stat s; - if (fstat(fd, &s) == 0) { - std::cerr << "fd: " << argv[1] << " should not be valid" << std::endl; - exit(2); - } - if (errno != EBADF) { - std::cerr << "fstat fd: " << argv[1] << " got errno: " << errno - << " wanted: " << EBADF << std::endl; - exit(1); - } - return 0; -} diff --git a/test/syscalls/linux/exec_basic_workload.cc b/test/syscalls/linux/exec_basic_workload.cc deleted file mode 100644 index 1bbd6437e..000000000 --- a/test/syscalls/linux/exec_basic_workload.cc +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdlib.h> - -#include <iostream> - -#include "test/syscalls/linux/exec.h" - -int main(int argc, char** argv, char** envp) { - int i; - for (i = 0; i < argc; i++) { - std::cerr << argv[i] << std::endl; - } - for (i = 0; envp[i] != nullptr; i++) { - std::cerr << envp[i] << std::endl; - } - exit(gvisor::testing::ArgEnvExitCode(argc - 1, i)); - return 0; -} diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc deleted file mode 100644 index 736452b0c..000000000 --- a/test/syscalls/linux/exec_binary.cc +++ /dev/null @@ -1,1521 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <elf.h> -#include <errno.h> -#include <signal.h> -#include <sys/ptrace.h> -#include <sys/syscall.h> -#include <sys/types.h> -#include <sys/user.h> -#include <unistd.h> - -#include <algorithm> -#include <functional> -#include <iterator> -#include <tuple> -#include <utility> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/proc_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -using ::testing::AnyOf; -using ::testing::Eq; - -#ifndef __x86_64__ -// The assembly stub and ELF internal details must be ported to other arches. -#error "Test only supported on x86-64" -#endif // __x86_64__ - -// amd64 stub that calls PTRACE_TRACEME and sends itself SIGSTOP. -const char kPtraceCode[] = { - // movq $101, %rax /* ptrace */ - '\x48', - '\xc7', - '\xc0', - '\x65', - '\x00', - '\x00', - '\x00', - // movq $0, %rsi /* PTRACE_TRACEME */ - '\x48', - '\xc7', - '\xc6', - '\x00', - '\x00', - '\x00', - '\x00', - // movq $0, %rdi - '\x48', - '\xc7', - '\xc7', - '\x00', - '\x00', - '\x00', - '\x00', - // movq $0, %rdx - '\x48', - '\xc7', - '\xc2', - '\x00', - '\x00', - '\x00', - '\x00', - // movq $0, %r10 - '\x49', - '\xc7', - '\xc2', - '\x00', - '\x00', - '\x00', - '\x00', - // syscall - '\x0f', - '\x05', - - // movq $39, %rax /* getpid */ - '\x48', - '\xc7', - '\xc0', - '\x27', - '\x00', - '\x00', - '\x00', - // syscall - '\x0f', - '\x05', - - // movq %rax, %rdi /* pid */ - '\x48', - '\x89', - '\xc7', - // movq $62, %rax /* kill */ - '\x48', - '\xc7', - '\xc0', - '\x3e', - '\x00', - '\x00', - '\x00', - // movq $19, %rsi /* SIGSTOP */ - '\x48', - '\xc7', - '\xc6', - '\x13', - '\x00', - '\x00', - '\x00', - // syscall - '\x0f', - '\x05', -}; - -// Size of a syscall instruction. -constexpr int kSyscallSize = 2; - -// This test suite tests executable loading in the kernel (ELF and interpreter -// scripts). - -// Parameterized ELF types for 64 and 32 bit. -template <int Size> -struct ElfTypes; - -template <> -struct ElfTypes<64> { - typedef Elf64_Ehdr ElfEhdr; - typedef Elf64_Phdr ElfPhdr; -}; - -template <> -struct ElfTypes<32> { - typedef Elf32_Ehdr ElfEhdr; - typedef Elf32_Phdr ElfPhdr; -}; - -template <int Size> -struct ElfBinary { - using ElfEhdr = typename ElfTypes<Size>::ElfEhdr; - using ElfPhdr = typename ElfTypes<Size>::ElfPhdr; - - ElfEhdr header = {}; - std::vector<ElfPhdr> phdrs; - std::vector<char> data; - - // UpdateOffsets updates p_offset, p_vaddr in all phdrs to account for the - // space taken by the header and phdrs. - // - // It also updates header.e_phnum and adds the offset to header.e_entry to - // account for the headers residing in the first PT_LOAD segment. - // - // Before calling UpdateOffsets each of those fields should be the appropriate - // offset into data. - void UpdateOffsets() { - size_t offset = sizeof(header) + phdrs.size() * sizeof(ElfPhdr); - header.e_entry += offset; - header.e_phnum = phdrs.size(); - for (auto& p : phdrs) { - p.p_offset += offset; - p.p_vaddr += offset; - } - } - - // AddInterpreter adds a PT_INTERP segment with the passed contents. - // - // A later call to UpdateOffsets is required to make the new phdr valid. - void AddInterpreter(std::vector<char> contents) { - const int start = data.size(); - data.insert(data.end(), contents.begin(), contents.end()); - const int size = data.size() - start; - - ElfPhdr phdr = {}; - phdr.p_type = PT_INTERP; - phdr.p_offset = start; - phdr.p_filesz = size; - phdr.p_memsz = size; - // "If [PT_INTERP] is present, it must precede any loadable segment entry." - phdrs.insert(phdrs.begin(), phdr); - } - - // Writes the header, phdrs, and data to fd. - PosixError Write(int fd) const { - int ret = WriteFd(fd, &header, sizeof(header)); - if (ret < 0) { - return PosixError(errno, "failed to write header"); - } else if (ret != sizeof(header)) { - return PosixError(EIO, absl::StrCat("short write of header: ", ret)); - } - - for (auto const& p : phdrs) { - ret = WriteFd(fd, &p, sizeof(p)); - if (ret < 0) { - return PosixError(errno, "failed to write phdr"); - } else if (ret != sizeof(p)) { - return PosixError(EIO, absl::StrCat("short write of phdr: ", ret)); - } - } - - ret = WriteFd(fd, data.data(), data.size()); - if (ret < 0) { - return PosixError(errno, "failed to write data"); - } else if (ret != static_cast<int>(data.size())) { - return PosixError(EIO, absl::StrCat("short write of data: ", ret)); - } - - return NoError(); - } -}; - -// Creates a new temporary executable ELF file in parent with elf as the -// contents. -template <int Size> -PosixErrorOr<TempPath> CreateElfWith(absl::string_view parent, - ElfBinary<Size> const& elf) { - ASSIGN_OR_RETURN_ERRNO( - auto file, TempPath::CreateFileWith(parent, absl::string_view(), 0755)); - ASSIGN_OR_RETURN_ERRNO(auto fd, Open(file.path(), O_RDWR)); - RETURN_IF_ERRNO(elf.Write(fd.get())); - return std::move(file); -} - -// Creates a new temporary executable ELF file with elf as the contents. -template <int Size> -PosixErrorOr<TempPath> CreateElfWith(ElfBinary<Size> const& elf) { - return CreateElfWith(GetAbsoluteTestTmpdir(), elf); -} - -// Wait for pid to stop, and assert that it stopped via SIGSTOP. -PosixError WaitStopped(pid_t pid) { - int status; - int ret = RetryEINTR(waitpid)(pid, &status, 0); - MaybeSave(); - if (ret < 0) { - return PosixError(errno, "wait failed"); - } else if (ret != pid) { - return PosixError(ESRCH, absl::StrCat("wait got ", ret, " want ", pid)); - } - - if (!WIFSTOPPED(status) || WSTOPSIG(status) != SIGSTOP) { - return PosixError(EINVAL, - absl::StrCat("pid did not SIGSTOP; status = ", status)); - } - - return NoError(); -} - -// Returns a valid ELF that PTRACE_TRACEME and SIGSTOPs itself. -// -// UpdateOffsets must be called before writing this ELF. -ElfBinary<64> StandardElf() { - ElfBinary<64> elf; - elf.header.e_ident[EI_MAG0] = ELFMAG0; - elf.header.e_ident[EI_MAG1] = ELFMAG1; - elf.header.e_ident[EI_MAG2] = ELFMAG2; - elf.header.e_ident[EI_MAG3] = ELFMAG3; - elf.header.e_ident[EI_CLASS] = ELFCLASS64; - elf.header.e_ident[EI_DATA] = ELFDATA2LSB; - elf.header.e_ident[EI_VERSION] = EV_CURRENT; - elf.header.e_type = ET_EXEC; - elf.header.e_machine = EM_X86_64; - elf.header.e_version = EV_CURRENT; - elf.header.e_phoff = sizeof(elf.header); - elf.header.e_phentsize = sizeof(decltype(elf)::ElfPhdr); - - // TODO(gvisor.dev/issue/153): Always include a PT_GNU_STACK segment to - // disable executable stacks. With this omitted the stack (and all PROT_READ) - // mappings should be executable, but gVisor doesn't support that. - decltype(elf)::ElfPhdr phdr = {}; - phdr.p_type = PT_GNU_STACK; - phdr.p_flags = PF_R | PF_W; - elf.phdrs.push_back(phdr); - - phdr = {}; - phdr.p_type = PT_LOAD; - phdr.p_flags = PF_R | PF_X; - phdr.p_offset = 0; - phdr.p_vaddr = 0x40000; - phdr.p_filesz = sizeof(kPtraceCode); - phdr.p_memsz = phdr.p_filesz; - elf.phdrs.push_back(phdr); - - elf.header.e_entry = phdr.p_vaddr; - - elf.data.assign(kPtraceCode, kPtraceCode + sizeof(kPtraceCode)); - - return elf; -} - -// Test that a trivial binary executes. -TEST(ElfTest, Execute) { - ElfBinary<64> elf = StandardElf(); - elf.UpdateOffsets(); - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - // Ensure it made it to SIGSTOP. - ASSERT_NO_ERRNO(WaitStopped(child)); - - struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); - // RIP is just beyond the final syscall instruction. - EXPECT_EQ(regs.rip, elf.header.e_entry + sizeof(kPtraceCode)); - - EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({ - {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0, - file.path().c_str()}, - }))); -} - -// StandardElf without data completes execve, but faults once running. -TEST(ElfTest, MissingText) { - ElfBinary<64> elf = StandardElf(); - elf.data.clear(); - elf.UpdateOffsets(); - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - int status; - ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), - SyscallSucceedsWithValue(child)); - // It runs off the end of the zeroes filling the end of the page. - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV) << status; -} - -// Typical ELF with a data + bss segment -TEST(ElfTest, DataSegment) { - ElfBinary<64> elf = StandardElf(); - - // Create a standard ELF, but extend to 1.5 pages. The second page will be the - // beginning of a multi-page data + bss segment. - elf.data.resize(kPageSize + kPageSize / 2); - - decltype(elf)::ElfPhdr phdr = {}; - phdr.p_type = PT_LOAD; - phdr.p_flags = PF_R | PF_W; - phdr.p_offset = kPageSize; - phdr.p_vaddr = 0x41000; - phdr.p_filesz = kPageSize / 2; - // The header is going to push vaddr up by a few hundred bytes. Keep p_memsz a - // bit less than 2 pages so this mapping doesn't extend beyond 0x43000. - phdr.p_memsz = 2 * kPageSize - kPageSize / 2; - elf.phdrs.push_back(phdr); - - elf.UpdateOffsets(); - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - ASSERT_NO_ERRNO(WaitStopped(child)); - - EXPECT_THAT( - child, ContainsMappings(std::vector<ProcMapsEntry>({ - // text page. - {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0, - file.path().c_str()}, - // data + bss page from file. - {0x41000, 0x42000, true, true, false, true, kPageSize, 0, 0, 0, - file.path().c_str()}, - // bss page from anon. - {0x42000, 0x43000, true, true, false, true, 0, 0, 0, 0, ""}, - }))); -} - -// Additonal pages beyond filesz honor (only) execute protections. -// -// N.B. Linux changed this in 4.11 (16e72e9b30986 "powerpc: do not make the -// entire heap executable"). Previously, extra pages were always RW. -TEST(ElfTest, ExtraMemPages) { - // gVisor has the newer behavior. - if (!IsRunningOnGvisor()) { - auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion()); - SKIP_IF(version.major < 4 || (version.major == 4 && version.minor < 11)); - } - - ElfBinary<64> elf = StandardElf(); - - // Create a standard ELF, but extend to 1.5 pages. The second page will be the - // beginning of a multi-page data + bss segment. - elf.data.resize(kPageSize + kPageSize / 2); - - decltype(elf)::ElfPhdr phdr = {}; - phdr.p_type = PT_LOAD; - // RWX segment. The extra anon page will also be RWX. - // - // N.B. Linux uses clear_user to clear the end of the file-mapped page, which - // respects the mapping protections. Thus if we map this RO with memsz > - // (unaligned) filesz, then execve will fail with EFAULT. See padzero(elf_bss) - // in fs/binfmt_elf.c:load_elf_binary. - // - // N.N.B.B. The above only applies to the last segment. For earlier segments, - // the clear_user error is ignored. - phdr.p_flags = PF_R | PF_W | PF_X; - phdr.p_offset = kPageSize; - phdr.p_vaddr = 0x41000; - phdr.p_filesz = kPageSize / 2; - // The header is going to push vaddr up by a few hundred bytes. Keep p_memsz a - // bit less than 2 pages so this mapping doesn't extend beyond 0x43000. - phdr.p_memsz = 2 * kPageSize - kPageSize / 2; - elf.phdrs.push_back(phdr); - - elf.UpdateOffsets(); - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - ASSERT_NO_ERRNO(WaitStopped(child)); - - EXPECT_THAT(child, - ContainsMappings(std::vector<ProcMapsEntry>({ - // text page. - {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0, - file.path().c_str()}, - // data + bss page from file. - {0x41000, 0x42000, true, true, true, true, kPageSize, 0, 0, 0, - file.path().c_str()}, - // extra page from anon. - {0x42000, 0x43000, true, true, true, true, 0, 0, 0, 0, ""}, - }))); -} - -// An aligned segment with filesz == 0, memsz > 0 is anon-only. -TEST(ElfTest, AnonOnlySegment) { - ElfBinary<64> elf = StandardElf(); - - decltype(elf)::ElfPhdr phdr = {}; - phdr.p_type = PT_LOAD; - // RO segment. The extra anon page will be RW anyways. - phdr.p_flags = PF_R; - phdr.p_offset = 0; - phdr.p_vaddr = 0x41000; - phdr.p_filesz = 0; - phdr.p_memsz = kPageSize; - elf.phdrs.push_back(phdr); - - elf.UpdateOffsets(); - - // UpdateOffsets adjusts p_vaddr and p_offset by the header size, but we need - // a page-aligned p_vaddr to get a truly anon-only page. - elf.phdrs[2].p_vaddr = 0x41000; - // N.B. p_offset is now unaligned, but Linux doesn't care since this is - // anon-only. - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - ASSERT_NO_ERRNO(WaitStopped(child)); - - EXPECT_THAT(child, - ContainsMappings(std::vector<ProcMapsEntry>({ - // text page. - {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0, - file.path().c_str()}, - // anon page. - {0x41000, 0x42000, true, true, false, true, 0, 0, 0, 0, ""}, - }))); -} - -// p_offset must have the same alignment as p_vaddr. -TEST(ElfTest, UnalignedOffset) { - ElfBinary<64> elf = StandardElf(); - - // Unaligned offset. - elf.phdrs[1].p_offset += 1; - - elf.UpdateOffsets(); - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - - // execve(2) return EINVAL, but behavior varies between Linux and gVisor. - // - // On Linux, the new mm is committed before attempting to map into it. By the - // time we hit EINVAL in the segment mmap, the old mm is gone. Linux returns - // to an empty mm, which immediately segfaults. - // - // OTOH, gVisor maps into the new mm before committing it. Thus when it hits - // failure, the caller is still intact to receive the error. - if (IsRunningOnGvisor()) { - ASSERT_EQ(execve_errno, EINVAL); - } else { - ASSERT_EQ(execve_errno, 0); - - int status; - ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), - SyscallSucceedsWithValue(child)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV) << status; - } -} - -// Linux will allow PT_LOAD segments to overlap. -TEST(ElfTest, DirectlyOverlappingSegments) { - // NOTE(b/37289926): see PIEOutOfOrderSegments. - SKIP_IF(IsRunningOnGvisor()); - - ElfBinary<64> elf = StandardElf(); - - // Same as the StandardElf mapping. - decltype(elf)::ElfPhdr phdr = {}; - phdr.p_type = PT_LOAD; - // Add PF_W so we can differentiate this mapping from the first. - phdr.p_flags = PF_R | PF_W | PF_X; - phdr.p_offset = 0; - phdr.p_vaddr = 0x40000; - phdr.p_filesz = sizeof(kPtraceCode); - phdr.p_memsz = phdr.p_filesz; - elf.phdrs.push_back(phdr); - - elf.UpdateOffsets(); - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - ASSERT_NO_ERRNO(WaitStopped(child)); - - EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({ - {0x40000, 0x41000, true, true, true, true, 0, 0, 0, 0, - file.path().c_str()}, - }))); -} - -// Linux allows out-of-order PT_LOAD segments. -TEST(ElfTest, OutOfOrderSegments) { - // NOTE(b/37289926): see PIEOutOfOrderSegments. - SKIP_IF(IsRunningOnGvisor()); - - ElfBinary<64> elf = StandardElf(); - - decltype(elf)::ElfPhdr phdr = {}; - phdr.p_type = PT_LOAD; - phdr.p_flags = PF_R | PF_X; - phdr.p_offset = 0; - phdr.p_vaddr = 0x20000; - phdr.p_filesz = sizeof(kPtraceCode); - phdr.p_memsz = phdr.p_filesz; - elf.phdrs.push_back(phdr); - - elf.UpdateOffsets(); - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - ASSERT_NO_ERRNO(WaitStopped(child)); - - EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({ - {0x20000, 0x21000, true, false, true, true, 0, 0, 0, 0, - file.path().c_str()}, - {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0, - file.path().c_str()}, - }))); -} - -// header.e_phoff is bound the end of the file. -TEST(ElfTest, OutOfBoundsPhdrs) { - ElfBinary<64> elf = StandardElf(); - elf.header.e_phoff = 0x100000; - elf.UpdateOffsets(); - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - // On Linux 3.11, this caused EIO. On newer Linux, it causes ENOEXEC. - EXPECT_THAT(execve_errno, AnyOf(Eq(ENOEXEC), Eq(EIO))); -} - -// Claim there is a phdr beyond the end of the file, but don't include it. -TEST(ElfTest, MissingPhdr) { - ElfBinary<64> elf = StandardElf(); - - // Clear data so the file ends immediately after the phdrs. - // N.B. Per ElfTest.MissingData, StandardElf without data completes execve - // without error. - elf.data.clear(); - elf.UpdateOffsets(); - - // Claim that there is another phdr just beyond the end of the file. Of - // course, it isn't accessible. - elf.header.e_phnum++; - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - // On Linux 3.11, this caused EIO. On newer Linux, it causes ENOEXEC. - EXPECT_THAT(execve_errno, AnyOf(Eq(ENOEXEC), Eq(EIO))); -} - -// No headers at all, just the ELF magic. -TEST(ElfTest, MissingHeader) { - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0755)); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); - - const char kElfMagic[] = {0x7f, 'E', 'L', 'F'}; - - ASSERT_THAT(WriteFd(fd.get(), &kElfMagic, sizeof(kElfMagic)), - SyscallSucceeds()); - fd.reset(); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - EXPECT_EQ(execve_errno, ENOEXEC); -} - -// Load a PIE ELF with a data + bss segment. -TEST(ElfTest, PIE) { - ElfBinary<64> elf = StandardElf(); - - elf.header.e_type = ET_DYN; - - // Create a standard ELF, but extend to 1.5 pages. The second page will be the - // beginning of a multi-page data + bss segment. - elf.data.resize(kPageSize + kPageSize / 2); - - elf.header.e_entry = 0x0; - - decltype(elf)::ElfPhdr phdr = {}; - phdr.p_type = PT_LOAD; - phdr.p_flags = PF_R | PF_W; - phdr.p_offset = kPageSize; - // Put the data segment at a bit of an offset. - phdr.p_vaddr = 0x20000; - phdr.p_filesz = kPageSize / 2; - // The header is going to push vaddr up by a few hundred bytes. Keep p_memsz a - // bit less than 2 pages so this mapping doesn't extend beyond 0x43000. - phdr.p_memsz = 2 * kPageSize - kPageSize / 2; - elf.phdrs.push_back(phdr); - - elf.UpdateOffsets(); - - // The first segment really needs to start at 0 for a normal PIE binary, and - // thus includes the headers. - const uint64_t offset = elf.phdrs[1].p_offset; - elf.phdrs[1].p_offset = 0x0; - elf.phdrs[1].p_vaddr = 0x0; - elf.phdrs[1].p_filesz += offset; - elf.phdrs[1].p_memsz += offset; - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - ASSERT_NO_ERRNO(WaitStopped(child)); - - // RIP tells us which page the first segment was loaded into. - struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); - - const uint64_t load_addr = regs.rip & ~(kPageSize - 1); - - EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({ - // text page. - {load_addr, load_addr + 0x1000, true, false, true, - true, 0, 0, 0, 0, file.path().c_str()}, - // data + bss page from file. - {load_addr + 0x20000, load_addr + 0x21000, true, true, - false, true, kPageSize, 0, 0, 0, file.path().c_str()}, - // bss page from anon. - {load_addr + 0x21000, load_addr + 0x22000, true, true, - false, true, 0, 0, 0, 0, ""}, - }))); -} - -// PIE binary with a non-zero start address. -// -// This is non-standard for a PIE binary, but valid. The binary is still loaded -// at an arbitrary address, not the first PT_LOAD vaddr. -// -// N.B. Linux changed this behavior in d1fd836dcf00d2028c700c7e44d2c23404062c90. -// Previously, with "randomization" enabled, PIE binaries with a non-zero start -// address would be be loaded at the address they specified because mmap was -// passed the load address, which wasn't 0 as expected. -// -// This change is present in kernel v4.1+. -TEST(ElfTest, PIENonZeroStart) { - // gVisor has the newer behavior. - if (!IsRunningOnGvisor()) { - auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion()); - SKIP_IF(version.major < 4 || (version.major == 4 && version.minor < 1)); - } - - ElfBinary<64> elf = StandardElf(); - - elf.header.e_type = ET_DYN; - - // Create a standard ELF, but extend to 1.5 pages. The second page will be the - // beginning of a multi-page data + bss segment. - elf.data.resize(kPageSize + kPageSize / 2); - - decltype(elf)::ElfPhdr phdr = {}; - phdr.p_type = PT_LOAD; - phdr.p_flags = PF_R | PF_W; - phdr.p_offset = kPageSize; - // Put the data segment at a bit of an offset. - phdr.p_vaddr = 0x60000; - phdr.p_filesz = kPageSize / 2; - // The header is going to push vaddr up by a few hundred bytes. Keep p_memsz a - // bit less than 2 pages so this mapping doesn't extend beyond 0x43000. - phdr.p_memsz = 2 * kPageSize - kPageSize / 2; - elf.phdrs.push_back(phdr); - - elf.UpdateOffsets(); - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - ASSERT_NO_ERRNO(WaitStopped(child)); - - // RIP tells us which page the first segment was loaded into. - struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); - - const uint64_t load_addr = regs.rip & ~(kPageSize - 1); - - // The ELF is loaded at an arbitrary address, not the first PT_LOAD vaddr. - // - // N.B. this is technically flaky, but Linux is *extremely* unlikely to pick - // this as the start address, as it searches from the top down. - EXPECT_NE(load_addr, 0x40000); - - EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({ - // text page. - {load_addr, load_addr + 0x1000, true, false, true, - true, 0, 0, 0, 0, file.path().c_str()}, - // data + bss page from file. - {load_addr + 0x20000, load_addr + 0x21000, true, true, - false, true, kPageSize, 0, 0, 0, file.path().c_str()}, - // bss page from anon. - {load_addr + 0x21000, load_addr + 0x22000, true, true, - false, true, 0, 0, 0, 0, ""}, - }))); -} - -TEST(ElfTest, PIEOutOfOrderSegments) { - // TODO(b/37289926): This triggers a bug in Linux where it computes the size - // of the binary as 0x20000 - 0x40000 = 0xfffffffffffe0000, which obviously - // fails to map. - // - // We test gVisor's behavior (of rejecting the binary) because I assert that - // Linux is wrong and needs to be fixed. - SKIP_IF(!IsRunningOnGvisor()); - - ElfBinary<64> elf = StandardElf(); - - elf.header.e_type = ET_DYN; - - // Create a standard ELF, but extend to 1.5 pages. The second page will be the - // beginning of a multi-page data + bss segment. - elf.data.resize(kPageSize + kPageSize / 2); - - decltype(elf)::ElfPhdr phdr = {}; - phdr.p_type = PT_LOAD; - phdr.p_flags = PF_R | PF_W; - phdr.p_offset = kPageSize; - // Put the data segment *before* the first segment. - phdr.p_vaddr = 0x20000; - phdr.p_filesz = kPageSize / 2; - // The header is going to push vaddr up by a few hundred bytes. Keep p_memsz a - // bit less than 2 pages so this mapping doesn't extend beyond 0x43000. - phdr.p_memsz = 2 * kPageSize - kPageSize / 2; - elf.phdrs.push_back(phdr); - - elf.UpdateOffsets(); - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - EXPECT_EQ(execve_errno, ENOEXEC); -} - -// Standard dynamically linked binary with an ELF interpreter. -TEST(ElfTest, ELFInterpreter) { - ElfBinary<64> interpreter = StandardElf(); - interpreter.header.e_type = ET_DYN; - interpreter.header.e_entry = 0x0; - interpreter.UpdateOffsets(); - - // The first segment really needs to start at 0 for a normal PIE binary, and - // thus includes the headers. - uint64_t const offset = interpreter.phdrs[1].p_offset; - // N.B. Since Linux 4.10 (0036d1f7eb95b "binfmt_elf: fix calculations for bss - // padding"), Linux unconditionally zeroes the remainder of the highest mapped - // page in an interpreter, failing if the protections don't allow write. Thus - // we must mark this writeable. - interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X; - interpreter.phdrs[1].p_offset = 0x0; - interpreter.phdrs[1].p_vaddr = 0x0; - interpreter.phdrs[1].p_filesz += offset; - interpreter.phdrs[1].p_memsz += offset; - - TempPath interpreter_file = - ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(interpreter)); - - ElfBinary<64> binary = StandardElf(); - - // Append the interpreter path. - int const interp_data_start = binary.data.size(); - for (char const c : interpreter_file.path()) { - binary.data.push_back(c); - } - // NUL-terminate. - binary.data.push_back(0); - int const interp_data_size = binary.data.size() - interp_data_start; - - decltype(binary)::ElfPhdr phdr = {}; - phdr.p_type = PT_INTERP; - phdr.p_offset = interp_data_start; - phdr.p_filesz = interp_data_size; - phdr.p_memsz = interp_data_size; - // "If [PT_INTERP] is present, it must precede any loadable segment entry." - // - // However, Linux allows it anywhere, so we just stick it at the end to make - // sure out-of-order PT_INTERP is OK. - binary.phdrs.push_back(phdr); - - binary.UpdateOffsets(); - - TempPath binary_file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(binary)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec( - binary_file.path(), {binary_file.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - ASSERT_NO_ERRNO(WaitStopped(child)); - - // RIP tells us which page the first segment of the interpreter was loaded - // into. - struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); - - const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1); - - EXPECT_THAT( - child, ContainsMappings(std::vector<ProcMapsEntry>({ - // Main binary - {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0, - binary_file.path().c_str()}, - // Interpreter - {interp_load_addr, interp_load_addr + 0x1000, true, true, true, - true, 0, 0, 0, 0, interpreter_file.path().c_str()}, - }))); -} - -// Test parameter to ElfInterpterStaticTest cases. The first item is a suffix to -// add to the end of the interpreter path in the PT_INTERP segment and the -// second is the expected execve(2) errno. -using ElfInterpreterStaticParam = std::tuple<std::vector<char>, int>; - -class ElfInterpreterStaticTest - : public ::testing::TestWithParam<ElfInterpreterStaticParam> {}; - -// Statically linked ELF with a statically linked ELF interpreter. -TEST_P(ElfInterpreterStaticTest, Test) { - const std::vector<char> segment_suffix = std::get<0>(GetParam()); - const int expected_errno = std::get<1>(GetParam()); - - ElfBinary<64> interpreter = StandardElf(); - // See comment in ElfTest.ELFInterpreter. - interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X; - interpreter.UpdateOffsets(); - TempPath interpreter_file = - ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(interpreter)); - - ElfBinary<64> binary = StandardElf(); - // The PT_LOAD segment conflicts with the interpreter's PT_LOAD segment. The - // interpreter's will be mapped directly over the binary's. - - // Interpreter path plus the parameterized suffix in the PT_INTERP segment. - const std::string path = interpreter_file.path(); - std::vector<char> segment(path.begin(), path.end()); - segment.insert(segment.end(), segment_suffix.begin(), segment_suffix.end()); - binary.AddInterpreter(segment); - - binary.UpdateOffsets(); - - TempPath binary_file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(binary)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec( - binary_file.path(), {binary_file.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, expected_errno); - - if (expected_errno == 0) { - ASSERT_NO_ERRNO(WaitStopped(child)); - - EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({ - // Interpreter. - {0x40000, 0x41000, true, true, true, true, 0, 0, 0, - 0, interpreter_file.path().c_str()}, - }))); - } -} - -INSTANTIATE_TEST_SUITE_P( - Cases, ElfInterpreterStaticTest, - ::testing::ValuesIn({ - // Simple NUL-terminator to run the interpreter as normal. - std::make_tuple(std::vector<char>({'\0'}), 0), - // Add some garbage to the segment followed by a NUL-terminator. This is - // ignored. - std::make_tuple(std::vector<char>({'\0', 'b', '\0'}), 0), - // Add some garbage to the segment without a NUL-terminator. Linux will - // reject - // this. - std::make_tuple(std::vector<char>({'\0', 'b'}), ENOEXEC), - })); - -// Test parameter to ElfInterpterBadPathTest cases. The first item is the -// contents of the PT_INTERP segment and the second is the expected execve(2) -// errno. -using ElfInterpreterBadPathParam = std::tuple<std::vector<char>, int>; - -class ElfInterpreterBadPathTest - : public ::testing::TestWithParam<ElfInterpreterBadPathParam> {}; - -TEST_P(ElfInterpreterBadPathTest, Test) { - const std::vector<char> segment = std::get<0>(GetParam()); - const int expected_errno = std::get<1>(GetParam()); - - ElfBinary<64> binary = StandardElf(); - binary.AddInterpreter(segment); - binary.UpdateOffsets(); - - TempPath binary_file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(binary)); - - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec( - binary_file.path(), {binary_file.path()}, {}, nullptr, &execve_errno)); - EXPECT_EQ(execve_errno, expected_errno); -} - -INSTANTIATE_TEST_SUITE_P( - Cases, ElfInterpreterBadPathTest, - ::testing::ValuesIn({ - // NUL-terminated fake path in the PT_INTERP segment. - std::make_tuple(std::vector<char>({'/', 'f', '/', 'b', '\0'}), ENOENT), - // ELF interpreter not NUL-terminated. - std::make_tuple(std::vector<char>({'/', 'f', '/', 'b'}), ENOEXEC), - // ELF interpreter path omitted entirely. - // - // fs/binfmt_elf.c:load_elf_binary returns ENOEXEC if p_filesz is < 2 - // bytes. - std::make_tuple(std::vector<char>({'\0'}), ENOEXEC), - // ELF interpreter path = "\0". - // - // fs/binfmt_elf.c:load_elf_binary returns ENOEXEC if p_filesz is < 2 - // bytes, so add an extra byte to pass that check. - // - // load_elf_binary -> open_exec -> do_open_execat fails to check that - // name != '\0' before calling do_filp_open, which thus opens the - // working directory. do_open_execat returns EACCES because the - // directory is not a regular file. - std::make_tuple(std::vector<char>({'\0', '\0'}), EACCES), - })); - -// Relative path to ELF interpreter. -TEST(ElfTest, ELFInterpreterRelative) { - ElfBinary<64> interpreter = StandardElf(); - interpreter.header.e_type = ET_DYN; - interpreter.header.e_entry = 0x0; - interpreter.UpdateOffsets(); - - // The first segment really needs to start at 0 for a normal PIE binary, and - // thus includes the headers. - uint64_t const offset = interpreter.phdrs[1].p_offset; - // See comment in ElfTest.ELFInterpreter. - interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X; - interpreter.phdrs[1].p_offset = 0x0; - interpreter.phdrs[1].p_vaddr = 0x0; - interpreter.phdrs[1].p_filesz += offset; - interpreter.phdrs[1].p_memsz += offset; - - TempPath interpreter_file = - ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(interpreter)); - auto cwd = ASSERT_NO_ERRNO_AND_VALUE(GetCWD()); - auto interpreter_relative = - ASSERT_NO_ERRNO_AND_VALUE(GetRelativePath(cwd, interpreter_file.path())); - - ElfBinary<64> binary = StandardElf(); - - // NUL-terminated path in the PT_INTERP segment. - std::vector<char> segment(interpreter_relative.begin(), - interpreter_relative.end()); - segment.push_back(0); - binary.AddInterpreter(segment); - - binary.UpdateOffsets(); - - TempPath binary_file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(binary)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec( - binary_file.path(), {binary_file.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - ASSERT_NO_ERRNO(WaitStopped(child)); - - // RIP tells us which page the first segment of the interpreter was loaded - // into. - struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); - - const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1); - - EXPECT_THAT( - child, ContainsMappings(std::vector<ProcMapsEntry>({ - // Main binary - {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0, - binary_file.path().c_str()}, - // Interpreter - {interp_load_addr, interp_load_addr + 0x1000, true, true, true, - true, 0, 0, 0, 0, interpreter_file.path().c_str()}, - }))); -} - -// ELF interpreter architecture doesn't match the binary. -TEST(ElfTest, ELFInterpreterWrongArch) { - ElfBinary<64> interpreter = StandardElf(); - interpreter.header.e_machine = EM_PPC64; - interpreter.header.e_type = ET_DYN; - interpreter.header.e_entry = 0x0; - interpreter.UpdateOffsets(); - - // The first segment really needs to start at 0 for a normal PIE binary, and - // thus includes the headers. - uint64_t const offset = interpreter.phdrs[1].p_offset; - // See comment in ElfTest.ELFInterpreter. - interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X; - interpreter.phdrs[1].p_offset = 0x0; - interpreter.phdrs[1].p_vaddr = 0x0; - interpreter.phdrs[1].p_filesz += offset; - interpreter.phdrs[1].p_memsz += offset; - - TempPath interpreter_file = - ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(interpreter)); - - ElfBinary<64> binary = StandardElf(); - - // NUL-terminated path in the PT_INTERP segment. - const std::string path = interpreter_file.path(); - std::vector<char> segment(path.begin(), path.end()); - segment.push_back(0); - binary.AddInterpreter(segment); - - binary.UpdateOffsets(); - - TempPath binary_file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(binary)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec( - binary_file.path(), {binary_file.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, ELIBBAD); -} - -// No execute permissions on the binary. -TEST(ElfTest, NoExecute) { - ElfBinary<64> elf = StandardElf(); - elf.UpdateOffsets(); - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - ASSERT_THAT(chmod(file.path().c_str(), 0644), SyscallSucceeds()); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - EXPECT_EQ(execve_errno, EACCES); -} - -// Execute, but no read permissions on the binary works just fine. -TEST(ElfTest, NoRead) { - // TODO(gvisor.dev/issue/160): gVisor's backing filesystem may prevent the - // sentry from reading the executable. - SKIP_IF(IsRunningOnGvisor()); - - ElfBinary<64> elf = StandardElf(); - elf.UpdateOffsets(); - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - ASSERT_THAT(chmod(file.path().c_str(), 0111), SyscallSucceeds()); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - ASSERT_NO_ERRNO(WaitStopped(child)); - - // TODO(gvisor.dev/issue/160): A task with a non-readable executable is marked - // non-dumpable, preventing access to proc files. gVisor does not implement - // this behavior. -} - -// No execute permissions on the ELF interpreter. -TEST(ElfTest, ElfInterpreterNoExecute) { - ElfBinary<64> interpreter = StandardElf(); - interpreter.header.e_type = ET_DYN; - interpreter.header.e_entry = 0x0; - interpreter.UpdateOffsets(); - - // The first segment really needs to start at 0 for a normal PIE binary, and - // thus includes the headers. - uint64_t const offset = interpreter.phdrs[1].p_offset; - // See comment in ElfTest.ELFInterpreter. - interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X; - interpreter.phdrs[1].p_offset = 0x0; - interpreter.phdrs[1].p_vaddr = 0x0; - interpreter.phdrs[1].p_filesz += offset; - interpreter.phdrs[1].p_memsz += offset; - - TempPath interpreter_file = - ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(interpreter)); - - ElfBinary<64> binary = StandardElf(); - - // NUL-terminated path in the PT_INTERP segment. - const std::string path = interpreter_file.path(); - std::vector<char> segment(path.begin(), path.end()); - segment.push_back(0); - binary.AddInterpreter(segment); - - binary.UpdateOffsets(); - - TempPath binary_file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(binary)); - - ASSERT_THAT(chmod(interpreter_file.path().c_str(), 0644), SyscallSucceeds()); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(interpreter_file.path(), {interpreter_file.path()}, {}, - &child, &execve_errno)); - EXPECT_EQ(execve_errno, EACCES); -} - -// Execute a basic interpreter script. -TEST(InterpreterScriptTest, Execute) { - ElfBinary<64> elf = StandardElf(); - elf.UpdateOffsets(); - // Use /tmp explicitly to ensure the path is short enough. - TempPath binary = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith("/tmp", elf)); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::StrCat("#!", binary.path()), 0755)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(script.path(), {script.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - EXPECT_NO_ERRNO(WaitStopped(child)); -} - -// Whitespace after #!. -TEST(InterpreterScriptTest, Whitespace) { - ElfBinary<64> elf = StandardElf(); - elf.UpdateOffsets(); - // Use /tmp explicitly to ensure the path is short enough. - TempPath binary = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith("/tmp", elf)); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::StrCat("#! \t \t", binary.path()), 0755)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(script.path(), {script.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - EXPECT_NO_ERRNO(WaitStopped(child)); -} - -// Interpreter script is missing execute permission. -TEST(InterpreterScriptTest, InterpreterScriptNoExecute) { - ElfBinary<64> elf = StandardElf(); - elf.UpdateOffsets(); - // Use /tmp explicitly to ensure the path is short enough. - TempPath binary = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith("/tmp", elf)); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::StrCat("#!", binary.path()), 0644)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(script.path(), {script.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, EACCES); -} - -// Binary interpreter script refers to is missing execute permission. -TEST(InterpreterScriptTest, BinaryNoExecute) { - ElfBinary<64> elf = StandardElf(); - elf.UpdateOffsets(); - // Use /tmp explicitly to ensure the path is short enough. - TempPath binary = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith("/tmp", elf)); - - ASSERT_THAT(chmod(binary.path().c_str(), 0644), SyscallSucceeds()); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::StrCat("#!", binary.path()), 0755)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(script.path(), {script.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, EACCES); -} - -// Linux will load interpreter scripts five levels deep, but no more. -TEST(InterpreterScriptTest, MaxRecursion) { - ElfBinary<64> elf = StandardElf(); - elf.UpdateOffsets(); - // Use /tmp explicitly to ensure the path is short enough. - TempPath binary = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith("/tmp", elf)); - - TempPath script1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - "/tmp", absl::StrCat("#!", binary.path()), 0755)); - TempPath script2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - "/tmp", absl::StrCat("#!", script1.path()), 0755)); - TempPath script3 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - "/tmp", absl::StrCat("#!", script2.path()), 0755)); - TempPath script4 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - "/tmp", absl::StrCat("#!", script3.path()), 0755)); - TempPath script5 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - "/tmp", absl::StrCat("#!", script4.path()), 0755)); - TempPath script6 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - "/tmp", absl::StrCat("#!", script5.path()), 0755)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(script6.path(), {script6.path()}, {}, &child, &execve_errno)); - // Too many levels of recursion. - EXPECT_EQ(execve_errno, ELOOP); - - // The next level up is OK. - auto cleanup2 = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(script5.path(), {script5.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - EXPECT_NO_ERRNO(WaitStopped(child)); -} - -// Interpreter script with a relative path. -TEST(InterpreterScriptTest, RelativePath) { - ElfBinary<64> elf = StandardElf(); - elf.UpdateOffsets(); - TempPath binary = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith("/tmp", elf)); - - auto cwd = ASSERT_NO_ERRNO_AND_VALUE(GetCWD()); - auto binary_relative = - ASSERT_NO_ERRNO_AND_VALUE(GetRelativePath(cwd, binary.path())); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::StrCat("#!", binary_relative), 0755)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(script.path(), {script.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - EXPECT_NO_ERRNO(WaitStopped(child)); -} - -// Interpreter script with .. in a path component. -TEST(InterpreterScriptTest, UncleanPath) { - ElfBinary<64> elf = StandardElf(); - elf.UpdateOffsets(); - // Use /tmp explicitly to ensure the path is short enough. - TempPath binary = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith("/tmp", elf)); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::StrCat("#!/tmp/../", binary.path()), - 0755)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(script.path(), {script.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - EXPECT_NO_ERRNO(WaitStopped(child)); -} - -// Passed interpreter script is a symlink. -TEST(InterpreterScriptTest, Symlink) { - ElfBinary<64> elf = StandardElf(); - elf.UpdateOffsets(); - // Use /tmp explicitly to ensure the path is short enough. - TempPath binary = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith("/tmp", elf)); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::StrCat("#!", binary.path()), 0755)); - - TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), script.path())); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(link.path(), {link.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - EXPECT_NO_ERRNO(WaitStopped(child)); -} - -// Interpreter script points to a symlink loop. -TEST(InterpreterScriptTest, SymlinkLoop) { - std::string const link1 = NewTempAbsPathInDir("/tmp"); - std::string const link2 = NewTempAbsPathInDir("/tmp"); - - ASSERT_THAT(symlink(link2.c_str(), link1.c_str()), SyscallSucceeds()); - auto remove_link1 = Cleanup( - [&link1] { EXPECT_THAT(unlink(link1.c_str()), SyscallSucceeds()); }); - - ASSERT_THAT(symlink(link1.c_str(), link2.c_str()), SyscallSucceeds()); - auto remove_link2 = Cleanup( - [&link2] { EXPECT_THAT(unlink(link2.c_str()), SyscallSucceeds()); }); - - TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::StrCat("#!", link1), 0755)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(script.path(), {script.path()}, {}, &child, &execve_errno)); - EXPECT_EQ(execve_errno, ELOOP); -} - -// Binary is a symlink loop. -TEST(ExecveTest, SymlinkLoop) { - std::string const link1 = NewTempAbsPathInDir("/tmp"); - std::string const link2 = NewTempAbsPathInDir("/tmp"); - - ASSERT_THAT(symlink(link2.c_str(), link1.c_str()), SyscallSucceeds()); - auto remove_link = Cleanup( - [&link1] { EXPECT_THAT(unlink(link1.c_str()), SyscallSucceeds()); }); - - ASSERT_THAT(symlink(link1.c_str(), link2.c_str()), SyscallSucceeds()); - auto remove_link2 = Cleanup( - [&link2] { EXPECT_THAT(unlink(link2.c_str()), SyscallSucceeds()); }); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(link1, {link1}, {}, &child, &execve_errno)); - EXPECT_EQ(execve_errno, ELOOP); -} - -// Binary is a directory. -TEST(ExecveTest, Directory) { - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec("/tmp", {"/tmp"}, {}, &child, &execve_errno)); - EXPECT_EQ(execve_errno, EACCES); -} - -// Pass a valid binary as a directory (extra / on the end). -TEST(ExecveTest, BinaryAsDirectory) { - ElfBinary<64> elf = StandardElf(); - elf.UpdateOffsets(); - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - std::string const path = absl::StrCat(file.path(), "/"); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(path, {path}, {}, &child, &execve_errno)); - EXPECT_EQ(execve_errno, ENOTDIR); -} - -// The initial brk value is after the page at the end of the binary. -TEST(ExecveTest, BrkAfterBinary) { - ElfBinary<64> elf = StandardElf(); - elf.UpdateOffsets(); - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); - - pid_t child; - int execve_errno; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); - ASSERT_EQ(execve_errno, 0); - - // Ensure it made it to SIGSTOP. - ASSERT_NO_ERRNO(WaitStopped(child)); - - struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); - - // RIP is just beyond the final syscall instruction. Rewind to execute a brk - // syscall. - regs.rip -= kSyscallSize; - regs.rax = __NR_brk; - regs.rdi = 0; - ASSERT_THAT(ptrace(PTRACE_SETREGS, child, 0, ®s), SyscallSucceeds()); - - // Resume the child, waiting for syscall entry. - ASSERT_THAT(ptrace(PTRACE_SYSCALL, child, 0, 0), SyscallSucceeds()); - int status; - ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), - SyscallSucceedsWithValue(child)); - ASSERT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP) - << "status = " << status; - - // Execute the syscall. - ASSERT_THAT(ptrace(PTRACE_SYSCALL, child, 0, 0), SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), - SyscallSucceedsWithValue(child)); - ASSERT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP) - << "status = " << status; - - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); - - // brk is after the text page. - // - // The kernel does brk randomization, so we can't be sure what the exact - // address will be, but it is always beyond the final page in the binary. - // i.e., it does not start immediately after memsz in the middle of a page. - // Userspace may expect to use that space. - EXPECT_GE(regs.rax, 0x41000); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/exec_proc_exe_workload.cc b/test/syscalls/linux/exec_proc_exe_workload.cc deleted file mode 100644 index 2989379b7..000000000 --- a/test/syscalls/linux/exec_proc_exe_workload.cc +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdlib.h> -#include <unistd.h> - -#include <iostream> - -#include "test/util/fs_util.h" -#include "test/util/posix_error.h" - -int main(int argc, char** argv, char** envp) { - // This is annoying. Because remote build systems may put these binaries - // in a content-addressable-store, you may wind up with /proc/self/exe - // pointing to some random path (but with a sensible argv[0]). - // - // Therefore, this test simply checks that the /proc/self/exe - // is absolute and *doesn't* match argv[1]. - std::string exe = - gvisor::testing::ProcessExePath(getpid()).ValueOrDie(); - if (exe[0] != '/') { - std::cerr << "relative path: " << exe << std::endl; - exit(1); - } - if (exe.find(argv[1]) != std::string::npos) { - std::cerr << "matching path: " << exe << std::endl; - exit(1); - } - - return 0; -} diff --git a/test/syscalls/linux/exec_state_workload.cc b/test/syscalls/linux/exec_state_workload.cc deleted file mode 100644 index 028902b14..000000000 --- a/test/syscalls/linux/exec_state_workload.cc +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <signal.h> -#include <stdint.h> -#include <stdio.h> -#include <stdlib.h> -#include <sys/auxv.h> -#include <sys/prctl.h> -#include <sys/time.h> - -#include <iostream> -#include <ostream> -#include <string> - -#include "absl/strings/numbers.h" - -// Pretty-print a sigset_t. -std::ostream& operator<<(std::ostream& out, const sigset_t& s) { - out << "{ "; - - for (int i = 0; i < NSIG; i++) { - if (sigismember(&s, i)) { - out << i << " "; - } - } - - out << "}"; - return out; -} - -// Verify that the signo handler is handler. -int CheckSigHandler(uint32_t signo, uintptr_t handler) { - struct sigaction sa; - int ret = sigaction(signo, nullptr, &sa); - if (ret < 0) { - perror("sigaction"); - return 1; - } - - if (reinterpret_cast<void (*)(int)>(handler) != sa.sa_handler) { - std::cerr << "signo " << signo << " handler got: " << sa.sa_handler - << " expected: " << std::hex << handler; - return 1; - } - return 0; -} - -// Verify that the signo is blocked. -int CheckSigBlocked(uint32_t signo) { - sigset_t s; - int ret = sigprocmask(SIG_SETMASK, nullptr, &s); - if (ret < 0) { - perror("sigprocmask"); - return 1; - } - - if (!sigismember(&s, signo)) { - std::cerr << "signal " << signo << " not blocked in signal mask: " << s - << std::endl; - return 1; - } - return 0; -} - -// Verify that the itimer is enabled. -int CheckItimerEnabled(uint32_t timer) { - struct itimerval itv; - int ret = getitimer(timer, &itv); - if (ret < 0) { - perror("getitimer"); - return 1; - } - - if (!itv.it_value.tv_sec && !itv.it_value.tv_usec && - !itv.it_interval.tv_sec && !itv.it_interval.tv_usec) { - std::cerr << "timer " << timer - << " not enabled. value sec: " << itv.it_value.tv_sec - << " usec: " << itv.it_value.tv_usec - << " interval sec: " << itv.it_interval.tv_sec - << " usec: " << itv.it_interval.tv_usec << std::endl; - return 1; - } - return 0; -} - -int PrintExecFn() { - unsigned long execfn = getauxval(AT_EXECFN); - if (!execfn) { - std::cerr << "AT_EXECFN missing" << std::endl; - return 1; - } - - std::cerr << reinterpret_cast<const char*>(execfn) << std::endl; - return 0; -} - -int PrintExecName() { - const size_t name_length = 20; - char name[name_length] = {0}; - if (prctl(PR_GET_NAME, name) < 0) { - std::cerr << "prctl(PR_GET_NAME) failed" << std::endl; - return 1; - } - - std::cerr << name << std::endl; - return 0; -} - -void usage(const std::string& prog) { - std::cerr << "usage:\n" - << "\t" << prog << " CheckSigHandler <signo> <handler addr (hex)>\n" - << "\t" << prog << " CheckSigBlocked <signo>\n" - << "\t" << prog << " CheckTimerDisabled <timer>\n" - << "\t" << prog << " PrintExecFn\n" - << "\t" << prog << " PrintExecName" << std::endl; -} - -int main(int argc, char** argv) { - if (argc < 2) { - usage(argv[0]); - return 1; - } - - std::string func(argv[1]); - - if (func == "CheckSigHandler") { - if (argc != 4) { - usage(argv[0]); - return 1; - } - - uint32_t signo; - if (!absl::SimpleAtoi(argv[2], &signo)) { - std::cerr << "invalid signo: " << argv[2] << std::endl; - return 1; - } - - uintptr_t handler; - if (!absl::numbers_internal::safe_strtoi_base(argv[3], &handler, 16)) { - std::cerr << "invalid handler: " << std::hex << argv[3] << std::endl; - return 1; - } - - return CheckSigHandler(signo, handler); - } - - if (func == "CheckSigBlocked") { - if (argc != 3) { - usage(argv[0]); - return 1; - } - - uint32_t signo; - if (!absl::SimpleAtoi(argv[2], &signo)) { - std::cerr << "invalid signo: " << argv[2] << std::endl; - return 1; - } - - return CheckSigBlocked(signo); - } - - if (func == "CheckItimerEnabled") { - if (argc != 3) { - usage(argv[0]); - return 1; - } - - uint32_t timer; - if (!absl::SimpleAtoi(argv[2], &timer)) { - std::cerr << "invalid signo: " << argv[2] << std::endl; - return 1; - } - - return CheckItimerEnabled(timer); - } - - if (func == "PrintExecFn") { - // N.B. This will be called as an interpreter script, with the script passed - // as the third argument. We don't care about that script. - return PrintExecFn(); - } - - if (func == "PrintExecName") { - // N.B. This may be called as an interpreter script like PrintExecFn. - return PrintExecName(); - } - - std::cerr << "Invalid function: " << func << std::endl; - return 1; -} diff --git a/test/syscalls/linux/exit.cc b/test/syscalls/linux/exit.cc deleted file mode 100644 index d52ea786b..000000000 --- a/test/syscalls/linux/exit.cc +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/wait.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "absl/time/time.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" -#include "test/util/time_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -void TestExit(int code) { - pid_t pid = fork(); - if (pid == 0) { - _exit(code); - } - - ASSERT_THAT(pid, SyscallSucceeds()); - - int status; - EXPECT_THAT(RetryEINTR(waitpid)(pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == code) << status; -} - -TEST(ExitTest, Success) { TestExit(0); } - -TEST(ExitTest, Failure) { TestExit(1); } - -// This test ensures that a process's file descriptors are closed when it calls -// exit(). In order to test this, the parent tries to read from a pipe whose -// write end is held by the child. While the read is blocking, the child exits, -// which should cause the parent to read 0 bytes due to EOF. -TEST(ExitTest, CloseFds) { - int pipe_fds[2]; - ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds()); - - FileDescriptor read_fd(pipe_fds[0]); - FileDescriptor write_fd(pipe_fds[1]); - - pid_t pid = fork(); - if (pid == 0) { - read_fd.reset(); - - SleepSafe(absl::Seconds(10)); - - _exit(0); - } - - EXPECT_THAT(pid, SyscallSucceeds()); - - write_fd.reset(); - - char buf[10]; - EXPECT_THAT(ReadFd(read_fd.get(), buf, sizeof(buf)), - SyscallSucceedsWithValue(0)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/exit_script.sh b/test/syscalls/linux/exit_script.sh deleted file mode 100755 index 527518e06..000000000 --- a/test/syscalls/linux/exit_script.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -if [ $# -ne 1 ]; then - echo "Usage: $0 exit_code" - exit 255 -fi - -exit $1 diff --git a/test/syscalls/linux/fadvise64.cc b/test/syscalls/linux/fadvise64.cc deleted file mode 100644 index 2af7aa6d9..000000000 --- a/test/syscalls/linux/fadvise64.cc +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <syscall.h> -#include <unistd.h> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -TEST(FAdvise64Test, Basic) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); - - // fadvise64 is noop in gVisor, so just test that it succeeds. - ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_NORMAL), - SyscallSucceeds()); - ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_RANDOM), - SyscallSucceeds()); - ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_SEQUENTIAL), - SyscallSucceeds()); - ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_WILLNEED), - SyscallSucceeds()); - ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_DONTNEED), - SyscallSucceeds()); - ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_NOREUSE), - SyscallSucceeds()); -} - -TEST(FAdvise64Test, InvalidArgs) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); - - // Note: offset is allowed to be negative. - ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, static_cast<off_t>(-1), - POSIX_FADV_NORMAL), - SyscallFailsWithErrno(EINVAL)); - ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, 12345), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(FAdvise64Test, NoPipes) { - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor read(fds[0]); - const FileDescriptor write(fds[1]); - - ASSERT_THAT(syscall(__NR_fadvise64, read.get(), 0, 10, POSIX_FADV_NORMAL), - SyscallFailsWithErrno(ESPIPE)); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/fallocate.cc b/test/syscalls/linux/fallocate.cc deleted file mode 100644 index 7819f4ac3..000000000 --- a/test/syscalls/linux/fallocate.cc +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <signal.h> -#include <sys/resource.h> -#include <sys/stat.h> -#include <syscall.h> -#include <time.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/file_base.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -int fallocate(int fd, int mode, off_t offset, off_t len) { - return RetryEINTR(syscall)(__NR_fallocate, fd, mode, offset, len); -} - -class AllocateTest : public FileTest { - void SetUp() override { FileTest::SetUp(); } -}; - -TEST_F(AllocateTest, Fallocate) { - // Check that it starts at size zero. - struct stat buf; - ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); - EXPECT_EQ(buf.st_size, 0); - - // Grow to ten bytes. - ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 10), SyscallSucceeds()); - ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); - EXPECT_EQ(buf.st_size, 10); - - // Allocate to a smaller size should be noop. - ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 5), SyscallSucceeds()); - ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); - EXPECT_EQ(buf.st_size, 10); - - // Grow again. - ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 20), SyscallSucceeds()); - ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); - EXPECT_EQ(buf.st_size, 20); - - // Grow with offset. - ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 10, 20), SyscallSucceeds()); - ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); - EXPECT_EQ(buf.st_size, 30); - - // Grow with offset beyond EOF. - ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 39, 1), SyscallSucceeds()); - ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); - EXPECT_EQ(buf.st_size, 40); -} - -TEST_F(AllocateTest, FallocateInvalid) { - // Invalid FD - EXPECT_THAT(fallocate(-1, 0, 0, 10), SyscallFailsWithErrno(EBADF)); - - // Negative offset and size. - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, -1, 10), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, -1), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, -1, -1), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_F(AllocateTest, FallocateReadonly) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); - EXPECT_THAT(fallocate(fd.get(), 0, 0, 10), SyscallFailsWithErrno(EBADF)); -} - -TEST_F(AllocateTest, FallocatePipe) { - int pipes[2]; - EXPECT_THAT(pipe(pipes), SyscallSucceeds()); - auto cleanup = Cleanup([&pipes] { - EXPECT_THAT(close(pipes[0]), SyscallSucceeds()); - EXPECT_THAT(close(pipes[1]), SyscallSucceeds()); - }); - - EXPECT_THAT(fallocate(pipes[1], 0, 0, 10), SyscallFailsWithErrno(ESPIPE)); -} - -TEST_F(AllocateTest, FallocateChar) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_RDWR)); - EXPECT_THAT(fallocate(fd.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV)); -} - -TEST_F(AllocateTest, FallocateRlimit) { - // Get the current rlimit and restore after test run. - struct rlimit initial_lim; - ASSERT_THAT(getrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds()); - auto cleanup = Cleanup([&initial_lim] { - EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds()); - }); - - // Try growing past the file size limit. - sigset_t new_mask; - sigemptyset(&new_mask); - sigaddset(&new_mask, SIGXFSZ); - sigprocmask(SIG_BLOCK, &new_mask, nullptr); - - struct rlimit setlim = {}; - setlim.rlim_cur = 1024; - setlim.rlim_max = RLIM_INFINITY; - ASSERT_THAT(setrlimit(RLIMIT_FSIZE, &setlim), SyscallSucceeds()); - - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 1025), - SyscallFailsWithErrno(EFBIG)); - - struct timespec timelimit = {}; - timelimit.tv_sec = 10; - EXPECT_EQ(sigtimedwait(&new_mask, nullptr, &timelimit), SIGXFSZ); - ASSERT_THAT(sigprocmask(SIG_UNBLOCK, &new_mask, nullptr), SyscallSucceeds()); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/fault.cc b/test/syscalls/linux/fault.cc deleted file mode 100644 index a85750382..000000000 --- a/test/syscalls/linux/fault.cc +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#define _GNU_SOURCE 1 -#include <signal.h> -#include <ucontext.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -__attribute__((noinline)) void Fault(void) { - volatile int* foo = nullptr; - *foo = 0; -} - -int GetPcFromUcontext(ucontext_t* uc, uintptr_t* pc) { -#if defined(__x86_64__) - *pc = uc->uc_mcontext.gregs[REG_RIP]; - return 1; -#elif defined(__i386__) - *pc = uc->uc_mcontext.gregs[REG_EIP]; - return 1; -#elif defined(__aarch64__) - *pc = uc->uc_mcontext.pc; - return 1; -#else - return 0; -#endif -} - -void sigact_handler(int sig, siginfo_t* siginfo, void* context) { - uintptr_t pc; - if (GetPcFromUcontext(reinterpret_cast<ucontext_t*>(context), &pc)) { - /* Expect Fault() to be at most 64 bytes in size. */ - uintptr_t fault_addr = reinterpret_cast<uintptr_t>(&Fault); - EXPECT_GE(pc, fault_addr); - EXPECT_LT(pc, fault_addr + 64); - exit(0); - } -} - -TEST(FaultTest, InRange) { - // Reset the signal handler to do nothing so that it doesn't freak out - // the test runner when we fire an alarm. - struct sigaction sa = {}; - sa.sa_sigaction = sigact_handler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_SIGINFO; - ASSERT_THAT(sigaction(SIGSEGV, &sa, nullptr), SyscallSucceeds()); - - Fault(); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/fchdir.cc b/test/syscalls/linux/fchdir.cc deleted file mode 100644 index 08bcae1e8..000000000 --- a/test/syscalls/linux/fchdir.cc +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <sys/socket.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/util/capability_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(FchdirTest, Success) { - auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - int fd; - ASSERT_THAT(fd = open(temp_dir.path().c_str(), O_DIRECTORY | O_RDONLY), - SyscallSucceeds()); - - EXPECT_THAT(fchdir(fd), SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - // Change CWD to a permanent location as temp dirs will be cleaned up. - EXPECT_THAT(chdir("/"), SyscallSucceeds()); -} - -TEST(FchdirTest, InvalidFD) { - EXPECT_THAT(fchdir(-1), SyscallFailsWithErrno(EBADF)); -} - -TEST(FchdirTest, PermissionDenied) { - // Drop capabilities that allow us to override directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0666 /* mode */)); - - int fd; - ASSERT_THAT(fd = open(temp_dir.path().c_str(), O_DIRECTORY | O_RDONLY), - SyscallSucceeds()); - - EXPECT_THAT(fchdir(fd), SyscallFailsWithErrno(EACCES)); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST(FchdirTest, NotDir) { - auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - int fd; - ASSERT_THAT(fd = open(temp_file.path().c_str(), O_CREAT | O_RDONLY, 0777), - SyscallSucceeds()); - - EXPECT_THAT(fchdir(fd), SyscallFailsWithErrno(ENOTDIR)); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc deleted file mode 100644 index c7cc5816e..000000000 --- a/test/syscalls/linux/fcntl.cc +++ /dev/null @@ -1,1132 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <signal.h> -#include <sys/types.h> -#include <syscall.h> -#include <unistd.h> - -#include <string> - -#include "gtest/gtest.h" -#include "absl/base/macros.h" -#include "absl/base/port.h" -#include "absl/flags/flag.h" -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/cleanup.h" -#include "test/util/eventfd_util.h" -#include "test/util/fs_util.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/save_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/timer_util.h" - -ABSL_FLAG(std::string, child_setlock_on, "", - "Contains the path to try to set a file lock on."); -ABSL_FLAG(bool, child_setlock_write, false, - "Whether to set a writable lock (otherwise readable)"); -ABSL_FLAG(bool, blocking, false, - "Whether to set a blocking lock (otherwise non-blocking)."); -ABSL_FLAG(bool, retry_eintr, false, - "Whether to retry in the subprocess on EINTR."); -ABSL_FLAG(uint64_t, child_setlock_start, 0, "The value of struct flock start"); -ABSL_FLAG(uint64_t, child_setlock_len, 0, "The value of struct flock len"); -ABSL_FLAG(int32_t, socket_fd, -1, - "A socket to use for communicating more state back " - "to the parent."); - -namespace gvisor { -namespace testing { - -class FcntlLockTest : public ::testing::Test { - public: - void SetUp() override { - // Let's make a socket pair. - ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, fds_), SyscallSucceeds()); - } - - void TearDown() override { - EXPECT_THAT(close(fds_[0]), SyscallSucceeds()); - EXPECT_THAT(close(fds_[1]), SyscallSucceeds()); - } - - int64_t GetSubprocessFcntlTimeInUsec() { - int64_t ret = 0; - EXPECT_THAT(ReadFd(fds_[0], reinterpret_cast<void*>(&ret), sizeof(ret)), - SyscallSucceedsWithValue(sizeof(ret))); - return ret; - } - - // The first fd will remain with the process creating the subprocess - // and the second will go to the subprocess. - int fds_[2] = {}; -}; - -namespace { - -PosixErrorOr<Cleanup> SubprocessLock(std::string const& path, bool for_write, - bool blocking, bool retry_eintr, int fd, - off_t start, off_t length, pid_t* child) { - std::vector<std::string> args = { - "/proc/self/exe", "--child_setlock_on", path, - "--child_setlock_start", absl::StrCat(start), "--child_setlock_len", - absl::StrCat(length), "--socket_fd", absl::StrCat(fd)}; - - if (for_write) { - args.push_back("--child_setlock_write"); - } - - if (blocking) { - args.push_back("--blocking"); - } - - if (retry_eintr) { - args.push_back("--retry_eintr"); - } - - int execve_errno = 0; - ASSIGN_OR_RETURN_ERRNO( - auto cleanup, - ForkAndExec("/proc/self/exe", ExecveArray(args.begin(), args.end()), {}, - nullptr, child, &execve_errno)); - - if (execve_errno != 0) { - return PosixError(execve_errno, "execve"); - } - - return std::move(cleanup); -} - -TEST(FcntlTest, SetCloExec) { - // Open an eventfd file descriptor with FD_CLOEXEC descriptor flag not set. - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0)); - ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(0)); - - // Set the FD_CLOEXEC flag. - ASSERT_THAT(fcntl(fd.get(), F_SETFD, FD_CLOEXEC), SyscallSucceeds()); - ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); -} - -TEST(FcntlTest, ClearCloExec) { - // Open an eventfd file descriptor with FD_CLOEXEC descriptor flag set. - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_CLOEXEC)); - ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); - - // Clear the FD_CLOEXEC flag. - ASSERT_THAT(fcntl(fd.get(), F_SETFD, 0), SyscallSucceeds()); - ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(0)); -} - -TEST(FcntlTest, IndependentDescriptorFlags) { - // Open an eventfd file descriptor with FD_CLOEXEC descriptor flag not set. - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0)); - ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(0)); - - // Duplicate the descriptor. Ensure that it also doesn't have FD_CLOEXEC. - FileDescriptor newfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); - ASSERT_THAT(fcntl(newfd.get(), F_GETFD), SyscallSucceedsWithValue(0)); - - // Set FD_CLOEXEC on the first FD. - ASSERT_THAT(fcntl(fd.get(), F_SETFD, FD_CLOEXEC), SyscallSucceeds()); - ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); - - // Ensure that the second FD is unaffected by the change on the first. - ASSERT_THAT(fcntl(newfd.get(), F_GETFD), SyscallSucceedsWithValue(0)); -} - -// All file description flags passed to open appear in F_GETFL. -TEST(FcntlTest, GetAllFlags) { - TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - int flags = O_RDWR | O_DIRECT | O_SYNC | O_NONBLOCK | O_APPEND; - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), flags)); - - // Linux forces O_LARGEFILE on all 64-bit kernels and gVisor's is 64-bit. - int expected = flags | kOLargeFile; - - int rflags; - EXPECT_THAT(rflags = fcntl(fd.get(), F_GETFL), SyscallSucceeds()); - EXPECT_EQ(rflags, expected); -} - -TEST(FcntlTest, SetFlags) { - TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), 0)); - - int const flags = O_RDWR | O_DIRECT | O_SYNC | O_NONBLOCK | O_APPEND; - EXPECT_THAT(fcntl(fd.get(), F_SETFL, flags), SyscallSucceeds()); - - // Can't set O_RDWR or O_SYNC. - // Linux forces O_LARGEFILE on all 64-bit kernels and gVisor's is 64-bit. - int expected = O_DIRECT | O_NONBLOCK | O_APPEND | kOLargeFile; - - int rflags; - EXPECT_THAT(rflags = fcntl(fd.get(), F_GETFL), SyscallSucceeds()); - EXPECT_EQ(rflags, expected); -} - -TEST_F(FcntlLockTest, SetLockBadFd) { - struct flock fl; - fl.l_type = F_WRLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - // len 0 has a special meaning: lock all bytes despite how - // large the file grows. - fl.l_len = 0; - EXPECT_THAT(fcntl(-1, F_SETLK, &fl), SyscallFailsWithErrno(EBADF)); -} - -TEST_F(FcntlLockTest, SetLockPipe) { - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - - struct flock fl; - fl.l_type = F_WRLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - // Same as SetLockBadFd, but doesn't matter, we expect this to fail. - fl.l_len = 0; - EXPECT_THAT(fcntl(fds[0], F_SETLK, &fl), SyscallFailsWithErrno(EBADF)); - EXPECT_THAT(close(fds[0]), SyscallSucceeds()); - EXPECT_THAT(close(fds[1]), SyscallSucceeds()); -} - -TEST_F(FcntlLockTest, SetLockDir) { - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY, 0666)); - - struct flock fl; - fl.l_type = F_RDLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - // Same as SetLockBadFd. - fl.l_len = 0; - - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); -} - -TEST_F(FcntlLockTest, SetLockBadOpenFlagsWrite) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY, 0666)); - - struct flock fl0; - fl0.l_type = F_WRLCK; - fl0.l_whence = SEEK_SET; - fl0.l_start = 0; - // Same as SetLockBadFd. - fl0.l_len = 0; - - // Expect that setting a write lock using a read only file descriptor - // won't work. - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl0), SyscallFailsWithErrno(EBADF)); -} - -TEST_F(FcntlLockTest, SetLockBadOpenFlagsRead) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY, 0666)); - - struct flock fl1; - fl1.l_type = F_RDLCK; - fl1.l_whence = SEEK_SET; - fl1.l_start = 0; - // Same as SetLockBadFd. - fl1.l_len = 0; - - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl1), SyscallFailsWithErrno(EBADF)); -} - -TEST_F(FcntlLockTest, SetLockUnlockOnNothing) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - - struct flock fl; - fl.l_type = F_UNLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - // Same as SetLockBadFd. - fl.l_len = 0; - - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); -} - -TEST_F(FcntlLockTest, SetWriteLockSingleProc) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd0 = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - - struct flock fl; - fl.l_type = F_WRLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - // Same as SetLockBadFd. - fl.l_len = 0; - - EXPECT_THAT(fcntl(fd0.get(), F_SETLK, &fl), SyscallSucceeds()); - // Expect to be able to take the same lock on the same fd no problem. - EXPECT_THAT(fcntl(fd0.get(), F_SETLK, &fl), SyscallSucceeds()); - - FileDescriptor fd1 = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - - // Expect to be able to take the same lock from a different fd but for - // the same process. - EXPECT_THAT(fcntl(fd1.get(), F_SETLK, &fl), SyscallSucceeds()); -} - -TEST_F(FcntlLockTest, SetReadLockMultiProc) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - - struct flock fl; - fl.l_type = F_RDLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - // Same as SetLockBadFd. - fl.l_len = 0; - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); - - // spawn a child process to take a read lock on the same file. - pid_t child_pid = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - SubprocessLock(file.path(), false /* write lock */, - false /* nonblocking */, false /* no eintr retry */, - -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid)); - - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "Exited with code: " << status; -} - -TEST_F(FcntlLockTest, SetReadThenWriteLockMultiProc) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - - struct flock fl; - fl.l_type = F_RDLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - // Same as SetLockBadFd. - fl.l_len = 0; - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); - - // Assert that another process trying to lock on the same file will fail - // with EAGAIN. It's important that we keep the fd above open so that - // that the other process will contend with the lock. - pid_t child_pid = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - SubprocessLock(file.path(), true /* write lock */, - false /* nonblocking */, false /* no eintr retry */, - -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid)); - - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == EAGAIN) - << "Exited with code: " << status; - - // Close the fd: we want to test that another process can acquire the - // lock after this point. - fd.reset(); - // Assert that another process can now acquire the lock. - - child_pid = 0; - auto cleanup2 = ASSERT_NO_ERRNO_AND_VALUE( - SubprocessLock(file.path(), true /* write lock */, - false /* nonblocking */, false /* no eintr retry */, - -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid)); - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "Exited with code: " << status; -} - -TEST_F(FcntlLockTest, SetWriteThenReadLockMultiProc) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - // Same as SetReadThenWriteLockMultiProc. - - struct flock fl; - fl.l_type = F_WRLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - // Same as SetLockBadFd. - fl.l_len = 0; - - // Same as SetReadThenWriteLockMultiProc. - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); - - // Same as SetReadThenWriteLockMultiProc. - pid_t child_pid = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - SubprocessLock(file.path(), false /* write lock */, - false /* nonblocking */, false /* no eintr retry */, - -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid)); - - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == EAGAIN) - << "Exited with code: " << status; - - // Same as SetReadThenWriteLockMultiProc. - fd.reset(); // Close the fd. - - // Same as SetReadThenWriteLockMultiProc. - child_pid = 0; - auto cleanup2 = ASSERT_NO_ERRNO_AND_VALUE( - SubprocessLock(file.path(), false /* write lock */, - false /* nonblocking */, false /* no eintr retry */, - -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid)); - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "Exited with code: " << status; -} - -TEST_F(FcntlLockTest, SetWriteLockMultiProc) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - // Same as SetReadThenWriteLockMultiProc. - - struct flock fl; - fl.l_type = F_WRLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - // Same as SetLockBadFd. - fl.l_len = 0; - - // Same as SetReadWriteLockMultiProc. - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); - - // Same as SetReadWriteLockMultiProc. - pid_t child_pid = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - SubprocessLock(file.path(), true /* write lock */, - false /* nonblocking */, false /* no eintr retry */, - -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid)); - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == EAGAIN) - << "Exited with code: " << status; - - fd.reset(); // Close the FD. - // Same as SetReadWriteLockMultiProc. - child_pid = 0; - auto cleanup2 = ASSERT_NO_ERRNO_AND_VALUE( - SubprocessLock(file.path(), true /* write lock */, - false /* nonblocking */, false /* no eintr retry */, - -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid)); - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "Exited with code: " << status; -} - -TEST_F(FcntlLockTest, SetLockIsRegional) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - - struct flock fl; - fl.l_type = F_WRLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - fl.l_len = 4096; - - // Same as SetReadWriteLockMultiProc. - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); - - // Same as SetReadWriteLockMultiProc. - pid_t child_pid = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - SubprocessLock(file.path(), true /* write lock */, - false /* nonblocking */, false /* no eintr retry */, - -1 /* no socket fd */, fl.l_len, 0, &child_pid)); - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "Exited with code: " << status; -} - -TEST_F(FcntlLockTest, SetLockUpgradeDowngrade) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - - struct flock fl; - fl.l_type = F_RDLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - // Same as SetLockBadFd. - fl.l_len = 0; - - // Same as SetReadWriteLockMultiProc. - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); - - // Upgrade to a write lock. This will prevent anyone else from taking - // the lock. - fl.l_type = F_WRLCK; - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); - - // Same as SetReadWriteLockMultiProc., - pid_t child_pid = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - SubprocessLock(file.path(), false /* write lock */, - false /* nonblocking */, false /* no eintr retry */, - -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid)); - - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == EAGAIN) - << "Exited with code: " << status; - - // Downgrade back to a read lock. - fl.l_type = F_RDLCK; - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); - - // Do the same stint as before, but this time it should succeed. - child_pid = 0; - auto cleanup2 = ASSERT_NO_ERRNO_AND_VALUE( - SubprocessLock(file.path(), false /* write lock */, - false /* nonblocking */, false /* no eintr retry */, - -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid)); - - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "Exited with code: " << status; -} - -TEST_F(FcntlLockTest, SetLockDroppedOnClose) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - - // While somewhat surprising, obtaining another fd to the same file and - // then closing it in this process drops *all* locks. - FileDescriptor other_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - // Same as SetReadThenWriteLockMultiProc. - - struct flock fl; - fl.l_type = F_WRLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - // Same as SetLockBadFd. - fl.l_len = 0; - - // Same as SetReadWriteLockMultiProc. - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); - - other_fd.reset(); // Close. - - // Expect to be able to get the lock, given that the close above dropped it. - pid_t child_pid = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - SubprocessLock(file.path(), true /* write lock */, - false /* nonblocking */, false /* no eintr retry */, - -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid)); - - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "Exited with code: " << status; -} - -TEST_F(FcntlLockTest, SetLockUnlock) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - - // Setup two regional locks with different permissions. - struct flock fl0; - fl0.l_type = F_WRLCK; - fl0.l_whence = SEEK_SET; - fl0.l_start = 0; - fl0.l_len = 4096; - - struct flock fl1; - fl1.l_type = F_RDLCK; - fl1.l_whence = SEEK_SET; - fl1.l_start = 4096; - // Same as SetLockBadFd. - fl1.l_len = 0; - - // Set both region locks. - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl0), SyscallSucceeds()); - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl1), SyscallSucceeds()); - - // Another process should fail to take a read lock on the entire file - // due to the regional write lock. - pid_t child_pid = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(SubprocessLock( - file.path(), false /* write lock */, false /* nonblocking */, - false /* no eintr retry */, -1 /* no socket fd */, 0, 0, &child_pid)); - - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == EAGAIN) - << "Exited with code: " << status; - - // Then only unlock the writable one. This should ensure that other - // processes can take any read lock that it wants. - fl0.l_type = F_UNLCK; - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl0), SyscallSucceeds()); - - // Another process should now succeed to get a read lock on the entire file. - child_pid = 0; - auto cleanup2 = ASSERT_NO_ERRNO_AND_VALUE(SubprocessLock( - file.path(), false /* write lock */, false /* nonblocking */, - false /* no eintr retry */, -1 /* no socket fd */, 0, 0, &child_pid)); - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "Exited with code: " << status; -} - -TEST_F(FcntlLockTest, SetLockAcrossRename) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - - // Setup two regional locks with different permissions. - struct flock fl; - fl.l_type = F_WRLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - // Same as SetLockBadFd. - fl.l_len = 0; - - // Set the region lock. - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); - - // Rename the file to someplace nearby. - std::string const newpath = NewTempAbsPath(); - EXPECT_THAT(rename(file.path().c_str(), newpath.c_str()), SyscallSucceeds()); - - // Another process should fail to take a read lock on the renamed file - // since we still have an open handle to the inode. - pid_t child_pid = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - SubprocessLock(newpath, false /* write lock */, false /* nonblocking */, - false /* no eintr retry */, -1 /* no socket fd */, - fl.l_start, fl.l_len, &child_pid)); - - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == EAGAIN) - << "Exited with code: " << status; -} - -// NOTE: The blocking tests below aren't perfect. It's hard to assert exactly -// what the kernel did while handling a syscall. These tests are timing based -// because there really isn't any other reasonable way to assert that correct -// blocking behavior happened. - -// This test will verify that blocking works as expected when another process -// holds a write lock when obtaining a write lock. This test will hold the lock -// for some amount of time and then wait for the second process to send over the -// socket_fd the amount of time it was blocked for before the lock succeeded. -TEST_F(FcntlLockTest, SetWriteLockThenBlockingWriteLock) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - - struct flock fl; - fl.l_type = F_WRLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - fl.l_len = 0; - - // Take the write lock. - ASSERT_THAT(fcntl(fd.get(), F_SETLKW, &fl), SyscallSucceeds()); - - // Attempt to take the read lock in a sub process. This will immediately block - // so we will release our lock after some amount of time and then assert the - // amount of time the other process was blocked for. - pid_t child_pid = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(SubprocessLock( - file.path(), true /* write lock */, true /* Blocking Lock */, - true /* Retry on EINTR */, fds_[1] /* Socket fd for timing information */, - fl.l_start, fl.l_len, &child_pid)); - - // We will wait kHoldLockForSec before we release our lock allowing the - // subprocess to obtain it. - constexpr absl::Duration kHoldLockFor = absl::Seconds(5); - const int64_t kMinBlockTimeUsec = absl::ToInt64Microseconds(absl::Seconds(1)); - - absl::SleepFor(kHoldLockFor); - - // Unlock our write lock. - fl.l_type = F_UNLCK; - ASSERT_THAT(fcntl(fd.get(), F_SETLKW, &fl), SyscallSucceeds()); - - // Read the blocked time from the subprocess socket. - int64_t subprocess_blocked_time_usec = GetSubprocessFcntlTimeInUsec(); - - // We must have been waiting at least kMinBlockTime. - EXPECT_GT(subprocess_blocked_time_usec, kMinBlockTimeUsec); - - // The FCNTL write lock must always succeed as it will simply block until it - // can obtain the lock. - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "Exited with code: " << status; -} - -// This test will veirfy that blocking works as expected when another process -// holds a read lock when obtaining a write lock. This test will hold the lock -// for some amount of time and then wait for the second process to send over the -// socket_fd the amount of time it was blocked for before the lock succeeded. -TEST_F(FcntlLockTest, SetReadLockThenBlockingWriteLock) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - - struct flock fl; - fl.l_type = F_RDLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - fl.l_len = 0; - - // Take the write lock. - ASSERT_THAT(fcntl(fd.get(), F_SETLKW, &fl), SyscallSucceeds()); - - // Attempt to take the read lock in a sub process. This will immediately block - // so we will release our lock after some amount of time and then assert the - // amount of time the other process was blocked for. - pid_t child_pid = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(SubprocessLock( - file.path(), true /* write lock */, true /* Blocking Lock */, - true /* Retry on EINTR */, fds_[1] /* Socket fd for timing information */, - fl.l_start, fl.l_len, &child_pid)); - - // We will wait kHoldLockForSec before we release our lock allowing the - // subprocess to obtain it. - constexpr absl::Duration kHoldLockFor = absl::Seconds(5); - - const int64_t kMinBlockTimeUsec = absl::ToInt64Microseconds(absl::Seconds(1)); - - absl::SleepFor(kHoldLockFor); - - // Unlock our READ lock. - fl.l_type = F_UNLCK; - ASSERT_THAT(fcntl(fd.get(), F_SETLKW, &fl), SyscallSucceeds()); - - // Read the blocked time from the subprocess socket. - int64_t subprocess_blocked_time_usec = GetSubprocessFcntlTimeInUsec(); - - // We must have been waiting at least kMinBlockTime. - EXPECT_GT(subprocess_blocked_time_usec, kMinBlockTimeUsec); - - // The FCNTL write lock must always succeed as it will simply block until it - // can obtain the lock. - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "Exited with code: " << status; -} - -// This test will veirfy that blocking works as expected when another process -// holds a write lock when obtaining a read lock. This test will hold the lock -// for some amount of time and then wait for the second process to send over the -// socket_fd the amount of time it was blocked for before the lock succeeded. -TEST_F(FcntlLockTest, SetWriteLockThenBlockingReadLock) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - - struct flock fl; - fl.l_type = F_WRLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - fl.l_len = 0; - - // Take the write lock. - ASSERT_THAT(fcntl(fd.get(), F_SETLKW, &fl), SyscallSucceeds()); - - // Attempt to take the read lock in a sub process. This will immediately block - // so we will release our lock after some amount of time and then assert the - // amount of time the other process was blocked for. - pid_t child_pid = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(SubprocessLock( - file.path(), false /* read lock */, true /* Blocking Lock */, - true /* Retry on EINTR */, fds_[1] /* Socket fd for timing information */, - fl.l_start, fl.l_len, &child_pid)); - - // We will wait kHoldLockForSec before we release our lock allowing the - // subprocess to obtain it. - constexpr absl::Duration kHoldLockFor = absl::Seconds(5); - - const int64_t kMinBlockTimeUsec = absl::ToInt64Microseconds(absl::Seconds(1)); - - absl::SleepFor(kHoldLockFor); - - // Unlock our write lock. - fl.l_type = F_UNLCK; - ASSERT_THAT(fcntl(fd.get(), F_SETLKW, &fl), SyscallSucceeds()); - - // Read the blocked time from the subprocess socket. - int64_t subprocess_blocked_time_usec = GetSubprocessFcntlTimeInUsec(); - - // We must have been waiting at least kMinBlockTime. - EXPECT_GT(subprocess_blocked_time_usec, kMinBlockTimeUsec); - - // The FCNTL read lock must always succeed as it will simply block until it - // can obtain the lock. - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "Exited with code: " << status; -} - -// This test will verify that when one process only holds a read lock that -// another will not block while obtaining a read lock when F_SETLKW is used. -TEST_F(FcntlLockTest, SetReadLockThenBlockingReadLock) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - - struct flock fl; - fl.l_type = F_RDLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - fl.l_len = 0; - - // Take the READ lock. - ASSERT_THAT(fcntl(fd.get(), F_SETLKW, &fl), SyscallSucceeds()); - - // Attempt to take the read lock in a sub process. Since multiple processes - // can hold a read lock this should immediately return without blocking - // even though we used F_SETLKW in the subprocess. - pid_t child_pid = 0; - auto sp = ASSERT_NO_ERRNO_AND_VALUE(SubprocessLock( - file.path(), false /* read lock */, true /* Blocking Lock */, - true /* Retry on EINTR */, -1 /* No fd, should not block */, fl.l_start, - fl.l_len, &child_pid)); - - // We never release the lock and the subprocess should still obtain it without - // blocking for any period of time. - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "Exited with code: " << status; -} - -TEST(FcntlTest, GetO_ASYNC) { - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - int flag_fl = -1; - ASSERT_THAT(flag_fl = fcntl(s.get(), F_GETFL), SyscallSucceeds()); - EXPECT_EQ(flag_fl & O_ASYNC, 0); - - int flag_fd = -1; - ASSERT_THAT(flag_fd = fcntl(s.get(), F_GETFD), SyscallSucceeds()); - EXPECT_EQ(flag_fd & O_ASYNC, 0); -} - -TEST(FcntlTest, SetFlO_ASYNC) { - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - int before_fl = -1; - ASSERT_THAT(before_fl = fcntl(s.get(), F_GETFL), SyscallSucceeds()); - - int before_fd = -1; - ASSERT_THAT(before_fd = fcntl(s.get(), F_GETFD), SyscallSucceeds()); - - ASSERT_THAT(fcntl(s.get(), F_SETFL, before_fl | O_ASYNC), SyscallSucceeds()); - - int after_fl = -1; - ASSERT_THAT(after_fl = fcntl(s.get(), F_GETFL), SyscallSucceeds()); - EXPECT_EQ(after_fl, before_fl | O_ASYNC); - - int after_fd = -1; - ASSERT_THAT(after_fd = fcntl(s.get(), F_GETFD), SyscallSucceeds()); - EXPECT_EQ(after_fd, before_fd); -} - -TEST(FcntlTest, SetFdO_ASYNC) { - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - int before_fl = -1; - ASSERT_THAT(before_fl = fcntl(s.get(), F_GETFL), SyscallSucceeds()); - - int before_fd = -1; - ASSERT_THAT(before_fd = fcntl(s.get(), F_GETFD), SyscallSucceeds()); - - ASSERT_THAT(fcntl(s.get(), F_SETFD, before_fd | O_ASYNC), SyscallSucceeds()); - - int after_fl = -1; - ASSERT_THAT(after_fl = fcntl(s.get(), F_GETFL), SyscallSucceeds()); - EXPECT_EQ(after_fl, before_fl); - - int after_fd = -1; - ASSERT_THAT(after_fd = fcntl(s.get(), F_GETFD), SyscallSucceeds()); - EXPECT_EQ(after_fd, before_fd); -} - -TEST(FcntlTest, DupAfterO_ASYNC) { - FileDescriptor s1 = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - int before = -1; - ASSERT_THAT(before = fcntl(s1.get(), F_GETFL), SyscallSucceeds()); - - ASSERT_THAT(fcntl(s1.get(), F_SETFL, before | O_ASYNC), SyscallSucceeds()); - - FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(s1.Dup()); - - int after = -1; - ASSERT_THAT(after = fcntl(fd2.get(), F_GETFL), SyscallSucceeds()); - EXPECT_EQ(after & O_ASYNC, O_ASYNC); -} - -TEST(FcntlTest, GetOwn) { - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - EXPECT_EQ(syscall(__NR_fcntl, s.get(), F_GETOWN), 0); - MaybeSave(); -} - -TEST(FcntlTest, GetOwnEx) { - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - f_owner_ex owner = {}; - EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &owner), - SyscallSucceedsWithValue(0)); -} - -TEST(FcntlTest, SetOwnExInvalidType) { - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - f_owner_ex owner = {}; - owner.type = __pid_type(-1); - EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(FcntlTest, SetOwnExInvalidTid) { - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - f_owner_ex owner = {}; - owner.type = F_OWNER_TID; - owner.pid = -1; - - EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), - SyscallFailsWithErrno(ESRCH)); -} - -TEST(FcntlTest, SetOwnExInvalidPid) { - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - f_owner_ex owner = {}; - owner.type = F_OWNER_PID; - owner.pid = -1; - - EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), - SyscallFailsWithErrno(ESRCH)); -} - -TEST(FcntlTest, SetOwnExInvalidPgrp) { - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - f_owner_ex owner = {}; - owner.type = F_OWNER_PGRP; - owner.pid = -1; - - EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), - SyscallFailsWithErrno(ESRCH)); -} - -TEST(FcntlTest, SetOwnExTid) { - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - f_owner_ex owner = {}; - owner.type = F_OWNER_TID; - EXPECT_THAT(owner.pid = syscall(__NR_gettid), SyscallSucceeds()); - - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), - SyscallSucceeds()); - - EXPECT_EQ(syscall(__NR_fcntl, s.get(), F_GETOWN), owner.pid); - MaybeSave(); -} - -TEST(FcntlTest, SetOwnExPid) { - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - f_owner_ex owner = {}; - owner.type = F_OWNER_PID; - EXPECT_THAT(owner.pid = getpid(), SyscallSucceeds()); - - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), - SyscallSucceeds()); - - EXPECT_EQ(syscall(__NR_fcntl, s.get(), F_GETOWN), owner.pid); - MaybeSave(); -} - -TEST(FcntlTest, SetOwnExPgrp) { - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - f_owner_ex owner = {}; - owner.type = F_OWNER_PGRP; - EXPECT_THAT(owner.pid = getpgrp(), SyscallSucceeds()); - - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), - SyscallSucceeds()); - - // NOTE(igudger): I don't understand why, but this is flaky on Linux. - // GetOwnExPgrp (below) does not have this issue. - SKIP_IF(!IsRunningOnGvisor()); - - EXPECT_EQ(syscall(__NR_fcntl, s.get(), F_GETOWN), -owner.pid); - MaybeSave(); -} - -TEST(FcntlTest, GetOwnExTid) { - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - f_owner_ex set_owner = {}; - set_owner.type = F_OWNER_TID; - EXPECT_THAT(set_owner.pid = syscall(__NR_gettid), SyscallSucceeds()); - - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), - SyscallSucceeds()); - - f_owner_ex got_owner = {}; - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(got_owner.type, set_owner.type); - EXPECT_EQ(got_owner.pid, set_owner.pid); -} - -TEST(FcntlTest, GetOwnExPid) { - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - f_owner_ex set_owner = {}; - set_owner.type = F_OWNER_PID; - EXPECT_THAT(set_owner.pid = getpid(), SyscallSucceeds()); - - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), - SyscallSucceeds()); - - f_owner_ex got_owner = {}; - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(got_owner.type, set_owner.type); - EXPECT_EQ(got_owner.pid, set_owner.pid); -} - -TEST(FcntlTest, GetOwnExPgrp) { - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - f_owner_ex set_owner = {}; - set_owner.type = F_OWNER_PGRP; - EXPECT_THAT(set_owner.pid = getpgrp(), SyscallSucceeds()); - - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), - SyscallSucceeds()); - - f_owner_ex got_owner = {}; - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(got_owner.type, set_owner.type); - EXPECT_EQ(got_owner.pid, set_owner.pid); -} - -} // namespace - -} // namespace testing -} // namespace gvisor - -int main(int argc, char** argv) { - gvisor::testing::TestInit(&argc, &argv); - - const std::string setlock_on = absl::GetFlag(FLAGS_child_setlock_on); - if (!setlock_on.empty()) { - int socket_fd = absl::GetFlag(FLAGS_socket_fd); - int fd = open(setlock_on.c_str(), O_RDWR, 0666); - if (fd == -1 && errno != 0) { - int err = errno; - std::cerr << "CHILD open " << setlock_on << " failed " << err - << std::endl; - exit(err); - } - - struct flock fl; - if (absl::GetFlag(FLAGS_child_setlock_write)) { - fl.l_type = F_WRLCK; - } else { - fl.l_type = F_RDLCK; - } - fl.l_whence = SEEK_SET; - fl.l_start = absl::GetFlag(FLAGS_child_setlock_start); - fl.l_len = absl::GetFlag(FLAGS_child_setlock_len); - - // Test the fcntl, no need to log, the error is unambiguously - // from fcntl at this point. - int err = 0; - int ret = 0; - - gvisor::testing::MonotonicTimer timer; - timer.Start(); - do { - ret = fcntl(fd, absl::GetFlag(FLAGS_blocking) ? F_SETLKW : F_SETLK, &fl); - } while (absl::GetFlag(FLAGS_retry_eintr) && ret == -1 && errno == EINTR); - auto usec = absl::ToInt64Microseconds(timer.Duration()); - - if (ret == -1 && errno != 0) { - err = errno; - } - - // If there is a socket fd let's send back the time in microseconds it took - // to execute this syscall. - if (socket_fd != -1) { - gvisor::testing::WriteFd(socket_fd, reinterpret_cast<void*>(&usec), - sizeof(usec)); - close(socket_fd); - } - - close(fd); - exit(err); - } - - return gvisor::testing::RunAllTests(); -} diff --git a/test/syscalls/linux/file_base.h b/test/syscalls/linux/file_base.h deleted file mode 100644 index 6f80bc97c..000000000 --- a/test/syscalls/linux/file_base.h +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_FILE_BASE_H_ -#define GVISOR_TEST_SYSCALLS_FILE_BASE_H_ - -#include <arpa/inet.h> -#include <errno.h> -#include <fcntl.h> -#include <netinet/in.h> -#include <stddef.h> -#include <stdio.h> -#include <string.h> -#include <sys/socket.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <sys/uio.h> -#include <unistd.h> - -#include <cstring> -#include <string> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/string_view.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -class FileTest : public ::testing::Test { - public: - void SetUp() override { - test_pipe_[0] = -1; - test_pipe_[1] = -1; - - test_file_name_ = NewTempAbsPath(); - test_file_fd_ = ASSERT_NO_ERRNO_AND_VALUE( - Open(test_file_name_, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR)); - - // FIXME(edahlgren): enable when mknod syscall is supported. - // test_fifo_name_ = NewTempAbsPath(); - // ASSERT_THAT(mknod(test_fifo_name_.c_str()), S_IFIFO|0644, 0, - // SyscallSucceeds()); - // ASSERT_THAT(test_fifo_[1] = open(test_fifo_name_.c_str(), - // O_WRONLY), - // SyscallSucceeds()); - // ASSERT_THAT(test_fifo_[0] = open(test_fifo_name_.c_str(), - // O_RDONLY), - // SyscallSucceeds()); - - ASSERT_THAT(pipe(test_pipe_), SyscallSucceeds()); - ASSERT_THAT(fcntl(test_pipe_[0], F_SETFL, O_NONBLOCK), SyscallSucceeds()); - } - - // CloseFile will allow the test to manually close the file descriptor. - void CloseFile() { test_file_fd_.reset(); } - - // UnlinkFile will allow the test to manually unlink the file. - void UnlinkFile() { - if (!test_file_name_.empty()) { - EXPECT_THAT(unlink(test_file_name_.c_str()), SyscallSucceeds()); - test_file_name_.clear(); - } - } - - // ClosePipes will allow the test to manually close the pipes. - void ClosePipes() { - if (test_pipe_[0] > 0) { - EXPECT_THAT(close(test_pipe_[0]), SyscallSucceeds()); - } - - if (test_pipe_[1] > 0) { - EXPECT_THAT(close(test_pipe_[1]), SyscallSucceeds()); - } - - test_pipe_[0] = -1; - test_pipe_[1] = -1; - } - - void TearDown() override { - CloseFile(); - UnlinkFile(); - ClosePipes(); - - // FIXME(edahlgren): enable when mknod syscall is supported. - // close(test_fifo_[0]); - // close(test_fifo_[1]); - // unlink(test_fifo_name_.c_str()); - } - - std::string test_file_name_; - std::string test_fifo_name_; - FileDescriptor test_file_fd_; - - int test_fifo_[2]; - int test_pipe_[2]; -}; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_FILE_BASE_H_ diff --git a/test/syscalls/linux/flock.cc b/test/syscalls/linux/flock.cc deleted file mode 100644 index 3ecb8db8e..000000000 --- a/test/syscalls/linux/flock.cc +++ /dev/null @@ -1,589 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <sys/file.h> - -#include <string> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/file_base.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" -#include "test/util/timer_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -class FlockTest : public FileTest {}; - -TEST_F(FlockTest, BadFD) { - // EBADF: fd is not an open file descriptor. - ASSERT_THAT(flock(-1, 0), SyscallFailsWithErrno(EBADF)); -} - -TEST_F(FlockTest, InvalidOpCombinations) { - // The operation cannot be both exclusive and shared. - EXPECT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_SH | LOCK_NB), - SyscallFailsWithErrno(EINVAL)); - - // Locking and Unlocking doesn't make sense. - EXPECT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_UN | LOCK_NB), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(flock(test_file_fd_.get(), LOCK_SH | LOCK_UN | LOCK_NB), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_F(FlockTest, NoOperationSpecified) { - // Not specifying an operation is invalid. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(FlockTestNoFixture, FlockSupportsPipes) { - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - - EXPECT_THAT(flock(fds[0], LOCK_EX | LOCK_NB), SyscallSucceeds()); - EXPECT_THAT(close(fds[0]), SyscallSucceeds()); - EXPECT_THAT(close(fds[1]), SyscallSucceeds()); -} - -TEST_F(FlockTest, TestSimpleExLock) { - // Test that we can obtain an exclusive lock (no other holders) - // and that we can unlock it. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), - SyscallSucceedsWithValue(0)); - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestSimpleShLock) { - // Test that we can obtain a shared lock (no other holders) - // and that we can unlock it. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH | LOCK_NB), - SyscallSucceedsWithValue(0)); - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestLockableAnyMode) { - // flock(2): A shared or exclusive lock can be placed on a file - // regardless of the mode in which the file was opened. - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( - Open(test_file_name_, O_RDONLY)); // open read only to test - - // Mode shouldn't prevent us from taking an exclusive lock. - ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceedsWithValue(0)); - - // Unlock - ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestUnlockWithNoHolders) { - // Test that unlocking when no one holds a lock succeeeds. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestRepeatedExLockingBySameHolder) { - // Test that repeated locking by the same holder for the - // same type of lock works correctly. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB | LOCK_EX), - SyscallSucceedsWithValue(0)); - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB | LOCK_EX), - SyscallSucceedsWithValue(0)); - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestRepeatedExLockingSingleUnlock) { - // Test that repeated locking by the same holder for the - // same type of lock works correctly and that a single unlock is required. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB | LOCK_EX), - SyscallSucceedsWithValue(0)); - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB | LOCK_EX), - SyscallSucceedsWithValue(0)); - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY)); - - // Should be unlocked at this point - ASSERT_THAT(flock(fd.get(), LOCK_NB | LOCK_EX), SyscallSucceedsWithValue(0)); - - ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestRepeatedShLockingBySameHolder) { - // Test that repeated locking by the same holder for the - // same type of lock works correctly. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB | LOCK_SH), - SyscallSucceedsWithValue(0)); - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB | LOCK_SH), - SyscallSucceedsWithValue(0)); - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestSingleHolderUpgrade) { - // Test that a shared lock is upgradable when no one else holds a lock. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB | LOCK_SH), - SyscallSucceedsWithValue(0)); - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB | LOCK_EX), - SyscallSucceedsWithValue(0)); - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestSingleHolderDowngrade) { - // Test single holder lock downgrade case. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), - SyscallSucceedsWithValue(0)); - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH | LOCK_NB), - SyscallSucceedsWithValue(0)); - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestMultipleShared) { - // This is a simple test to verify that multiple independent shared - // locks will be granted. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH | LOCK_NB), - SyscallSucceedsWithValue(0)); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - - // A shared lock should be granted as there only exists other shared locks. - ASSERT_THAT(flock(fd.get(), LOCK_SH | LOCK_NB), SyscallSucceedsWithValue(0)); - - // Unlock both. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); - ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -/* - * flock(2): If a process uses open(2) (or similar) to obtain more than one - * descriptor for the same file, these descriptors are treated - * independently by flock(). An attempt to lock the file using one of - * these file descriptors may be denied by a lock that the calling process - * has already placed via another descriptor. - */ -TEST_F(FlockTest, TestMultipleHolderSharedExclusive) { - // This test will verify that an exclusive lock will not be granted - // while a shared is held. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH | LOCK_NB), - SyscallSucceedsWithValue(0)); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - - // Verify We're unable to get an exlcusive lock via the second FD. - // because someone is holding a shared lock. - ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Unlock - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestSharedLockFailExclusiveHolder) { - // This test will verify that a shared lock is denied while - // someone holds an exclusive lock. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), - SyscallSucceedsWithValue(0)); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - - // Verify we're unable to get an shared lock via the second FD. - // because someone is holding an exclusive lock. - ASSERT_THAT(flock(fd.get(), LOCK_SH | LOCK_NB), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Unlock - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolder) { - // This test will verify that an exclusive lock is denied while - // someone already holds an exclsuive lock. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), - SyscallSucceedsWithValue(0)); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - - // Verify we're unable to get an exclusive lock via the second FD - // because someone is already holding an exclusive lock. - ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Unlock - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestMultipleHolderSharedExclusiveUpgrade) { - // This test will verify that we cannot obtain an exclusive lock while - // a shared lock is held by another descriptor, then verify that an upgrade - // is possible on a shared lock once all other shared locks have closed. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH | LOCK_NB), - SyscallSucceedsWithValue(0)); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - - // Verify we're unable to get an exclusive lock via the second FD because - // a shared lock is held. - ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Verify that we can get a shared lock via the second descriptor instead - ASSERT_THAT(flock(fd.get(), LOCK_SH | LOCK_NB), SyscallSucceedsWithValue(0)); - - // Unlock the first and there will only be one shared lock remaining. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); - - // Upgrade 2nd fd. - ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceedsWithValue(0)); - - // Finally unlock the second - ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestMultipleHolderSharedExclusiveDowngrade) { - // This test will verify that a shared lock is not obtainable while an - // exclusive lock is held but that once the first is downgraded that - // the second independent file descriptor can also get a shared lock. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), - SyscallSucceedsWithValue(0)); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - - // Verify We're unable to get a shared lock via the second FD because - // an exclusive lock is held. - ASSERT_THAT(flock(fd.get(), LOCK_SH | LOCK_NB), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Verify that we can downgrade the first. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH | LOCK_NB), - SyscallSucceedsWithValue(0)); - - // Now verify that we can obtain a shared lock since the first was downgraded. - ASSERT_THAT(flock(fd.get(), LOCK_SH | LOCK_NB), SyscallSucceedsWithValue(0)); - - // Finally unlock both. - ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceedsWithValue(0)); - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -/* - * flock(2): Locks created by flock() are associated with an open file table - * entry. This means that duplicate file descriptors (created by, for example, - * fork(2) or dup(2)) refer to the same lock, and this lock may be modified or - * released using any of these descriptors. Furthermore, the lock is released - * either by an explicit LOCK_UN operation on any of these duplicate descriptors - * or when all such descriptors have been closed. - */ -TEST_F(FlockTest, TestDupFdUpgrade) { - // This test will verify that a shared lock is upgradeable via a dupped - // file descriptor, if the FD wasn't dupped this would fail. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH | LOCK_NB), - SyscallSucceedsWithValue(0)); - - const FileDescriptor dup_fd = ASSERT_NO_ERRNO_AND_VALUE(test_file_fd_.Dup()); - - // Now we should be able to upgrade via the dupped fd. - ASSERT_THAT(flock(dup_fd.get(), LOCK_EX | LOCK_NB), - SyscallSucceedsWithValue(0)); - - // Validate unlock via dupped fd. - ASSERT_THAT(flock(dup_fd.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestDupFdDowngrade) { - // This test will verify that a exclusive lock is downgradable via a dupped - // file descriptor, if the FD wasn't dupped this would fail. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), - SyscallSucceedsWithValue(0)); - - const FileDescriptor dup_fd = ASSERT_NO_ERRNO_AND_VALUE(test_file_fd_.Dup()); - - // Now we should be able to downgrade via the dupped fd. - ASSERT_THAT(flock(dup_fd.get(), LOCK_SH | LOCK_NB), - SyscallSucceedsWithValue(0)); - - // Validate unlock via dupped fd - ASSERT_THAT(flock(dup_fd.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestDupFdCloseRelease) { - // flock(2): Furthermore, the lock is released either by an explicit LOCK_UN - // operation on any of these duplicate descriptors, or when all such - // descriptors have been closed. - // - // This test will verify that a dupped fd closing will not release the - // underlying lock until all such dupped fds have closed. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), - SyscallSucceedsWithValue(0)); - - FileDescriptor dup_fd = ASSERT_NO_ERRNO_AND_VALUE(test_file_fd_.Dup()); - - // At this point we have ONE exclusive locked referenced by two different fds. - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - - // Validate that we cannot get a lock on a new unrelated FD. - ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Closing the dupped fd shouldn't affect the lock until all are closed. - dup_fd.reset(); // Closed the duped fd. - - // Validate that we still cannot get a lock on a new unrelated FD. - ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Closing the first fd - CloseFile(); // Will validate the syscall succeeds. - - // Now we should actually be able to get a lock since all fds related to - // the first lock are closed. - ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceedsWithValue(0)); - - // Unlock. - ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestDupFdUnlockRelease) { - /* flock(2): Furthermore, the lock is released either by an explicit LOCK_UN - * operation on any of these duplicate descriptors, or when all such - * descriptors have been closed. - */ - // This test will verify that an explict unlock on a dupped FD will release - // the underlying lock unlike the previous case where close on a dup was - // not enough to release the lock. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), - SyscallSucceedsWithValue(0)); - - const FileDescriptor dup_fd = ASSERT_NO_ERRNO_AND_VALUE(test_file_fd_.Dup()); - - // At this point we have ONE exclusive locked referenced by two different fds. - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - - // Validate that we cannot get a lock on a new unrelated FD. - ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Explicitly unlock via the dupped descriptor. - ASSERT_THAT(flock(dup_fd.get(), LOCK_UN), SyscallSucceedsWithValue(0)); - - // Validate that we can now get the lock since we explicitly unlocked. - ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceedsWithValue(0)); - - // Unlock - ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -TEST_F(FlockTest, TestDupFdFollowedByLock) { - // This test will verify that taking a lock on a file descriptor that has - // already been dupped means that the lock is shared between both. This is - // slightly different than than duping on an already locked FD. - FileDescriptor dup_fd = ASSERT_NO_ERRNO_AND_VALUE(test_file_fd_.Dup()); - - // Take a lock. - ASSERT_THAT(flock(dup_fd.get(), LOCK_EX | LOCK_NB), SyscallSucceeds()); - - // Now dup_fd and test_file_ should both reference the same lock. - // We shouldn't be able to obtain a lock until both are closed. - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - - // Closing the first fd - dup_fd.reset(); // Close the duped fd. - - // Validate that we cannot get a lock yet because the dupped descriptor. - ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Closing the second fd. - CloseFile(); // CloseFile() will validate the syscall succeeds. - - // Now we should be able to get the lock. - ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceeds()); - - // Unlock. - ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceedsWithValue(0)); -} - -// NOTE: These blocking tests are not perfect. Unfortunately it's very hard to -// determine if a thread was actually blocked in the kernel so we're forced -// to use timing. -TEST_F(FlockTest, BlockingLockNoBlockingForSharedLocks_NoRandomSave) { - // This test will verify that although LOCK_NB isn't specified - // two different fds can obtain shared locks without blocking. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH), SyscallSucceeds()); - - // kHoldLockTime is the amount of time we will hold the lock before releasing. - constexpr absl::Duration kHoldLockTime = absl::Seconds(30); - - const DisableSave ds; // Timing-related. - - // We do this in another thread so we can determine if it was actually - // blocked by timing the amount of time it took for the syscall to complete. - ScopedThread t([&] { - MonotonicTimer timer; - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - - // Only a single shared lock is held, the lock will be granted immediately. - // This should be granted without any blocking. Don't save here to avoid - // wild discrepencies on timing. - timer.Start(); - ASSERT_THAT(flock(fd.get(), LOCK_SH), SyscallSucceeds()); - - // We held the lock for 30 seconds but this thread should not have - // blocked at all so we expect a very small duration on syscall completion. - ASSERT_LT(timer.Duration(), - absl::Seconds(1)); // 1000ms is much less than 30s. - - // We can release our second shared lock - ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceeds()); - }); - - // Sleep before unlocking. - absl::SleepFor(kHoldLockTime); - - // Release the first shared lock. Don't save in this situation to avoid - // discrepencies in timing. - EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds()); -} - -TEST_F(FlockTest, BlockingLockFirstSharedSecondExclusive_NoRandomSave) { - // This test will verify that if someone holds a shared lock any attempt to - // obtain an exclusive lock will result in blocking. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH), SyscallSucceeds()); - - // kHoldLockTime is the amount of time we will hold the lock before releasing. - constexpr absl::Duration kHoldLockTime = absl::Seconds(2); - - const DisableSave ds; // Timing-related. - - // We do this in another thread so we can determine if it was actually - // blocked by timing the amount of time it took for the syscall to complete. - ScopedThread t([&] { - MonotonicTimer timer; - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - - // This exclusive lock should block because someone is already holding a - // shared lock. We don't save here to avoid wild discrepencies on timing. - timer.Start(); - ASSERT_THAT(RetryEINTR(flock)(fd.get(), LOCK_EX), SyscallSucceeds()); - - // We should be blocked, we will expect to be blocked for more than 1.0s. - ASSERT_GT(timer.Duration(), absl::Seconds(1)); - - // We can release our exclusive lock. - ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceeds()); - }); - - // Sleep before unlocking. - absl::SleepFor(kHoldLockTime); - - // Release the shared lock allowing the thread to proceed. - // We don't save here to avoid wild discrepencies in timing. - EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds()); -} - -TEST_F(FlockTest, BlockingLockFirstExclusiveSecondShared_NoRandomSave) { - // This test will verify that if someone holds an exclusive lock any attempt - // to obtain a shared lock will result in blocking. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX), SyscallSucceeds()); - - // kHoldLockTime is the amount of time we will hold the lock before releasing. - constexpr absl::Duration kHoldLockTime = absl::Seconds(2); - - const DisableSave ds; // Timing-related. - - // We do this in another thread so we can determine if it was actually - // blocked by timing the amount of time it took for the syscall to complete. - ScopedThread t([&] { - MonotonicTimer timer; - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - - // This shared lock should block because someone is already holding an - // exclusive lock. We don't save here to avoid wild discrepencies on timing. - timer.Start(); - ASSERT_THAT(RetryEINTR(flock)(fd.get(), LOCK_SH), SyscallSucceeds()); - - // We should be blocked, we will expect to be blocked for more than 1.0s. - ASSERT_GT(timer.Duration(), absl::Seconds(1)); - - // We can release our shared lock. - ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceeds()); - }); - - // Sleep before unlocking. - absl::SleepFor(kHoldLockTime); - - // Release the exclusive lock allowing the blocked thread to proceed. - // We don't save here to avoid wild discrepencies in timing. - EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds()); -} - -TEST_F(FlockTest, BlockingLockFirstExclusiveSecondExclusive_NoRandomSave) { - // This test will verify that if someone holds an exclusive lock any attempt - // to obtain another exclusive lock will result in blocking. - ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX), SyscallSucceeds()); - - // kHoldLockTime is the amount of time we will hold the lock before releasing. - constexpr absl::Duration kHoldLockTime = absl::Seconds(2); - - const DisableSave ds; // Timing-related. - - // We do this in another thread so we can determine if it was actually - // blocked by timing the amount of time it took for the syscall to complete. - ScopedThread t([&] { - MonotonicTimer timer; - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - - // This exclusive lock should block because someone is already holding an - // exclusive lock. - timer.Start(); - ASSERT_THAT(RetryEINTR(flock)(fd.get(), LOCK_EX), SyscallSucceeds()); - - // We should be blocked, we will expect to be blocked for more than 1.0s. - ASSERT_GT(timer.Duration(), absl::Seconds(1)); - - // We can release our exclusive lock. - ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceeds()); - }); - - // Sleep before unlocking. - absl::SleepFor(kHoldLockTime); - - // Release the exclusive lock allowing the blocked thread to proceed. - // We don't save to avoid wild discrepencies in timing. - EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/fork.cc b/test/syscalls/linux/fork.cc deleted file mode 100644 index ff8bdfeb0..000000000 --- a/test/syscalls/linux/fork.cc +++ /dev/null @@ -1,451 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <sched.h> -#include <stdlib.h> -#include <sys/mman.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include <atomic> -#include <cstdlib> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/capability_util.h" -#include "test/util/logging.h" -#include "test/util/memory_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -using ::testing::Ge; - -class ForkTest : public ::testing::Test { - protected: - // SetUp creates a populated, open file. - void SetUp() override { - // Make a shared mapping. - shared_ = reinterpret_cast<char*>(mmap(0, kPageSize, PROT_READ | PROT_WRITE, - MAP_SHARED | MAP_ANONYMOUS, -1, 0)); - ASSERT_NE(reinterpret_cast<void*>(shared_), MAP_FAILED); - - // Make a private mapping. - private_ = - reinterpret_cast<char*>(mmap(0, kPageSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0)); - ASSERT_NE(reinterpret_cast<void*>(private_), MAP_FAILED); - - // Make a pipe. - ASSERT_THAT(pipe(pipes_), SyscallSucceeds()); - } - - // TearDown frees associated resources. - void TearDown() override { - EXPECT_THAT(munmap(shared_, kPageSize), SyscallSucceeds()); - EXPECT_THAT(munmap(private_, kPageSize), SyscallSucceeds()); - EXPECT_THAT(close(pipes_[0]), SyscallSucceeds()); - EXPECT_THAT(close(pipes_[1]), SyscallSucceeds()); - } - - // Fork executes a clone system call. - pid_t Fork() { - pid_t pid = fork(); - MaybeSave(); - TEST_PCHECK_MSG(pid >= 0, "fork failed"); - return pid; - } - - // Wait waits for the given pid and returns the exit status. If the child was - // killed by a signal or an error occurs, then 256+signal is returned. - int Wait(pid_t pid) { - int status; - while (true) { - int rval = wait4(pid, &status, 0, NULL); - if (rval < 0) { - return rval; - } - if (rval != pid) { - continue; - } - if (WIFEXITED(status)) { - return WEXITSTATUS(status); - } - if (WIFSIGNALED(status)) { - return 256 + WTERMSIG(status); - } - } - } - - // Exit exits the proccess. - void Exit(int code) { - _exit(code); - - // Should never reach here. Since the exit above failed, we really don't - // have much in the way of options to indicate failure. So we just try to - // log an assertion failure to the logs. The parent process will likely - // fail anyways if exit is not working. - TEST_CHECK_MSG(false, "_exit returned"); - } - - // ReadByte reads a byte from the shared pipe. - char ReadByte() { - char val = -1; - TEST_PCHECK(ReadFd(pipes_[0], &val, 1) == 1); - MaybeSave(); - return val; - } - - // WriteByte writes a byte from the shared pipe. - void WriteByte(char val) { - TEST_PCHECK(WriteFd(pipes_[1], &val, 1) == 1); - MaybeSave(); - } - - // Shared pipe. - int pipes_[2]; - - // Shared mapping (one page). - char* shared_; - - // Private mapping (one page). - char* private_; -}; - -TEST_F(ForkTest, Simple) { - pid_t child = Fork(); - if (child == 0) { - Exit(0); - } - EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0)); -} - -TEST_F(ForkTest, ExitCode) { - pid_t child = Fork(); - if (child == 0) { - Exit(123); - } - EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(123)); - child = Fork(); - if (child == 0) { - Exit(1); - } - EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(1)); -} - -TEST_F(ForkTest, Multi) { - pid_t child1 = Fork(); - if (child1 == 0) { - Exit(0); - } - pid_t child2 = Fork(); - if (child2 == 0) { - Exit(1); - } - EXPECT_THAT(Wait(child1), SyscallSucceedsWithValue(0)); - EXPECT_THAT(Wait(child2), SyscallSucceedsWithValue(1)); -} - -TEST_F(ForkTest, Pipe) { - pid_t child = Fork(); - if (child == 0) { - WriteByte(1); - Exit(0); - } - EXPECT_EQ(ReadByte(), 1); - EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0)); -} - -TEST_F(ForkTest, SharedMapping) { - pid_t child = Fork(); - if (child == 0) { - // Wait for the parent. - ReadByte(); - if (shared_[0] == 1) { - Exit(0); - } - // Failed. - Exit(1); - } - // Change the mapping. - ASSERT_EQ(shared_[0], 0); - shared_[0] = 1; - // Unblock the child. - WriteByte(0); - // Did it work? - EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0)); -} - -TEST_F(ForkTest, PrivateMapping) { - pid_t child = Fork(); - if (child == 0) { - // Wait for the parent. - ReadByte(); - if (private_[0] == 0) { - Exit(0); - } - // Failed. - Exit(1); - } - // Change the mapping. - ASSERT_EQ(private_[0], 0); - private_[0] = 1; - // Unblock the child. - WriteByte(0); - // Did it work? - EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0)); -} - -// CPUID is x86 specific. -#ifdef __x86_64__ -// Test that cpuid works after a fork. -TEST_F(ForkTest, Cpuid) { - pid_t child = Fork(); - - // We should be able to determine the CPU vendor. - ASSERT_NE(GetCPUVendor(), CPUVendor::kUnknownVendor); - - if (child == 0) { - Exit(0); - } - EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0)); -} -#endif - -TEST_F(ForkTest, Mmap) { - pid_t child = Fork(); - - if (child == 0) { - void* addr = - mmap(0, kPageSize, PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - MaybeSave(); - Exit(addr == MAP_FAILED); - } - - EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0)); -} - -static volatile int alarmed = 0; - -void AlarmHandler(int sig, siginfo_t* info, void* context) { alarmed = 1; } - -TEST_F(ForkTest, Alarm) { - // Setup an alarm handler. - struct sigaction sa; - sa.sa_sigaction = AlarmHandler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_SIGINFO; - EXPECT_THAT(sigaction(SIGALRM, &sa, nullptr), SyscallSucceeds()); - - pid_t child = Fork(); - - if (child == 0) { - alarm(1); - sleep(3); - if (!alarmed) { - Exit(1); - } - Exit(0); - } - - EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0)); - EXPECT_EQ(0, alarmed); -} - -// Child cannot affect parent private memory. Regression test for b/24137240. -TEST_F(ForkTest, PrivateMemory) { - std::atomic<uint32_t> local(0); - - pid_t child1 = Fork(); - if (child1 == 0) { - local++; - - pid_t child2 = Fork(); - if (child2 == 0) { - local++; - - TEST_CHECK(local.load() == 2); - - Exit(0); - } - - TEST_PCHECK(Wait(child2) == 0); - TEST_CHECK(local.load() == 1); - Exit(0); - } - - EXPECT_THAT(Wait(child1), SyscallSucceedsWithValue(0)); - EXPECT_EQ(0, local.load()); -} - -// Kernel-accessed buffers should remain coherent across COW. -// -// The buffer must be >= usermem.ZeroCopyMinBytes, as UnsafeAccess operates -// differently. Regression test for b/33811887. -TEST_F(ForkTest, COWSegment) { - constexpr int kBufSize = 1024; - char* read_buf = private_; - char* touch = private_ + kPageSize / 2; - - std::string contents(kBufSize, 'a'); - - ScopedThread t([&] { - // Wait to be sure the parent is blocked in read. - absl::SleepFor(absl::Seconds(3)); - - // Fork to mark private pages for COW. - // - // Use fork directly rather than the Fork wrapper to skip the multi-threaded - // check, and limit the child to async-signal-safe functions: - // - // "After a fork() in a multithreaded program, the child can safely call - // only async-signal-safe functions (see signal(7)) until such time as it - // calls execve(2)." - // - // Skip ASSERT in the child, as it isn't async-signal-safe. - pid_t child = fork(); - if (child == 0) { - // Wait to be sure parent touched memory. - sleep(3); - Exit(0); - } - - // Check success only in the parent. - ASSERT_THAT(child, SyscallSucceedsWithValue(Ge(0))); - - // Trigger COW on private page. - *touch = 42; - - // Write to pipe. Parent should still be able to read this. - EXPECT_THAT(WriteFd(pipes_[1], contents.c_str(), kBufSize), - SyscallSucceedsWithValue(kBufSize)); - - EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0)); - }); - - EXPECT_THAT(ReadFd(pipes_[0], read_buf, kBufSize), - SyscallSucceedsWithValue(kBufSize)); - EXPECT_STREQ(contents.c_str(), read_buf); -} - -TEST_F(ForkTest, SigAltStack) { - std::vector<char> stack_mem(SIGSTKSZ); - stack_t stack = {}; - stack.ss_size = SIGSTKSZ; - stack.ss_sp = stack_mem.data(); - ASSERT_THAT(sigaltstack(&stack, nullptr), SyscallSucceeds()); - - pid_t child = Fork(); - - if (child == 0) { - stack_t oss = {}; - TEST_PCHECK(sigaltstack(nullptr, &oss) == 0); - MaybeSave(); - - TEST_CHECK((oss.ss_flags & SS_DISABLE) == 0); - TEST_CHECK(oss.ss_size == SIGSTKSZ); - TEST_CHECK(oss.ss_sp == stack.ss_sp); - - Exit(0); - } - EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0)); -} - -TEST_F(ForkTest, Affinity) { - // Make a non-default cpumask. - cpu_set_t parent_mask; - EXPECT_THAT(sched_getaffinity(/*pid=*/0, sizeof(cpu_set_t), &parent_mask), - SyscallSucceeds()); - // Knock out the lowest bit. - for (unsigned int n = 0; n < CPU_SETSIZE; n++) { - if (CPU_ISSET(n, &parent_mask)) { - CPU_CLR(n, &parent_mask); - break; - } - } - EXPECT_THAT(sched_setaffinity(/*pid=*/0, sizeof(cpu_set_t), &parent_mask), - SyscallSucceeds()); - - pid_t child = Fork(); - if (child == 0) { - cpu_set_t child_mask; - - int ret = sched_getaffinity(/*pid=*/0, sizeof(cpu_set_t), &child_mask); - MaybeSave(); - if (ret < 0) { - Exit(-ret); - } - - TEST_CHECK(CPU_EQUAL(&child_mask, &parent_mask)); - - Exit(0); - } - - EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0)); -} - -TEST(CloneTest, NewUserNamespacePermitsAllOtherNamespaces) { - // "If CLONE_NEWUSER is specified along with other CLONE_NEW* flags in a - // single clone(2) or unshare(2) call, the user namespace is guaranteed to be - // created first, giving the child (clone(2)) or caller (unshare(2)) - // privileges over the remaining namespaces created by the call. Thus, it is - // possible for an unprivileged caller to specify this combination of flags." - // - user_namespaces(7) - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace())); - Mapping child_stack = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - int child_pid; - // We only test with CLONE_NEWIPC, CLONE_NEWNET, and CLONE_NEWUTS since these - // namespaces were implemented in Linux before user namespaces. - ASSERT_THAT( - child_pid = clone( - +[](void*) { return 0; }, - reinterpret_cast<void*>(child_stack.addr() + kPageSize), - CLONE_NEWUSER | CLONE_NEWIPC | CLONE_NEWNET | CLONE_NEWUTS | SIGCHLD, - /* arg = */ nullptr), - SyscallSucceeds()); - - int status; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "status = " << status; -} - -#ifdef __x86_64__ -// Clone with CLONE_SETTLS and a non-canonical TLS address is rejected. -TEST(CloneTest, NonCanonicalTLS) { - constexpr uintptr_t kNonCanonical = 1ull << 48; - - // We need a valid address for the stack pointer. We'll never actually execute - // on this. - char stack; - - EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr, - nullptr, kNonCanonical), - SyscallFailsWithErrno(EPERM)); -} -#endif - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/fpsig_fork.cc b/test/syscalls/linux/fpsig_fork.cc deleted file mode 100644 index a346f1f00..000000000 --- a/test/syscalls/linux/fpsig_fork.cc +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This test verifies that fork(2) in a signal handler will correctly -// restore floating point state after the signal handler returns in both -// the child and parent. -#include <sys/time.h> - -#include "gtest/gtest.h" -#include "test/util/logging.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -#define GET_XMM(__var, __xmm) \ - asm volatile("movq %%" #__xmm ", %0" : "=r"(__var)) -#define SET_XMM(__var, __xmm) asm volatile("movq %0, %%" #__xmm : : "r"(__var)) - -int parent, child; - -void sigusr1(int s, siginfo_t* siginfo, void* _uc) { - // Fork and clobber %xmm0. The fpstate should be restored by sigreturn(2) - // in both parent and child. - child = fork(); - TEST_CHECK_MSG(child >= 0, "fork failed"); - - uint64_t val = SIGUSR1; - SET_XMM(val, xmm0); -} - -TEST(FPSigTest, Fork) { - parent = getpid(); - pid_t parent_tid = gettid(); - - struct sigaction sa = {}; - sigemptyset(&sa.sa_mask); - sa.sa_flags = SA_SIGINFO; - sa.sa_sigaction = sigusr1; - ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds()); - - // The amd64 ABI specifies that the XMM register set is caller-saved. This - // implies that if there is any function call between SET_XMM and GET_XMM the - // compiler might save/restore xmm0 implicitly. This defeats the entire - // purpose of the test which is to verify that fpstate is restored by - // sigreturn(2). - // - // This is the reason why 'tgkill(getpid(), gettid(), SIGUSR1)' is implemented - // in inline assembly below. - // - // If the OS is broken and registers are clobbered by the child, using tgkill - // to signal the current thread increases the likelihood that this thread will - // be the one clobbered. - - uint64_t expected = 0xdeadbeeffacefeed; - SET_XMM(expected, xmm0); - - asm volatile( - "movl %[killnr], %%eax;" - "movl %[parent], %%edi;" - "movl %[tid], %%esi;" - "movl %[sig], %%edx;" - "syscall;" - : - : [ killnr ] "i"(__NR_tgkill), [ parent ] "rm"(parent), - [ tid ] "rm"(parent_tid), [ sig ] "i"(SIGUSR1) - : "rax", "rdi", "rsi", "rdx", - // Clobbered by syscall. - "rcx", "r11"); - - uint64_t got; - GET_XMM(got, xmm0); - - if (getpid() == parent) { // Parent. - int status; - ASSERT_THAT(waitpid(child, &status, 0), SyscallSucceedsWithValue(child)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); - } - - // TEST_CHECK_MSG since this may run in the child. - TEST_CHECK_MSG(expected == got, "Bad xmm0 value"); - - if (getpid() != parent) { // Child. - _exit(0); - } -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/fpsig_nested.cc b/test/syscalls/linux/fpsig_nested.cc deleted file mode 100644 index c476a8e7a..000000000 --- a/test/syscalls/linux/fpsig_nested.cc +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This program verifies that application floating point state is restored -// correctly after a signal handler returns. It also verifies that this works -// with nested signals. -#include <sys/time.h> - -#include "gtest/gtest.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -#define GET_XMM(__var, __xmm) \ - asm volatile("movq %%" #__xmm ", %0" : "=r"(__var)) -#define SET_XMM(__var, __xmm) asm volatile("movq %0, %%" #__xmm : : "r"(__var)) - -int pid; -int tid; - -volatile uint64_t entryxmm[2] = {~0UL, ~0UL}; -volatile uint64_t exitxmm[2]; - -void sigusr2(int s, siginfo_t* siginfo, void* _uc) { - uint64_t val = SIGUSR2; - - // Record the value of %xmm0 on entry and then clobber it. - GET_XMM(entryxmm[1], xmm0); - SET_XMM(val, xmm0); - GET_XMM(exitxmm[1], xmm0); -} - -void sigusr1(int s, siginfo_t* siginfo, void* _uc) { - uint64_t val = SIGUSR1; - - // Record the value of %xmm0 on entry and then clobber it. - GET_XMM(entryxmm[0], xmm0); - SET_XMM(val, xmm0); - - // Send a SIGUSR2 to ourself. The signal mask is configured such that - // the SIGUSR2 handler will run before this handler returns. - asm volatile( - "movl %[killnr], %%eax;" - "movl %[pid], %%edi;" - "movl %[tid], %%esi;" - "movl %[sig], %%edx;" - "syscall;" - : - : [ killnr ] "i"(__NR_tgkill), [ pid ] "rm"(pid), [ tid ] "rm"(tid), - [ sig ] "i"(SIGUSR2) - : "rax", "rdi", "rsi", "rdx", - // Clobbered by syscall. - "rcx", "r11"); - - // Record value of %xmm0 again to verify that the nested signal handler - // does not clobber it. - GET_XMM(exitxmm[0], xmm0); -} - -TEST(FPSigTest, NestedSignals) { - pid = getpid(); - tid = gettid(); - - struct sigaction sa = {}; - sigemptyset(&sa.sa_mask); - sa.sa_flags = SA_SIGINFO; - sa.sa_sigaction = sigusr1; - ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds()); - - sa.sa_sigaction = sigusr2; - ASSERT_THAT(sigaction(SIGUSR2, &sa, nullptr), SyscallSucceeds()); - - // The amd64 ABI specifies that the XMM register set is caller-saved. This - // implies that if there is any function call between SET_XMM and GET_XMM the - // compiler might save/restore xmm0 implicitly. This defeats the entire - // purpose of the test which is to verify that fpstate is restored by - // sigreturn(2). - // - // This is the reason why 'tgkill(getpid(), gettid(), SIGUSR1)' is implemented - // in inline assembly below. - // - // If the OS is broken and registers are clobbered by the signal, using tgkill - // to signal the current thread ensures that this is the clobbered thread. - - uint64_t expected = 0xdeadbeeffacefeed; - SET_XMM(expected, xmm0); - - asm volatile( - "movl %[killnr], %%eax;" - "movl %[pid], %%edi;" - "movl %[tid], %%esi;" - "movl %[sig], %%edx;" - "syscall;" - : - : [ killnr ] "i"(__NR_tgkill), [ pid ] "rm"(pid), [ tid ] "rm"(tid), - [ sig ] "i"(SIGUSR1) - : "rax", "rdi", "rsi", "rdx", - // Clobbered by syscall. - "rcx", "r11"); - - uint64_t got; - GET_XMM(got, xmm0); - - // - // The checks below verifies the following: - // - signal handlers must called with a clean fpu state. - // - sigreturn(2) must restore fpstate of the interrupted context. - // - EXPECT_EQ(expected, got); - EXPECT_EQ(entryxmm[0], 0); - EXPECT_EQ(entryxmm[1], 0); - EXPECT_EQ(exitxmm[0], SIGUSR1); - EXPECT_EQ(exitxmm[1], SIGUSR2); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/fsync.cc b/test/syscalls/linux/fsync.cc deleted file mode 100644 index e7e057f06..000000000 --- a/test/syscalls/linux/fsync.cc +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <stdio.h> -#include <unistd.h> - -#include <string> - -#include "gtest/gtest.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(FsyncTest, TempFileSucceeds) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); - const std::string data = "some data to sync"; - EXPECT_THAT(write(fd.get(), data.c_str(), data.size()), - SyscallSucceedsWithValue(data.size())); - EXPECT_THAT(fsync(fd.get()), SyscallSucceeds()); -} - -TEST(FsyncTest, TempDirSucceeds) { - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY)); - EXPECT_THAT(fsync(fd.get()), SyscallSucceeds()); -} - -TEST(FsyncTest, CannotFsyncOnUnopenedFd) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - int fd; - ASSERT_THAT(fd = open(file.path().c_str(), O_RDONLY), SyscallSucceeds()); - ASSERT_THAT(close(fd), SyscallSucceeds()); - - // fd is now invalid. - EXPECT_THAT(fsync(fd), SyscallFailsWithErrno(EBADF)); -} -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/futex.cc b/test/syscalls/linux/futex.cc deleted file mode 100644 index 40c80a6e1..000000000 --- a/test/syscalls/linux/futex.cc +++ /dev/null @@ -1,742 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <linux/futex.h> -#include <linux/types.h> -#include <sys/syscall.h> -#include <sys/time.h> -#include <sys/types.h> -#include <unistd.h> - -#include <algorithm> -#include <atomic> -#include <memory> -#include <vector> - -#include "gtest/gtest.h" -#include "absl/memory/memory.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/memory_util.h" -#include "test/util/save_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" -#include "test/util/time_util.h" -#include "test/util/timer_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Amount of time we wait for threads doing futex_wait to start running before -// doing futex_wake. -constexpr auto kWaiterStartupDelay = absl::Seconds(3); - -// Default timeout for waiters in tests where we expect a futex_wake to be -// ineffective. -constexpr auto kIneffectiveWakeTimeout = absl::Seconds(6); - -static_assert(kWaiterStartupDelay < kIneffectiveWakeTimeout, - "futex_wait will time out before futex_wake is called"); - -int futex_wait(bool priv, std::atomic<int>* uaddr, int val, - absl::Duration timeout = absl::InfiniteDuration()) { - int op = FUTEX_WAIT; - if (priv) { - op |= FUTEX_PRIVATE_FLAG; - } - - if (timeout == absl::InfiniteDuration()) { - return RetryEINTR(syscall)(SYS_futex, uaddr, op, val, nullptr); - } - - // FUTEX_WAIT doesn't adjust the timeout if it returns EINTR, so we have to do - // so. - while (true) { - auto const timeout_ts = absl::ToTimespec(timeout); - MonotonicTimer timer; - timer.Start(); - int const ret = syscall(SYS_futex, uaddr, op, val, &timeout_ts); - if (ret != -1 || errno != EINTR) { - return ret; - } - timeout = std::max(timeout - timer.Duration(), absl::ZeroDuration()); - } -} - -int futex_wait_bitset(bool priv, std::atomic<int>* uaddr, int val, int bitset, - absl::Time deadline = absl::InfiniteFuture()) { - int op = FUTEX_WAIT_BITSET | FUTEX_CLOCK_REALTIME; - if (priv) { - op |= FUTEX_PRIVATE_FLAG; - } - - auto const deadline_ts = absl::ToTimespec(deadline); - return RetryEINTR(syscall)( - SYS_futex, uaddr, op, val, - deadline == absl::InfiniteFuture() ? nullptr : &deadline_ts, nullptr, - bitset); -} - -int futex_wake(bool priv, std::atomic<int>* uaddr, int count) { - int op = FUTEX_WAKE; - if (priv) { - op |= FUTEX_PRIVATE_FLAG; - } - return syscall(SYS_futex, uaddr, op, count); -} - -int futex_wake_bitset(bool priv, std::atomic<int>* uaddr, int count, - int bitset) { - int op = FUTEX_WAKE_BITSET; - if (priv) { - op |= FUTEX_PRIVATE_FLAG; - } - return syscall(SYS_futex, uaddr, op, count, nullptr, nullptr, bitset); -} - -int futex_wake_op(bool priv, std::atomic<int>* uaddr1, std::atomic<int>* uaddr2, - int nwake1, int nwake2, uint32_t sub_op) { - int op = FUTEX_WAKE_OP; - if (priv) { - op |= FUTEX_PRIVATE_FLAG; - } - return syscall(SYS_futex, uaddr1, op, nwake1, nwake2, uaddr2, sub_op); -} - -int futex_lock_pi(bool priv, std::atomic<int>* uaddr) { - int op = FUTEX_LOCK_PI; - if (priv) { - op |= FUTEX_PRIVATE_FLAG; - } - int zero = 0; - if (uaddr->compare_exchange_strong(zero, gettid())) { - return 0; - } - return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr); -} - -int futex_trylock_pi(bool priv, std::atomic<int>* uaddr) { - int op = FUTEX_TRYLOCK_PI; - if (priv) { - op |= FUTEX_PRIVATE_FLAG; - } - int zero = 0; - if (uaddr->compare_exchange_strong(zero, gettid())) { - return 0; - } - return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr); -} - -int futex_unlock_pi(bool priv, std::atomic<int>* uaddr) { - int op = FUTEX_UNLOCK_PI; - if (priv) { - op |= FUTEX_PRIVATE_FLAG; - } - int tid = gettid(); - if (uaddr->compare_exchange_strong(tid, 0)) { - return 0; - } - return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr); -} - -// Fixture for futex tests parameterized by whether to use private or shared -// futexes. -class PrivateAndSharedFutexTest : public ::testing::TestWithParam<bool> { - protected: - bool IsPrivate() const { return GetParam(); } - int PrivateFlag() const { return IsPrivate() ? FUTEX_PRIVATE_FLAG : 0; } -}; - -// FUTEX_WAIT with 0 timeout does not block. -TEST_P(PrivateAndSharedFutexTest, Wait_ZeroTimeout) { - struct timespec timeout = {}; - - // Don't use the futex_wait helper because it adjusts timeout. - int a = 1; - EXPECT_THAT(syscall(SYS_futex, &a, FUTEX_WAIT | PrivateFlag(), a, &timeout), - SyscallFailsWithErrno(ETIMEDOUT)); -} - -TEST_P(PrivateAndSharedFutexTest, Wait_Timeout) { - std::atomic<int> a = ATOMIC_VAR_INIT(1); - - MonotonicTimer timer; - timer.Start(); - constexpr absl::Duration kTimeout = absl::Seconds(1); - EXPECT_THAT(futex_wait(IsPrivate(), &a, a, kTimeout), - SyscallFailsWithErrno(ETIMEDOUT)); - EXPECT_GE(timer.Duration(), kTimeout); -} - -TEST_P(PrivateAndSharedFutexTest, Wait_BitsetTimeout) { - std::atomic<int> a = ATOMIC_VAR_INIT(1); - - MonotonicTimer timer; - timer.Start(); - constexpr absl::Duration kTimeout = absl::Seconds(1); - EXPECT_THAT( - futex_wait_bitset(IsPrivate(), &a, a, 0xffffffff, absl::Now() + kTimeout), - SyscallFailsWithErrno(ETIMEDOUT)); - EXPECT_GE(timer.Duration(), kTimeout); -} - -TEST_P(PrivateAndSharedFutexTest, WaitBitset_NegativeTimeout) { - std::atomic<int> a = ATOMIC_VAR_INIT(1); - - MonotonicTimer timer; - timer.Start(); - EXPECT_THAT(futex_wait_bitset(IsPrivate(), &a, a, 0xffffffff, - absl::Now() - absl::Seconds(1)), - SyscallFailsWithErrno(ETIMEDOUT)); -} - -TEST_P(PrivateAndSharedFutexTest, Wait_WrongVal) { - std::atomic<int> a = ATOMIC_VAR_INIT(1); - EXPECT_THAT(futex_wait(IsPrivate(), &a, a + 1), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(PrivateAndSharedFutexTest, Wait_ZeroBitset) { - std::atomic<int> a = ATOMIC_VAR_INIT(1); - EXPECT_THAT(futex_wait_bitset(IsPrivate(), &a, a, 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(PrivateAndSharedFutexTest, Wake1_NoRandomSave) { - constexpr int kInitialValue = 1; - std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); - - // Prevent save/restore from interrupting futex_wait, which will cause it to - // return EAGAIN instead of the expected result if futex_wait is restarted - // after we change the value of a below. - DisableSave ds; - ScopedThread thread([&] { - EXPECT_THAT(futex_wait(IsPrivate(), &a, kInitialValue), - SyscallSucceedsWithValue(0)); - }); - absl::SleepFor(kWaiterStartupDelay); - - // Change a so that if futex_wake happens before futex_wait, the latter - // returns EAGAIN instead of hanging the test. - a.fetch_add(1); - EXPECT_THAT(futex_wake(IsPrivate(), &a, 1), SyscallSucceedsWithValue(1)); -} - -TEST_P(PrivateAndSharedFutexTest, Wake0_NoRandomSave) { - constexpr int kInitialValue = 1; - std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); - - // Prevent save/restore from interrupting futex_wait, which will cause it to - // return EAGAIN instead of the expected result if futex_wait is restarted - // after we change the value of a below. - DisableSave ds; - ScopedThread thread([&] { - EXPECT_THAT(futex_wait(IsPrivate(), &a, kInitialValue), - SyscallSucceedsWithValue(0)); - }); - absl::SleepFor(kWaiterStartupDelay); - - // Change a so that if futex_wake happens before futex_wait, the latter - // returns EAGAIN instead of hanging the test. - a.fetch_add(1); - // The Linux kernel wakes one waiter even if val is 0 or negative. - EXPECT_THAT(futex_wake(IsPrivate(), &a, 0), SyscallSucceedsWithValue(1)); -} - -TEST_P(PrivateAndSharedFutexTest, WakeAll_NoRandomSave) { - constexpr int kInitialValue = 1; - std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); - - DisableSave ds; - constexpr int kThreads = 5; - std::vector<std::unique_ptr<ScopedThread>> threads; - threads.reserve(kThreads); - for (int i = 0; i < kThreads; i++) { - threads.push_back(absl::make_unique<ScopedThread>([&] { - EXPECT_THAT(futex_wait(IsPrivate(), &a, kInitialValue), - SyscallSucceeds()); - })); - } - absl::SleepFor(kWaiterStartupDelay); - - a.fetch_add(1); - EXPECT_THAT(futex_wake(IsPrivate(), &a, kThreads), - SyscallSucceedsWithValue(kThreads)); -} - -TEST_P(PrivateAndSharedFutexTest, WakeSome_NoRandomSave) { - constexpr int kInitialValue = 1; - std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); - - DisableSave ds; - constexpr int kThreads = 5; - constexpr int kWokenThreads = 3; - static_assert(kWokenThreads < kThreads, - "can't wake more threads than are created"); - std::vector<std::unique_ptr<ScopedThread>> threads; - threads.reserve(kThreads); - std::vector<int> rets; - rets.reserve(kThreads); - std::vector<int> errs; - errs.reserve(kThreads); - for (int i = 0; i < kThreads; i++) { - rets.push_back(-1); - errs.push_back(0); - } - for (int i = 0; i < kThreads; i++) { - threads.push_back(absl::make_unique<ScopedThread>([&, i] { - rets[i] = - futex_wait(IsPrivate(), &a, kInitialValue, kIneffectiveWakeTimeout); - errs[i] = errno; - })); - } - absl::SleepFor(kWaiterStartupDelay); - - a.fetch_add(1); - EXPECT_THAT(futex_wake(IsPrivate(), &a, kWokenThreads), - SyscallSucceedsWithValue(kWokenThreads)); - - int woken = 0; - int timedout = 0; - for (int i = 0; i < kThreads; i++) { - threads[i]->Join(); - if (rets[i] == 0) { - woken++; - } else if (errs[i] == ETIMEDOUT) { - timedout++; - } else { - ADD_FAILURE() << " thread " << i << ": returned " << rets[i] << ", errno " - << errs[i]; - } - } - EXPECT_EQ(woken, kWokenThreads); - EXPECT_EQ(timedout, kThreads - kWokenThreads); -} - -TEST_P(PrivateAndSharedFutexTest, WaitBitset_Wake_NoRandomSave) { - constexpr int kInitialValue = 1; - std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); - - DisableSave ds; - ScopedThread thread([&] { - EXPECT_THAT(futex_wait_bitset(IsPrivate(), &a, kInitialValue, 0b01001000), - SyscallSucceeds()); - }); - absl::SleepFor(kWaiterStartupDelay); - - a.fetch_add(1); - EXPECT_THAT(futex_wake(IsPrivate(), &a, 1), SyscallSucceedsWithValue(1)); -} - -TEST_P(PrivateAndSharedFutexTest, Wait_WakeBitset_NoRandomSave) { - constexpr int kInitialValue = 1; - std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); - - DisableSave ds; - ScopedThread thread([&] { - EXPECT_THAT(futex_wait(IsPrivate(), &a, kInitialValue), SyscallSucceeds()); - }); - absl::SleepFor(kWaiterStartupDelay); - - a.fetch_add(1); - EXPECT_THAT(futex_wake_bitset(IsPrivate(), &a, 1, 0b01001000), - SyscallSucceedsWithValue(1)); -} - -TEST_P(PrivateAndSharedFutexTest, WaitBitset_WakeBitsetMatch_NoRandomSave) { - constexpr int kInitialValue = 1; - std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); - - constexpr int kBitset = 0b01001000; - - DisableSave ds; - ScopedThread thread([&] { - EXPECT_THAT(futex_wait_bitset(IsPrivate(), &a, kInitialValue, kBitset), - SyscallSucceeds()); - }); - absl::SleepFor(kWaiterStartupDelay); - - a.fetch_add(1); - EXPECT_THAT(futex_wake_bitset(IsPrivate(), &a, 1, kBitset), - SyscallSucceedsWithValue(1)); -} - -TEST_P(PrivateAndSharedFutexTest, WaitBitset_WakeBitsetNoMatch_NoRandomSave) { - constexpr int kInitialValue = 1; - std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); - - constexpr int kWaitBitset = 0b01000001; - constexpr int kWakeBitset = 0b00101000; - static_assert((kWaitBitset & kWakeBitset) == 0, - "futex_wake_bitset will wake waiter"); - - DisableSave ds; - ScopedThread thread([&] { - EXPECT_THAT(futex_wait_bitset(IsPrivate(), &a, kInitialValue, kWaitBitset, - absl::Now() + kIneffectiveWakeTimeout), - SyscallFailsWithErrno(ETIMEDOUT)); - }); - absl::SleepFor(kWaiterStartupDelay); - - a.fetch_add(1); - EXPECT_THAT(futex_wake_bitset(IsPrivate(), &a, 1, kWakeBitset), - SyscallSucceedsWithValue(0)); -} - -TEST_P(PrivateAndSharedFutexTest, WakeOpCondSuccess_NoRandomSave) { - constexpr int kInitialValue = 1; - std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); - std::atomic<int> b = ATOMIC_VAR_INIT(kInitialValue); - - DisableSave ds; - ScopedThread thread_a([&] { - EXPECT_THAT(futex_wait(IsPrivate(), &a, kInitialValue), SyscallSucceeds()); - }); - ScopedThread thread_b([&] { - EXPECT_THAT(futex_wait(IsPrivate(), &b, kInitialValue), SyscallSucceeds()); - }); - absl::SleepFor(kWaiterStartupDelay); - - a.fetch_add(1); - b.fetch_add(1); - // This futex_wake_op should: - // - Wake 1 waiter on a unconditionally. - // - Wake 1 waiter on b if b == kInitialValue + 1, which it is. - // - Do "b += 1". - EXPECT_THAT(futex_wake_op(IsPrivate(), &a, &b, 1, 1, - FUTEX_OP(FUTEX_OP_ADD, 1, FUTEX_OP_CMP_EQ, - (kInitialValue + 1))), - SyscallSucceedsWithValue(2)); - EXPECT_EQ(b, kInitialValue + 2); -} - -TEST_P(PrivateAndSharedFutexTest, WakeOpCondFailure_NoRandomSave) { - constexpr int kInitialValue = 1; - std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); - std::atomic<int> b = ATOMIC_VAR_INIT(kInitialValue); - - DisableSave ds; - ScopedThread thread_a([&] { - EXPECT_THAT(futex_wait(IsPrivate(), &a, kInitialValue), SyscallSucceeds()); - }); - ScopedThread thread_b([&] { - EXPECT_THAT( - futex_wait(IsPrivate(), &b, kInitialValue, kIneffectiveWakeTimeout), - SyscallFailsWithErrno(ETIMEDOUT)); - }); - absl::SleepFor(kWaiterStartupDelay); - - a.fetch_add(1); - b.fetch_add(1); - // This futex_wake_op should: - // - Wake 1 waiter on a unconditionally. - // - Wake 1 waiter on b if b == kInitialValue - 1, which it isn't. - // - Do "b += 1". - EXPECT_THAT(futex_wake_op(IsPrivate(), &a, &b, 1, 1, - FUTEX_OP(FUTEX_OP_ADD, 1, FUTEX_OP_CMP_EQ, - (kInitialValue - 1))), - SyscallSucceedsWithValue(1)); - EXPECT_EQ(b, kInitialValue + 2); -} - -TEST_P(PrivateAndSharedFutexTest, NoWakeInterprocessPrivateAnon_NoRandomSave) { - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - auto const ptr = static_cast<std::atomic<int>*>(mapping.ptr()); - constexpr int kInitialValue = 1; - ptr->store(kInitialValue); - - DisableSave ds; - pid_t const child_pid = fork(); - if (child_pid == 0) { - TEST_PCHECK(futex_wait(IsPrivate(), ptr, kInitialValue, - kIneffectiveWakeTimeout) == -1 && - errno == ETIMEDOUT); - _exit(0); - } - ASSERT_THAT(child_pid, SyscallSucceeds()); - absl::SleepFor(kWaiterStartupDelay); - - EXPECT_THAT(futex_wake(IsPrivate(), ptr, 1), SyscallSucceedsWithValue(0)); - - int status; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << " status " << status; -} - -TEST_P(PrivateAndSharedFutexTest, WakeAfterCOWBreak_NoRandomSave) { - // Use a futex on a non-stack mapping so we can be sure that the child process - // below isn't the one that breaks copy-on-write. - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - auto const ptr = static_cast<std::atomic<int>*>(mapping.ptr()); - constexpr int kInitialValue = 1; - ptr->store(kInitialValue); - - DisableSave ds; - ScopedThread thread([&] { - EXPECT_THAT(futex_wait(IsPrivate(), ptr, kInitialValue), SyscallSucceeds()); - }); - absl::SleepFor(kWaiterStartupDelay); - - pid_t const child_pid = fork(); - if (child_pid == 0) { - // Wait to be killed by the parent. - while (true) pause(); - } - ASSERT_THAT(child_pid, SyscallSucceeds()); - auto cleanup_child = Cleanup([&] { - EXPECT_THAT(kill(child_pid, SIGKILL), SyscallSucceeds()); - int status; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL) - << " status " << status; - }); - - // In addition to preventing a late futex_wait from sleeping, this breaks - // copy-on-write on the mapped page. - ptr->fetch_add(1); - EXPECT_THAT(futex_wake(IsPrivate(), ptr, 1), SyscallSucceedsWithValue(1)); -} - -TEST_P(PrivateAndSharedFutexTest, WakeWrongKind_NoRandomSave) { - constexpr int kInitialValue = 1; - std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); - - DisableSave ds; - ScopedThread thread([&] { - EXPECT_THAT( - futex_wait(IsPrivate(), &a, kInitialValue, kIneffectiveWakeTimeout), - SyscallFailsWithErrno(ETIMEDOUT)); - }); - absl::SleepFor(kWaiterStartupDelay); - - a.fetch_add(1); - // The value of priv passed to futex_wake is the opposite of that passed to - // the futex_waiter; we expect this not to wake the waiter. - EXPECT_THAT(futex_wake(!IsPrivate(), &a, 1), SyscallSucceedsWithValue(0)); -} - -INSTANTIATE_TEST_SUITE_P(SharedPrivate, PrivateAndSharedFutexTest, - ::testing::Bool()); - -// Passing null as the address only works for private futexes. - -TEST(PrivateFutexTest, WakeOp0Set) { - std::atomic<int> a = ATOMIC_VAR_INIT(1); - - int futex_op = FUTEX_OP(FUTEX_OP_SET, 2, 0, 0); - EXPECT_THAT(futex_wake_op(true, nullptr, &a, 0, 0, futex_op), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(a, 2); -} - -TEST(PrivateFutexTest, WakeOp0Add) { - std::atomic<int> a = ATOMIC_VAR_INIT(1); - int futex_op = FUTEX_OP(FUTEX_OP_ADD, 1, 0, 0); - EXPECT_THAT(futex_wake_op(true, nullptr, &a, 0, 0, futex_op), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(a, 2); -} - -TEST(PrivateFutexTest, WakeOp0Or) { - std::atomic<int> a = ATOMIC_VAR_INIT(0b01); - int futex_op = FUTEX_OP(FUTEX_OP_OR, 0b10, 0, 0); - EXPECT_THAT(futex_wake_op(true, nullptr, &a, 0, 0, futex_op), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(a, 0b11); -} - -TEST(PrivateFutexTest, WakeOp0Andn) { - std::atomic<int> a = ATOMIC_VAR_INIT(0b11); - int futex_op = FUTEX_OP(FUTEX_OP_ANDN, 0b10, 0, 0); - EXPECT_THAT(futex_wake_op(true, nullptr, &a, 0, 0, futex_op), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(a, 0b01); -} - -TEST(PrivateFutexTest, WakeOp0Xor) { - std::atomic<int> a = ATOMIC_VAR_INIT(0b1010); - int futex_op = FUTEX_OP(FUTEX_OP_XOR, 0b1100, 0, 0); - EXPECT_THAT(futex_wake_op(true, nullptr, &a, 0, 0, futex_op), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(a, 0b0110); -} - -TEST(SharedFutexTest, WakeInterprocessSharedAnon_NoRandomSave) { - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED)); - auto const ptr = static_cast<std::atomic<int>*>(mapping.ptr()); - constexpr int kInitialValue = 1; - ptr->store(kInitialValue); - - DisableSave ds; - pid_t const child_pid = fork(); - if (child_pid == 0) { - TEST_PCHECK(futex_wait(false, ptr, kInitialValue) == 0); - _exit(0); - } - ASSERT_THAT(child_pid, SyscallSucceeds()); - auto kill_child = Cleanup( - [&] { EXPECT_THAT(kill(child_pid, SIGKILL), SyscallSucceeds()); }); - absl::SleepFor(kWaiterStartupDelay); - - ptr->fetch_add(1); - // This is an ASSERT so that if it fails, we immediately abort the test (and - // kill the subprocess). - ASSERT_THAT(futex_wake(false, ptr, 1), SyscallSucceedsWithValue(1)); - - kill_child.Release(); - int status; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << " status " << status; -} - -TEST(SharedFutexTest, WakeInterprocessFile_NoRandomSave) { - auto const file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - ASSERT_THAT(truncate(file.path().c_str(), kPageSize), SyscallSucceeds()); - auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(Mmap( - nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0)); - auto const ptr = static_cast<std::atomic<int>*>(mapping.ptr()); - constexpr int kInitialValue = 1; - ptr->store(kInitialValue); - - DisableSave ds; - pid_t const child_pid = fork(); - if (child_pid == 0) { - TEST_PCHECK(futex_wait(false, ptr, kInitialValue) == 0); - _exit(0); - } - ASSERT_THAT(child_pid, SyscallSucceeds()); - auto kill_child = Cleanup( - [&] { EXPECT_THAT(kill(child_pid, SIGKILL), SyscallSucceeds()); }); - absl::SleepFor(kWaiterStartupDelay); - - ptr->fetch_add(1); - // This is an ASSERT so that if it fails, we immediately abort the test (and - // kill the subprocess). - ASSERT_THAT(futex_wake(false, ptr, 1), SyscallSucceedsWithValue(1)); - - kill_child.Release(); - int status; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << " status " << status; -} - -TEST_P(PrivateAndSharedFutexTest, PIBasic) { - std::atomic<int> a = ATOMIC_VAR_INIT(0); - - ASSERT_THAT(futex_lock_pi(IsPrivate(), &a), SyscallSucceeds()); - EXPECT_EQ(a.load(), gettid()); - EXPECT_THAT(futex_lock_pi(IsPrivate(), &a), SyscallFailsWithErrno(EDEADLK)); - - ASSERT_THAT(futex_unlock_pi(IsPrivate(), &a), SyscallSucceeds()); - EXPECT_EQ(a.load(), 0); - EXPECT_THAT(futex_unlock_pi(IsPrivate(), &a), SyscallFailsWithErrno(EPERM)); -} - -TEST_P(PrivateAndSharedFutexTest, PIConcurrency_NoRandomSave) { - DisableSave ds; // Too many syscalls. - - std::atomic<int> a = ATOMIC_VAR_INIT(0); - const bool is_priv = IsPrivate(); - - std::unique_ptr<ScopedThread> threads[100]; - for (size_t i = 0; i < ABSL_ARRAYSIZE(threads); ++i) { - threads[i] = absl::make_unique<ScopedThread>([is_priv, &a] { - for (size_t j = 0; j < 10; ++j) { - ASSERT_THAT(futex_lock_pi(is_priv, &a), SyscallSucceeds()); - EXPECT_EQ(a.load() & FUTEX_TID_MASK, gettid()); - SleepSafe(absl::Milliseconds(5)); - ASSERT_THAT(futex_unlock_pi(is_priv, &a), SyscallSucceeds()); - } - }); - } -} - -TEST_P(PrivateAndSharedFutexTest, PIWaiters) { - std::atomic<int> a = ATOMIC_VAR_INIT(0); - const bool is_priv = IsPrivate(); - - ASSERT_THAT(futex_lock_pi(is_priv, &a), SyscallSucceeds()); - EXPECT_EQ(a.load(), gettid()); - - ScopedThread th([is_priv, &a] { - ASSERT_THAT(futex_lock_pi(is_priv, &a), SyscallSucceeds()); - ASSERT_THAT(futex_unlock_pi(is_priv, &a), SyscallSucceeds()); - }); - - // Wait until the thread blocks on the futex, setting the waiters bit. - auto start = absl::Now(); - while (a.load() != (FUTEX_WAITERS | gettid())) { - ASSERT_LT(absl::Now() - start, absl::Seconds(5)); - absl::SleepFor(absl::Milliseconds(100)); - } - ASSERT_THAT(futex_unlock_pi(is_priv, &a), SyscallSucceeds()); -} - -TEST_P(PrivateAndSharedFutexTest, PITryLock) { - std::atomic<int> a = ATOMIC_VAR_INIT(0); - const bool is_priv = IsPrivate(); - - ASSERT_THAT(futex_trylock_pi(IsPrivate(), &a), SyscallSucceeds()); - EXPECT_EQ(a.load(), gettid()); - - EXPECT_THAT(futex_trylock_pi(is_priv, &a), SyscallFailsWithErrno(EDEADLK)); - ScopedThread th([is_priv, &a] { - EXPECT_THAT(futex_trylock_pi(is_priv, &a), SyscallFailsWithErrno(EAGAIN)); - }); - th.Join(); - - ASSERT_THAT(futex_unlock_pi(IsPrivate(), &a), SyscallSucceeds()); -} - -TEST_P(PrivateAndSharedFutexTest, PITryLockConcurrency_NoRandomSave) { - DisableSave ds; // Too many syscalls. - - std::atomic<int> a = ATOMIC_VAR_INIT(0); - const bool is_priv = IsPrivate(); - - std::unique_ptr<ScopedThread> threads[10]; - for (size_t i = 0; i < ABSL_ARRAYSIZE(threads); ++i) { - threads[i] = absl::make_unique<ScopedThread>([is_priv, &a] { - for (size_t j = 0; j < 10;) { - if (futex_trylock_pi(is_priv, &a) == 0) { - ++j; - EXPECT_EQ(a.load() & FUTEX_TID_MASK, gettid()); - SleepSafe(absl::Milliseconds(5)); - ASSERT_THAT(futex_unlock_pi(is_priv, &a), SyscallSucceeds()); - } - } - }); - } -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/getcpu.cc b/test/syscalls/linux/getcpu.cc deleted file mode 100644 index f4d94bd6a..000000000 --- a/test/syscalls/linux/getcpu.cc +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sched.h> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(GetcpuTest, IsValidCpuStress) { - const int num_cpus = NumCPUs(); - absl::Time deadline = absl::Now() + absl::Seconds(10); - while (absl::Now() < deadline) { - int cpu; - ASSERT_THAT(cpu = sched_getcpu(), SyscallSucceeds()); - ASSERT_LT(cpu, num_cpus); - } -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/getdents.cc b/test/syscalls/linux/getdents.cc deleted file mode 100644 index b147d6181..000000000 --- a/test/syscalls/linux/getdents.cc +++ /dev/null @@ -1,539 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <dirent.h> -#include <errno.h> -#include <fcntl.h> -#include <stddef.h> -#include <stdint.h> -#include <stdio.h> -#include <string.h> -#include <sys/mman.h> -#include <sys/types.h> -#include <syscall.h> -#include <unistd.h> - -#include <map> -#include <string> -#include <unordered_map> -#include <unordered_set> -#include <utility> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_cat.h" -#include "test/util/eventfd_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -using ::testing::Contains; -using ::testing::IsEmpty; -using ::testing::IsSupersetOf; -using ::testing::Not; -using ::testing::NotNull; - -namespace gvisor { -namespace testing { - -namespace { - -// New Linux dirent format. -struct linux_dirent64 { - uint64_t d_ino; // Inode number - int64_t d_off; // Offset to next linux_dirent64 - unsigned short d_reclen; // NOLINT, Length of this linux_dirent64 - unsigned char d_type; // NOLINT, File type - char d_name[0]; // Filename (null-terminated) -}; - -// Old Linux dirent format. -struct linux_dirent { - unsigned long d_ino; // NOLINT - unsigned long d_off; // NOLINT - unsigned short d_reclen; // NOLINT - char d_name[0]; -}; - -// Wraps a buffer to provide a set of dirents. -// T is the underlying dirent type. -template <typename T> -class DirentBuffer { - public: - // DirentBuffer manages the buffer. - explicit DirentBuffer(size_t size) - : managed_(true), actual_size_(size), reported_size_(size) { - data_ = new char[actual_size_]; - } - - // The buffer is managed externally. - DirentBuffer(char* data, size_t actual_size, size_t reported_size) - : managed_(false), - data_(data), - actual_size_(actual_size), - reported_size_(reported_size) {} - - ~DirentBuffer() { - if (managed_) { - delete[] data_; - } - } - - T* Data() { return reinterpret_cast<T*>(data_); } - - T* Start(size_t read) { - read_ = read; - if (read_) { - return Data(); - } else { - return nullptr; - } - } - - T* Current() { return reinterpret_cast<T*>(&data_[off_]); } - - T* Next() { - size_t new_off = off_ + Current()->d_reclen; - if (new_off >= read_ || new_off >= actual_size_) { - return nullptr; - } - - off_ = new_off; - return Current(); - } - - size_t Size() { return reported_size_; } - - void Reset() { - off_ = 0; - read_ = 0; - memset(data_, 0, actual_size_); - } - - private: - bool managed_; - char* data_; - size_t actual_size_; - size_t reported_size_; - - size_t off_ = 0; - - size_t read_ = 0; -}; - -// Test for getdents/getdents64. -// T is the Linux dirent type. -template <typename T> -class GetdentsTest : public ::testing::Test { - public: - using LinuxDirentType = T; - using DirentBufferType = DirentBuffer<T>; - - protected: - void SetUp() override { - dir_ = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - fd_ = ASSERT_NO_ERRNO_AND_VALUE(Open(dir_.path(), O_RDONLY | O_DIRECTORY)); - } - - // Must be overridden with explicit specialization. See below. - int SyscallNum(); - - int Getdents(LinuxDirentType* dirp, unsigned int count) { - return RetryEINTR(syscall)(SyscallNum(), fd_.get(), dirp, count); - } - - // Fill directory with num files, named by number starting at 0. - void FillDirectory(size_t num) { - for (size_t i = 0; i < num; i++) { - auto name = JoinPath(dir_.path(), absl::StrCat(i)); - TEST_CHECK(CreateWithContents(name, "").ok()); - } - } - - // Fill directory with a given list of filenames. - void FillDirectoryWithFiles(const std::vector<std::string>& filenames) { - for (const auto& filename : filenames) { - auto name = JoinPath(dir_.path(), filename); - TEST_CHECK(CreateWithContents(name, "").ok()); - } - } - - // Seek to the start of the directory. - PosixError SeekStart() { - constexpr off_t kStartOfFile = 0; - off_t offset = lseek(fd_.get(), kStartOfFile, SEEK_SET); - if (offset < 0) { - return PosixError(errno, absl::StrCat("error seeking to ", kStartOfFile)); - } - if (offset != kStartOfFile) { - return PosixError(EINVAL, absl::StrCat("tried to seek to ", kStartOfFile, - " but got ", offset)); - } - return NoError(); - } - - // Call getdents multiple times, reading all dirents and calling f on each. - // f has the type signature PosixError f(T*). - // If f returns a non-OK error, so does ReadDirents. - template <typename F> - PosixError ReadDirents(DirentBufferType* dirents, F const& f) { - int n; - do { - dirents->Reset(); - - n = Getdents(dirents->Data(), dirents->Size()); - MaybeSave(); - if (n < 0) { - return PosixError(errno, "getdents"); - } - - for (auto d = dirents->Start(n); d; d = dirents->Next()) { - RETURN_IF_ERRNO(f(d)); - } - } while (n > 0); - - return NoError(); - } - - // Call Getdents successively and count all entries. - int ReadAndCountAllEntries(DirentBufferType* dirents) { - int found = 0; - - EXPECT_NO_ERRNO(ReadDirents(dirents, [&](LinuxDirentType* d) { - found++; - return NoError(); - })); - - return found; - } - - private: - TempPath dir_; - FileDescriptor fd_; -}; - -// Multiple template parameters are not allowed, so we must use explicit -// template specialization to set the syscall number. - -// SYS_getdents isn't defined on arm64. -#ifdef __x86_64__ -template <> -int GetdentsTest<struct linux_dirent>::SyscallNum() { - return SYS_getdents; -} -#endif - -template <> -int GetdentsTest<struct linux_dirent64>::SyscallNum() { - return SYS_getdents64; -} - -#ifdef __x86_64__ -// Test both legacy getdents and getdents64 on x86_64. -typedef ::testing::Types<struct linux_dirent, struct linux_dirent64> - GetdentsTypes; -#elif __aarch64__ -// Test only getdents64 on arm64. -typedef ::testing::Types<struct linux_dirent64> GetdentsTypes; -#endif -TYPED_TEST_SUITE(GetdentsTest, GetdentsTypes); - -// N.B. TYPED_TESTs require explicitly using this-> to access members of -// GetdentsTest, since we are inside of a derived class template. - -TYPED_TEST(GetdentsTest, VerifyEntries) { - typename TestFixture::DirentBufferType dirents(1024); - - this->FillDirectory(2); - - // Map of all the entries we expect to find. - std::map<std::string, bool> found; - found["."] = false; - found[".."] = false; - found["0"] = false; - found["1"] = false; - - EXPECT_NO_ERRNO(this->ReadDirents( - &dirents, [&](typename TestFixture::LinuxDirentType* d) { - auto kv = found.find(d->d_name); - EXPECT_NE(kv, found.end()) << "Unexpected file: " << d->d_name; - if (kv != found.end()) { - EXPECT_FALSE(kv->second); - } - found[d->d_name] = true; - return NoError(); - })); - - for (auto& kv : found) { - EXPECT_TRUE(kv.second) << "File not found: " << kv.first; - } -} - -TYPED_TEST(GetdentsTest, VerifyPadding) { - typename TestFixture::DirentBufferType dirents(1024); - - // Create files with names of length 1 through 16. - std::vector<std::string> files; - std::string filename; - for (int i = 0; i < 16; ++i) { - absl::StrAppend(&filename, "a"); - files.push_back(filename); - } - this->FillDirectoryWithFiles(files); - - // We expect to find all the files, plus '.' and '..'. - const int expect_found = 2 + files.size(); - int found = 0; - - EXPECT_NO_ERRNO(this->ReadDirents( - &dirents, [&](typename TestFixture::LinuxDirentType* d) { - EXPECT_EQ(d->d_reclen % 8, 0) - << "Dirent " << d->d_name - << " had reclen that was not byte aligned: " << d->d_name; - found++; - return NoError(); - })); - - // Make sure we found all the files. - EXPECT_EQ(found, expect_found); -} - -// For a small directory, the provided buffer should be large enough -// for all entries. -TYPED_TEST(GetdentsTest, SmallDir) { - // . and .. should be in an otherwise empty directory. - int expect = 2; - - // Add some actual files. - this->FillDirectory(2); - expect += 2; - - typename TestFixture::DirentBufferType dirents(256); - - EXPECT_EQ(expect, this->ReadAndCountAllEntries(&dirents)); -} - -// A directory with lots of files requires calling getdents multiple times. -TYPED_TEST(GetdentsTest, LargeDir) { - // . and .. should be in an otherwise empty directory. - int expect = 2; - - // Add some actual files. - this->FillDirectory(100); - expect += 100; - - typename TestFixture::DirentBufferType dirents(256); - - EXPECT_EQ(expect, this->ReadAndCountAllEntries(&dirents)); -} - -// If we lie about the size of the buffer, we should still be able to read the -// entries with the available space. -TYPED_TEST(GetdentsTest, PartialBuffer) { - // . and .. should be in an otherwise empty directory. - int expect = 2; - - // Add some actual files. - this->FillDirectory(100); - expect += 100; - - void* addr = mmap(0, 2 * kPageSize, PROT_READ | PROT_WRITE, - MAP_ANONYMOUS | MAP_PRIVATE, -1, 0); - ASSERT_NE(addr, MAP_FAILED); - - char* buf = reinterpret_cast<char*>(addr); - - // Guard page - EXPECT_THAT( - mprotect(reinterpret_cast<void*>(buf + kPageSize), kPageSize, PROT_NONE), - SyscallSucceeds()); - - // Limit space in buf to 256 bytes. - buf += kPageSize - 256; - - // Lie about the buffer. Even though we claim the buffer is 1 page, - // we should still get all of the dirents in the first 256 bytes. - typename TestFixture::DirentBufferType dirents(buf, 256, kPageSize); - - EXPECT_EQ(expect, this->ReadAndCountAllEntries(&dirents)); - - EXPECT_THAT(munmap(addr, 2 * kPageSize), SyscallSucceeds()); -} - -// Open many file descriptors, then scan through /proc/self/fd to find and close -// them all. (The latter is commonly used to handle races between fork/execve -// and the creation of unwanted non-O_CLOEXEC file descriptors.) This tests that -// getdents iterates correctly despite mutation of /proc/self/fd. -TYPED_TEST(GetdentsTest, ProcSelfFd) { - constexpr size_t kNfds = 10; - std::unordered_map<int, FileDescriptor> fds; - fds.reserve(kNfds); - for (size_t i = 0; i < kNfds; i++) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()); - fds.emplace(fd.get(), std::move(fd)); - } - - const FileDescriptor proc_self_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/fd", O_RDONLY | O_DIRECTORY)); - - // Make the buffer very small since we want to iterate. - typename TestFixture::DirentBufferType dirents( - 2 * sizeof(typename TestFixture::LinuxDirentType)); - std::unordered_set<int> prev_fds; - while (true) { - dirents.Reset(); - int rv; - ASSERT_THAT(rv = RetryEINTR(syscall)(this->SyscallNum(), proc_self_fd.get(), - dirents.Data(), dirents.Size()), - SyscallSucceeds()); - if (rv == 0) { - break; - } - for (auto* d = dirents.Start(rv); d; d = dirents.Next()) { - int dfd; - if (!absl::SimpleAtoi(d->d_name, &dfd)) continue; - EXPECT_TRUE(prev_fds.insert(dfd).second) - << "Repeated observation of /proc/self/fd/" << dfd; - fds.erase(dfd); - } - } - - // Check that we closed every fd. - EXPECT_THAT(fds, ::testing::IsEmpty()); -} - -// Test that getdents returns ENOTDIR when called on a file. -TYPED_TEST(GetdentsTest, NotDir) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); - - typename TestFixture::DirentBufferType dirents(256); - EXPECT_THAT(RetryEINTR(syscall)(this->SyscallNum(), fd.get(), dirents.Data(), - dirents.Size()), - SyscallFailsWithErrno(ENOTDIR)); -} - -// Test that SEEK_SET to 0 causes getdents to re-read the entries. -TYPED_TEST(GetdentsTest, SeekResetsCursor) { - // . and .. should be in an otherwise empty directory. - int expect = 2; - - // Add some files to the directory. - this->FillDirectory(10); - expect += 10; - - typename TestFixture::DirentBufferType dirents(256); - - // We should get all the expected entries. - EXPECT_EQ(expect, this->ReadAndCountAllEntries(&dirents)); - - // Seek back to 0. - ASSERT_NO_ERRNO(this->SeekStart()); - - // We should get all the expected entries again. - EXPECT_EQ(expect, this->ReadAndCountAllEntries(&dirents)); -} - -// Test that getdents() after SEEK_END succeeds. -// This is a regression test for #128. -TYPED_TEST(GetdentsTest, Issue128ProcSeekEnd) { - auto fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self", O_RDONLY | O_DIRECTORY)); - typename TestFixture::DirentBufferType dirents(256); - - ASSERT_THAT(lseek(fd.get(), 0, SEEK_END), SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(syscall)(this->SyscallNum(), fd.get(), dirents.Data(), - dirents.Size()), - SyscallSucceeds()); -} - -// Some tests using the glibc readdir interface. -TEST(ReaddirTest, OpenDir) { - DIR* dev; - ASSERT_THAT(dev = opendir("/dev"), NotNull()); - EXPECT_THAT(closedir(dev), SyscallSucceeds()); -} - -TEST(ReaddirTest, RootContainsBasicDirectories) { - EXPECT_THAT(ListDir("/", true), - IsPosixErrorOkAndHolds(IsSupersetOf( - {"bin", "dev", "etc", "lib", "proc", "sbin", "usr"}))); -} - -TEST(ReaddirTest, Bug24096713Dev) { - auto contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/dev", true)); - EXPECT_THAT(contents, Not(IsEmpty())); -} - -TEST(ReaddirTest, Bug24096713ProcTid) { - auto contents = ASSERT_NO_ERRNO_AND_VALUE( - ListDir(absl::StrCat("/proc/", syscall(SYS_gettid), "/"), true)); - EXPECT_THAT(contents, Not(IsEmpty())); -} - -TEST(ReaddirTest, Bug33429925Proc) { - auto contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/proc", true)); - EXPECT_THAT(contents, Not(IsEmpty())); -} - -TEST(ReaddirTest, Bug35110122Root) { - auto contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/", true)); - EXPECT_THAT(contents, Not(IsEmpty())); -} - -// Unlink should invalidate getdents cache. -TEST(ReaddirTest, GoneAfterRemoveCache) { - TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); - std::string name = std::string(Basename(file.path())); - - auto contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(dir.path(), true)); - EXPECT_THAT(contents, Contains(name)); - - file.reset(); - - contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(dir.path(), true)); - EXPECT_THAT(contents, Not(Contains(name))); -} - -// Regression test for b/137398511. Rename should invalidate getdents cache. -TEST(ReaddirTest, GoneAfterRenameCache) { - TempPath src = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - TempPath dst = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(src.path())); - std::string name = std::string(Basename(file.path())); - - auto contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(src.path(), true)); - EXPECT_THAT(contents, Contains(name)); - - ASSERT_THAT(rename(file.path().c_str(), JoinPath(dst.path(), name).c_str()), - SyscallSucceeds()); - // Release file since it was renamed. dst cleanup will ultimately delete it. - file.release(); - - contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(src.path(), true)); - EXPECT_THAT(contents, Not(Contains(name))); - - contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(dst.path(), true)); - EXPECT_THAT(contents, Contains(name)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/getrandom.cc b/test/syscalls/linux/getrandom.cc deleted file mode 100644 index f97f60029..000000000 --- a/test/syscalls/linux/getrandom.cc +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/syscall.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -#ifndef SYS_getrandom -#if defined(__x86_64__) -#define SYS_getrandom 318 -#elif defined(__i386__) -#define SYS_getrandom 355 -#else -#error "Unknown architecture" -#endif -#endif // SYS_getrandom - -bool SomeByteIsNonZero(char* random_bytes, int length) { - for (int i = 0; i < length; i++) { - if (random_bytes[i] != 0) { - return true; - } - } - return false; -} - -TEST(GetrandomTest, IsRandom) { - // This test calls get_random and makes sure that the array is filled in with - // something that is non-zero. Perhaps we get back \x00\x00\x00\x00\x00.... as - // a random result, but it's so unlikely that we'll just ignore this. - char random_bytes[64] = {}; - int n = syscall(SYS_getrandom, random_bytes, 64, 0); - SKIP_IF(!IsRunningOnGvisor() && n < 0 && errno == ENOSYS); - EXPECT_THAT(n, SyscallSucceeds()); - EXPECT_GT(n, 0); // Some bytes should be returned. - EXPECT_TRUE(SomeByteIsNonZero(random_bytes, n)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/getrusage.cc b/test/syscalls/linux/getrusage.cc deleted file mode 100644 index 0e51d42a8..000000000 --- a/test/syscalls/linux/getrusage.cc +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <signal.h> -#include <sys/mman.h> -#include <sys/resource.h> -#include <sys/types.h> -#include <sys/wait.h> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/logging.h" -#include "test/util/memory_util.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(GetrusageTest, BasicFork) { - pid_t pid = fork(); - if (pid == 0) { - struct rusage rusage_self; - TEST_PCHECK(getrusage(RUSAGE_SELF, &rusage_self) == 0); - struct rusage rusage_children; - TEST_PCHECK(getrusage(RUSAGE_CHILDREN, &rusage_children) == 0); - // The child has consumed some memory. - TEST_CHECK(rusage_self.ru_maxrss != 0); - // The child has no children of its own. - TEST_CHECK(rusage_children.ru_maxrss == 0); - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - int status; - ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0), SyscallSucceeds()); - struct rusage rusage_self; - ASSERT_THAT(getrusage(RUSAGE_SELF, &rusage_self), SyscallSucceeds()); - struct rusage rusage_children; - ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &rusage_children), SyscallSucceeds()); - // The parent has consumed some memory. - EXPECT_GT(rusage_self.ru_maxrss, 0); - // The child has consumed some memory, and because it has exited we can get - // its max RSS. - EXPECT_GT(rusage_children.ru_maxrss, 0); -} - -// Verifies that a process can get the max resident set size of its grandchild, -// i.e. that maxrss propagates correctly from children to waiting parents. -TEST(GetrusageTest, Grandchild) { - constexpr int kGrandchildSizeKb = 1024; - pid_t pid = fork(); - if (pid == 0) { - pid = fork(); - if (pid == 0) { - int flags = MAP_ANONYMOUS | MAP_POPULATE | MAP_PRIVATE; - void* addr = - mmap(nullptr, kGrandchildSizeKb * 1024, PROT_WRITE, flags, -1, 0); - TEST_PCHECK(addr != MAP_FAILED); - } else { - int status; - TEST_PCHECK(RetryEINTR(waitpid)(pid, &status, 0) == pid); - } - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - int status; - ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0), SyscallSucceeds()); - struct rusage rusage_self; - ASSERT_THAT(getrusage(RUSAGE_SELF, &rusage_self), SyscallSucceeds()); - struct rusage rusage_children; - ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &rusage_children), SyscallSucceeds()); - // The parent has consumed some memory. - EXPECT_GT(rusage_self.ru_maxrss, 0); - // The child should consume next to no memory, but the grandchild will - // consume at least 1MB. Verify that usage bubbles up to the grandparent. - EXPECT_GT(rusage_children.ru_maxrss, kGrandchildSizeKb); -} - -// Verifies that processes ignoring SIGCHLD do not have updated child maxrss -// updated. -TEST(GetrusageTest, IgnoreSIGCHLD) { - struct sigaction sa; - sa.sa_handler = SIG_IGN; - sa.sa_flags = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGCHLD, sa)); - pid_t pid = fork(); - if (pid == 0) { - struct rusage rusage_self; - TEST_PCHECK(getrusage(RUSAGE_SELF, &rusage_self) == 0); - // The child has consumed some memory. - TEST_CHECK(rusage_self.ru_maxrss != 0); - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - int status; - ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0), - SyscallFailsWithErrno(ECHILD)); - struct rusage rusage_self; - ASSERT_THAT(getrusage(RUSAGE_SELF, &rusage_self), SyscallSucceeds()); - struct rusage rusage_children; - ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &rusage_children), SyscallSucceeds()); - // The parent has consumed some memory. - EXPECT_GT(rusage_self.ru_maxrss, 0); - // The child's maxrss should not have propagated up. - EXPECT_EQ(rusage_children.ru_maxrss, 0); -} - -// Verifies that zombie processes do not update their parent's maxrss. Only -// reaped processes should do this. -TEST(GetrusageTest, IgnoreZombie) { - pid_t pid = fork(); - if (pid == 0) { - struct rusage rusage_self; - TEST_PCHECK(getrusage(RUSAGE_SELF, &rusage_self) == 0); - struct rusage rusage_children; - TEST_PCHECK(getrusage(RUSAGE_CHILDREN, &rusage_children) == 0); - // The child has consumed some memory. - TEST_CHECK(rusage_self.ru_maxrss != 0); - // The child has no children of its own. - TEST_CHECK(rusage_children.ru_maxrss == 0); - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - // Give the child time to exit. Because we don't call wait, the child should - // remain a zombie. - absl::SleepFor(absl::Seconds(5)); - struct rusage rusage_self; - ASSERT_THAT(getrusage(RUSAGE_SELF, &rusage_self), SyscallSucceeds()); - struct rusage rusage_children; - ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &rusage_children), SyscallSucceeds()); - // The parent has consumed some memory. - EXPECT_GT(rusage_self.ru_maxrss, 0); - // The child has consumed some memory, but hasn't been reaped. - EXPECT_EQ(rusage_children.ru_maxrss, 0); -} - -TEST(GetrusageTest, Wait4) { - pid_t pid = fork(); - if (pid == 0) { - struct rusage rusage_self; - TEST_PCHECK(getrusage(RUSAGE_SELF, &rusage_self) == 0); - struct rusage rusage_children; - TEST_PCHECK(getrusage(RUSAGE_CHILDREN, &rusage_children) == 0); - // The child has consumed some memory. - TEST_CHECK(rusage_self.ru_maxrss != 0); - // The child has no children of its own. - TEST_CHECK(rusage_children.ru_maxrss == 0); - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - struct rusage rusage_children; - int status; - ASSERT_THAT(RetryEINTR(wait4)(pid, &status, 0, &rusage_children), - SyscallSucceeds()); - // The child has consumed some memory, and because it has exited we can get - // its max RSS. - EXPECT_GT(rusage_children.ru_maxrss, 0); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc deleted file mode 100644 index 0e13ad190..000000000 --- a/test/syscalls/linux/inotify.cc +++ /dev/null @@ -1,1629 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <libgen.h> -#include <sched.h> -#include <sys/epoll.h> -#include <sys/inotify.h> -#include <sys/ioctl.h> -#include <sys/time.h> - -#include <atomic> -#include <list> -#include <string> -#include <vector> - -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/epoll_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { -namespace { - -using ::absl::StreamFormat; -using ::absl::StrFormat; - -constexpr int kBufSize = 1024; - -// C++-friendly version of struct inotify_event. -struct Event { - int32_t wd; - uint32_t mask; - uint32_t cookie; - uint32_t len; - std::string name; - - Event(uint32_t mask, int32_t wd, absl::string_view name, uint32_t cookie) - : wd(wd), - mask(mask), - cookie(cookie), - len(name.size()), - name(std::string(name)) {} - Event(uint32_t mask, int32_t wd, absl::string_view name) - : Event(mask, wd, name, 0) {} - Event(uint32_t mask, int32_t wd) : Event(mask, wd, "", 0) {} - Event() : Event(0, 0, "", 0) {} -}; - -// Prints the symbolic name for a struct inotify_event's 'mask' field. -std::string FlagString(uint32_t flags) { - std::vector<std::string> names; - -#define EMIT(target) \ - if (flags & target) { \ - names.push_back(#target); \ - flags &= ~target; \ - } - - EMIT(IN_ACCESS); - EMIT(IN_ATTRIB); - EMIT(IN_CLOSE_WRITE); - EMIT(IN_CLOSE_NOWRITE); - EMIT(IN_CREATE); - EMIT(IN_DELETE); - EMIT(IN_DELETE_SELF); - EMIT(IN_MODIFY); - EMIT(IN_MOVE_SELF); - EMIT(IN_MOVED_FROM); - EMIT(IN_MOVED_TO); - EMIT(IN_OPEN); - - EMIT(IN_DONT_FOLLOW); - EMIT(IN_EXCL_UNLINK); - EMIT(IN_ONESHOT); - EMIT(IN_ONLYDIR); - - EMIT(IN_IGNORED); - EMIT(IN_ISDIR); - EMIT(IN_Q_OVERFLOW); - EMIT(IN_UNMOUNT); - -#undef EMIT - - // If we have anything left over at the end, print it as a hex value. - if (flags) { - names.push_back(absl::StrCat("0x", absl::Hex(flags))); - } - - return absl::StrJoin(names, "|"); -} - -std::string DumpEvent(const Event& event) { - return StrFormat( - "%s, wd=%d%s%s", FlagString(event.mask), event.wd, - (event.len > 0) ? StrFormat(", name=%s", event.name) : "", - (event.cookie > 0) ? StrFormat(", cookie=%ud", event.cookie) : ""); -} - -std::string DumpEvents(const std::vector<Event>& events, int indent_level) { - std::stringstream ss; - ss << StreamFormat("%d event%s:\n", events.size(), - (events.size() > 1) ? "s" : ""); - int i = 0; - for (const Event& ev : events) { - ss << StreamFormat("%sevents[%d]: %s\n", std::string(indent_level, '\t'), - i++, DumpEvent(ev)); - } - return ss.str(); -} - -// A matcher which takes an expected list of events to match against another -// list of inotify events, in order. This is similar to the ElementsAre matcher, -// but displays more informative messages on mismatch. -class EventsAreMatcher - : public ::testing::MatcherInterface<std::vector<Event>> { - public: - explicit EventsAreMatcher(std::vector<Event> references) - : references_(std::move(references)) {} - - bool MatchAndExplain( - std::vector<Event> events, - ::testing::MatchResultListener* const listener) const override { - if (references_.size() != events.size()) { - *listener << StreamFormat("\n\tCount mismatch, got %s", - DumpEvents(events, 2)); - return false; - } - - bool success = true; - for (unsigned int i = 0; i < references_.size(); ++i) { - const Event& reference = references_[i]; - const Event& target = events[i]; - - if (target.mask != reference.mask || target.wd != reference.wd || - target.name != reference.name || target.cookie != reference.cookie) { - *listener << StreamFormat("\n\tMismatch at index %d, want %s, got %s,", - i, DumpEvent(reference), DumpEvent(target)); - success = false; - } - } - - if (!success) { - *listener << StreamFormat("\n\tIn total of %s", DumpEvents(events, 2)); - } - return success; - } - - void DescribeTo(::std::ostream* const os) const override { - *os << StreamFormat("%s", DumpEvents(references_, 1)); - } - - void DescribeNegationTo(::std::ostream* const os) const override { - *os << StreamFormat("mismatch from %s", DumpEvents(references_, 1)); - } - - private: - std::vector<Event> references_; -}; - -::testing::Matcher<std::vector<Event>> Are(std::vector<Event> events) { - return MakeMatcher(new EventsAreMatcher(std::move(events))); -} - -// Similar to the EventsAre matcher, but the order of events are ignored. -class UnorderedEventsAreMatcher - : public ::testing::MatcherInterface<std::vector<Event>> { - public: - explicit UnorderedEventsAreMatcher(std::vector<Event> references) - : references_(std::move(references)) {} - - bool MatchAndExplain( - std::vector<Event> events, - ::testing::MatchResultListener* const listener) const override { - if (references_.size() != events.size()) { - *listener << StreamFormat("\n\tCount mismatch, got %s", - DumpEvents(events, 2)); - return false; - } - - std::vector<Event> unmatched(references_); - - for (const Event& candidate : events) { - for (auto it = unmatched.begin(); it != unmatched.end();) { - const Event& reference = *it; - if (candidate.mask == reference.mask && candidate.wd == reference.wd && - candidate.name == reference.name && - candidate.cookie == reference.cookie) { - it = unmatched.erase(it); - break; - } else { - ++it; - } - } - } - - // Anything left unmatched? If so, the matcher fails. - if (!unmatched.empty()) { - *listener << StreamFormat("\n\tFailed to match %s", - DumpEvents(unmatched, 2)); - *listener << StreamFormat("\n\tIn total of %s", DumpEvents(events, 2)); - return false; - } - - return true; - } - - void DescribeTo(::std::ostream* const os) const override { - *os << StreamFormat("unordered %s", DumpEvents(references_, 1)); - } - - void DescribeNegationTo(::std::ostream* const os) const override { - *os << StreamFormat("mismatch from unordered %s", - DumpEvents(references_, 1)); - } - - private: - std::vector<Event> references_; -}; - -::testing::Matcher<std::vector<Event>> AreUnordered(std::vector<Event> events) { - return MakeMatcher(new UnorderedEventsAreMatcher(std::move(events))); -} - -// Reads events from an inotify fd until either EOF, or read returns EAGAIN. -PosixErrorOr<std::vector<Event>> DrainEvents(int fd) { - std::vector<Event> events; - while (true) { - int events_size = 0; - if (ioctl(fd, FIONREAD, &events_size) < 0) { - return PosixError(errno, "ioctl(FIONREAD) failed on inotify fd"); - } - // Deliberately use a buffer that is larger than necessary, expecting to - // only read events_size bytes. - std::vector<char> buf(events_size + kBufSize, 0); - const ssize_t readlen = read(fd, buf.data(), buf.size()); - MaybeSave(); - // Read error? - if (readlen < 0) { - if (errno == EAGAIN) { - // If EAGAIN, no more events at the moment. Return what we have so far. - return events; - } - // Some other read error. Return an error. Right now if we encounter this - // after already reading some events, they get lost. However, we don't - // expect to see any error, and the calling test will fail immediately if - // we signal an error anyways, so this is acceptable. - return PosixError(errno, "read() failed on inotify fd"); - } - if (readlen < static_cast<int>(sizeof(struct inotify_event))) { - // Impossibly short read. - return PosixError( - EIO, - "read() didn't return enough data represent even a single event"); - } - if (readlen != events_size) { - return PosixError(EINVAL, absl::StrCat("read ", readlen, - " bytes, expected ", events_size)); - } - if (readlen == 0) { - // EOF. - return events; - } - - // Normal read. - const char* cursor = buf.data(); - while (cursor < (buf.data() + readlen)) { - struct inotify_event event = {}; - memcpy(&event, cursor, sizeof(struct inotify_event)); - - Event ev; - ev.wd = event.wd; - ev.mask = event.mask; - ev.cookie = event.cookie; - ev.len = event.len; - if (event.len > 0) { - TEST_CHECK(static_cast<int>(sizeof(struct inotify_event) + event.len) <= - readlen); - ev.name = std::string(cursor + - offsetof(struct inotify_event, name)); // NOLINT - // Name field should always be smaller than event.len, otherwise we have - // a buffer overflow. The two sizes aren't equal because the string - // constructor will stop at the first null byte, while event.name may be - // padded up to event.len using multiple null bytes. - TEST_CHECK(ev.name.size() <= event.len); - } - - events.push_back(ev); - cursor += sizeof(struct inotify_event) + event.len; - } - } -} - -PosixErrorOr<FileDescriptor> InotifyInit1(int flags) { - int fd; - EXPECT_THAT(fd = inotify_init1(flags), SyscallSucceeds()); - if (fd < 0) { - return PosixError(errno, "inotify_init1() failed"); - } - return FileDescriptor(fd); -} - -PosixErrorOr<int> InotifyAddWatch(int fd, const std::string& path, - uint32_t mask) { - int wd; - EXPECT_THAT(wd = inotify_add_watch(fd, path.c_str(), mask), - SyscallSucceeds()); - if (wd < 0) { - return PosixError(errno, "inotify_add_watch() failed"); - } - return wd; -} - -TEST(Inotify, InotifyFdNotWritable) { - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0)); - EXPECT_THAT(write(fd.get(), "x", 1), SyscallFailsWithErrno(EBADF)); -} - -TEST(Inotify, NonBlockingReadReturnsEagain) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - std::vector<char> buf(kBufSize, 0); - - // The read below should return fail with EAGAIN because there is no data to - // read and we've specified IN_NONBLOCK. We're guaranteed that there is no - // data to read because we haven't registered any watches yet. - EXPECT_THAT(read(fd.get(), buf.data(), buf.size()), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST(Inotify, AddWatchOnInvalidFdFails) { - // Garbage fd. - EXPECT_THAT(inotify_add_watch(-1, "/tmp", IN_ALL_EVENTS), - SyscallFailsWithErrno(EBADF)); - EXPECT_THAT(inotify_add_watch(1337, "/tmp", IN_ALL_EVENTS), - SyscallFailsWithErrno(EBADF)); - - // Non-inotify fds. - EXPECT_THAT(inotify_add_watch(0, "/tmp", IN_ALL_EVENTS), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(inotify_add_watch(1, "/tmp", IN_ALL_EVENTS), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(inotify_add_watch(2, "/tmp", IN_ALL_EVENTS), - SyscallFailsWithErrno(EINVAL)); - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open("/tmp", O_RDONLY)); - EXPECT_THAT(inotify_add_watch(fd.get(), "/tmp", IN_ALL_EVENTS), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(Inotify, RemovingWatchGeneratesEvent) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - EXPECT_THAT(inotify_rm_watch(fd.get(), wd), SyscallSucceeds()); - - // Read events, ensure the first event is IN_IGNORED. - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - EXPECT_THAT(events, Are({Event(IN_IGNORED, wd)})); -} - -TEST(Inotify, CanDeleteFileAfterRemovingWatch) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - - EXPECT_THAT(inotify_rm_watch(fd.get(), wd), SyscallSucceeds()); - file1.reset(); -} - -TEST(Inotify, CanRemoveWatchAfterDeletingFile) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - - file1.reset(); - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - EXPECT_THAT(events, Are({Event(IN_ATTRIB, wd), Event(IN_DELETE_SELF, wd), - Event(IN_IGNORED, wd)})); - - EXPECT_THAT(inotify_rm_watch(fd.get(), wd), SyscallFailsWithErrno(EINVAL)); -} - -TEST(Inotify, DuplicateWatchRemovalFails) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - - EXPECT_THAT(inotify_rm_watch(fd.get(), wd), SyscallSucceeds()); - EXPECT_THAT(inotify_rm_watch(fd.get(), wd), SyscallFailsWithErrno(EINVAL)); -} - -TEST(Inotify, ConcurrentFileDeletionAndWatchRemoval) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const std::string filename = NewTempAbsPathInDir(root.path()); - - auto file_create_delete = [filename]() { - const DisableSave ds; // Too expensive. - for (int i = 0; i < 100; ++i) { - FileDescriptor file_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_CREAT, S_IRUSR | S_IWUSR)); - file_fd.reset(); // Close before unlinking (although save is disabled). - EXPECT_THAT(unlink(filename.c_str()), SyscallSucceeds()); - } - }; - - const int shared_fd = fd.get(); // We need to pass it to the thread. - auto add_remove_watch = [shared_fd, filename]() { - for (int i = 0; i < 100; ++i) { - int wd = inotify_add_watch(shared_fd, filename.c_str(), IN_ALL_EVENTS); - MaybeSave(); - if (wd != -1) { - // Watch added successfully, try removal. - if (inotify_rm_watch(shared_fd, wd)) { - // If removal fails, the only acceptable reason is if the wd - // is invalid, which will be the case if we try to remove - // the watch after the file has been deleted. - EXPECT_EQ(errno, EINVAL); - } - } else { - // Add watch failed, this should only fail if the target file doesn't - // exist. - EXPECT_EQ(errno, ENOENT); - } - } - }; - - ScopedThread t1(file_create_delete); - ScopedThread t2(add_remove_watch); -} - -TEST(Inotify, DeletingChildGeneratesEvents) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const int root_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - - const std::string file1_path = file1.reset(); - - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT( - events, - AreUnordered({Event(IN_ATTRIB, file1_wd), Event(IN_DELETE_SELF, file1_wd), - Event(IN_IGNORED, file1_wd), - Event(IN_DELETE, root_wd, Basename(file1_path))})); -} - -TEST(Inotify, CreatingFileGeneratesEvents) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - - // Create a new file in the directory. - const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - - // The library function we use to create the new file opens it for writing to - // create it and sets permissions on it, so we expect the three extra events. - ASSERT_THAT(events, Are({Event(IN_CREATE, wd, Basename(file1.path())), - Event(IN_OPEN, wd, Basename(file1.path())), - Event(IN_CLOSE_WRITE, wd, Basename(file1.path())), - Event(IN_ATTRIB, wd, Basename(file1.path()))})); -} - -TEST(Inotify, ReadingFileGeneratesAccessEvent) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - root.path(), "some content", TempPath::kDefaultFileMode)); - - const FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY)); - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - - char buf; - EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); - - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_ACCESS, wd, Basename(file1.path()))})); -} - -TEST(Inotify, WritingFileGeneratesModifyEvent) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - - const FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY)); - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - - const std::string data = "some content"; - EXPECT_THAT(write(file1_fd.get(), data.c_str(), data.length()), - SyscallSucceeds()); - - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_MODIFY, wd, Basename(file1.path()))})); -} - -TEST(Inotify, WatchSetAfterOpenReportsCloseFdEvent) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - - FileDescriptor file1_fd_writable = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY)); - FileDescriptor file1_fd_not_writable = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY)); - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - - file1_fd_writable.reset(); // Close file1_fd_writable. - std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_CLOSE_WRITE, wd, Basename(file1.path()))})); - - file1_fd_not_writable.reset(); // Close file1_fd_not_writable. - events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, - Are({Event(IN_CLOSE_NOWRITE, wd, Basename(file1.path()))})); -} - -TEST(Inotify, ChildrenDeletionInWatchedDirGeneratesEvent) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - TempPath dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path())); - - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - - const std::string file1_path = file1.reset(); - const std::string dir1_path = dir1.release(); - EXPECT_THAT(rmdir(dir1_path.c_str()), SyscallSucceeds()); - - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - - ASSERT_THAT(events, - Are({Event(IN_DELETE, wd, Basename(file1_path)), - Event(IN_DELETE | IN_ISDIR, wd, Basename(dir1_path))})); -} - -TEST(Inotify, WatchTargetDeletionGeneratesEvent) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - - EXPECT_THAT(rmdir(root.path().c_str()), SyscallSucceeds()); - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_DELETE_SELF, wd), Event(IN_IGNORED, wd)})); -} - -TEST(Inotify, MoveGeneratesEvents) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - - const TempPath dir1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path())); - const TempPath dir2 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path())); - - const int root_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - const int dir1_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), dir1.path(), IN_ALL_EVENTS)); - const int dir2_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), dir2.path(), IN_ALL_EVENTS)); - // Test move from root -> root. - std::string newpath = NewTempAbsPathInDir(root.path()); - std::string oldpath = file1.release(); - EXPECT_THAT(rename(oldpath.c_str(), newpath.c_str()), SyscallSucceeds()); - file1.reset(newpath); - std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT( - events, - Are({Event(IN_MOVED_FROM, root_wd, Basename(oldpath), events[0].cookie), - Event(IN_MOVED_TO, root_wd, Basename(newpath), events[1].cookie)})); - EXPECT_NE(events[0].cookie, 0); - EXPECT_EQ(events[0].cookie, events[1].cookie); - uint32_t last_cookie = events[0].cookie; - - // Test move from root -> root/dir1. - newpath = NewTempAbsPathInDir(dir1.path()); - oldpath = file1.release(); - EXPECT_THAT(rename(oldpath.c_str(), newpath.c_str()), SyscallSucceeds()); - file1.reset(newpath); - events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT( - events, - Are({Event(IN_MOVED_FROM, root_wd, Basename(oldpath), events[0].cookie), - Event(IN_MOVED_TO, dir1_wd, Basename(newpath), events[1].cookie)})); - // Cookies should be distinct between distinct rename events. - EXPECT_NE(events[0].cookie, last_cookie); - EXPECT_EQ(events[0].cookie, events[1].cookie); - last_cookie = events[0].cookie; - - // Test move from root/dir1 -> root/dir2. - newpath = NewTempAbsPathInDir(dir2.path()); - oldpath = file1.release(); - EXPECT_THAT(rename(oldpath.c_str(), newpath.c_str()), SyscallSucceeds()); - file1.reset(newpath); - events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT( - events, - Are({Event(IN_MOVED_FROM, dir1_wd, Basename(oldpath), events[0].cookie), - Event(IN_MOVED_TO, dir2_wd, Basename(newpath), events[1].cookie)})); - EXPECT_NE(events[0].cookie, last_cookie); - EXPECT_EQ(events[0].cookie, events[1].cookie); - last_cookie = events[0].cookie; -} - -TEST(Inotify, MoveWatchedTargetGeneratesEvents) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - - const int root_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - - const std::string newpath = NewTempAbsPathInDir(root.path()); - const std::string oldpath = file1.release(); - EXPECT_THAT(rename(oldpath.c_str(), newpath.c_str()), SyscallSucceeds()); - file1.reset(newpath); - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT( - events, - Are({Event(IN_MOVED_FROM, root_wd, Basename(oldpath), events[0].cookie), - Event(IN_MOVED_TO, root_wd, Basename(newpath), events[1].cookie), - // Self move events do not have a cookie. - Event(IN_MOVE_SELF, file1_wd)})); - EXPECT_NE(events[0].cookie, 0); - EXPECT_EQ(events[0].cookie, events[1].cookie); -} - -TEST(Inotify, CoalesceEvents) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - root.path(), "some content", TempPath::kDefaultFileMode)); - - FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY)); - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - - // Read the file a few times. This will would generate multiple IN_ACCESS - // events but they should get coalesced to a single event. - char buf; - EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); - EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); - EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); - EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); - - // Use the close event verify that we haven't simply left the additional - // IN_ACCESS events unread. - file1_fd.reset(); // Close file1_fd. - - const std::string file1_name = std::string(Basename(file1.path())); - std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_ACCESS, wd, file1_name), - Event(IN_CLOSE_NOWRITE, wd, file1_name)})); - - // Now let's try interleaving other events into a stream of repeated events. - file1_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDWR)); - - EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); - EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); - EXPECT_THAT(write(file1_fd.get(), "x", 1), SyscallSucceeds()); - EXPECT_THAT(write(file1_fd.get(), "x", 1), SyscallSucceeds()); - EXPECT_THAT(write(file1_fd.get(), "x", 1), SyscallSucceeds()); - EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); - EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); - - file1_fd.reset(); // Close the file. - - events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT( - events, - Are({Event(IN_OPEN, wd, file1_name), Event(IN_ACCESS, wd, file1_name), - Event(IN_MODIFY, wd, file1_name), Event(IN_ACCESS, wd, file1_name), - Event(IN_CLOSE_WRITE, wd, file1_name)})); - - // Ensure events aren't coalesced if they are from different files. - const TempPath file2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - root.path(), "some content", TempPath::kDefaultFileMode)); - // Discard events resulting from creation of file2. - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - - file1_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY)); - FileDescriptor file2_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file2.path(), O_RDONLY)); - - EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); - EXPECT_THAT(read(file2_fd.get(), &buf, 1), SyscallSucceeds()); - EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); - EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); - - // Close both files. - file1_fd.reset(); - file2_fd.reset(); - - events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - const std::string file2_name = std::string(Basename(file2.path())); - ASSERT_THAT( - events, - Are({Event(IN_OPEN, wd, file1_name), Event(IN_OPEN, wd, file2_name), - Event(IN_ACCESS, wd, file1_name), Event(IN_ACCESS, wd, file2_name), - Event(IN_ACCESS, wd, file1_name), - Event(IN_CLOSE_NOWRITE, wd, file1_name), - Event(IN_CLOSE_NOWRITE, wd, file2_name)})); -} - -TEST(Inotify, ClosingInotifyFdWithoutRemovingWatchesWorks) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - const FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY)); - - ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - // Note: The check on close will happen in FileDescriptor::~FileDescriptor(). -} - -TEST(Inotify, NestedWatches) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - root.path(), "some content", TempPath::kDefaultFileMode)); - const FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY)); - - const int root_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - - // Read from file1. This should generate an event for both watches. - char buf; - EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); - - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_ACCESS, root_wd, Basename(file1.path())), - Event(IN_ACCESS, file1_wd)})); -} - -TEST(Inotify, ConcurrentThreadsGeneratingEvents) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - std::vector<TempPath> files; - files.reserve(10); - for (int i = 0; i < 10; i++) { - files.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - root.path(), "some content", TempPath::kDefaultFileMode))); - } - - auto test_thread = [&files]() { - uint32_t seed = time(nullptr); - for (int i = 0; i < 20; i++) { - const TempPath& file = files[rand_r(&seed) % files.size()]; - const FileDescriptor file_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY)); - TEST_PCHECK(write(file_fd.get(), "x", 1) == 1); - } - }; - - ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - - std::list<ScopedThread> threads; - for (int i = 0; i < 3; i++) { - threads.emplace_back(test_thread); - } - for (auto& t : threads) { - t.Join(); - } - - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - // 3 threads doing 20 iterations, 3 events per iteration (open, write, - // close). However, some events may be coalesced, and we can't reliably - // predict how they'll be coalesced since the test threads aren't - // synchronized. We can only check that we aren't getting unexpected events. - for (const Event& ev : events) { - EXPECT_NE(ev.mask & (IN_OPEN | IN_MODIFY | IN_CLOSE_WRITE), 0); - } -} - -TEST(Inotify, ReadWithTooSmallBufferFails) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - - // Open the file to queue an event. This event will not have a filename, so - // reading from the inotify fd should return sizeof(struct inotify_event) - // bytes of data. - FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY)); - std::vector<char> buf(kBufSize, 0); - ssize_t readlen; - - // Try a buffer too small to hold any potential event. This is rejected - // outright without the event being dequeued. - EXPECT_THAT(read(fd.get(), buf.data(), sizeof(struct inotify_event) - 1), - SyscallFailsWithErrno(EINVAL)); - // Try a buffer just large enough. This should succeeed. - EXPECT_THAT( - readlen = read(fd.get(), buf.data(), sizeof(struct inotify_event)), - SyscallSucceeds()); - EXPECT_EQ(readlen, sizeof(struct inotify_event)); - // Event queue is now empty, the next read should return EAGAIN. - EXPECT_THAT(read(fd.get(), buf.data(), sizeof(struct inotify_event)), - SyscallFailsWithErrno(EAGAIN)); - - // Now put a watch on the directory, so that generated events contain a name. - EXPECT_THAT(inotify_rm_watch(fd.get(), wd), SyscallSucceeds()); - - // Drain the event generated from the watch removal. - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - - ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - - file1_fd.reset(); // Close file to generate an event. - - // Try a buffer too small to hold any event and one too small to hold an event - // with a name. These should both fail without consuming the event. - EXPECT_THAT(read(fd.get(), buf.data(), sizeof(struct inotify_event) - 1), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(read(fd.get(), buf.data(), sizeof(struct inotify_event)), - SyscallFailsWithErrno(EINVAL)); - // Now try with a large enough buffer. This should return the one event. - EXPECT_THAT(readlen = read(fd.get(), buf.data(), buf.size()), - SyscallSucceeds()); - EXPECT_GE(readlen, - sizeof(struct inotify_event) + Basename(file1.path()).size()); - // With the single event read, the queue should once again be empty. - EXPECT_THAT(read(fd.get(), buf.data(), sizeof(struct inotify_event)), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST(Inotify, BlockingReadOnInotifyFd) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0)); - const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - root.path(), "some content", TempPath::kDefaultFileMode)); - - const FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY)); - - ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - - // Spawn a thread performing a blocking read for new events on the inotify fd. - std::vector<char> buf(kBufSize, 0); - const int shared_fd = fd.get(); // The thread needs it. - ScopedThread t([shared_fd, &buf]() { - ssize_t readlen; - EXPECT_THAT(readlen = read(shared_fd, buf.data(), buf.size()), - SyscallSucceeds()); - }); - - // Perform a read on the watched file, which should generate an IN_ACCESS - // event, unblocking the event_reader thread. - char c; - EXPECT_THAT(read(file1_fd.get(), &c, 1), SyscallSucceeds()); - - // Wait for the thread to read the event and exit. - t.Join(); - - // Make sure the event we got back is sane. - uint32_t event_mask; - memcpy(&event_mask, buf.data() + offsetof(struct inotify_event, mask), - sizeof(event_mask)); - EXPECT_EQ(event_mask, IN_ACCESS); -} - -TEST(Inotify, WatchOnRelativePath) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - root.path(), "some content", TempPath::kDefaultFileMode)); - - const FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY)); - - // Change working directory to root. - const FileDescriptor cwd = ASSERT_NO_ERRNO_AND_VALUE(Open(".", O_PATH)); - EXPECT_THAT(chdir(root.path().c_str()), SyscallSucceeds()); - - // Add a watch on file1 with a relative path. - const int wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( - fd.get(), std::string(Basename(file1.path())), IN_ALL_EVENTS)); - - // Perform a read on file1, this should generate an IN_ACCESS event. - char c; - EXPECT_THAT(read(file1_fd.get(), &c, 1), SyscallSucceeds()); - - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - EXPECT_THAT(events, Are({Event(IN_ACCESS, wd)})); - - // Explicitly reset the working directory so that we don't continue to - // reference "root". Once the test ends, "root" will get unlinked. If we - // continue to hold a reference, random save/restore tests can fail if a save - // is triggered after "root" is unlinked; we can't save deleted fs objects - // with active references. - EXPECT_THAT(fchdir(cwd.get()), SyscallSucceeds()); -} - -TEST(Inotify, ZeroLengthReadWriteDoesNotGenerateEvent) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - const char kContent[] = "some content"; - TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - root.path(), kContent, TempPath::kDefaultFileMode)); - const int kContentSize = sizeof(kContent) - 1; - - const FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDWR)); - - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - - std::vector<char> buf(kContentSize, 0); - // Read all available data. - ssize_t readlen; - EXPECT_THAT(readlen = read(file1_fd.get(), buf.data(), kContentSize), - SyscallSucceeds()); - EXPECT_EQ(readlen, kContentSize); - // Drain all events and make sure we got the IN_ACCESS for the read. - std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - EXPECT_THAT(events, Are({Event(IN_ACCESS, wd, Basename(file1.path()))})); - - // Now try read again. This should be a 0-length read, since we're at EOF. - char c; - EXPECT_THAT(readlen = read(file1_fd.get(), &c, 1), SyscallSucceeds()); - EXPECT_EQ(readlen, 0); - // We should have no new events. - events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - EXPECT_TRUE(events.empty()); - - // Try issuing a zero-length read. - EXPECT_THAT(readlen = read(file1_fd.get(), &c, 0), SyscallSucceeds()); - EXPECT_EQ(readlen, 0); - // We should have no new events. - events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - EXPECT_TRUE(events.empty()); - - // Try issuing a zero-length write. - ssize_t writelen; - EXPECT_THAT(writelen = write(file1_fd.get(), &c, 0), SyscallSucceeds()); - EXPECT_EQ(writelen, 0); - // We should have no new events. - events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - EXPECT_TRUE(events.empty()); -} - -TEST(Inotify, ChmodGeneratesAttribEvent_NoRandomSave) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - - FileDescriptor root_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(root.path(), O_RDONLY)); - FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDWR)); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - const int root_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - - auto verify_chmod_events = [&]() { - std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_ATTRIB, root_wd, Basename(file1.path())), - Event(IN_ATTRIB, file1_wd)})); - }; - - // Don't do cooperative S/R tests for any of the {f}chmod* syscalls below, the - // test will always fail because nodes cannot be saved when they have stricter - // permissions than the original host node. - const DisableSave ds; - - // Chmod. - ASSERT_THAT(chmod(file1.path().c_str(), S_IWGRP), SyscallSucceeds()); - verify_chmod_events(); - - // Fchmod. - ASSERT_THAT(fchmod(file1_fd.get(), S_IRGRP | S_IWGRP), SyscallSucceeds()); - verify_chmod_events(); - - // Fchmodat. - const std::string file1_basename = std::string(Basename(file1.path())); - ASSERT_THAT(fchmodat(root_fd.get(), file1_basename.c_str(), S_IWGRP, 0), - SyscallSucceeds()); - verify_chmod_events(); - - // Make sure the chmod'ed file descriptors are destroyed before DisableSave - // is destructed. - root_fd.reset(); - file1_fd.reset(); -} - -TEST(Inotify, TruncateGeneratesModifyEvent) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - const FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDWR)); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const int root_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - - auto verify_truncate_events = [&]() { - std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_MODIFY, root_wd, Basename(file1.path())), - Event(IN_MODIFY, file1_wd)})); - }; - - // Truncate. - EXPECT_THAT(truncate(file1.path().c_str(), 4096), SyscallSucceeds()); - verify_truncate_events(); - - // Ftruncate. - EXPECT_THAT(ftruncate(file1_fd.get(), 8192), SyscallSucceeds()); - verify_truncate_events(); - - // No events if truncate fails. - EXPECT_THAT(ftruncate(file1_fd.get(), -1), SyscallFailsWithErrno(EINVAL)); - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({})); -} - -TEST(Inotify, GetdentsGeneratesAccessEvent) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - - // This internally calls getdents(2). We also expect to see an open/close - // event for the dirfd. - ASSERT_NO_ERRNO_AND_VALUE(ListDir(root.path(), false)); - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - - // Linux only seems to generate access events on getdents() on some - // calls. Allow the test to pass even if it isn't generated. gVisor will - // always generate the IN_ACCESS event so the test will at least ensure gVisor - // behaves reasonably. - int i = 0; - EXPECT_EQ(events[i].mask, IN_OPEN | IN_ISDIR); - ++i; - if (IsRunningOnGvisor()) { - EXPECT_EQ(events[i].mask, IN_ACCESS | IN_ISDIR); - ++i; - } else { - if (events[i].mask == (IN_ACCESS | IN_ISDIR)) { - // Skip over the IN_ACCESS event on Linux, it only shows up some of the - // time so we can't assert its existence. - ++i; - } - } - EXPECT_EQ(events[i].mask, IN_CLOSE_NOWRITE | IN_ISDIR); -} - -TEST(Inotify, MknodGeneratesCreateEvent) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - - const TempPath file1(root.path() + "/file1"); - const int rc = mknod(file1.path().c_str(), S_IFREG, 0); - // mknod(2) is only supported on tmpfs in the sandbox. - SKIP_IF(IsRunningOnGvisor() && rc != 0); - ASSERT_THAT(rc, SyscallSucceeds()); - - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_CREATE, wd, Basename(file1.path()))})); -} - -TEST(Inotify, SymlinkGeneratesCreateEvent) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - const TempPath link1(NewTempAbsPathInDir(root.path())); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - const int root_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - - ASSERT_THAT(symlink(file1.path().c_str(), link1.path().c_str()), - SyscallSucceeds()); - - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - - ASSERT_THAT(events, Are({Event(IN_CREATE, root_wd, Basename(link1.path()))})); -} - -TEST(Inotify, LinkGeneratesAttribAndCreateEvents) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - const TempPath link1(root.path() + "/link1"); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - const int root_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - - const int rc = link(file1.path().c_str(), link1.path().c_str()); - // link(2) is only supported on tmpfs in the sandbox. - SKIP_IF(IsRunningOnGvisor() && rc != 0 && - (errno == EPERM || errno == ENOENT)); - ASSERT_THAT(rc, SyscallSucceeds()); - - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_ATTRIB, file1_wd), - Event(IN_CREATE, root_wd, Basename(link1.path()))})); -} - -TEST(Inotify, UtimesGeneratesAttribEvent) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - - const FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDWR)); - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - - struct timeval times[2] = {{1, 0}, {2, 0}}; - EXPECT_THAT(futimes(file1_fd.get(), times), SyscallSucceeds()); - - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_ATTRIB, wd, Basename(file1.path()))})); -} - -TEST(Inotify, HardlinksReuseSameWatch) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - TempPath link1(root.path() + "/link1"); - const int rc = link(file1.path().c_str(), link1.path().c_str()); - // link(2) is only supported on tmpfs in the sandbox. - SKIP_IF(IsRunningOnGvisor() && rc != 0 && - (errno == EPERM || errno == ENOENT)); - ASSERT_THAT(rc, SyscallSucceeds()); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - const int root_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - const int link1_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), link1.path(), IN_ALL_EVENTS)); - - // The watch descriptors for watches on different links to the same file - // should be identical. - EXPECT_NE(root_wd, file1_wd); - EXPECT_EQ(file1_wd, link1_wd); - - FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY)); - - std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, - AreUnordered({Event(IN_OPEN, root_wd, Basename(file1.path())), - Event(IN_OPEN, file1_wd)})); - - // For the next step, we want to ensure all fds to the file are closed. Do - // that now and drain the resulting events. - file1_fd.reset(); - events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, - Are({Event(IN_CLOSE_WRITE, root_wd, Basename(file1.path())), - Event(IN_CLOSE_WRITE, file1_wd)})); - - // Try removing the link and let's see what events show up. Note that after - // this, we still have a link to the file so the watch shouldn't be - // automatically removed. - const std::string link1_path = link1.reset(); - - events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_ATTRIB, link1_wd), - Event(IN_DELETE, root_wd, Basename(link1_path))})); - - // Now remove the other link. Since this is the last link to the file, the - // watch should be automatically removed. - const std::string file1_path = file1.reset(); - - events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT( - events, - AreUnordered({Event(IN_ATTRIB, file1_wd), Event(IN_DELETE_SELF, file1_wd), - Event(IN_IGNORED, file1_wd), - Event(IN_DELETE, root_wd, Basename(file1_path))})); -} - -TEST(Inotify, MkdirGeneratesCreateEventWithDirFlag) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const int root_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - - const TempPath dir1(NewTempAbsPathInDir(root.path())); - ASSERT_THAT(mkdir(dir1.path().c_str(), 0777), SyscallSucceeds()); - - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT( - events, - Are({Event(IN_CREATE | IN_ISDIR, root_wd, Basename(dir1.path()))})); -} - -TEST(Inotify, MultipleInotifyInstancesAndWatchesAllGetEvents) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - - const FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY)); - constexpr int kNumFds = 30; - std::vector<FileDescriptor> inotify_fds; - - for (int i = 0; i < kNumFds; ++i) { - const DisableSave ds; // Too expensive. - inotify_fds.emplace_back( - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK))); - const FileDescriptor& fd = inotify_fds[inotify_fds.size() - 1]; // Back. - ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - } - - const std::string data = "some content"; - EXPECT_THAT(write(file1_fd.get(), data.c_str(), data.length()), - SyscallSucceeds()); - - for (const FileDescriptor& fd : inotify_fds) { - const DisableSave ds; // Too expensive. - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - if (events.size() >= 2) { - EXPECT_EQ(events[0].mask, IN_MODIFY); - EXPECT_EQ(events[0].wd, 1); - EXPECT_EQ(events[0].name, Basename(file1.path())); - EXPECT_EQ(events[1].mask, IN_MODIFY); - EXPECT_EQ(events[1].wd, 2); - EXPECT_EQ(events[1].name, ""); - } - } -} - -TEST(Inotify, EventsGoUpAtMostOneLevel) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath dir1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path())); - TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - const int dir1_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), dir1.path(), IN_ALL_EVENTS)); - - const std::string file1_path = file1.reset(); - - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_DELETE, dir1_wd, Basename(file1_path))})); -} - -TEST(Inotify, DuplicateWatchReturnsSameWatchDescriptor) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - const int wd1 = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - const int wd2 = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - - EXPECT_EQ(wd1, wd2); - - const FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY)); - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - // The watch shouldn't be duplicated, we only expect one event. - ASSERT_THAT(events, Are({Event(IN_OPEN, wd1)})); -} - -TEST(Inotify, UnmatchedEventsAreDiscarded) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(fd.get(), file1.path(), IN_ACCESS)); - - const FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY)); - - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - // We only asked for access events, the open event should be discarded. - ASSERT_THAT(events, Are({})); -} - -TEST(Inotify, AddWatchWithInvalidEventMaskFails) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - EXPECT_THAT(inotify_add_watch(fd.get(), root.path().c_str(), 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(Inotify, AddWatchOnInvalidPathFails) { - const TempPath nonexistent(NewTempAbsPath()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - // Non-existent path. - EXPECT_THAT( - inotify_add_watch(fd.get(), nonexistent.path().c_str(), IN_CREATE), - SyscallFailsWithErrno(ENOENT)); - - // Garbage path pointer. - EXPECT_THAT(inotify_add_watch(fd.get(), nullptr, IN_CREATE), - SyscallFailsWithErrno(EFAULT)); -} - -TEST(Inotify, InOnlyDirFlagRespected) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - EXPECT_THAT( - inotify_add_watch(fd.get(), root.path().c_str(), IN_ACCESS | IN_ONLYDIR), - SyscallSucceeds()); - - EXPECT_THAT( - inotify_add_watch(fd.get(), file1.path().c_str(), IN_ACCESS | IN_ONLYDIR), - SyscallFailsWithErrno(ENOTDIR)); -} - -TEST(Inotify, MaskAddMergesWithExistingEventMask) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY)); - - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_OPEN | IN_CLOSE_WRITE)); - - const std::string data = "some content"; - EXPECT_THAT(write(file1_fd.get(), data.c_str(), data.length()), - SyscallSucceeds()); - - // We shouldn't get any events, since IN_MODIFY wasn't in the event mask. - std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({})); - - // Add IN_MODIFY to event mask. - ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_MODIFY | IN_MASK_ADD)); - - EXPECT_THAT(write(file1_fd.get(), data.c_str(), data.length()), - SyscallSucceeds()); - - // This time we should get the modify event. - events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_MODIFY, wd)})); - - // Now close the fd. If the modify event was added to the event mask rather - // than replacing the event mask we won't get the close event. - file1_fd.reset(); - events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_CLOSE_WRITE, wd)})); -} - -// Test that control events bits are not considered when checking event mask. -TEST(Inotify, ControlEvents) { - const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), dir.path(), IN_ACCESS)); - - // Check that events in the mask are dispatched and that control bits are - // part of the event mask. - std::vector<std::string> files = - ASSERT_NO_ERRNO_AND_VALUE(ListDir(dir.path(), false)); - ASSERT_EQ(files.size(), 2); - - const std::vector<Event> events1 = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events1, Are({Event(IN_ACCESS | IN_ISDIR, wd)})); - - // Check that events not in the mask are discarded. - const FileDescriptor dir_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY)); - - const std::vector<Event> events2 = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events2, Are({})); -} - -// Regression test to ensure epoll and directory access doesn't deadlock. -TEST(Inotify, EpollNoDeadlock) { - const DisableSave ds; // Too many syscalls. - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - - // Create lots of directories and watch all of them. - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - std::vector<TempPath> children; - for (size_t i = 0; i < 1000; ++i) { - auto child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path())); - ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), child.path(), IN_ACCESS)); - children.emplace_back(std::move(child)); - } - - // Run epoll_wait constantly in a separate thread. - std::atomic<bool> done(false); - ScopedThread th([&fd, &done] { - for (auto start = absl::Now(); absl::Now() - start < absl::Seconds(5);) { - FileDescriptor epoll_fd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); - ASSERT_NO_ERRNO(RegisterEpollFD(epoll_fd.get(), fd.get(), - EPOLLIN | EPOLLOUT | EPOLLET, 0)); - struct epoll_event result[1]; - EXPECT_THAT(RetryEINTR(epoll_wait)(epoll_fd.get(), result, 1, -1), - SyscallSucceedsWithValue(1)); - - sched_yield(); - } - done = true; - }); - - // While epoll thread is running, constantly access all directories to - // generate inotify events. - while (!done) { - std::vector<std::string> files = - ASSERT_NO_ERRNO_AND_VALUE(ListDir(root.path(), false)); - ASSERT_EQ(files.size(), 1002); - for (const auto& child : files) { - if (child == "." || child == "..") { - continue; - } - ASSERT_NO_ERRNO_AND_VALUE(ListDir(JoinPath(root.path(), child), false)); - } - sched_yield(); - } -} - -TEST(Inotify, SpliceEvent) { - int pipes[2]; - ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds()); - - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - root.path(), "some content", TempPath::kDefaultFileMode)); - - const FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY)); - const int watcher = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - - char buf; - EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); - - EXPECT_THAT(splice(fd.get(), nullptr, pipes[1], nullptr, - sizeof(struct inotify_event) + 1, SPLICE_F_NONBLOCK), - SyscallSucceedsWithValue(sizeof(struct inotify_event))); - - const FileDescriptor read_fd(pipes[0]); - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(read_fd.get())); - ASSERT_THAT(events, Are({Event(IN_ACCESS, watcher)})); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/ioctl.cc b/test/syscalls/linux/ioctl.cc deleted file mode 100644 index b0a07a064..000000000 --- a/test/syscalls/linux/ioctl.cc +++ /dev/null @@ -1,406 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <arpa/inet.h> -#include <errno.h> -#include <fcntl.h> -#include <net/if.h> -#include <netdb.h> -#include <signal.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -bool CheckNonBlocking(int fd) { - int ret = fcntl(fd, F_GETFL, 0); - TEST_CHECK(ret != -1); - return (ret & O_NONBLOCK) == O_NONBLOCK; -} - -bool CheckCloExec(int fd) { - int ret = fcntl(fd, F_GETFD, 0); - TEST_CHECK(ret != -1); - return (ret & FD_CLOEXEC) == FD_CLOEXEC; -} - -class IoctlTest : public ::testing::Test { - protected: - void SetUp() override { - ASSERT_THAT(fd_ = open("/dev/null", O_RDONLY), SyscallSucceeds()); - } - - void TearDown() override { - if (fd_ >= 0) { - ASSERT_THAT(close(fd_), SyscallSucceeds()); - fd_ = -1; - } - } - - int fd() const { return fd_; } - - private: - int fd_ = -1; -}; - -TEST_F(IoctlTest, BadFileDescriptor) { - EXPECT_THAT(ioctl(-1 /* fd */, 0), SyscallFailsWithErrno(EBADF)); -} - -TEST_F(IoctlTest, InvalidControlNumber) { - EXPECT_THAT(ioctl(STDOUT_FILENO, 0), SyscallFailsWithErrno(ENOTTY)); -} - -TEST_F(IoctlTest, FIONBIOSucceeds) { - EXPECT_FALSE(CheckNonBlocking(fd())); - int set = 1; - EXPECT_THAT(ioctl(fd(), FIONBIO, &set), SyscallSucceeds()); - EXPECT_TRUE(CheckNonBlocking(fd())); - set = 0; - EXPECT_THAT(ioctl(fd(), FIONBIO, &set), SyscallSucceeds()); - EXPECT_FALSE(CheckNonBlocking(fd())); -} - -TEST_F(IoctlTest, FIONBIOFails) { - EXPECT_THAT(ioctl(fd(), FIONBIO, nullptr), SyscallFailsWithErrno(EFAULT)); -} - -TEST_F(IoctlTest, FIONCLEXSucceeds) { - EXPECT_THAT(ioctl(fd(), FIONCLEX), SyscallSucceeds()); - EXPECT_FALSE(CheckCloExec(fd())); -} - -TEST_F(IoctlTest, FIOCLEXSucceeds) { - EXPECT_THAT(ioctl(fd(), FIOCLEX), SyscallSucceeds()); - EXPECT_TRUE(CheckCloExec(fd())); -} - -TEST_F(IoctlTest, FIOASYNCFails) { - EXPECT_THAT(ioctl(fd(), FIOASYNC, nullptr), SyscallFailsWithErrno(EFAULT)); -} - -TEST_F(IoctlTest, FIOASYNCSucceeds) { - // Not all FDs support FIOASYNC. - const FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - int before = -1; - ASSERT_THAT(before = fcntl(s.get(), F_GETFL), SyscallSucceeds()); - - int set = 1; - EXPECT_THAT(ioctl(s.get(), FIOASYNC, &set), SyscallSucceeds()); - - int after_set = -1; - ASSERT_THAT(after_set = fcntl(s.get(), F_GETFL), SyscallSucceeds()); - EXPECT_EQ(after_set, before | O_ASYNC) << "before was " << before; - - set = 0; - EXPECT_THAT(ioctl(s.get(), FIOASYNC, &set), SyscallSucceeds()); - - ASSERT_THAT(fcntl(s.get(), F_GETFL), SyscallSucceedsWithValue(before)); -} - -/* Count of the number of SIGIOs handled. */ -static volatile int io_received = 0; - -void inc_io_handler(int sig, siginfo_t* siginfo, void* arg) { io_received++; } - -TEST_F(IoctlTest, FIOASYNCNoTarget) { - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - // Count SIGIOs received. - io_received = 0; - struct sigaction sa; - sa.sa_sigaction = inc_io_handler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_RESTART; - auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGIO, sa)); - - // Actually allow SIGIO delivery. - auto mask_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGIO)); - - int set = 1; - EXPECT_THAT(ioctl(pair->second_fd(), FIOASYNC, &set), SyscallSucceeds()); - - constexpr char kData[] = "abc"; - ASSERT_THAT(WriteFd(pair->first_fd(), kData, sizeof(kData)), - SyscallSucceedsWithValue(sizeof(kData))); - - EXPECT_EQ(io_received, 0); -} - -TEST_F(IoctlTest, FIOASYNCSelfTarget) { - // FIXME(b/120624367): gVisor erroneously sends SIGIO on close(2), which would - // kill the test when pair goes out of scope. Temporarily ignore SIGIO so that - // that the close signal is ignored. - struct sigaction sa; - sa.sa_handler = SIG_IGN; - auto early_sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGIO, sa)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - // Count SIGIOs received. - io_received = 0; - sa.sa_sigaction = inc_io_handler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_RESTART; - auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGIO, sa)); - - // Actually allow SIGIO delivery. - auto mask_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGIO)); - - int set = 1; - EXPECT_THAT(ioctl(pair->second_fd(), FIOASYNC, &set), SyscallSucceeds()); - - pid_t pid = getpid(); - EXPECT_THAT(ioctl(pair->second_fd(), FIOSETOWN, &pid), SyscallSucceeds()); - - constexpr char kData[] = "abc"; - ASSERT_THAT(WriteFd(pair->first_fd(), kData, sizeof(kData)), - SyscallSucceedsWithValue(sizeof(kData))); - - EXPECT_EQ(io_received, 1); -} - -// Equivalent to FIOASYNCSelfTarget except that FIOSETOWN is called before -// FIOASYNC. -TEST_F(IoctlTest, FIOASYNCSelfTarget2) { - // FIXME(b/120624367): gVisor erroneously sends SIGIO on close(2), which would - // kill the test when pair goes out of scope. Temporarily ignore SIGIO so that - // that the close signal is ignored. - struct sigaction sa; - sa.sa_handler = SIG_IGN; - auto early_sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGIO, sa)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - // Count SIGIOs received. - io_received = 0; - sa.sa_sigaction = inc_io_handler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_RESTART; - auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGIO, sa)); - - // Actually allow SIGIO delivery. - auto mask_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGIO)); - - pid_t pid = -1; - EXPECT_THAT(pid = getpid(), SyscallSucceeds()); - EXPECT_THAT(ioctl(pair->second_fd(), FIOSETOWN, &pid), SyscallSucceeds()); - - int set = 1; - EXPECT_THAT(ioctl(pair->second_fd(), FIOASYNC, &set), SyscallSucceeds()); - - constexpr char kData[] = "abc"; - ASSERT_THAT(WriteFd(pair->first_fd(), kData, sizeof(kData)), - SyscallSucceedsWithValue(sizeof(kData))); - - EXPECT_EQ(io_received, 1); -} - -// Check that closing an FD does not result in an event. -TEST_F(IoctlTest, FIOASYNCSelfTargetClose) { - // Count SIGIOs received. - struct sigaction sa; - io_received = 0; - sa.sa_sigaction = inc_io_handler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_RESTART; - auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGIO, sa)); - - // Actually allow SIGIO delivery. - auto mask_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGIO)); - - for (int i = 0; i < 2; i++) { - auto pair = ASSERT_NO_ERRNO_AND_VALUE( - UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - pid_t pid = getpid(); - EXPECT_THAT(ioctl(pair->second_fd(), FIOSETOWN, &pid), SyscallSucceeds()); - - int set = 1; - EXPECT_THAT(ioctl(pair->second_fd(), FIOASYNC, &set), SyscallSucceeds()); - } - - // FIXME(b/120624367): gVisor erroneously sends SIGIO on close. - SKIP_IF(IsRunningOnGvisor()); - - EXPECT_EQ(io_received, 0); -} - -TEST_F(IoctlTest, FIOASYNCInvalidPID) { - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - int set = 1; - ASSERT_THAT(ioctl(pair->second_fd(), FIOASYNC, &set), SyscallSucceeds()); - pid_t pid = INT_MAX; - // This succeeds (with behavior equivalent to a pid of 0) in Linux prior to - // f73127356f34 "fs/fcntl: return -ESRCH in f_setown when pid/pgid can't be - // found", and fails with EPERM after that commit. - EXPECT_THAT(ioctl(pair->second_fd(), FIOSETOWN, &pid), - AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ESRCH))); -} - -TEST_F(IoctlTest, FIOASYNCUnsetTarget) { - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - // Count SIGIOs received. - io_received = 0; - struct sigaction sa; - sa.sa_sigaction = inc_io_handler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_RESTART; - auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGIO, sa)); - - // Actually allow SIGIO delivery. - auto mask_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGIO)); - - int set = 1; - EXPECT_THAT(ioctl(pair->second_fd(), FIOASYNC, &set), SyscallSucceeds()); - - pid_t pid = getpid(); - EXPECT_THAT(ioctl(pair->second_fd(), FIOSETOWN, &pid), SyscallSucceeds()); - - // Passing a PID of 0 unsets the target. - pid = 0; - EXPECT_THAT(ioctl(pair->second_fd(), FIOSETOWN, &pid), SyscallSucceeds()); - - constexpr char kData[] = "abc"; - ASSERT_THAT(WriteFd(pair->first_fd(), kData, sizeof(kData)), - SyscallSucceedsWithValue(sizeof(kData))); - - EXPECT_EQ(io_received, 0); -} - -using IoctlTestSIOCGIFCONF = SimpleSocketTest; - -TEST_P(IoctlTestSIOCGIFCONF, ValidateNoArrayGetsLength) { - auto fd = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Validate that no array can be used to get the length required. - struct ifconf ifconf = {}; - ASSERT_THAT(ioctl(fd->get(), SIOCGIFCONF, &ifconf), SyscallSucceeds()); - ASSERT_GT(ifconf.ifc_len, 0); -} - -// This test validates that we will only return a partial array list and not -// partial ifrreq structs. -TEST_P(IoctlTestSIOCGIFCONF, ValidateNoPartialIfrsReturned) { - auto fd = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - struct ifreq ifr = {}; - struct ifconf ifconf = {}; - ifconf.ifc_len = sizeof(ifr) - 1; // One byte too few. - ifconf.ifc_ifcu.ifcu_req = 𝔦 - - ASSERT_THAT(ioctl(fd->get(), SIOCGIFCONF, &ifconf), SyscallSucceeds()); - ASSERT_EQ(ifconf.ifc_len, 0); - ASSERT_EQ(ifr.ifr_name[0], '\0'); // Nothing is returned. - - ifconf.ifc_len = sizeof(ifreq); - ASSERT_THAT(ioctl(fd->get(), SIOCGIFCONF, &ifconf), SyscallSucceeds()); - ASSERT_GT(ifconf.ifc_len, 0); - ASSERT_NE(ifr.ifr_name[0], '\0'); // An interface can now be returned. -} - -TEST_P(IoctlTestSIOCGIFCONF, ValidateLoopbackIsPresent) { - auto fd = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - struct ifconf ifconf = {}; - struct ifreq ifr[10] = {}; // Storage for up to 10 interfaces. - - ifconf.ifc_req = ifr; - ifconf.ifc_len = sizeof(ifr); - - ASSERT_THAT(ioctl(fd->get(), SIOCGIFCONF, &ifconf), SyscallSucceeds()); - size_t num_if = ifconf.ifc_len / sizeof(struct ifreq); - - // We should have at least one interface. - ASSERT_GE(num_if, 1); - - // One of the interfaces should be a loopback. - bool found_loopback = false; - for (size_t i = 0; i < num_if; ++i) { - if (strcmp(ifr[i].ifr_name, "lo") == 0) { - // SIOCGIFCONF returns the ipv4 address of the interface, let's check it. - ASSERT_EQ(ifr[i].ifr_addr.sa_family, AF_INET); - - // Validate the address is correct for loopback. - sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(&ifr[i].ifr_addr); - ASSERT_EQ(htonl(sin->sin_addr.s_addr), INADDR_LOOPBACK); - - found_loopback = true; - break; - } - } - ASSERT_TRUE(found_loopback); -} - -std::vector<SocketKind> IoctlSocketTypes() { - return {SimpleSocket(AF_UNIX, SOCK_STREAM, 0), - SimpleSocket(AF_UNIX, SOCK_DGRAM, 0), - SimpleSocket(AF_INET, SOCK_STREAM, 0), - SimpleSocket(AF_INET6, SOCK_STREAM, 0), - SimpleSocket(AF_INET, SOCK_DGRAM, 0), - SimpleSocket(AF_INET6, SOCK_DGRAM, 0)}; -} - -INSTANTIATE_TEST_SUITE_P(IoctlTest, IoctlTestSIOCGIFCONF, - ::testing::ValuesIn(IoctlSocketTypes())); - -} // namespace - -TEST_F(IoctlTest, FIOGETOWNSucceeds) { - const FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - int get = -1; - ASSERT_THAT(ioctl(s.get(), FIOGETOWN, &get), SyscallSucceeds()); - EXPECT_EQ(get, 0); -} - -TEST_F(IoctlTest, SIOCGPGRPSucceeds) { - const FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - - int get = -1; - ASSERT_THAT(ioctl(s.get(), SIOCGPGRP, &get), SyscallSucceeds()); - EXPECT_EQ(get, 0); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc deleted file mode 100644 index bba022a41..000000000 --- a/test/syscalls/linux/ip_socket_test_util.cc +++ /dev/null @@ -1,241 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/ip_socket_test_util.h" - -#include <net/if.h> -#include <netinet/in.h> -#include <sys/ioctl.h> -#include <sys/socket.h> - -#include <cstring> - -namespace gvisor { -namespace testing { - -uint32_t IPFromInetSockaddr(const struct sockaddr* addr) { - auto* in_addr = reinterpret_cast<const struct sockaddr_in*>(addr); - return in_addr->sin_addr.s_addr; -} - -uint16_t PortFromInetSockaddr(const struct sockaddr* addr) { - auto* in_addr = reinterpret_cast<const struct sockaddr_in*>(addr); - return ntohs(in_addr->sin_port); -} - -PosixErrorOr<int> InterfaceIndex(std::string name) { - // TODO(igudger): Consider using netlink. - ifreq req = {}; - memcpy(req.ifr_name, name.c_str(), name.size()); - ASSIGN_OR_RETURN_ERRNO(auto sock, Socket(AF_INET, SOCK_DGRAM, 0)); - RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(sock.get(), SIOCGIFINDEX, &req)); - return req.ifr_ifindex; -} - -namespace { - -std::string DescribeSocketType(int type) { - return absl::StrCat(((type & SOCK_NONBLOCK) != 0) ? "non-blocking " : "", - ((type & SOCK_CLOEXEC) != 0) ? "close-on-exec " : ""); -} - -} // namespace - -SocketPairKind IPv6TCPAcceptBindSocketPair(int type) { - std::string description = - absl::StrCat(DescribeSocketType(type), "connected IPv6 TCP socket"); - return SocketPairKind{ - description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP, - TCPAcceptBindSocketPairCreator(AF_INET6, type | SOCK_STREAM, 0, - /* dual_stack = */ false)}; -} - -SocketPairKind IPv4TCPAcceptBindSocketPair(int type) { - std::string description = - absl::StrCat(DescribeSocketType(type), "connected IPv4 TCP socket"); - return SocketPairKind{ - description, AF_INET, type | SOCK_STREAM, IPPROTO_TCP, - TCPAcceptBindSocketPairCreator(AF_INET, type | SOCK_STREAM, 0, - /* dual_stack = */ false)}; -} - -SocketPairKind DualStackTCPAcceptBindSocketPair(int type) { - std::string description = - absl::StrCat(DescribeSocketType(type), "connected dual stack TCP socket"); - return SocketPairKind{ - description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP, - TCPAcceptBindSocketPairCreator(AF_INET6, type | SOCK_STREAM, 0, - /* dual_stack = */ true)}; -} - -SocketPairKind IPv6TCPAcceptBindPersistentListenerSocketPair(int type) { - std::string description = - absl::StrCat(DescribeSocketType(type), "connected IPv6 TCP socket"); - return SocketPairKind{description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP, - TCPAcceptBindPersistentListenerSocketPairCreator( - AF_INET6, type | SOCK_STREAM, 0, - /* dual_stack = */ false)}; -} - -SocketPairKind IPv4TCPAcceptBindPersistentListenerSocketPair(int type) { - std::string description = - absl::StrCat(DescribeSocketType(type), "connected IPv4 TCP socket"); - return SocketPairKind{description, AF_INET, type | SOCK_STREAM, IPPROTO_TCP, - TCPAcceptBindPersistentListenerSocketPairCreator( - AF_INET, type | SOCK_STREAM, 0, - /* dual_stack = */ false)}; -} - -SocketPairKind DualStackTCPAcceptBindPersistentListenerSocketPair(int type) { - std::string description = - absl::StrCat(DescribeSocketType(type), "connected dual stack TCP socket"); - return SocketPairKind{description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP, - TCPAcceptBindPersistentListenerSocketPairCreator( - AF_INET6, type | SOCK_STREAM, 0, - /* dual_stack = */ true)}; -} - -SocketPairKind IPv6UDPBidirectionalBindSocketPair(int type) { - std::string description = - absl::StrCat(DescribeSocketType(type), "connected IPv6 UDP socket"); - return SocketPairKind{ - description, AF_INET6, type | SOCK_DGRAM, IPPROTO_UDP, - UDPBidirectionalBindSocketPairCreator(AF_INET6, type | SOCK_DGRAM, 0, - /* dual_stack = */ false)}; -} - -SocketPairKind IPv4UDPBidirectionalBindSocketPair(int type) { - std::string description = - absl::StrCat(DescribeSocketType(type), "connected IPv4 UDP socket"); - return SocketPairKind{ - description, AF_INET, type | SOCK_DGRAM, IPPROTO_UDP, - UDPBidirectionalBindSocketPairCreator(AF_INET, type | SOCK_DGRAM, 0, - /* dual_stack = */ false)}; -} - -SocketPairKind DualStackUDPBidirectionalBindSocketPair(int type) { - std::string description = - absl::StrCat(DescribeSocketType(type), "connected dual stack UDP socket"); - return SocketPairKind{ - description, AF_INET6, type | SOCK_DGRAM, IPPROTO_UDP, - UDPBidirectionalBindSocketPairCreator(AF_INET6, type | SOCK_DGRAM, 0, - /* dual_stack = */ true)}; -} - -SocketPairKind IPv4UDPUnboundSocketPair(int type) { - std::string description = - absl::StrCat(DescribeSocketType(type), "IPv4 UDP socket"); - return SocketPairKind{ - description, AF_INET, type | SOCK_DGRAM, IPPROTO_UDP, - UDPUnboundSocketPairCreator(AF_INET, type | SOCK_DGRAM, 0, - /* dual_stack = */ false)}; -} - -SocketKind IPv4UDPUnboundSocket(int type) { - std::string description = - absl::StrCat(DescribeSocketType(type), "IPv4 UDP socket"); - return SocketKind{ - description, AF_INET, type | SOCK_DGRAM, IPPROTO_UDP, - UnboundSocketCreator(AF_INET, type | SOCK_DGRAM, IPPROTO_UDP)}; -} - -SocketKind IPv6UDPUnboundSocket(int type) { - std::string description = - absl::StrCat(DescribeSocketType(type), "IPv6 UDP socket"); - return SocketKind{ - description, AF_INET6, type | SOCK_DGRAM, IPPROTO_UDP, - UnboundSocketCreator(AF_INET6, type | SOCK_DGRAM, IPPROTO_UDP)}; -} - -SocketKind IPv4TCPUnboundSocket(int type) { - std::string description = - absl::StrCat(DescribeSocketType(type), "IPv4 TCP socket"); - return SocketKind{ - description, AF_INET, type | SOCK_STREAM, IPPROTO_TCP, - UnboundSocketCreator(AF_INET, type | SOCK_STREAM, IPPROTO_TCP)}; -} - -SocketKind IPv6TCPUnboundSocket(int type) { - std::string description = - absl::StrCat(DescribeSocketType(type), "IPv6 TCP socket"); - return SocketKind{ - description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP, - UnboundSocketCreator(AF_INET6, type | SOCK_STREAM, IPPROTO_TCP)}; -} - -PosixError IfAddrHelper::Load() { - Release(); - RETURN_ERROR_IF_SYSCALL_FAIL(getifaddrs(&ifaddr_)); - return PosixError(0); -} - -void IfAddrHelper::Release() { - if (ifaddr_) { - freeifaddrs(ifaddr_); - } - ifaddr_ = nullptr; -} - -std::vector<std::string> IfAddrHelper::InterfaceList(int family) { - std::vector<std::string> names; - for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) { - if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) { - continue; - } - names.emplace(names.end(), ifa->ifa_name); - } - return names; -} - -sockaddr* IfAddrHelper::GetAddr(int family, std::string name) { - for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) { - if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) { - continue; - } - if (name == ifa->ifa_name) { - return ifa->ifa_addr; - } - } - return nullptr; -} - -PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) { - return InterfaceIndex(name); -} - -std::string GetAddr4Str(const in_addr* a) { - char str[INET_ADDRSTRLEN]; - inet_ntop(AF_INET, a, str, sizeof(str)); - return std::string(str); -} - -std::string GetAddr6Str(const in6_addr* a) { - char str[INET6_ADDRSTRLEN]; - inet_ntop(AF_INET6, a, str, sizeof(str)); - return std::string(str); -} - -std::string GetAddrStr(const sockaddr* a) { - if (a->sa_family == AF_INET) { - auto src = &(reinterpret_cast<const sockaddr_in*>(a)->sin_addr); - return GetAddr4Str(src); - } else if (a->sa_family == AF_INET6) { - auto src = &(reinterpret_cast<const sockaddr_in6*>(a)->sin6_addr); - return GetAddr6Str(src); - } - return std::string("<invalid>"); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h deleted file mode 100644 index 39fd6709d..000000000 --- a/test/syscalls/linux/ip_socket_test_util.h +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_IP_SOCKET_TEST_UTIL_H_ -#define GVISOR_TEST_SYSCALLS_IP_SOCKET_TEST_UTIL_H_ - -#include <arpa/inet.h> -#include <ifaddrs.h> -#include <sys/types.h> - -#include <string> - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Extracts the IP address from an inet sockaddr in network byte order. -uint32_t IPFromInetSockaddr(const struct sockaddr* addr); - -// Extracts the port from an inet sockaddr in host byte order. -uint16_t PortFromInetSockaddr(const struct sockaddr* addr); - -// InterfaceIndex returns the index of the named interface. -PosixErrorOr<int> InterfaceIndex(std::string name); - -// IPv6TCPAcceptBindSocketPair returns a SocketPairKind that represents -// SocketPairs created with bind() and accept() syscalls with AF_INET6 and the -// given type bound to the IPv6 loopback. -SocketPairKind IPv6TCPAcceptBindSocketPair(int type); - -// IPv4TCPAcceptBindSocketPair returns a SocketPairKind that represents -// SocketPairs created with bind() and accept() syscalls with AF_INET and the -// given type bound to the IPv4 loopback. -SocketPairKind IPv4TCPAcceptBindSocketPair(int type); - -// DualStackTCPAcceptBindSocketPair returns a SocketPairKind that represents -// SocketPairs created with bind() and accept() syscalls with AF_INET6 and the -// given type bound to the IPv4 loopback. -SocketPairKind DualStackTCPAcceptBindSocketPair(int type); - -// IPv6TCPAcceptBindPersistentListenerSocketPair is like -// IPv6TCPAcceptBindSocketPair except it uses a persistent listening socket to -// create all socket pairs. -SocketPairKind IPv6TCPAcceptBindPersistentListenerSocketPair(int type); - -// IPv4TCPAcceptBindPersistentListenerSocketPair is like -// IPv4TCPAcceptBindSocketPair except it uses a persistent listening socket to -// create all socket pairs. -SocketPairKind IPv4TCPAcceptBindPersistentListenerSocketPair(int type); - -// DualStackTCPAcceptBindPersistentListenerSocketPair is like -// DualStackTCPAcceptBindSocketPair except it uses a persistent listening socket -// to create all socket pairs. -SocketPairKind DualStackTCPAcceptBindPersistentListenerSocketPair(int type); - -// IPv6UDPBidirectionalBindSocketPair returns a SocketPairKind that represents -// SocketPairs created with bind() and connect() syscalls with AF_INET6 and the -// given type bound to the IPv6 loopback. -SocketPairKind IPv6UDPBidirectionalBindSocketPair(int type); - -// IPv4UDPBidirectionalBindSocketPair returns a SocketPairKind that represents -// SocketPairs created with bind() and connect() syscalls with AF_INET and the -// given type bound to the IPv4 loopback. -SocketPairKind IPv4UDPBidirectionalBindSocketPair(int type); - -// DualStackUDPBidirectionalBindSocketPair returns a SocketPairKind that -// represents SocketPairs created with bind() and connect() syscalls with -// AF_INET6 and the given type bound to the IPv4 loopback. -SocketPairKind DualStackUDPBidirectionalBindSocketPair(int type); - -// IPv4UDPUnboundSocketPair returns a SocketPairKind that represents -// SocketPairs created with AF_INET and the given type. -SocketPairKind IPv4UDPUnboundSocketPair(int type); - -// IPv4UDPUnboundSocket returns a SocketKind that represents a SimpleSocket -// created with AF_INET, SOCK_DGRAM, and the given type. -SocketKind IPv4UDPUnboundSocket(int type); - -// IPv6UDPUnboundSocket returns a SocketKind that represents a SimpleSocket -// created with AF_INET6, SOCK_DGRAM, and the given type. -SocketKind IPv6UDPUnboundSocket(int type); - -// IPv4TCPUnboundSocket returns a SocketKind that represents a SimpleSocket -// created with AF_INET, SOCK_STREAM and the given type. -SocketKind IPv4TCPUnboundSocket(int type); - -// IPv6TCPUnboundSocket returns a SocketKind that represents a SimpleSocket -// created with AF_INET6, SOCK_STREAM and the given type. -SocketKind IPv6TCPUnboundSocket(int type); - -// IfAddrHelper is a helper class that determines the local interfaces present -// and provides functions to obtain their names, index numbers, and IP address. -class IfAddrHelper { - public: - IfAddrHelper() : ifaddr_(nullptr) {} - ~IfAddrHelper() { Release(); } - - PosixError Load(); - void Release(); - - std::vector<std::string> InterfaceList(int family); - - struct sockaddr* GetAddr(int family, std::string name); - PosixErrorOr<int> GetIndex(std::string name); - - private: - struct ifaddrs* ifaddr_; -}; - -// GetAddr4Str returns the given IPv4 network address structure as a string. -std::string GetAddr4Str(const in_addr* a); - -// GetAddr6Str returns the given IPv6 network address structure as a string. -std::string GetAddr6Str(const in6_addr* a); - -// GetAddrStr returns the given IPv4 or IPv6 network address structure as a -// string. -std::string GetAddrStr(const sockaddr* a); - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_IP_SOCKET_TEST_UTIL_H_ diff --git a/test/syscalls/linux/iptables.cc b/test/syscalls/linux/iptables.cc deleted file mode 100644 index b8e4ece64..000000000 --- a/test/syscalls/linux/iptables.cc +++ /dev/null @@ -1,204 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/iptables.h" - -#include <arpa/inet.h> -#include <linux/capability.h> -#include <linux/netfilter/x_tables.h> -#include <net/if.h> -#include <netinet/in.h> -#include <netinet/ip.h> -#include <netinet/ip_icmp.h> -#include <stdio.h> -#include <sys/poll.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include <algorithm> - -#include "gtest/gtest.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -constexpr char kNatTablename[] = "nat"; -constexpr char kErrorTarget[] = "ERROR"; -constexpr size_t kEmptyStandardEntrySize = - sizeof(struct ipt_entry) + sizeof(struct ipt_standard_target); -constexpr size_t kEmptyErrorEntrySize = - sizeof(struct ipt_entry) + sizeof(struct ipt_error_target); - -TEST(IPTablesBasic, CreateSocket) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - int sock; - ASSERT_THAT(sock = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), - SyscallSucceeds()); - - ASSERT_THAT(close(sock), SyscallSucceeds()); -} - -TEST(IPTablesBasic, FailSockoptNonRaw) { - // Even if the user has CAP_NET_RAW, they shouldn't be able to use the - // iptables sockopts with a non-raw socket. - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - int sock; - ASSERT_THAT(sock = socket(AF_INET, SOCK_DGRAM, 0), SyscallSucceeds()); - - struct ipt_getinfo info = {}; - snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); - socklen_t info_size = sizeof(info); - EXPECT_THAT(getsockopt(sock, IPPROTO_IP, SO_GET_INFO, &info, &info_size), - SyscallFailsWithErrno(ENOPROTOOPT)); - - ASSERT_THAT(close(sock), SyscallSucceeds()); -} - -// Fixture for iptables tests. -class IPTablesTest : public ::testing::Test { - protected: - // Creates a socket to be used in tests. - void SetUp() override; - - // Closes the socket created by SetUp(). - void TearDown() override; - - // The socket via which to manipulate iptables. - int s_; -}; - -void IPTablesTest::SetUp() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds()); -} - -void IPTablesTest::TearDown() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - EXPECT_THAT(close(s_), SyscallSucceeds()); -} - -// This tests the initial state of a machine with empty iptables. We don't have -// a guarantee that the iptables are empty when running in native, but we can -// test that gVisor has the same initial state that a newly-booted Linux machine -// would have. -TEST_F(IPTablesTest, InitialState) { - SKIP_IF(!IsRunningOnGvisor()); - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - // - // Get info via sockopt. - // - struct ipt_getinfo info = {}; - snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); - socklen_t info_size = sizeof(info); - ASSERT_THAT(getsockopt(s_, IPPROTO_IP, SO_GET_INFO, &info, &info_size), - SyscallSucceeds()); - - // The nat table supports PREROUTING, and OUTPUT. - unsigned int valid_hooks = (1 << NF_IP_PRE_ROUTING) | (1 << NF_IP_LOCAL_OUT) | - (1 << NF_IP_POST_ROUTING) | (1 << NF_IP_LOCAL_IN); - - EXPECT_EQ(info.valid_hooks, valid_hooks); - - // Each chain consists of an empty entry with a standard target.. - EXPECT_EQ(info.hook_entry[NF_IP_PRE_ROUTING], 0); - EXPECT_EQ(info.hook_entry[NF_IP_LOCAL_IN], kEmptyStandardEntrySize); - EXPECT_EQ(info.hook_entry[NF_IP_LOCAL_OUT], kEmptyStandardEntrySize * 2); - EXPECT_EQ(info.hook_entry[NF_IP_POST_ROUTING], kEmptyStandardEntrySize * 3); - - // The underflow points are the same as the entry points. - EXPECT_EQ(info.underflow[NF_IP_PRE_ROUTING], 0); - EXPECT_EQ(info.underflow[NF_IP_LOCAL_IN], kEmptyStandardEntrySize); - EXPECT_EQ(info.underflow[NF_IP_LOCAL_OUT], kEmptyStandardEntrySize * 2); - EXPECT_EQ(info.underflow[NF_IP_POST_ROUTING], kEmptyStandardEntrySize * 3); - - // One entry for each chain, plus an error entry at the end. - EXPECT_EQ(info.num_entries, 5); - - EXPECT_EQ(info.size, 4 * kEmptyStandardEntrySize + kEmptyErrorEntrySize); - EXPECT_EQ(strcmp(info.name, kNatTablename), 0); - - // - // Use info to get entries. - // - socklen_t entries_size = sizeof(struct ipt_get_entries) + info.size; - struct ipt_get_entries* entries = - static_cast<struct ipt_get_entries*>(malloc(entries_size)); - snprintf(entries->name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); - entries->size = info.size; - ASSERT_THAT( - getsockopt(s_, IPPROTO_IP, SO_GET_ENTRIES, entries, &entries_size), - SyscallSucceeds()); - - // Verify the name and size. - ASSERT_EQ(info.size, entries->size); - ASSERT_EQ(strcmp(entries->name, kNatTablename), 0); - - // Verify that the entrytable is 4 entries with accept targets and no matches - // followed by a single error target. - size_t entry_offset = 0; - while (entry_offset < entries->size) { - struct ipt_entry* entry = reinterpret_cast<struct ipt_entry*>( - reinterpret_cast<char*>(entries->entrytable) + entry_offset); - - // ip should be zeroes. - struct ipt_ip zeroed = {}; - EXPECT_EQ(memcmp(static_cast<void*>(&zeroed), - static_cast<void*>(&entry->ip), sizeof(zeroed)), - 0); - - // target_offset should be zero. - EXPECT_EQ(entry->target_offset, sizeof(ipt_entry)); - - if (entry_offset < kEmptyStandardEntrySize * 4) { - // The first 4 entries are standard targets - struct ipt_standard_target* target = - reinterpret_cast<struct ipt_standard_target*>(entry->elems); - EXPECT_EQ(entry->next_offset, kEmptyStandardEntrySize); - EXPECT_EQ(target->target.u.user.target_size, sizeof(*target)); - EXPECT_EQ(strcmp(target->target.u.user.name, ""), 0); - EXPECT_EQ(target->target.u.user.revision, 0); - // This is what's returned for an accept verdict. I don't know why. - EXPECT_EQ(target->verdict, -NF_ACCEPT - 1); - } else { - // The last entry is an error target - struct ipt_error_target* target = - reinterpret_cast<struct ipt_error_target*>(entry->elems); - EXPECT_EQ(entry->next_offset, kEmptyErrorEntrySize); - EXPECT_EQ(target->target.u.user.target_size, sizeof(*target)); - EXPECT_EQ(strcmp(target->target.u.user.name, kErrorTarget), 0); - EXPECT_EQ(target->target.u.user.revision, 0); - EXPECT_EQ(strcmp(target->errorname, kErrorTarget), 0); - } - - entry_offset += entry->next_offset; - } - - free(entries); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/iptables.h b/test/syscalls/linux/iptables.h deleted file mode 100644 index 0719c60a4..000000000 --- a/test/syscalls/linux/iptables.h +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// There are a number of structs and values that we can't #include because of a -// difference between C and C++ (C++ won't let you implicitly cast from void* to -// struct something*). We re-define them here. - -#ifndef GVISOR_TEST_SYSCALLS_IPTABLES_TYPES_H_ -#define GVISOR_TEST_SYSCALLS_IPTABLES_TYPES_H_ - -// Netfilter headers require some headers to preceed them. -// clang-format off -#include <netinet/in.h> -#include <stddef.h> -// clang-format on - -#include <linux/netfilter/x_tables.h> -#include <linux/netfilter_ipv4.h> -#include <net/if.h> -#include <netinet/ip.h> -#include <stdint.h> - -#define ipt_standard_target xt_standard_target -#define ipt_entry_target xt_entry_target -#define ipt_error_target xt_error_target - -enum SockOpts { - // For setsockopt. - BASE_CTL = 64, - SO_SET_REPLACE = BASE_CTL, - SO_SET_ADD_COUNTERS, - SO_SET_MAX = SO_SET_ADD_COUNTERS, - - // For getsockopt. - SO_GET_INFO = BASE_CTL, - SO_GET_ENTRIES, - SO_GET_REVISION_MATCH, - SO_GET_REVISION_TARGET, - SO_GET_MAX = SO_GET_REVISION_TARGET -}; - -// ipt_ip specifies basic matching criteria that can be applied by examining -// only the IP header of a packet. -struct ipt_ip { - // Source IP address. - struct in_addr src; - - // Destination IP address. - struct in_addr dst; - - // Source IP address mask. - struct in_addr smsk; - - // Destination IP address mask. - struct in_addr dmsk; - - // Input interface. - char iniface[IFNAMSIZ]; - - // Output interface. - char outiface[IFNAMSIZ]; - - // Input interface mask. - unsigned char iniface_mask[IFNAMSIZ]; - - // Output interface mask. - unsigned char outiface_mask[IFNAMSIZ]; - - // Transport protocol. - uint16_t proto; - - // Flags. - uint8_t flags; - - // Inverse flags. - uint8_t invflags; -}; - -// ipt_entry is an iptables rule. It contains information about what packets the -// rule matches and what action (target) to perform for matching packets. -struct ipt_entry { - // Basic matching information used to match a packet's IP header. - struct ipt_ip ip; - - // A caching field that isn't used by userspace. - unsigned int nfcache; - - // The number of bytes between the start of this ipt_entry struct and the - // rule's target. - uint16_t target_offset; - - // The total size of this rule, from the beginning of the entry to the end of - // the target. - uint16_t next_offset; - - // A return pointer not used by userspace. - unsigned int comefrom; - - // Counters for packets and bytes, which we don't yet implement. - struct xt_counters counters; - - // The data for all this rules matches followed by the target. This runs - // beyond the value of sizeof(struct ipt_entry). - unsigned char elems[0]; -}; - -// Passed to getsockopt(SO_GET_INFO). -struct ipt_getinfo { - // The name of the table. The user only fills this in, the rest is filled in - // when returning from getsockopt. Currently "nat" and "mangle" are supported. - char name[XT_TABLE_MAXNAMELEN]; - - // A bitmap of which hooks apply to the table. For example, a table with hooks - // PREROUTING and FORWARD has the value - // (1 << NF_IP_PRE_REOUTING) | (1 << NF_IP_FORWARD). - unsigned int valid_hooks; - - // The offset into the entry table for each valid hook. The entry table is - // returned by getsockopt(SO_GET_ENTRIES). - unsigned int hook_entry[NF_IP_NUMHOOKS]; - - // For each valid hook, the underflow is the offset into the entry table to - // jump to in case traversing the table yields no verdict (although I have no - // clue how that could happen - builtin chains always end with a policy, and - // user-defined chains always end with a RETURN. - // - // The entry referred to must be an "unconditional" entry, meaning it has no - // matches, specifies no IP criteria, and either DROPs or ACCEPTs packets. It - // basically has to be capable of making a definitive decision no matter what - // it's passed. - unsigned int underflow[NF_IP_NUMHOOKS]; - - // The number of entries in the entry table returned by - // getsockopt(SO_GET_ENTRIES). - unsigned int num_entries; - - // The size of the entry table returned by getsockopt(SO_GET_ENTRIES). - unsigned int size; -}; - -// Passed to getsockopt(SO_GET_ENTRIES). -struct ipt_get_entries { - // The name of the table. The user fills this in. Currently "nat" and "mangle" - // are supported. - char name[XT_TABLE_MAXNAMELEN]; - - // The size of the entry table in bytes. The user fills this in with the value - // from struct ipt_getinfo.size. - unsigned int size; - - // The entries for the given table. This will run past the size defined by - // sizeof(struct ipt_get_entries). - struct ipt_entry entrytable[0]; -}; - -// Passed to setsockopt(SO_SET_REPLACE). -struct ipt_replace { - // The name of the table. - char name[XT_TABLE_MAXNAMELEN]; - - // The same as struct ipt_getinfo.valid_hooks. Users don't change this. - unsigned int valid_hooks; - - // The same as struct ipt_getinfo.num_entries. - unsigned int num_entries; - - // The same as struct ipt_getinfo.size. - unsigned int size; - - // The same as struct ipt_getinfo.hook_entry. - unsigned int hook_entry[NF_IP_NUMHOOKS]; - - // The same as struct ipt_getinfo.underflow. - unsigned int underflow[NF_IP_NUMHOOKS]; - - // The number of counters, which should equal the number of entries. - unsigned int num_counters; - - // The unchanged values from each ipt_entry's counters. - struct xt_counters* counters; - - // The entries to write to the table. This will run past the size defined by - // sizeof(srtuct ipt_replace); - struct ipt_entry entries[0]; -}; - -#endif // GVISOR_TEST_SYSCALLS_IPTABLES_TYPES_H_ diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc deleted file mode 100644 index 8b48f0804..000000000 --- a/test/syscalls/linux/itimer.cc +++ /dev/null @@ -1,353 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <signal.h> -#include <sys/socket.h> -#include <sys/time.h> -#include <sys/types.h> -#include <time.h> - -#include <atomic> -#include <functional> -#include <iostream> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/string_view.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/file_descriptor.h" -#include "test/util/logging.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" -#include "test/util/timer_util.h" - -namespace gvisor { -namespace testing { -namespace { - -constexpr char kSIGALRMToMainThread[] = "--itimer_sigarlm_to_main_thread"; -constexpr char kSIGPROFFairnessActive[] = "--itimer_sigprof_fairness_active"; -constexpr char kSIGPROFFairnessIdle[] = "--itimer_sigprof_fairness_idle"; - -// Time period to be set for the itimers. -constexpr absl::Duration kPeriod = absl::Milliseconds(25); -// Total amount of time to spend per thread. -constexpr absl::Duration kTestDuration = absl::Seconds(20); -// Amount of spin iterations to perform as the minimum work item per thread. -// Chosen to be sub-millisecond range. -constexpr int kIterations = 10000000; -// Allow deviation in the number of samples. -constexpr double kNumSamplesDeviationRatio = 0.2; - -TEST(ItimerTest, ItimervalUpdatedBeforeExpiration) { - constexpr int kSleepSecs = 10; - constexpr int kAlarmSecs = 15; - static_assert( - kSleepSecs < kAlarmSecs, - "kSleepSecs must be less than kAlarmSecs for the test to be meaningful"); - constexpr int kMaxRemainingSecs = kAlarmSecs - kSleepSecs; - - // Install a no-op handler for SIGALRM. - struct sigaction sa = {}; - sigfillset(&sa.sa_mask); - sa.sa_handler = +[](int signo) {}; - auto const cleanup_sa = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa)); - - // Set an itimer-based alarm for kAlarmSecs from now. - struct itimerval itv = {}; - itv.it_value.tv_sec = kAlarmSecs; - auto const cleanup_itimer = - ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_REAL, itv)); - - // After sleeping for kSleepSecs, the itimer value should reflect the elapsed - // time even if it hasn't expired. - absl::SleepFor(absl::Seconds(kSleepSecs)); - ASSERT_THAT(getitimer(ITIMER_REAL, &itv), SyscallSucceeds()); - EXPECT_TRUE( - itv.it_value.tv_sec < kMaxRemainingSecs || - (itv.it_value.tv_sec == kMaxRemainingSecs && itv.it_value.tv_usec == 0)) - << "Remaining time: " << itv.it_value.tv_sec << " seconds + " - << itv.it_value.tv_usec << " microseconds"; -} - -ABSL_CONST_INIT static thread_local std::atomic_int signal_test_num_samples = - ATOMIC_VAR_INIT(0); - -void SignalTestSignalHandler(int /*signum*/) { signal_test_num_samples++; } - -struct SignalTestResult { - int expected_total; - int main_thread_samples; - std::vector<int> worker_samples; -}; - -std::ostream& operator<<(std::ostream& os, const SignalTestResult& r) { - os << "{expected_total: " << r.expected_total - << ", main_thread_samples: " << r.main_thread_samples - << ", worker_samples: ["; - bool first = true; - for (int sample : r.worker_samples) { - if (!first) { - os << ", "; - } - os << sample; - first = false; - } - os << "]}"; - return os; -} - -// Starts two worker threads and itimer id and measures the number of signal -// delivered to each thread. -SignalTestResult ItimerSignalTest(int id, clock_t main_clock, - clock_t worker_clock, int signal, - absl::Duration sleep) { - signal_test_num_samples = 0; - - struct sigaction sa = {}; - sa.sa_handler = &SignalTestSignalHandler; - sa.sa_flags = SA_RESTART; - sigemptyset(&sa.sa_mask); - auto sigaction_cleanup = ScopedSigaction(signal, sa).ValueOrDie(); - - int socketfds[2]; - TEST_PCHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, socketfds) == 0); - - // Do the spinning in the workers. - std::function<void*(int)> work = [&](int socket_fd) { - FileDescriptor fd(socket_fd); - - absl::Time finish = Now(worker_clock) + kTestDuration; - while (Now(worker_clock) < finish) { - // Blocked on read. - char c; - RetryEINTR(read)(fd.get(), &c, 1); - for (int i = 0; i < kIterations; i++) { - // Ensure compiler won't optimize this loop away. - asm(""); - } - - if (sleep != absl::ZeroDuration()) { - // Sleep so that the entire process is idle for a while. - absl::SleepFor(sleep); - } - - // Unblock the other thread. - RetryEINTR(write)(fd.get(), &c, 1); - } - - return reinterpret_cast<void*>(signal_test_num_samples.load()); - }; - - ScopedThread th1( - static_cast<std::function<void*()>>(std::bind(work, socketfds[0]))); - ScopedThread th2( - static_cast<std::function<void*()>>(std::bind(work, socketfds[1]))); - - absl::Time start = Now(main_clock); - // Start the timer. - struct itimerval timer = {}; - timer.it_value = absl::ToTimeval(kPeriod); - timer.it_interval = absl::ToTimeval(kPeriod); - auto cleanup_itimer = ScopedItimer(id, timer).ValueOrDie(); - - // Unblock th1. - // - // N.B. th2 owns socketfds[1] but can't close it until it unblocks. - char c = 0; - TEST_CHECK(write(socketfds[1], &c, 1) == 1); - - SignalTestResult result; - - // Wait for the workers to be done and collect their sample counts. - result.worker_samples.push_back(reinterpret_cast<int64_t>(th1.Join())); - result.worker_samples.push_back(reinterpret_cast<int64_t>(th2.Join())); - cleanup_itimer.Release()(); - result.expected_total = (Now(main_clock) - start) / kPeriod; - result.main_thread_samples = signal_test_num_samples.load(); - - return result; -} - -int TestSIGALRMToMainThread() { - SignalTestResult result = - ItimerSignalTest(ITIMER_REAL, CLOCK_REALTIME, CLOCK_REALTIME, SIGALRM, - absl::ZeroDuration()); - - std::cerr << "result: " << result << std::endl; - - // ITIMER_REAL-generated SIGALRMs prefer to deliver to the thread group leader - // (but don't guarantee it), so we expect to see most samples on the main - // thread. - // - // The number of SIGALRMs delivered to a worker should not exceed 20% - // of the number of total signals expected (this is somewhat arbitrary). - const int worker_threshold = result.expected_total / 5; - - // - // Linux only guarantees timers will never expire before the requested time. - // Thus, we only check the upper bound and also it at least have one sample. - TEST_CHECK(result.main_thread_samples <= result.expected_total); - TEST_CHECK(result.main_thread_samples > 0); - for (int num : result.worker_samples) { - TEST_CHECK_MSG(num <= worker_threshold, "worker received too many samples"); - } - - return 0; -} - -// Random save/restore is disabled as it introduces additional latency and -// unpredictable distribution patterns. -TEST(ItimerTest, DeliversSIGALRMToMainThread_NoRandomSave) { - pid_t child; - int execve_errno; - auto kill = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec("/proc/self/exe", {"/proc/self/exe", kSIGALRMToMainThread}, - {}, &child, &execve_errno)); - EXPECT_EQ(0, execve_errno); - - int status; - EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), - SyscallSucceedsWithValue(child)); - - // Not required anymore. - kill.Release(); - - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) << status; -} - -// Signals are delivered to threads fairly. -// -// sleep indicates how long to sleep worker threads each iteration to make the -// entire process idle. -int TestSIGPROFFairness(absl::Duration sleep) { - SignalTestResult result = - ItimerSignalTest(ITIMER_PROF, CLOCK_PROCESS_CPUTIME_ID, - CLOCK_THREAD_CPUTIME_ID, SIGPROF, sleep); - - std::cerr << "result: " << result << std::endl; - - // The number of samples on the main thread should be very low as it did - // nothing. - TEST_CHECK(result.main_thread_samples < 60); - - // Both workers should get roughly equal number of samples. - TEST_CHECK(result.worker_samples.size() == 2); - - TEST_CHECK(result.expected_total > 0); - - // In an ideal world each thread would get exactly 50% of the signals, - // but since that's unlikely to happen we allow for them to get no less than - // kNumSamplesDeviationRatio of the total observed samples. - TEST_CHECK_MSG(std::abs(result.worker_samples[0] - result.worker_samples[1]) < - ((result.worker_samples[0] + result.worker_samples[1]) * - kNumSamplesDeviationRatio), - "one worker received disproportionate share of samples"); - - return 0; -} - -// Random save/restore is disabled as it introduces additional latency and -// unpredictable distribution patterns. -TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive_NoRandomSave) { - // TODO(b/143247272): CPU time accounting is inaccurate for the KVM platform. - SKIP_IF(GvisorPlatform() == Platform::kKVM); - - pid_t child; - int execve_errno; - auto kill = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec("/proc/self/exe", {"/proc/self/exe", kSIGPROFFairnessActive}, - {}, &child, &execve_errno)); - EXPECT_EQ(0, execve_errno); - - int status; - EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), - SyscallSucceedsWithValue(child)); - - // Not required anymore. - kill.Release(); - - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "Exited with code: " << status; -} - -// Random save/restore is disabled as it introduces additional latency and -// unpredictable distribution patterns. -TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyIdle_NoRandomSave) { - // TODO(b/143247272): CPU time accounting is inaccurate for the KVM platform. - SKIP_IF(GvisorPlatform() == Platform::kKVM); - - pid_t child; - int execve_errno; - auto kill = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec("/proc/self/exe", {"/proc/self/exe", kSIGPROFFairnessIdle}, - {}, &child, &execve_errno)); - EXPECT_EQ(0, execve_errno); - - int status; - EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), - SyscallSucceedsWithValue(child)); - - // Not required anymore. - kill.Release(); - - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "Exited with code: " << status; -} - -} // namespace -} // namespace testing -} // namespace gvisor - -namespace { -void MaskSIGPIPE() { - // Always mask SIGPIPE as it's common and tests aren't expected to handle it. - // We don't take the TestInit() path so we must do this manually. - struct sigaction sa = {}; - sa.sa_handler = SIG_IGN; - TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0); -} -} // namespace - -int main(int argc, char** argv) { - // These tests require no background threads, so check for them before - // TestInit. - for (int i = 0; i < argc; i++) { - absl::string_view arg(argv[i]); - - if (arg == gvisor::testing::kSIGALRMToMainThread) { - MaskSIGPIPE(); - return gvisor::testing::TestSIGALRMToMainThread(); - } - if (arg == gvisor::testing::kSIGPROFFairnessActive) { - MaskSIGPIPE(); - return gvisor::testing::TestSIGPROFFairness(absl::ZeroDuration()); - } - if (arg == gvisor::testing::kSIGPROFFairnessIdle) { - MaskSIGPIPE(); - // Sleep time > ClockTick (10ms) exercises sleeping gVisor's - // kernel.cpuClockTicker. - return gvisor::testing::TestSIGPROFFairness(absl::Milliseconds(25)); - } - } - - gvisor::testing::TestInit(&argc, &argv); - return gvisor::testing::RunAllTests(); -} diff --git a/test/syscalls/linux/kill.cc b/test/syscalls/linux/kill.cc deleted file mode 100644 index db29bd59c..000000000 --- a/test/syscalls/linux/kill.cc +++ /dev/null @@ -1,383 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <sys/syscall.h> -#include <sys/types.h> -#include <unistd.h> - -#include <cerrno> -#include <csignal> - -#include "gtest/gtest.h" -#include "absl/flags/flag.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/logging.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID"); -ABSL_FLAG(int32_t, scratch_gid, 65534, "scratch GID"); - -using ::testing::Ge; - -namespace gvisor { -namespace testing { - -namespace { - -TEST(KillTest, CanKillValidPid) { - // If pid is positive, then signal sig is sent to the process with the ID - // specified by pid. - EXPECT_THAT(kill(getpid(), 0), SyscallSucceeds()); - // If pid equals 0, then sig is sent to every process in the process group of - // the calling process. - EXPECT_THAT(kill(0, 0), SyscallSucceeds()); - - ScopedThread([] { EXPECT_THAT(kill(gettid(), 0), SyscallSucceeds()); }); -} - -void SigHandler(int sig, siginfo_t* info, void* context) { _exit(0); } - -// If pid equals -1, then sig is sent to every process for which the calling -// process has permission to send signals, except for process 1 (init). -TEST(KillTest, CanKillAllPIDs) { - int pipe_fds[2]; - ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds()); - FileDescriptor read_fd(pipe_fds[0]); - FileDescriptor write_fd(pipe_fds[1]); - - pid_t pid = fork(); - if (pid == 0) { - read_fd.reset(); - - struct sigaction sa; - sa.sa_sigaction = SigHandler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_SIGINFO; - TEST_PCHECK(sigaction(SIGWINCH, &sa, nullptr) == 0); - MaybeSave(); - - // Indicate to the parent that we're ready. - write_fd.reset(); - - // Wait until we get the signal from the parent. - while (true) { - pause(); - } - } - - ASSERT_THAT(pid, SyscallSucceeds()); - - write_fd.reset(); - - // Wait for the child to indicate that it's unmasked the signal by closing - // the write end. - char buf; - ASSERT_THAT(ReadFd(read_fd.get(), &buf, 1), SyscallSucceedsWithValue(0)); - - // Signal the child and wait for it to die with status 0, indicating that - // it got the expected signal. - EXPECT_THAT(kill(-1, SIGWINCH), SyscallSucceeds()); - - int status; - ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0), - SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFEXITED(status)); - EXPECT_EQ(0, WEXITSTATUS(status)); -} - -TEST(KillTest, CannotKillInvalidPID) { - // We need an unused pid to verify that kill fails when given one. - // - // There is no way to guarantee that a PID is unused, but the PID of a - // recently exited process likely won't be reused soon. - pid_t fake_pid = fork(); - if (fake_pid == 0) { - _exit(0); - } - - ASSERT_THAT(fake_pid, SyscallSucceeds()); - - int status; - ASSERT_THAT(RetryEINTR(waitpid)(fake_pid, &status, 0), - SyscallSucceedsWithValue(fake_pid)); - EXPECT_TRUE(WIFEXITED(status)); - EXPECT_EQ(0, WEXITSTATUS(status)); - - EXPECT_THAT(kill(fake_pid, 0), SyscallFailsWithErrno(ESRCH)); -} - -TEST(KillTest, CannotUseInvalidSignal) { - EXPECT_THAT(kill(getpid(), 200), SyscallFailsWithErrno(EINVAL)); -} - -TEST(KillTest, CanKillRemoteProcess) { - pid_t pid = fork(); - if (pid == 0) { - while (true) { - pause(); - } - } - - ASSERT_THAT(pid, SyscallSucceeds()); - - EXPECT_THAT(kill(pid, SIGKILL), SyscallSucceeds()); - - int status; - ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0), - SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFSIGNALED(status)); - EXPECT_EQ(SIGKILL, WTERMSIG(status)); -} - -TEST(KillTest, CanKillOwnProcess) { - EXPECT_THAT(kill(getpid(), 0), SyscallSucceeds()); -} - -// Verify that you can kill a process even using a tid from a thread other than -// the group leader. -TEST(KillTest, CannotKillTid) { - pid_t tid; - bool tid_available = false; - bool finished = false; - absl::Mutex mu; - ScopedThread t([&] { - mu.Lock(); - tid = gettid(); - tid_available = true; - mu.Await(absl::Condition(&finished)); - mu.Unlock(); - }); - mu.LockWhen(absl::Condition(&tid_available)); - EXPECT_THAT(kill(tid, 0), SyscallSucceeds()); - finished = true; - mu.Unlock(); -} - -TEST(KillTest, SetPgid) { - for (int i = 0; i < 10; i++) { - // The following in the normal pattern for creating a new process group. - // Both the parent and child process will call setpgid in order to avoid any - // race conditions. We do this ten times to catch races. - pid_t pid = fork(); - if (pid == 0) { - setpgid(0, 0); - while (true) { - pause(); - } - } - - ASSERT_THAT(pid, SyscallSucceeds()); - - // Set the child's group and exit. - ASSERT_THAT(setpgid(pid, pid), SyscallSucceeds()); - EXPECT_THAT(kill(pid, SIGKILL), SyscallSucceeds()); - - int status; - EXPECT_THAT(RetryEINTR(waitpid)(-pid, &status, 0), - SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFSIGNALED(status)); - EXPECT_EQ(SIGKILL, WTERMSIG(status)); - } -} - -TEST(KillTest, ProcessGroups) { - // Fork a new child. - // - // other_child is used as a placeholder process. We use this PID as our "does - // not exist" process group to ensure some amount of safety. (It is still - // possible to violate this assumption, but extremely unlikely.) - pid_t child = fork(); - if (child == 0) { - while (true) { - pause(); - } - } - ASSERT_THAT(child, SyscallSucceeds()); - - pid_t other_child = fork(); - if (other_child == 0) { - while (true) { - pause(); - } - } - ASSERT_THAT(other_child, SyscallSucceeds()); - - // Ensure the kill does not succeed without the new group. - EXPECT_THAT(kill(-child, SIGKILL), SyscallFailsWithErrno(ESRCH)); - - // Put the child in its own process group. - ASSERT_THAT(setpgid(child, child), SyscallSucceeds()); - - // This should be not allowed: you can only create a new group with the same - // id or join an existing one. The other_child group should not exist. - ASSERT_THAT(setpgid(child, other_child), SyscallFailsWithErrno(EPERM)); - - // Done with other_child; kill it. - EXPECT_THAT(kill(other_child, SIGKILL), SyscallSucceeds()); - int status; - EXPECT_THAT(RetryEINTR(waitpid)(other_child, &status, 0), SyscallSucceeds()); - - // Linux returns success for the no-op call. - ASSERT_THAT(setpgid(child, child), SyscallSucceeds()); - - // Kill the child's process group. - ASSERT_THAT(kill(-child, SIGKILL), SyscallSucceeds()); - - // Wait on the process group; ensure that the signal was as expected. - EXPECT_THAT(RetryEINTR(waitpid)(-child, &status, 0), - SyscallSucceedsWithValue(child)); - EXPECT_TRUE(WIFSIGNALED(status)); - EXPECT_EQ(SIGKILL, WTERMSIG(status)); - - // Try to kill the process group again; ensure that the wait fails. - EXPECT_THAT(kill(-child, SIGKILL), SyscallFailsWithErrno(ESRCH)); - EXPECT_THAT(RetryEINTR(waitpid)(-child, &status, 0), - SyscallFailsWithErrno(ECHILD)); -} - -TEST(KillTest, ChildDropsPrivsCannotKill) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); - - const int uid = absl::GetFlag(FLAGS_scratch_uid); - const int gid = absl::GetFlag(FLAGS_scratch_gid); - - // Create the child that drops privileges and tries to kill the parent. - pid_t pid = fork(); - if (pid == 0) { - TEST_PCHECK(setresgid(gid, gid, gid) == 0); - MaybeSave(); - - TEST_PCHECK(setresuid(uid, uid, uid) == 0); - MaybeSave(); - - // setresuid should have dropped CAP_KILL. Make sure. - TEST_CHECK(!HaveCapability(CAP_KILL).ValueOrDie()); - - // Try to kill parent with every signal-sending syscall possible. - pid_t parent = getppid(); - - TEST_CHECK(kill(parent, SIGKILL) < 0); - TEST_PCHECK_MSG(errno == EPERM, "kill failed with wrong errno"); - MaybeSave(); - - TEST_CHECK(tgkill(parent, parent, SIGKILL) < 0); - TEST_PCHECK_MSG(errno == EPERM, "tgkill failed with wrong errno"); - MaybeSave(); - - TEST_CHECK(syscall(SYS_tkill, parent, SIGKILL) < 0); - TEST_PCHECK_MSG(errno == EPERM, "tkill failed with wrong errno"); - MaybeSave(); - - siginfo_t uinfo; - uinfo.si_code = -1; // SI_QUEUE (allowed). - - TEST_CHECK(syscall(SYS_rt_sigqueueinfo, parent, SIGKILL, &uinfo) < 0); - TEST_PCHECK_MSG(errno == EPERM, "rt_sigqueueinfo failed with wrong errno"); - MaybeSave(); - - TEST_CHECK(syscall(SYS_rt_tgsigqueueinfo, parent, parent, SIGKILL, &uinfo) < - 0); - TEST_PCHECK_MSG(errno == EPERM, "rt_sigqueueinfo failed with wrong errno"); - MaybeSave(); - - _exit(0); - } - - ASSERT_THAT(pid, SyscallSucceeds()); - - int status; - EXPECT_THAT(RetryEINTR(waitpid)(pid, &status, 0), - SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "status = " << status; -} - -TEST(KillTest, CanSIGCONTSameSession) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); - - pid_t stopped_child = fork(); - if (stopped_child == 0) { - raise(SIGSTOP); - _exit(0); - } - - ASSERT_THAT(stopped_child, SyscallSucceeds()); - - // Put the child in its own process group. The child and parent process - // groups also share a session. - ASSERT_THAT(setpgid(stopped_child, stopped_child), SyscallSucceeds()); - - // Make sure child stopped. - int status; - EXPECT_THAT(RetryEINTR(waitpid)(stopped_child, &status, WUNTRACED), - SyscallSucceedsWithValue(stopped_child)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP) - << "status " << status; - - const int uid = absl::GetFlag(FLAGS_scratch_uid); - const int gid = absl::GetFlag(FLAGS_scratch_gid); - - // Drop privileges only in child process, or else this parent process won't be - // able to open some log files after the test ends. - pid_t other_child = fork(); - if (other_child == 0) { - // Drop privileges. - TEST_PCHECK(setresgid(gid, gid, gid) == 0); - MaybeSave(); - - TEST_PCHECK(setresuid(uid, uid, uid) == 0); - MaybeSave(); - - // setresuid should have dropped CAP_KILL. - TEST_CHECK(!HaveCapability(CAP_KILL).ValueOrDie()); - - // Child 2 and child should now not share a thread group and any UIDs. - // Child 2 should have no privileges. That means any signal other than - // SIGCONT should fail. - TEST_CHECK(kill(stopped_child, SIGKILL) < 0); - TEST_PCHECK_MSG(errno == EPERM, "kill failed with wrong errno"); - MaybeSave(); - - TEST_PCHECK(kill(stopped_child, SIGCONT) == 0); - MaybeSave(); - - _exit(0); - } - - ASSERT_THAT(stopped_child, SyscallSucceeds()); - - // Make sure child exited normally. - EXPECT_THAT(RetryEINTR(waitpid)(stopped_child, &status, 0), - SyscallSucceedsWithValue(stopped_child)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "status " << status; - - // Make sure other_child exited normally. - EXPECT_THAT(RetryEINTR(waitpid)(other_child, &status, 0), - SyscallSucceedsWithValue(other_child)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "status " << status; -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/link.cc b/test/syscalls/linux/link.cc deleted file mode 100644 index e74fa2ed5..000000000 --- a/test/syscalls/linux/link.cc +++ /dev/null @@ -1,294 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <string.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include <string> - -#include "gtest/gtest.h" -#include "absl/flags/flag.h" -#include "absl/strings/str_cat.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID"); - -namespace gvisor { -namespace testing { - -namespace { - -// IsSameFile returns true if both filenames have the same device and inode. -bool IsSameFile(const std::string& f1, const std::string& f2) { - // Use lstat rather than stat, so that symlinks are not followed. - struct stat stat1 = {}; - EXPECT_THAT(lstat(f1.c_str(), &stat1), SyscallSucceeds()); - struct stat stat2 = {}; - EXPECT_THAT(lstat(f2.c_str(), &stat2), SyscallSucceeds()); - - return stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino; -} - -TEST(LinkTest, CanCreateLinkFile) { - auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const std::string newname = NewTempAbsPath(); - - // Get the initial link count. - uint64_t initial_link_count = - ASSERT_NO_ERRNO_AND_VALUE(Links(oldfile.path())); - - EXPECT_THAT(link(oldfile.path().c_str(), newname.c_str()), SyscallSucceeds()); - - EXPECT_TRUE(IsSameFile(oldfile.path(), newname)); - - // Link count should be incremented. - EXPECT_THAT(Links(oldfile.path()), - IsPosixErrorOkAndHolds(initial_link_count + 1)); - - // Delete the link. - EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds()); - - // Link count should be back to initial. - EXPECT_THAT(Links(oldfile.path()), - IsPosixErrorOkAndHolds(initial_link_count)); -} - -TEST(LinkTest, PermissionDenied) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_FOWNER))); - - // Make the file "unsafe" to link by making it only readable, but not - // writable. - const auto oldfile = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0400)); - const std::string newname = NewTempAbsPath(); - - // Do setuid in a separate thread so that after finishing this test, the - // process can still open files the test harness created before starting this - // test. Otherwise, the files are created by root (UID before the test), but - // cannot be opened by the `uid` set below after the test. After calling - // setuid(non-zero-UID), there is no way to get root privileges back. - ScopedThread([&] { - // Use syscall instead of glibc setuid wrapper because we want this setuid - // call to only apply to this task. POSIX threads, however, require that all - // threads have the same UIDs, so using the setuid wrapper sets all threads' - // real UID. - // Also drops capabilities. - EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)), - SyscallSucceeds()); - - EXPECT_THAT(link(oldfile.path().c_str(), newname.c_str()), - SyscallFailsWithErrno(EPERM)); - }); -} - -TEST(LinkTest, CannotLinkDirectory) { - auto olddir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const std::string newdir = NewTempAbsPath(); - - EXPECT_THAT(link(olddir.path().c_str(), newdir.c_str()), - SyscallFailsWithErrno(EPERM)); - - EXPECT_THAT(rmdir(olddir.path().c_str()), SyscallSucceeds()); -} - -TEST(LinkTest, CannotLinkWithSlash) { - auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - // Put a final "/" on newname. - const std::string newname = absl::StrCat(NewTempAbsPath(), "/"); - - EXPECT_THAT(link(oldfile.path().c_str(), newname.c_str()), - SyscallFailsWithErrno(ENOENT)); -} - -TEST(LinkTest, OldnameIsEmpty) { - const std::string newname = NewTempAbsPath(); - EXPECT_THAT(link("", newname.c_str()), SyscallFailsWithErrno(ENOENT)); -} - -TEST(LinkTest, OldnameDoesNotExist) { - const std::string oldname = NewTempAbsPath(); - const std::string newname = NewTempAbsPath(); - EXPECT_THAT(link("", newname.c_str()), SyscallFailsWithErrno(ENOENT)); -} - -TEST(LinkTest, NewnameCannotExist) { - const std::string newname = - JoinPath(GetAbsoluteTestTmpdir(), "thisdoesnotexist", "foo"); - EXPECT_THAT(link("/thisdoesnotmatter", newname.c_str()), - SyscallFailsWithErrno(ENOENT)); -} - -TEST(LinkTest, WithOldDirFD) { - const std::string oldname_parent = NewTempAbsPath(); - const std::string oldname_base = "child"; - const std::string oldname = JoinPath(oldname_parent, oldname_base); - const std::string newname = NewTempAbsPath(); - - // Create oldname_parent directory, and get an FD. - ASSERT_THAT(mkdir(oldname_parent.c_str(), 0777), SyscallSucceeds()); - const FileDescriptor oldname_parent_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(oldname_parent, O_DIRECTORY | O_RDONLY)); - - // Create oldname file. - const FileDescriptor oldname_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(oldname, O_CREAT | O_RDWR, 0666)); - - // Link oldname to newname, using oldname_parent_fd. - EXPECT_THAT(linkat(oldname_parent_fd.get(), oldname_base.c_str(), AT_FDCWD, - newname.c_str(), 0), - SyscallSucceeds()); - - EXPECT_TRUE(IsSameFile(oldname, newname)); - - EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds()); - EXPECT_THAT(unlink(oldname.c_str()), SyscallSucceeds()); - EXPECT_THAT(rmdir(oldname_parent.c_str()), SyscallSucceeds()); -} - -TEST(LinkTest, BogusFlags) { - ASSERT_THAT(linkat(1, "foo", 2, "bar", 3), SyscallFailsWithErrno(EINVAL)); -} - -TEST(LinkTest, WithNewDirFD) { - auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const std::string newname_parent = NewTempAbsPath(); - const std::string newname_base = "child"; - const std::string newname = JoinPath(newname_parent, newname_base); - - // Create newname_parent directory, and get an FD. - EXPECT_THAT(mkdir(newname_parent.c_str(), 0777), SyscallSucceeds()); - const FileDescriptor newname_parent_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(newname_parent, O_DIRECTORY | O_RDONLY)); - - // Link newname to oldfile, using newname_parent_fd. - EXPECT_THAT(linkat(AT_FDCWD, oldfile.path().c_str(), newname_parent_fd.get(), - newname.c_str(), 0), - SyscallSucceeds()); - - EXPECT_TRUE(IsSameFile(oldfile.path(), newname)); - - EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds()); - EXPECT_THAT(rmdir(newname_parent.c_str()), SyscallSucceeds()); -} - -TEST(LinkTest, RelPathsWithNonDirFDs) { - auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // Create a file that will be passed as the directory fd for old/new names. - const std::string filename = NewTempAbsPath(); - const FileDescriptor file_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_CREAT | O_RDWR, 0666)); - - // Using file_fd as olddirfd will fail. - EXPECT_THAT(linkat(file_fd.get(), "foo", AT_FDCWD, "bar", 0), - SyscallFailsWithErrno(ENOTDIR)); - - // Using file_fd as newdirfd will fail. - EXPECT_THAT(linkat(AT_FDCWD, oldfile.path().c_str(), file_fd.get(), "bar", 0), - SyscallFailsWithErrno(ENOTDIR)); -} - -TEST(LinkTest, AbsPathsWithNonDirFDs) { - auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const std::string newname = NewTempAbsPath(); - - // Create a file that will be passed as the directory fd for old/new names. - const std::string filename = NewTempAbsPath(); - const FileDescriptor file_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_CREAT | O_RDWR, 0666)); - - // Using file_fd as the dirfds is OK as long as paths are absolute. - EXPECT_THAT(linkat(file_fd.get(), oldfile.path().c_str(), file_fd.get(), - newname.c_str(), 0), - SyscallSucceeds()); -} - -TEST(LinkTest, LinkDoesNotFollowSymlinks) { - // Create oldfile, and oldsymlink which points to it. - auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const std::string oldsymlink = NewTempAbsPath(); - EXPECT_THAT(symlink(oldfile.path().c_str(), oldsymlink.c_str()), - SyscallSucceeds()); - - // Now hard link newname to oldsymlink. - const std::string newname = NewTempAbsPath(); - EXPECT_THAT(link(oldsymlink.c_str(), newname.c_str()), SyscallSucceeds()); - - // The link should not have resolved the symlink, so newname and oldsymlink - // are the same. - EXPECT_TRUE(IsSameFile(oldsymlink, newname)); - EXPECT_FALSE(IsSameFile(oldfile.path(), newname)); - - EXPECT_THAT(unlink(oldsymlink.c_str()), SyscallSucceeds()); - EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds()); -} - -TEST(LinkTest, LinkatDoesNotFollowSymlinkByDefault) { - // Create oldfile, and oldsymlink which points to it. - auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const std::string oldsymlink = NewTempAbsPath(); - EXPECT_THAT(symlink(oldfile.path().c_str(), oldsymlink.c_str()), - SyscallSucceeds()); - - // Now hard link newname to oldsymlink. - const std::string newname = NewTempAbsPath(); - EXPECT_THAT( - linkat(AT_FDCWD, oldsymlink.c_str(), AT_FDCWD, newname.c_str(), 0), - SyscallSucceeds()); - - // The link should not have resolved the symlink, so newname and oldsymlink - // are the same. - EXPECT_TRUE(IsSameFile(oldsymlink, newname)); - EXPECT_FALSE(IsSameFile(oldfile.path(), newname)); - - EXPECT_THAT(unlink(oldsymlink.c_str()), SyscallSucceeds()); - EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds()); -} - -TEST(LinkTest, LinkatWithSymlinkFollow) { - // Create oldfile, and oldsymlink which points to it. - auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const std::string oldsymlink = NewTempAbsPath(); - ASSERT_THAT(symlink(oldfile.path().c_str(), oldsymlink.c_str()), - SyscallSucceeds()); - - // Now hard link newname to oldsymlink, and pass AT_SYMLINK_FOLLOW flag. - const std::string newname = NewTempAbsPath(); - ASSERT_THAT(linkat(AT_FDCWD, oldsymlink.c_str(), AT_FDCWD, newname.c_str(), - AT_SYMLINK_FOLLOW), - SyscallSucceeds()); - - // The link should have resolved the symlink, so oldfile and newname are the - // same. - EXPECT_TRUE(IsSameFile(oldfile.path(), newname)); - EXPECT_FALSE(IsSameFile(oldsymlink, newname)); - - EXPECT_THAT(unlink(oldsymlink.c_str()), SyscallSucceeds()); - EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/lseek.cc b/test/syscalls/linux/lseek.cc deleted file mode 100644 index a8af8e545..000000000 --- a/test/syscalls/linux/lseek.cc +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <stdlib.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(LseekTest, InvalidWhence) { - const std::string kFileData = "hello world\n"; - const TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kFileData, TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDWR, 0644)); - - ASSERT_THAT(lseek(fd.get(), 0, -1), SyscallFailsWithErrno(EINVAL)); -} - -TEST(LseekTest, NegativeOffset) { - const std::string kFileData = "hello world\n"; - const TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kFileData, TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDWR, 0644)); - - EXPECT_THAT(lseek(fd.get(), -(kFileData.length() + 1), SEEK_CUR), - SyscallFailsWithErrno(EINVAL)); -} - -// A 32-bit off_t is not large enough to represent an offset larger than -// maximum file size on standard file systems, so it isn't possible to cause -// overflow. -#ifdef __x86_64__ -TEST(LseekTest, Overflow) { - // HA! Classic Linux. We really should have an EOVERFLOW - // here, since we're seeking to something that cannot be - // represented.. but instead we are given an EINVAL. - const std::string kFileData = "hello world\n"; - const TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kFileData, TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDWR, 0644)); - EXPECT_THAT(lseek(fd.get(), 0x7fffffffffffffff, SEEK_END), - SyscallFailsWithErrno(EINVAL)); -} -#endif - -TEST(LseekTest, Set) { - const std::string kFileData = "hello world\n"; - const TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kFileData, TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDWR, 0644)); - - char buf = '\0'; - EXPECT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); - ASSERT_THAT(read(fd.get(), &buf, 1), SyscallSucceedsWithValue(1)); - EXPECT_EQ(buf, kFileData.c_str()[0]); - EXPECT_THAT(lseek(fd.get(), 6, SEEK_SET), SyscallSucceedsWithValue(6)); - ASSERT_THAT(read(fd.get(), &buf, 1), SyscallSucceedsWithValue(1)); - EXPECT_EQ(buf, kFileData.c_str()[6]); -} - -TEST(LseekTest, Cur) { - const std::string kFileData = "hello world\n"; - const TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kFileData, TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDWR, 0644)); - - char buf = '\0'; - EXPECT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); - ASSERT_THAT(read(fd.get(), &buf, 1), SyscallSucceedsWithValue(1)); - EXPECT_EQ(buf, kFileData.c_str()[0]); - EXPECT_THAT(lseek(fd.get(), 3, SEEK_CUR), SyscallSucceedsWithValue(4)); - ASSERT_THAT(read(fd.get(), &buf, 1), SyscallSucceedsWithValue(1)); - EXPECT_EQ(buf, kFileData.c_str()[4]); -} - -TEST(LseekTest, End) { - const std::string kFileData = "hello world\n"; - const TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kFileData, TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDWR, 0644)); - - char buf = '\0'; - EXPECT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); - ASSERT_THAT(read(fd.get(), &buf, 1), SyscallSucceedsWithValue(1)); - EXPECT_EQ(buf, kFileData.c_str()[0]); - EXPECT_THAT(lseek(fd.get(), -2, SEEK_END), SyscallSucceedsWithValue(10)); - ASSERT_THAT(read(fd.get(), &buf, 1), SyscallSucceedsWithValue(1)); - EXPECT_EQ(buf, kFileData.c_str()[kFileData.length() - 2]); -} - -TEST(LseekTest, InvalidFD) { - EXPECT_THAT(lseek(-1, 0, SEEK_SET), SyscallFailsWithErrno(EBADF)); -} - -TEST(LseekTest, DirCurEnd) { - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open("/tmp", O_RDONLY)); - ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); -} - -TEST(LseekTest, ProcDir) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self", O_RDONLY)); - ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds()); - ASSERT_THAT(lseek(fd.get(), 0, SEEK_END), SyscallSucceeds()); -} - -TEST(LseekTest, ProcFile) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/meminfo", O_RDONLY)); - ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds()); - ASSERT_THAT(lseek(fd.get(), 0, SEEK_END), SyscallFailsWithErrno(EINVAL)); -} - -TEST(LseekTest, SysDir) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/sys/devices", O_RDONLY)); - ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds()); - ASSERT_THAT(lseek(fd.get(), 0, SEEK_END), SyscallSucceeds()); -} - -TEST(LseekTest, SeekCurrentDir) { - // From include/linux/fs.h. - constexpr loff_t MAX_LFS_FILESIZE = 0x7fffffffffffffff; - - char* dir = get_current_dir_name(); - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir, O_RDONLY)); - - ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds()); - ASSERT_THAT(lseek(fd.get(), 0, SEEK_END), - // Some filesystems (like ext4) allow lseek(SEEK_END) on a - // directory and return MAX_LFS_FILESIZE, others return EINVAL. - AnyOf(SyscallSucceedsWithValue(MAX_LFS_FILESIZE), - SyscallFailsWithErrno(EINVAL))); - free(dir); -} - -TEST(LseekTest, ProcStatTwice) { - const FileDescriptor fd1 = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/stat", O_RDONLY)); - const FileDescriptor fd2 = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/stat", O_RDONLY)); - - ASSERT_THAT(lseek(fd1.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); - ASSERT_THAT(lseek(fd1.get(), 0, SEEK_END), SyscallFailsWithErrno(EINVAL)); - ASSERT_THAT(lseek(fd1.get(), 1000, SEEK_CUR), SyscallSucceeds()); - // Check that just because we moved fd1, fd2 doesn't move. - ASSERT_THAT(lseek(fd2.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); - - const FileDescriptor fd3 = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/stat", O_RDONLY)); - ASSERT_THAT(lseek(fd3.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); -} - -TEST(LseekTest, EtcPasswdDup) { - const FileDescriptor fd1 = - ASSERT_NO_ERRNO_AND_VALUE(Open("/etc/passwd", O_RDONLY)); - const FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(fd1.Dup()); - - ASSERT_THAT(lseek(fd1.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); - ASSERT_THAT(lseek(fd2.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); - ASSERT_THAT(lseek(fd1.get(), 1000, SEEK_CUR), SyscallSucceeds()); - // Check that just because we moved fd1, fd2 doesn't move. - ASSERT_THAT(lseek(fd2.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(1000)); - - const FileDescriptor fd3 = ASSERT_NO_ERRNO_AND_VALUE(fd1.Dup()); - ASSERT_THAT(lseek(fd3.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(1000)); -} - -// TODO(magi): Add tests where we have donated in sockets. - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/madvise.cc b/test/syscalls/linux/madvise.cc deleted file mode 100644 index 5a1973f60..000000000 --- a/test/syscalls/linux/madvise.cc +++ /dev/null @@ -1,251 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <stdlib.h> -#include <string.h> -#include <sys/mman.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <sys/wait.h> -#include <unistd.h> - -#include <string> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/util/file_descriptor.h" -#include "test/util/logging.h" -#include "test/util/memory_util.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -void ExpectAllMappingBytes(Mapping const& m, char c) { - auto const v = m.view(); - for (size_t i = 0; i < kPageSize; i++) { - ASSERT_EQ(v[i], c) << "at offset " << i; - } -} - -// Equivalent to ExpectAllMappingBytes but async-signal-safe and with less -// helpful failure messages. -void CheckAllMappingBytes(Mapping const& m, char c) { - auto const v = m.view(); - for (size_t i = 0; i < kPageSize; i++) { - TEST_CHECK_MSG(v[i] == c, "mapping contains wrong value"); - } -} - -TEST(MadviseDontneedTest, ZerosPrivateAnonPage) { - auto m = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - ExpectAllMappingBytes(m, 0); - memset(m.ptr(), 1, m.len()); - ExpectAllMappingBytes(m, 1); - ASSERT_THAT(madvise(m.ptr(), m.len(), MADV_DONTNEED), SyscallSucceeds()); - ExpectAllMappingBytes(m, 0); -} - -TEST(MadviseDontneedTest, ZerosCOWAnonPageInCallerOnly) { - auto m = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - ExpectAllMappingBytes(m, 0); - memset(m.ptr(), 2, m.len()); - ExpectAllMappingBytes(m, 2); - - // Do madvise in a child process. - pid_t pid = fork(); - CheckAllMappingBytes(m, 2); - if (pid == 0) { - TEST_PCHECK(madvise(m.ptr(), m.len(), MADV_DONTNEED) == 0); - CheckAllMappingBytes(m, 0); - _exit(0); - } - - ASSERT_THAT(pid, SyscallSucceeds()); - - int status = 0; - ASSERT_THAT(waitpid(-1, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFEXITED(status)); - EXPECT_EQ(WEXITSTATUS(status), 0); - // The child's madvise should not have affected the parent. - ExpectAllMappingBytes(m, 2); -} - -TEST(MadviseDontneedTest, DoesNotModifySharedAnonPage) { - auto m = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED)); - ExpectAllMappingBytes(m, 0); - memset(m.ptr(), 3, m.len()); - ExpectAllMappingBytes(m, 3); - ASSERT_THAT(madvise(m.ptr(), m.len(), MADV_DONTNEED), SyscallSucceeds()); - ExpectAllMappingBytes(m, 3); -} - -TEST(MadviseDontneedTest, CleansPrivateFilePage) { - TempPath f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - /* parent = */ GetAbsoluteTestTmpdir(), - /* content = */ std::string(kPageSize, 4), TempPath::kDefaultFileMode)); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDWR)); - - Mapping m = ASSERT_NO_ERRNO_AND_VALUE(Mmap( - nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, fd.get(), 0)); - ExpectAllMappingBytes(m, 4); - memset(m.ptr(), 5, m.len()); - ExpectAllMappingBytes(m, 5); - ASSERT_THAT(madvise(m.ptr(), m.len(), MADV_DONTNEED), SyscallSucceeds()); - ExpectAllMappingBytes(m, 4); -} - -TEST(MadviseDontneedTest, DoesNotModifySharedFilePage) { - TempPath f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - /* parent = */ GetAbsoluteTestTmpdir(), - /* content = */ std::string(kPageSize, 6), TempPath::kDefaultFileMode)); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDWR)); - - Mapping m = ASSERT_NO_ERRNO_AND_VALUE(Mmap( - nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0)); - ExpectAllMappingBytes(m, 6); - memset(m.ptr(), 7, m.len()); - ExpectAllMappingBytes(m, 7); - ASSERT_THAT(madvise(m.ptr(), m.len(), MADV_DONTNEED), SyscallSucceeds()); - ExpectAllMappingBytes(m, 7); -} - -TEST(MadviseDontneedTest, IgnoresPermissions) { - auto m = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_PRIVATE)); - EXPECT_THAT(madvise(m.ptr(), m.len(), MADV_DONTNEED), SyscallSucceeds()); -} - -TEST(MadviseDontforkTest, AddressLength) { - auto m = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_PRIVATE)); - char* addr = static_cast<char*>(m.ptr()); - - // Address must be page aligned. - EXPECT_THAT(madvise(addr + 1, kPageSize, MADV_DONTFORK), - SyscallFailsWithErrno(EINVAL)); - - // Zero length madvise always succeeds. - EXPECT_THAT(madvise(addr, 0, MADV_DONTFORK), SyscallSucceeds()); - - // Length must not roll over after rounding up. - size_t badlen = std::numeric_limits<std::size_t>::max() - (kPageSize / 2); - EXPECT_THAT(madvise(0, badlen, MADV_DONTFORK), SyscallFailsWithErrno(EINVAL)); - - // Length need not be page aligned - it is implicitly rounded up. - EXPECT_THAT(madvise(addr, 1, MADV_DONTFORK), SyscallSucceeds()); - EXPECT_THAT(madvise(addr, kPageSize, MADV_DONTFORK), SyscallSucceeds()); -} - -TEST(MadviseDontforkTest, DontforkShared) { - // Mmap two shared file-backed pages and MADV_DONTFORK the second page. - TempPath f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - /* parent = */ GetAbsoluteTestTmpdir(), - /* content = */ std::string(kPageSize * 2, 2), - TempPath::kDefaultFileMode)); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDWR)); - - Mapping m = ASSERT_NO_ERRNO_AND_VALUE(Mmap( - nullptr, kPageSize * 2, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0)); - - const Mapping ms1 = Mapping(reinterpret_cast<void*>(m.addr()), kPageSize); - const Mapping ms2 = - Mapping(reinterpret_cast<void*>(m.addr() + kPageSize), kPageSize); - m.release(); - - ASSERT_THAT(madvise(ms2.ptr(), kPageSize, MADV_DONTFORK), SyscallSucceeds()); - - const auto rest = [&] { - // First page is mapped in child and modifications are visible to parent - // via the shared mapping. - TEST_CHECK(IsMapped(ms1.addr())); - ExpectAllMappingBytes(ms1, 2); - memset(ms1.ptr(), 1, kPageSize); - ExpectAllMappingBytes(ms1, 1); - - // Second page must not be mapped in child. - TEST_CHECK(!IsMapped(ms2.addr())); - }; - - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); - - ExpectAllMappingBytes(ms1, 1); // page contents modified by child. - ExpectAllMappingBytes(ms2, 2); // page contents unchanged. -} - -TEST(MadviseDontforkTest, DontforkAnonPrivate) { - // Mmap three anonymous pages and MADV_DONTFORK the middle page. - Mapping m = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize * 3, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - const Mapping mp1 = Mapping(reinterpret_cast<void*>(m.addr()), kPageSize); - const Mapping mp2 = - Mapping(reinterpret_cast<void*>(m.addr() + kPageSize), kPageSize); - const Mapping mp3 = - Mapping(reinterpret_cast<void*>(m.addr() + 2 * kPageSize), kPageSize); - m.release(); - - ASSERT_THAT(madvise(mp2.ptr(), kPageSize, MADV_DONTFORK), SyscallSucceeds()); - - // Verify that all pages are zeroed and memset the first, second and third - // pages to 1, 2, and 3 respectively. - ExpectAllMappingBytes(mp1, 0); - memset(mp1.ptr(), 1, kPageSize); - - ExpectAllMappingBytes(mp2, 0); - memset(mp2.ptr(), 2, kPageSize); - - ExpectAllMappingBytes(mp3, 0); - memset(mp3.ptr(), 3, kPageSize); - - const auto rest = [&] { - // Verify first page is mapped, verify its contents and then modify the - // page. The mapping is private so the modifications are not visible to - // the parent. - TEST_CHECK(IsMapped(mp1.addr())); - ExpectAllMappingBytes(mp1, 1); - memset(mp1.ptr(), 11, kPageSize); - ExpectAllMappingBytes(mp1, 11); - - // Verify second page is not mapped. - TEST_CHECK(!IsMapped(mp2.addr())); - - // Verify third page is mapped, verify its contents and then modify the - // page. The mapping is private so the modifications are not visible to - // the parent. - TEST_CHECK(IsMapped(mp3.addr())); - ExpectAllMappingBytes(mp3, 3); - memset(mp3.ptr(), 13, kPageSize); - ExpectAllMappingBytes(mp3, 13); - }; - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); - - // The fork and COW by child should not affect the parent mappings. - ExpectAllMappingBytes(mp1, 1); - ExpectAllMappingBytes(mp2, 2); - ExpectAllMappingBytes(mp3, 3); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/memfd.cc b/test/syscalls/linux/memfd.cc deleted file mode 100644 index e57b49a4a..000000000 --- a/test/syscalls/linux/memfd.cc +++ /dev/null @@ -1,556 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <linux/magic.h> -#include <linux/memfd.h> -#include <string.h> -#include <sys/mman.h> -#include <sys/statfs.h> -#include <sys/syscall.h> - -#include <vector> - -#include "gtest/gtest.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/memory_util.h" -#include "test/util/multiprocess_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -// The header sys/memfd.h isn't available on all systems, so redefining some of -// the constants here. -#define F_LINUX_SPECIFIC_BASE 1024 - -#ifndef F_ADD_SEALS -#define F_ADD_SEALS (F_LINUX_SPECIFIC_BASE + 9) -#endif /* F_ADD_SEALS */ - -#ifndef F_GET_SEALS -#define F_GET_SEALS (F_LINUX_SPECIFIC_BASE + 10) -#endif /* F_GET_SEALS */ - -#define F_SEAL_SEAL 0x0001 -#define F_SEAL_SHRINK 0x0002 -#define F_SEAL_GROW 0x0004 -#define F_SEAL_WRITE 0x0008 - -using ::testing::StartsWith; - -const std::string kMemfdName = "some-memfd"; - -int memfd_create(const std::string& name, unsigned int flags) { - return syscall(__NR_memfd_create, name.c_str(), flags); -} - -PosixErrorOr<FileDescriptor> MemfdCreate(const std::string& name, - uint32_t flags) { - int fd = memfd_create(name, flags); - if (fd < 0) { - return PosixError( - errno, absl::StrFormat("memfd_create(\"%s\", %#x)", name, flags)); - } - MaybeSave(); - return FileDescriptor(fd); -} - -// Procfs entries for memfds display the appropriate name. -TEST(MemfdTest, Name) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, 0)); - const std::string proc_name = ASSERT_NO_ERRNO_AND_VALUE( - ReadLink(absl::StrFormat("/proc/self/fd/%d", memfd.get()))); - EXPECT_THAT(proc_name, StartsWith("/memfd:" + kMemfdName)); -} - -// Memfds support read/write syscalls. -TEST(MemfdTest, WriteRead) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, 0)); - - // Write a random page of data to the memfd via write(2). - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(memfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - - // Read back the same data and verify. - std::vector<char> buf2(kPageSize); - ASSERT_THAT(lseek(memfd.get(), 0, SEEK_SET), SyscallSucceeds()); - EXPECT_THAT(read(memfd.get(), buf2.data(), buf2.size()), - SyscallSucceedsWithValue(kPageSize)); - EXPECT_EQ(buf, buf2); -} - -// Memfds can be mapped and used as usual. -TEST(MemfdTest, Mmap) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, 0)); - const Mapping m1 = ASSERT_NO_ERRNO_AND_VALUE(Mmap( - nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, memfd.get(), 0)); - - // Write a random page of data to the memfd via mmap m1. - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallSucceeds()); - memcpy(m1.ptr(), buf.data(), buf.size()); - - // Read the data back via a read syscall on the memfd. - std::vector<char> buf2(kPageSize); - EXPECT_THAT(read(memfd.get(), buf2.data(), buf2.size()), - SyscallSucceedsWithValue(kPageSize)); - EXPECT_EQ(buf, buf2); - - // The same data should be accessible via a new mapping m2. - const Mapping m2 = ASSERT_NO_ERRNO_AND_VALUE(Mmap( - nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, memfd.get(), 0)); - EXPECT_EQ(0, memcmp(m1.ptr(), m2.ptr(), kPageSize)); -} - -TEST(MemfdTest, DuplicateFDsShareContent) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, 0)); - const Mapping m1 = ASSERT_NO_ERRNO_AND_VALUE(Mmap( - nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, memfd.get(), 0)); - const FileDescriptor memfd2 = ASSERT_NO_ERRNO_AND_VALUE(memfd.Dup()); - - // Write a random page of data to the memfd via mmap m1. - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallSucceeds()); - memcpy(m1.ptr(), buf.data(), buf.size()); - - // Read the data back via a read syscall on a duplicate fd. - std::vector<char> buf2(kPageSize); - EXPECT_THAT(read(memfd2.get(), buf2.data(), buf2.size()), - SyscallSucceedsWithValue(kPageSize)); - EXPECT_EQ(buf, buf2); -} - -// File seals are disabled by default on memfds. -TEST(MemfdTest, SealingDisabledByDefault) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, 0)); - EXPECT_THAT(fcntl(memfd.get(), F_GET_SEALS), - SyscallSucceedsWithValue(F_SEAL_SEAL)); - // Attempting to set any seal should fail. - EXPECT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), - SyscallFailsWithErrno(EPERM)); -} - -// Seals can be retrieved and updated for memfds. -TEST(MemfdTest, SealsGetSet) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING)); - int seals; - ASSERT_THAT(seals = fcntl(memfd.get(), F_GET_SEALS), SyscallSucceeds()); - // No seals are set yet. - EXPECT_EQ(0, seals); - - // Set a seal and check that we can get it back. - ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), SyscallSucceeds()); - EXPECT_THAT(fcntl(memfd.get(), F_GET_SEALS), - SyscallSucceedsWithValue(F_SEAL_WRITE)); - - // Set some more seals and verify. - ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_GROW | F_SEAL_SHRINK), - SyscallSucceeds()); - EXPECT_THAT( - fcntl(memfd.get(), F_GET_SEALS), - SyscallSucceedsWithValue(F_SEAL_WRITE | F_SEAL_GROW | F_SEAL_SHRINK)); - - // Attempting to set a seal that is already set is a no-op. - ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), SyscallSucceeds()); - EXPECT_THAT( - fcntl(memfd.get(), F_GET_SEALS), - SyscallSucceedsWithValue(F_SEAL_WRITE | F_SEAL_GROW | F_SEAL_SHRINK)); - - // Add remaining seals and verify. - ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_SEAL), SyscallSucceeds()); - EXPECT_THAT(fcntl(memfd.get(), F_GET_SEALS), - SyscallSucceedsWithValue(F_SEAL_WRITE | F_SEAL_GROW | - F_SEAL_SHRINK | F_SEAL_SEAL)); -} - -// F_SEAL_GROW prevents a memfd from being grown using ftruncate. -TEST(MemfdTest, SealGrowWithTruncate) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING)); - ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallSucceeds()); - ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_GROW), SyscallSucceeds()); - - // Try grow the memfd by 1 page. - ASSERT_THAT(ftruncate(memfd.get(), kPageSize * 2), - SyscallFailsWithErrno(EPERM)); - - // Ftruncate calls that don't actually grow the memfd are allowed. - ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallSucceeds()); - ASSERT_THAT(ftruncate(memfd.get(), kPageSize / 2), SyscallSucceeds()); - - // After shrinking, growing back is not allowed. - ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallFailsWithErrno(EPERM)); -} - -// F_SEAL_GROW prevents a memfd from being grown using the write syscall. -TEST(MemfdTest, SealGrowWithWrite) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING)); - - // Initially, writing to the memfd succeeds. - const std::vector<char> buf(kPageSize); - EXPECT_THAT(write(memfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - - // Apply F_SEAL_GROW, subsequent writes which extend the memfd should fail. - ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_GROW), SyscallSucceeds()); - EXPECT_THAT(write(memfd.get(), buf.data(), buf.size()), - SyscallFailsWithErrno(EPERM)); - - // However, zero-length writes are ok since they don't grow the memfd. - EXPECT_THAT(write(memfd.get(), buf.data(), 0), SyscallSucceeds()); - - // Writing to existing parts of the memfd is also ok. - ASSERT_THAT(lseek(memfd.get(), 0, SEEK_SET), SyscallSucceeds()); - EXPECT_THAT(write(memfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - - // Returning the end of the file and writing still not allowed. - EXPECT_THAT(write(memfd.get(), buf.data(), buf.size()), - SyscallFailsWithErrno(EPERM)); -} - -// F_SEAL_GROW causes writes which partially extend off the current EOF to -// partially succeed, up to the page containing the EOF. -TEST(MemfdTest, SealGrowPartialWriteTruncated) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING)); - ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallSucceeds()); - ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_GROW), SyscallSucceeds()); - - // FD offset: 1 page, EOF: 1 page. - - ASSERT_THAT(lseek(memfd.get(), kPageSize * 3 / 4, SEEK_SET), - SyscallSucceeds()); - - // FD offset: 3/4 page. Writing a full page now should only write 1/4 page - // worth of data. This partially succeeds because the first page is entirely - // within the file and requires no growth, but attempting to write the final - // 3/4 page would require growing the file. - const std::vector<char> buf(kPageSize); - EXPECT_THAT(write(memfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize / 4)); -} - -// F_SEAL_GROW causes writes which partially extend off the current EOF to fail -// in its entirety if the only data written would be to the page containing the -// EOF. -TEST(MemfdTest, SealGrowPartialWriteTruncatedSamePage) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING)); - ASSERT_THAT(ftruncate(memfd.get(), kPageSize * 3 / 4), SyscallSucceeds()); - ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_GROW), SyscallSucceeds()); - - // EOF: 3/4 page, writing 1/2 page starting at 1/2 page would cause the file - // to grow. Since this would require only the page containing the EOF to be - // modified, the write is rejected entirely. - const std::vector<char> buf(kPageSize / 2); - EXPECT_THAT(pwrite(memfd.get(), buf.data(), buf.size(), kPageSize / 2), - SyscallFailsWithErrno(EPERM)); - - // However, writing up to EOF is fine. - EXPECT_THAT(pwrite(memfd.get(), buf.data(), buf.size() / 2, kPageSize / 2), - SyscallSucceedsWithValue(kPageSize / 4)); -} - -// F_SEAL_SHRINK prevents a memfd from being shrunk using ftruncate. -TEST(MemfdTest, SealShrink) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING)); - ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallSucceeds()); - ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_SHRINK), - SyscallSucceeds()); - - // Shrink by half a page. - ASSERT_THAT(ftruncate(memfd.get(), kPageSize / 2), - SyscallFailsWithErrno(EPERM)); - - // Ftruncate calls that don't actually shrink the file are allowed. - ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallSucceeds()); - ASSERT_THAT(ftruncate(memfd.get(), kPageSize * 2), SyscallSucceeds()); - - // After growing, shrinking is still not allowed. - ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallFailsWithErrno(EPERM)); -} - -// F_SEAL_WRITE prevents a memfd from being written to through a write -// syscall. -TEST(MemfdTest, SealWriteWithWrite) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING)); - const std::vector<char> buf(kPageSize); - ASSERT_THAT(write(memfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), SyscallSucceeds()); - - // Attemping to write at the end of the file fails. - EXPECT_THAT(write(memfd.get(), buf.data(), 1), SyscallFailsWithErrno(EPERM)); - - // Attemping to overwrite an existing part of the memfd fails. - EXPECT_THAT(pwrite(memfd.get(), buf.data(), 1, 0), - SyscallFailsWithErrno(EPERM)); - EXPECT_THAT(pwrite(memfd.get(), buf.data(), buf.size() / 2, kPageSize / 2), - SyscallFailsWithErrno(EPERM)); - EXPECT_THAT(pwrite(memfd.get(), buf.data(), buf.size(), kPageSize / 2), - SyscallFailsWithErrno(EPERM)); - - // Zero-length writes however do not fail. - EXPECT_THAT(write(memfd.get(), buf.data(), 0), SyscallSucceeds()); -} - -// F_SEAL_WRITE prevents a memfd from being written to through an mmap. -TEST(MemfdTest, SealWriteWithMmap) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING)); - const std::vector<char> buf(kPageSize); - ASSERT_THAT(write(memfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), SyscallSucceeds()); - - // Can't create a shared mapping with writes sealed. - void* ret = mmap(nullptr, kPageSize, PROT_WRITE, MAP_SHARED, memfd.get(), 0); - EXPECT_EQ(ret, MAP_FAILED); - EXPECT_EQ(errno, EPERM); - ret = mmap(nullptr, kPageSize, PROT_READ, MAP_SHARED, memfd.get(), 0); - EXPECT_EQ(ret, MAP_FAILED); - EXPECT_EQ(errno, EPERM); - - // However, private mappings are ok. - EXPECT_NO_ERRNO(Mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, - memfd.get(), 0)); -} - -// Adding F_SEAL_WRITE fails when there are outstanding writable mappings to a -// memfd. -TEST(MemfdTest, SealWriteWithOutstandingWritbleMapping) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING)); - const std::vector<char> buf(kPageSize); - ASSERT_THAT(write(memfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - - // Attempting to add F_SEAL_WRITE with active shared mapping with any set of - // permissions fails. - - // Read-only shared mapping. - { - const Mapping m = ASSERT_NO_ERRNO_AND_VALUE( - Mmap(nullptr, kPageSize, PROT_READ, MAP_SHARED, memfd.get(), 0)); - EXPECT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), - SyscallFailsWithErrno(EBUSY)); - } - - // Write-only shared mapping. - { - const Mapping m = ASSERT_NO_ERRNO_AND_VALUE( - Mmap(nullptr, kPageSize, PROT_WRITE, MAP_SHARED, memfd.get(), 0)); - EXPECT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), - SyscallFailsWithErrno(EBUSY)); - } - - // Read-write shared mapping. - { - const Mapping m = ASSERT_NO_ERRNO_AND_VALUE( - Mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - memfd.get(), 0)); - EXPECT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), - SyscallFailsWithErrno(EBUSY)); - } - - // F_SEAL_WRITE can be set with private mappings with any permissions. - { - const Mapping m = ASSERT_NO_ERRNO_AND_VALUE( - Mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, - memfd.get(), 0)); - EXPECT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), - SyscallSucceeds()); - } -} - -// When applying F_SEAL_WRITE fails due to outstanding writable mappings, any -// additional seals passed to the same add seal call are also rejected. -TEST(MemfdTest, NoPartialSealApplicationWhenWriteSealRejected) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING)); - const Mapping m = ASSERT_NO_ERRNO_AND_VALUE(Mmap( - nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, memfd.get(), 0)); - - // Try add some seals along with F_SEAL_WRITE. The seal application should - // fail since there exists an active shared mapping. - EXPECT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE | F_SEAL_GROW), - SyscallFailsWithErrno(EBUSY)); - - // None of the seals should be applied. - EXPECT_THAT(fcntl(memfd.get(), F_GET_SEALS), SyscallSucceedsWithValue(0)); -} - -// Seals are inode level properties, and apply to all file descriptors referring -// to a memfd. -TEST(MemfdTest, SealsAreInodeLevelProperties) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING)); - const FileDescriptor memfd2 = ASSERT_NO_ERRNO_AND_VALUE(memfd.Dup()); - - // Add seal through the original memfd, and verify that it appears on the - // dupped fd. - ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), SyscallSucceeds()); - EXPECT_THAT(fcntl(memfd2.get(), F_GET_SEALS), - SyscallSucceedsWithValue(F_SEAL_WRITE)); - - // Verify the seal actually applies to both fds. - std::vector<char> buf(kPageSize); - EXPECT_THAT(write(memfd.get(), buf.data(), buf.size()), - SyscallFailsWithErrno(EPERM)); - EXPECT_THAT(write(memfd2.get(), buf.data(), buf.size()), - SyscallFailsWithErrno(EPERM)); - - // Seals are enforced on new FDs that are dupped after the seal is already - // applied. - const FileDescriptor memfd3 = ASSERT_NO_ERRNO_AND_VALUE(memfd2.Dup()); - EXPECT_THAT(write(memfd3.get(), buf.data(), buf.size()), - SyscallFailsWithErrno(EPERM)); - - // Try a new seal applied to one of the dupped fds. - ASSERT_THAT(fcntl(memfd3.get(), F_ADD_SEALS, F_SEAL_GROW), SyscallSucceeds()); - EXPECT_THAT(ftruncate(memfd.get(), kPageSize), SyscallFailsWithErrno(EPERM)); - EXPECT_THAT(ftruncate(memfd2.get(), kPageSize), SyscallFailsWithErrno(EPERM)); - EXPECT_THAT(ftruncate(memfd3.get(), kPageSize), SyscallFailsWithErrno(EPERM)); -} - -PosixErrorOr<bool> IsTmpfs(const std::string& path) { - struct statfs stat; - if (statfs(path.c_str(), &stat)) { - if (errno == ENOENT) { - // Nothing at path, don't raise this as an error. Instead, just report no - // tmpfs at path. - return false; - } - return PosixError(errno, - absl::StrFormat("statfs(\"%s\", %#p)", path, &stat)); - } - return stat.f_type == TMPFS_MAGIC; -} - -// Tmpfs files also support seals, but are created with F_SEAL_SEAL. -TEST(MemfdTest, TmpfsFilesHaveSealSeal) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs("/tmp"))); - const TempPath tmpfs_file = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn("/tmp")); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfs_file.path(), O_RDWR, 0644)); - EXPECT_THAT(fcntl(fd.get(), F_GET_SEALS), - SyscallSucceedsWithValue(F_SEAL_SEAL)); -} - -// Can open a memfd from procfs and use as normal. -TEST(MemfdTest, CanOpenFromProcfs) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING)); - - // Write a random page of data to the memfd via write(2). - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(memfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - - // Read back the same data from the fd obtained from procfs and verify. - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( - Open(absl::StrFormat("/proc/self/fd/%d", memfd.get()), O_RDWR)); - std::vector<char> buf2(kPageSize); - EXPECT_THAT(pread(fd.get(), buf2.data(), buf2.size(), 0), - SyscallSucceedsWithValue(kPageSize)); - EXPECT_EQ(buf, buf2); -} - -// Test that memfd permissions are set up correctly to allow another process to -// open it from procfs. -TEST(MemfdTest, OtherProcessCanOpenFromProcfs) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING)); - const auto memfd_path = - absl::StrFormat("/proc/%d/fd/%d", getpid(), memfd.get()); - const auto rest = [&] { - int fd = open(memfd_path.c_str(), O_RDWR); - TEST_PCHECK(fd >= 0); - TEST_PCHECK(close(fd) >= 0); - }; - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); -} - -// Test that only files opened as writable can have seals applied to them. -// Normally there's no way to specify file permissions on memfds, but we can -// obtain a read-only memfd by opening the corresponding procfs fd entry as -// read-only. -TEST(MemfdTest, MemfdMustBeWritableToModifySeals) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING)); - - // Initially adding a seal works. - EXPECT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), SyscallSucceeds()); - - // Re-open the memfd as read-only from procfs. - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( - Open(absl::StrFormat("/proc/self/fd/%d", memfd.get()), O_RDONLY)); - - // Can't add seals through an unwritable fd. - EXPECT_THAT(fcntl(fd.get(), F_ADD_SEALS, F_SEAL_GROW), - SyscallFailsWithErrno(EPERM)); -} - -// Test that the memfd implementation internally tracks potentially writable -// maps correctly. -TEST(MemfdTest, MultipleWritableAndNonWritableRefsToSameFileRegion) { - const FileDescriptor memfd = - ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, 0)); - - // Populate with a random page of data. - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(memfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - - // Read-only map to the page. This should cause an initial mapping to be - // created. - Mapping m1 = ASSERT_NO_ERRNO_AND_VALUE( - Mmap(nullptr, kPageSize, PROT_READ, MAP_PRIVATE, memfd.get(), 0)); - - // Create a shared writable map to the page. This should cause the internal - // mapping to become potentially writable. - Mapping m2 = ASSERT_NO_ERRNO_AND_VALUE(Mmap( - nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, memfd.get(), 0)); - - // Drop the read-only mapping first. If writable-ness isn't tracked correctly, - // this can cause some misaccounting, which can trigger asserts internally. - m1.reset(); - m2.reset(); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/memory_accounting.cc b/test/syscalls/linux/memory_accounting.cc deleted file mode 100644 index 94aea4077..000000000 --- a/test/syscalls/linux/memory_accounting.cc +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/mman.h> - -#include <map> - -#include "gtest/gtest.h" -#include "absl/strings/match.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_split.h" -#include "test/util/fs_util.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -using ::absl::StrFormat; - -// AnonUsageFromMeminfo scrapes the current anonymous memory usage from -// /proc/meminfo and returns it in bytes. -PosixErrorOr<uint64_t> AnonUsageFromMeminfo() { - ASSIGN_OR_RETURN_ERRNO(auto meminfo, GetContents("/proc/meminfo")); - std::vector<std::string> lines(absl::StrSplit(meminfo, '\n')); - - // Try to find AnonPages line, the format is AnonPages:\\s+(\\d+) kB\n. - for (const auto& line : lines) { - if (!absl::StartsWith(line, "AnonPages:")) { - continue; - } - - std::vector<std::string> parts( - absl::StrSplit(line, ' ', absl::SkipEmpty())); - if (parts.size() == 3) { - // The size is the second field, let's try to parse it as a number. - ASSIGN_OR_RETURN_ERRNO(auto anon_kb, Atoi<uint64_t>(parts[1])); - return anon_kb * 1024; - } - - return PosixError(EINVAL, "AnonPages field in /proc/meminfo was malformed"); - } - - return PosixError(EINVAL, "AnonPages field not found in /proc/meminfo"); -} - -TEST(MemoryAccounting, AnonAccountingPreservedOnSaveRestore) { - // This test isn't meaningful on Linux. /proc/meminfo reports system-wide - // memory usage, which can change arbitrarily in Linux from other activity on - // the machine. In gvisor, this test is the only thing running on the - // "machine", so values in /proc/meminfo accurately reflect the memory used by - // the test. - SKIP_IF(!IsRunningOnGvisor()); - - uint64_t anon_initial = ASSERT_NO_ERRNO_AND_VALUE(AnonUsageFromMeminfo()); - - // Cause some anonymous memory usage. - uint64_t map_bytes = Megabytes(512); - char* mem = - static_cast<char*>(mmap(nullptr, map_bytes, PROT_READ | PROT_WRITE, - MAP_POPULATE | MAP_ANON | MAP_PRIVATE, -1, 0)); - ASSERT_NE(mem, MAP_FAILED) - << "Map failed, errno: " << errno << " (" << strerror(errno) << ")."; - - // Write something to each page to prevent them from being decommited on - // S/R. Zero pages are dropped on save. - for (uint64_t i = 0; i < map_bytes; i += kPageSize) { - mem[i] = 'a'; - } - - uint64_t anon_after_alloc = ASSERT_NO_ERRNO_AND_VALUE(AnonUsageFromMeminfo()); - EXPECT_THAT(anon_after_alloc, - EquivalentWithin(anon_initial + map_bytes, 0.03)); - - // We have many implicit S/R cycles from scraping /proc/meminfo throughout the - // test, but throw an explicit S/R in here as well. - MaybeSave(); - - // Usage should remain the same across S/R. - uint64_t anon_after_sr = ASSERT_NO_ERRNO_AND_VALUE(AnonUsageFromMeminfo()); - EXPECT_THAT(anon_after_sr, EquivalentWithin(anon_after_alloc, 0.03)); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/mempolicy.cc b/test/syscalls/linux/mempolicy.cc deleted file mode 100644 index 059fad598..000000000 --- a/test/syscalls/linux/mempolicy.cc +++ /dev/null @@ -1,289 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <sys/syscall.h> - -#include "gtest/gtest.h" -#include "absl/memory/memory.h" -#include "test/util/cleanup.h" -#include "test/util/memory_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -#define BITS_PER_BYTE 8 - -#define MPOL_F_STATIC_NODES (1 << 15) -#define MPOL_F_RELATIVE_NODES (1 << 14) -#define MPOL_DEFAULT 0 -#define MPOL_PREFERRED 1 -#define MPOL_BIND 2 -#define MPOL_INTERLEAVE 3 -#define MPOL_LOCAL 4 -#define MPOL_F_NODE (1 << 0) -#define MPOL_F_ADDR (1 << 1) -#define MPOL_F_MEMS_ALLOWED (1 << 2) -#define MPOL_MF_STRICT (1 << 0) -#define MPOL_MF_MOVE (1 << 1) -#define MPOL_MF_MOVE_ALL (1 << 2) - -int get_mempolicy(int* policy, uint64_t* nmask, uint64_t maxnode, void* addr, - int flags) { - return syscall(SYS_get_mempolicy, policy, nmask, maxnode, addr, flags); -} - -int set_mempolicy(int mode, uint64_t* nmask, uint64_t maxnode) { - return syscall(SYS_set_mempolicy, mode, nmask, maxnode); -} - -int mbind(void* addr, unsigned long len, int mode, - const unsigned long* nodemask, unsigned long maxnode, - unsigned flags) { - return syscall(SYS_mbind, addr, len, mode, nodemask, maxnode, flags); -} - -// Creates a cleanup object that resets the calling thread's mempolicy to the -// system default when the calling scope ends. -Cleanup ScopedMempolicy() { - return Cleanup([] { - EXPECT_THAT(set_mempolicy(MPOL_DEFAULT, nullptr, 0), SyscallSucceeds()); - }); -} - -// Temporarily change the memory policy for the calling thread within the -// caller's scope. -PosixErrorOr<Cleanup> ScopedSetMempolicy(int mode, uint64_t* nmask, - uint64_t maxnode) { - if (set_mempolicy(mode, nmask, maxnode)) { - return PosixError(errno, "set_mempolicy"); - } - return ScopedMempolicy(); -} - -TEST(MempolicyTest, CheckDefaultPolicy) { - int mode = 0; - uint64_t nodemask = 0; - ASSERT_THAT(get_mempolicy(&mode, &nodemask, sizeof(nodemask) * BITS_PER_BYTE, - nullptr, 0), - SyscallSucceeds()); - - EXPECT_EQ(MPOL_DEFAULT, mode); - EXPECT_EQ(0x0, nodemask); -} - -TEST(MempolicyTest, PolicyPreservedAfterSetMempolicy) { - uint64_t nodemask = 0x1; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSetMempolicy( - MPOL_BIND, &nodemask, sizeof(nodemask) * BITS_PER_BYTE)); - - int mode = 0; - uint64_t nodemask_after = 0x0; - ASSERT_THAT(get_mempolicy(&mode, &nodemask_after, - sizeof(nodemask_after) * BITS_PER_BYTE, nullptr, 0), - SyscallSucceeds()); - EXPECT_EQ(MPOL_BIND, mode); - EXPECT_EQ(0x1, nodemask_after); - - // Try throw in some mode flags. - for (auto mode_flag : {MPOL_F_STATIC_NODES, MPOL_F_RELATIVE_NODES}) { - auto cleanup2 = ASSERT_NO_ERRNO_AND_VALUE( - ScopedSetMempolicy(MPOL_INTERLEAVE | mode_flag, &nodemask, - sizeof(nodemask) * BITS_PER_BYTE)); - mode = 0; - nodemask_after = 0x0; - ASSERT_THAT( - get_mempolicy(&mode, &nodemask_after, - sizeof(nodemask_after) * BITS_PER_BYTE, nullptr, 0), - SyscallSucceeds()); - EXPECT_EQ(MPOL_INTERLEAVE | mode_flag, mode); - EXPECT_EQ(0x1, nodemask_after); - } -} - -TEST(MempolicyTest, SetMempolicyRejectsInvalidInputs) { - auto cleanup = ScopedMempolicy(); - uint64_t nodemask; - - if (IsRunningOnGvisor()) { - // Invalid nodemask, we only support a single node on gvisor. - nodemask = 0x4; - ASSERT_THAT(set_mempolicy(MPOL_DEFAULT, &nodemask, - sizeof(nodemask) * BITS_PER_BYTE), - SyscallFailsWithErrno(EINVAL)); - } - - nodemask = 0x1; - - // Invalid mode. - ASSERT_THAT(set_mempolicy(7439, &nodemask, sizeof(nodemask) * BITS_PER_BYTE), - SyscallFailsWithErrno(EINVAL)); - - // Invalid nodemask size. - ASSERT_THAT(set_mempolicy(MPOL_DEFAULT, &nodemask, 0), - SyscallFailsWithErrno(EINVAL)); - - // Invalid mode flag. - ASSERT_THAT( - set_mempolicy(MPOL_DEFAULT | MPOL_F_STATIC_NODES | MPOL_F_RELATIVE_NODES, - &nodemask, sizeof(nodemask) * BITS_PER_BYTE), - SyscallFailsWithErrno(EINVAL)); - - // MPOL_INTERLEAVE with empty nodemask. - nodemask = 0x0; - ASSERT_THAT(set_mempolicy(MPOL_INTERLEAVE, &nodemask, - sizeof(nodemask) * BITS_PER_BYTE), - SyscallFailsWithErrno(EINVAL)); -} - -// The manpages specify that the nodemask provided to set_mempolicy are -// considered empty if the nodemask pointer is null, or if the nodemask size is -// 0. We use a policy which accepts both empty and non-empty nodemasks -// (MPOL_PREFERRED), a policy which requires a non-empty nodemask (MPOL_BIND), -// and a policy which completely ignores the nodemask (MPOL_DEFAULT) to verify -// argument checking around nodemasks. -TEST(MempolicyTest, EmptyNodemaskOnSet) { - auto cleanup = ScopedMempolicy(); - - EXPECT_THAT(set_mempolicy(MPOL_DEFAULT, nullptr, 1), SyscallSucceeds()); - EXPECT_THAT(set_mempolicy(MPOL_BIND, nullptr, 1), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(set_mempolicy(MPOL_PREFERRED, nullptr, 1), SyscallSucceeds()); - - uint64_t nodemask = 0x1; - EXPECT_THAT(set_mempolicy(MPOL_DEFAULT, &nodemask, 0), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(set_mempolicy(MPOL_BIND, &nodemask, 0), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(set_mempolicy(MPOL_PREFERRED, &nodemask, 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(MempolicyTest, QueryAvailableNodes) { - uint64_t nodemask = 0; - ASSERT_THAT( - get_mempolicy(nullptr, &nodemask, sizeof(nodemask) * BITS_PER_BYTE, - nullptr, MPOL_F_MEMS_ALLOWED), - SyscallSucceeds()); - // We can only be sure there is a single node if running on gvisor. - if (IsRunningOnGvisor()) { - EXPECT_EQ(0x1, nodemask); - } - - // MPOL_F_ADDR and MPOL_F_NODE flags may not be combined with - // MPOL_F_MEMS_ALLLOWED. - for (auto flags : - {MPOL_F_MEMS_ALLOWED | MPOL_F_ADDR, MPOL_F_MEMS_ALLOWED | MPOL_F_NODE, - MPOL_F_MEMS_ALLOWED | MPOL_F_ADDR | MPOL_F_NODE}) { - ASSERT_THAT(get_mempolicy(nullptr, &nodemask, - sizeof(nodemask) * BITS_PER_BYTE, nullptr, flags), - SyscallFailsWithErrno(EINVAL)); - } -} - -TEST(MempolicyTest, GetMempolicyQueryNodeForAddress) { - uint64_t dummy_stack_address; - auto dummy_heap_address = absl::make_unique<uint64_t>(); - int mode; - - for (auto ptr : {&dummy_stack_address, dummy_heap_address.get()}) { - mode = -1; - ASSERT_THAT( - get_mempolicy(&mode, nullptr, 0, ptr, MPOL_F_ADDR | MPOL_F_NODE), - SyscallSucceeds()); - // If we're not running on gvisor, the address may be allocated on a - // different numa node. - if (IsRunningOnGvisor()) { - EXPECT_EQ(0, mode); - } - } - - void* invalid_address = reinterpret_cast<void*>(-1); - - // Invalid address. - ASSERT_THAT(get_mempolicy(&mode, nullptr, 0, invalid_address, - MPOL_F_ADDR | MPOL_F_NODE), - SyscallFailsWithErrno(EFAULT)); - - // Invalid mode pointer. - ASSERT_THAT(get_mempolicy(reinterpret_cast<int*>(invalid_address), nullptr, 0, - &dummy_stack_address, MPOL_F_ADDR | MPOL_F_NODE), - SyscallFailsWithErrno(EFAULT)); -} - -TEST(MempolicyTest, GetMempolicyCanOmitPointers) { - int mode; - uint64_t nodemask; - - // Omit nodemask pointer. - ASSERT_THAT(get_mempolicy(&mode, nullptr, 0, nullptr, 0), SyscallSucceeds()); - // Omit mode pointer. - ASSERT_THAT(get_mempolicy(nullptr, &nodemask, - sizeof(nodemask) * BITS_PER_BYTE, nullptr, 0), - SyscallSucceeds()); - // Omit both pointers. - ASSERT_THAT(get_mempolicy(nullptr, nullptr, 0, nullptr, 0), - SyscallSucceeds()); -} - -TEST(MempolicyTest, GetMempolicyNextInterleaveNode) { - int mode; - // Policy for thread not yet set to MPOL_INTERLEAVE, can't query for - // the next node which will be used for allocation. - ASSERT_THAT(get_mempolicy(&mode, nullptr, 0, nullptr, MPOL_F_NODE), - SyscallFailsWithErrno(EINVAL)); - - // Set default policy for thread to MPOL_INTERLEAVE. - uint64_t nodemask = 0x1; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSetMempolicy( - MPOL_INTERLEAVE, &nodemask, sizeof(nodemask) * BITS_PER_BYTE)); - - mode = -1; - ASSERT_THAT(get_mempolicy(&mode, nullptr, 0, nullptr, MPOL_F_NODE), - SyscallSucceeds()); - EXPECT_EQ(0, mode); -} - -TEST(MempolicyTest, Mbind) { - // Temporarily set the thread policy to MPOL_PREFERRED. - const auto cleanup_thread_policy = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSetMempolicy(MPOL_PREFERRED, nullptr, 0)); - - const auto mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS)); - - // vmas default to MPOL_DEFAULT irrespective of the thread policy (currently - // MPOL_PREFERRED). - int mode; - ASSERT_THAT(get_mempolicy(&mode, nullptr, 0, mapping.ptr(), MPOL_F_ADDR), - SyscallSucceeds()); - EXPECT_EQ(mode, MPOL_DEFAULT); - - // Set MPOL_PREFERRED for the vma and read it back. - ASSERT_THAT( - mbind(mapping.ptr(), mapping.len(), MPOL_PREFERRED, nullptr, 0, 0), - SyscallSucceeds()); - ASSERT_THAT(get_mempolicy(&mode, nullptr, 0, mapping.ptr(), MPOL_F_ADDR), - SyscallSucceeds()); - EXPECT_EQ(mode, MPOL_PREFERRED); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/mincore.cc b/test/syscalls/linux/mincore.cc deleted file mode 100644 index 5c1240c89..000000000 --- a/test/syscalls/linux/mincore.cc +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <stdint.h> -#include <string.h> -#include <sys/mman.h> -#include <unistd.h> - -#include <algorithm> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/util/memory_util.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -size_t CountSetLSBs(std::vector<unsigned char> const& vec) { - return std::count_if(begin(vec), end(vec), - [](unsigned char c) { return (c & 1) != 0; }); -} - -TEST(MincoreTest, DirtyAnonPagesAreResident) { - constexpr size_t kTestPageCount = 10; - auto const kTestMappingBytes = kTestPageCount * kPageSize; - auto m = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kTestMappingBytes, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - memset(m.ptr(), 0, m.len()); - - std::vector<unsigned char> vec(kTestPageCount, 0); - ASSERT_THAT(mincore(m.ptr(), kTestMappingBytes, vec.data()), - SyscallSucceeds()); - EXPECT_EQ(kTestPageCount, CountSetLSBs(vec)); -} - -TEST(MincoreTest, UnalignedAddressFails) { - // Map and touch two pages, then try to mincore the second half of the first - // page + the first half of the second page. Both pages are mapped, but - // mincore should return EINVAL due to the misaligned start address. - constexpr size_t kTestPageCount = 2; - auto const kTestMappingBytes = kTestPageCount * kPageSize; - auto m = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kTestMappingBytes, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - memset(m.ptr(), 0, m.len()); - - std::vector<unsigned char> vec(kTestPageCount, 0); - EXPECT_THAT(mincore(reinterpret_cast<void*>(m.addr() + kPageSize / 2), - kPageSize, vec.data()), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(MincoreTest, UnalignedLengthSucceedsAndIsRoundedUp) { - // Map and touch two pages, then try to mincore the first page + the first - // half of the second page. mincore should silently round up the length to - // include both pages. - constexpr size_t kTestPageCount = 2; - auto const kTestMappingBytes = kTestPageCount * kPageSize; - auto m = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kTestMappingBytes, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - memset(m.ptr(), 0, m.len()); - - std::vector<unsigned char> vec(kTestPageCount, 0); - ASSERT_THAT(mincore(m.ptr(), kPageSize + kPageSize / 2, vec.data()), - SyscallSucceeds()); - EXPECT_EQ(kTestPageCount, CountSetLSBs(vec)); -} - -TEST(MincoreTest, ZeroLengthSucceedsAndAllowsAnyVecBelowTaskSize) { - EXPECT_THAT(mincore(nullptr, 0, nullptr), SyscallSucceeds()); -} - -TEST(MincoreTest, InvalidLengthFails) { - EXPECT_THAT(mincore(nullptr, -1, nullptr), SyscallFailsWithErrno(ENOMEM)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc deleted file mode 100644 index def4c50a4..000000000 --- a/test/syscalls/linux/mkdir.cc +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/capability_util.h" -#include "test/util/fs_util.h" -#include "test/util/temp_path.h" -#include "test/util/temp_umask.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -class MkdirTest : public ::testing::Test { - protected: - // SetUp creates various configurations of files. - void SetUp() override { dirname_ = NewTempAbsPath(); } - - // TearDown unlinks created files. - void TearDown() override { - // FIXME(edahlgren): We don't currently implement rmdir. - // We do this unconditionally because there's no harm in trying. - rmdir(dirname_.c_str()); - } - - std::string dirname_; -}; - -TEST_F(MkdirTest, DISABLED_CanCreateReadbleDir) { - ASSERT_THAT(mkdir(dirname_.c_str(), 0444), SyscallSucceeds()); - ASSERT_THAT( - open(JoinPath(dirname_, "anything").c_str(), O_RDWR | O_CREAT, 0666), - SyscallFailsWithErrno(EACCES)); -} - -TEST_F(MkdirTest, CanCreateWritableDir) { - ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds()); - std::string filename = JoinPath(dirname_, "anything"); - int fd; - ASSERT_THAT(fd = open(filename.c_str(), O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - ASSERT_THAT(unlink(filename.c_str()), SyscallSucceeds()); -} - -TEST_F(MkdirTest, HonorsUmask) { - constexpr mode_t kMask = 0111; - TempUmask mask(kMask); - ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds()); - struct stat statbuf; - ASSERT_THAT(stat(dirname_.c_str(), &statbuf), SyscallSucceeds()); - EXPECT_EQ(0777 & ~kMask, statbuf.st_mode & 0777); -} - -TEST_F(MkdirTest, HonorsUmask2) { - constexpr mode_t kMask = 0142; - TempUmask mask(kMask); - ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds()); - struct stat statbuf; - ASSERT_THAT(stat(dirname_.c_str(), &statbuf), SyscallSucceeds()); - EXPECT_EQ(0777 & ~kMask, statbuf.st_mode & 0777); -} - -TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - auto parent = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0555)); - auto dir = JoinPath(parent.path(), "foo"); - ASSERT_THAT(mkdir(dir.c_str(), 0777), SyscallFailsWithErrno(EACCES)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/mknod.cc b/test/syscalls/linux/mknod.cc deleted file mode 100644 index 4c45766c7..000000000 --- a/test/syscalls/linux/mknod.cc +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <sys/stat.h> -#include <sys/un.h> -#include <unistd.h> - -#include <vector> - -#include "gtest/gtest.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(MknodTest, RegularFile) { - const std::string node0 = NewTempAbsPath(); - EXPECT_THAT(mknod(node0.c_str(), S_IFREG, 0), SyscallSucceeds()); - - const std::string node1 = NewTempAbsPath(); - EXPECT_THAT(mknod(node1.c_str(), 0, 0), SyscallSucceeds()); -} - -TEST(MknodTest, MknodAtRegularFile) { - const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const std::string fifo_relpath = NewTempRelPath(); - const std::string fifo = JoinPath(dir.path(), fifo_relpath); - - const FileDescriptor dirfd = - ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path().c_str(), O_RDONLY)); - ASSERT_THAT(mknodat(dirfd.get(), fifo_relpath.c_str(), S_IFIFO | S_IRUSR, 0), - SyscallSucceeds()); - - struct stat st; - ASSERT_THAT(stat(fifo.c_str(), &st), SyscallSucceeds()); - EXPECT_TRUE(S_ISFIFO(st.st_mode)); -} - -TEST(MknodTest, MknodOnExistingPathFails) { - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const TempPath slink = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), file.path())); - - EXPECT_THAT(mknod(file.path().c_str(), S_IFREG, 0), - SyscallFailsWithErrno(EEXIST)); - EXPECT_THAT(mknod(file.path().c_str(), S_IFIFO, 0), - SyscallFailsWithErrno(EEXIST)); - EXPECT_THAT(mknod(slink.path().c_str(), S_IFREG, 0), - SyscallFailsWithErrno(EEXIST)); - EXPECT_THAT(mknod(slink.path().c_str(), S_IFIFO, 0), - SyscallFailsWithErrno(EEXIST)); -} - -TEST(MknodTest, UnimplementedTypesReturnError) { - const std::string path = NewTempAbsPath(); - - if (IsRunningOnGvisor()) { - ASSERT_THAT(mknod(path.c_str(), S_IFSOCK, 0), - SyscallFailsWithErrno(EOPNOTSUPP)); - } - // These will fail on linux as well since we don't have CAP_MKNOD. - ASSERT_THAT(mknod(path.c_str(), S_IFCHR, 0), SyscallFailsWithErrno(EPERM)); - ASSERT_THAT(mknod(path.c_str(), S_IFBLK, 0), SyscallFailsWithErrno(EPERM)); -} - -TEST(MknodTest, Fifo) { - const std::string fifo = NewTempAbsPath(); - ASSERT_THAT(mknod(fifo.c_str(), S_IFIFO | S_IRUSR | S_IWUSR, 0), - SyscallSucceeds()); - - struct stat st; - ASSERT_THAT(stat(fifo.c_str(), &st), SyscallSucceeds()); - EXPECT_TRUE(S_ISFIFO(st.st_mode)); - - std::string msg = "some std::string"; - std::vector<char> buf(512); - - // Read-end of the pipe. - ScopedThread t([&fifo, &buf, &msg]() { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY)); - EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(msg.length())); - EXPECT_EQ(msg, std::string(buf.data())); - }); - - // Write-end of the pipe. - FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY)); - EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()), - SyscallSucceedsWithValue(msg.length())); -} - -TEST(MknodTest, FifoOtrunc) { - const std::string fifo = NewTempAbsPath(); - ASSERT_THAT(mknod(fifo.c_str(), S_IFIFO | S_IRUSR | S_IWUSR, 0), - SyscallSucceeds()); - - struct stat st = {}; - ASSERT_THAT(stat(fifo.c_str(), &st), SyscallSucceeds()); - EXPECT_TRUE(S_ISFIFO(st.st_mode)); - - std::string msg = "some std::string"; - std::vector<char> buf(512); - // Read-end of the pipe. - ScopedThread t([&fifo, &buf, &msg]() { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY)); - EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(msg.length())); - EXPECT_EQ(msg, std::string(buf.data())); - }); - - // Write-end of the pipe. - FileDescriptor wfd = - ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY | O_TRUNC)); - EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()), - SyscallSucceedsWithValue(msg.length())); -} - -TEST(MknodTest, FifoTruncNoOp) { - const std::string fifo = NewTempAbsPath(); - ASSERT_THAT(mknod(fifo.c_str(), S_IFIFO | S_IRUSR | S_IWUSR, 0), - SyscallSucceeds()); - - EXPECT_THAT(truncate(fifo.c_str(), 0), SyscallFailsWithErrno(EINVAL)); - - struct stat st = {}; - ASSERT_THAT(stat(fifo.c_str(), &st), SyscallSucceeds()); - EXPECT_TRUE(S_ISFIFO(st.st_mode)); - - std::string msg = "some std::string"; - std::vector<char> buf(512); - // Read-end of the pipe. - ScopedThread t([&fifo, &buf, &msg]() { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY)); - EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(msg.length())); - EXPECT_EQ(msg, std::string(buf.data())); - }); - - FileDescriptor wfd = - ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY | O_TRUNC)); - EXPECT_THAT(ftruncate(wfd.get(), 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()), - SyscallSucceedsWithValue(msg.length())); - EXPECT_THAT(ftruncate(wfd.get(), 0), SyscallFailsWithErrno(EINVAL)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/mlock.cc b/test/syscalls/linux/mlock.cc deleted file mode 100644 index 367a90fe1..000000000 --- a/test/syscalls/linux/mlock.cc +++ /dev/null @@ -1,330 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/mman.h> -#include <sys/resource.h> -#include <sys/syscall.h> -#include <unistd.h> - -#include <cerrno> -#include <cstring> - -#include "gmock/gmock.h" -#include "test/util/capability_util.h" -#include "test/util/cleanup.h" -#include "test/util/memory_util.h" -#include "test/util/multiprocess_util.h" -#include "test/util/rlimit_util.h" -#include "test/util/test_util.h" - -using ::testing::_; - -namespace gvisor { -namespace testing { - -namespace { - -PosixErrorOr<bool> CanMlock() { - struct rlimit rlim; - if (getrlimit(RLIMIT_MEMLOCK, &rlim) < 0) { - return PosixError(errno, "getrlimit(RLIMIT_MEMLOCK)"); - } - if (rlim.rlim_cur != 0) { - return true; - } - return HaveCapability(CAP_IPC_LOCK); -} - -// Returns true if the page containing addr is mlocked. -bool IsPageMlocked(uintptr_t addr) { - // This relies on msync(MS_INVALIDATE) interacting correctly with mlocked - // pages, which is tested for by the MsyncInvalidate case below. - int const rv = msync(reinterpret_cast<void*>(addr & ~(kPageSize - 1)), - kPageSize, MS_ASYNC | MS_INVALIDATE); - if (rv == 0) { - return false; - } - // This uses TEST_PCHECK_MSG since it's used in subprocesses. - TEST_PCHECK_MSG(errno == EBUSY, "msync failed with unexpected errno"); - return true; -} - -TEST(MlockTest, Basic) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - EXPECT_FALSE(IsPageMlocked(mapping.addr())); - ASSERT_THAT(mlock(mapping.ptr(), mapping.len()), SyscallSucceeds()); - EXPECT_TRUE(IsPageMlocked(mapping.addr())); -} - -TEST(MlockTest, ProtNone) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); - auto const mapping = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_PRIVATE)); - EXPECT_FALSE(IsPageMlocked(mapping.addr())); - ASSERT_THAT(mlock(mapping.ptr(), mapping.len()), - SyscallFailsWithErrno(ENOMEM)); - // ENOMEM is returned because mlock can't populate the page, but it's still - // considered locked. - EXPECT_TRUE(IsPageMlocked(mapping.addr())); -} - -TEST(MlockTest, MadviseDontneed) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - ASSERT_THAT(mlock(mapping.ptr(), mapping.len()), SyscallSucceeds()); - EXPECT_THAT(madvise(mapping.ptr(), mapping.len(), MADV_DONTNEED), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(MlockTest, MsyncInvalidate) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - ASSERT_THAT(mlock(mapping.ptr(), mapping.len()), SyscallSucceeds()); - EXPECT_THAT(msync(mapping.ptr(), mapping.len(), MS_ASYNC | MS_INVALIDATE), - SyscallFailsWithErrno(EBUSY)); - EXPECT_THAT(msync(mapping.ptr(), mapping.len(), MS_SYNC | MS_INVALIDATE), - SyscallFailsWithErrno(EBUSY)); -} - -TEST(MlockTest, Fork) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - EXPECT_FALSE(IsPageMlocked(mapping.addr())); - ASSERT_THAT(mlock(mapping.ptr(), mapping.len()), SyscallSucceeds()); - EXPECT_TRUE(IsPageMlocked(mapping.addr())); - EXPECT_THAT( - InForkedProcess([&] { TEST_CHECK(!IsPageMlocked(mapping.addr())); }), - IsPosixErrorOkAndHolds(0)); -} - -TEST(MlockTest, RlimitMemlockZero) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } - Cleanup reset_rlimit = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, 0)); - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - EXPECT_FALSE(IsPageMlocked(mapping.addr())); - ASSERT_THAT(mlock(mapping.ptr(), mapping.len()), - SyscallFailsWithErrno(EPERM)); -} - -TEST(MlockTest, RlimitMemlockInsufficient) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } - Cleanup reset_rlimit = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, kPageSize)); - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - EXPECT_FALSE(IsPageMlocked(mapping.addr())); - ASSERT_THAT(mlock(mapping.ptr(), mapping.len()), - SyscallFailsWithErrno(ENOMEM)); -} - -TEST(MunlockTest, Basic) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - EXPECT_FALSE(IsPageMlocked(mapping.addr())); - ASSERT_THAT(mlock(mapping.ptr(), mapping.len()), SyscallSucceeds()); - EXPECT_TRUE(IsPageMlocked(mapping.addr())); - ASSERT_THAT(munlock(mapping.ptr(), mapping.len()), SyscallSucceeds()); - EXPECT_FALSE(IsPageMlocked(mapping.addr())); -} - -TEST(MunlockTest, NotLocked) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - EXPECT_FALSE(IsPageMlocked(mapping.addr())); - EXPECT_THAT(munlock(mapping.ptr(), mapping.len()), SyscallSucceeds()); - EXPECT_FALSE(IsPageMlocked(mapping.addr())); -} - -// There is currently no test for mlockall(MCL_CURRENT) because the default -// RLIMIT_MEMLOCK of 64 KB is insufficient to actually invoke -// mlockall(MCL_CURRENT). - -TEST(MlockallTest, Future) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); - - // Run this test in a separate (single-threaded) subprocess to ensure that a - // background thread doesn't try to mmap a large amount of memory, fail due - // to hitting RLIMIT_MEMLOCK, and explode the process violently. - auto const do_test = [] { - auto const mapping = - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE).ValueOrDie(); - TEST_CHECK(!IsPageMlocked(mapping.addr())); - TEST_PCHECK(mlockall(MCL_FUTURE) == 0); - // Ensure that mlockall(MCL_FUTURE) is turned off before the end of the - // test, as otherwise mmaps may fail unexpectedly. - Cleanup do_munlockall([] { TEST_PCHECK(munlockall() == 0); }); - auto const mapping2 = - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE).ValueOrDie(); - TEST_CHECK(IsPageMlocked(mapping2.addr())); - // Fire munlockall() and check that it disables mlockall(MCL_FUTURE). - do_munlockall.Release()(); - auto const mapping3 = - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE).ValueOrDie(); - TEST_CHECK(!IsPageMlocked(mapping2.addr())); - }; - EXPECT_THAT(InForkedProcess(do_test), IsPosixErrorOkAndHolds(0)); -} - -TEST(MunlockallTest, Basic) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED)); - EXPECT_TRUE(IsPageMlocked(mapping.addr())); - ASSERT_THAT(munlockall(), SyscallSucceeds()); - EXPECT_FALSE(IsPageMlocked(mapping.addr())); -} - -#ifndef SYS_mlock2 -#ifdef __x86_64__ -#define SYS_mlock2 325 -#endif -#endif - -#ifndef MLOCK_ONFAULT -#define MLOCK_ONFAULT 0x01 // Linux: include/uapi/asm-generic/mman-common.h -#endif - -#ifdef SYS_mlock2 - -int mlock2(void const* addr, size_t len, int flags) { - return syscall(SYS_mlock2, addr, len, flags); -} - -TEST(Mlock2Test, NoFlags) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - EXPECT_FALSE(IsPageMlocked(mapping.addr())); - ASSERT_THAT(mlock2(mapping.ptr(), mapping.len(), 0), SyscallSucceeds()); - EXPECT_TRUE(IsPageMlocked(mapping.addr())); -} - -TEST(Mlock2Test, MlockOnfault) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - EXPECT_FALSE(IsPageMlocked(mapping.addr())); - ASSERT_THAT(mlock2(mapping.ptr(), mapping.len(), MLOCK_ONFAULT), - SyscallSucceeds()); - EXPECT_TRUE(IsPageMlocked(mapping.addr())); -} - -TEST(Mlock2Test, UnknownFlags) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - EXPECT_THAT(mlock2(mapping.ptr(), mapping.len(), ~0), - SyscallFailsWithErrno(EINVAL)); -} - -#endif // defined(SYS_mlock2) - -TEST(MapLockedTest, Basic) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED)); - EXPECT_TRUE(IsPageMlocked(mapping.addr())); - EXPECT_THAT(munlock(mapping.ptr(), mapping.len()), SyscallSucceeds()); - EXPECT_FALSE(IsPageMlocked(mapping.addr())); -} - -TEST(MapLockedTest, RlimitMemlockZero) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } - Cleanup reset_rlimit = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, 0)); - EXPECT_THAT( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED), - PosixErrorIs(EPERM, _)); -} - -TEST(MapLockedTest, RlimitMemlockInsufficient) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } - Cleanup reset_rlimit = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, kPageSize)); - EXPECT_THAT( - MmapAnon(2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED), - PosixErrorIs(EAGAIN, _)); -} - -TEST(MremapLockedTest, Basic) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); - auto mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED)); - EXPECT_TRUE(IsPageMlocked(mapping.addr())); - - void* addr = mremap(mapping.ptr(), mapping.len(), 2 * mapping.len(), - MREMAP_MAYMOVE, nullptr); - if (addr == MAP_FAILED) { - FAIL() << "mremap failed: " << errno << " (" << strerror(errno) << ")"; - } - mapping.release(); - mapping.reset(addr, 2 * mapping.len()); - EXPECT_TRUE(IsPageMlocked(reinterpret_cast<uintptr_t>(addr))); -} - -TEST(MremapLockedTest, RlimitMemlockZero) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); - auto mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED)); - EXPECT_TRUE(IsPageMlocked(mapping.addr())); - - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } - Cleanup reset_rlimit = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, 0)); - void* addr = mremap(mapping.ptr(), mapping.len(), 2 * mapping.len(), - MREMAP_MAYMOVE, nullptr); - EXPECT_TRUE(addr == MAP_FAILED && errno == EAGAIN) - << "addr = " << addr << ", errno = " << errno; -} - -TEST(MremapLockedTest, RlimitMemlockInsufficient) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); - auto mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED)); - EXPECT_TRUE(IsPageMlocked(mapping.addr())); - - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } - Cleanup reset_rlimit = ASSERT_NO_ERRNO_AND_VALUE( - ScopedSetSoftRlimit(RLIMIT_MEMLOCK, mapping.len())); - void* addr = mremap(mapping.ptr(), mapping.len(), 2 * mapping.len(), - MREMAP_MAYMOVE, nullptr); - EXPECT_TRUE(addr == MAP_FAILED && errno == EAGAIN) - << "addr = " << addr << ", errno = " << errno; -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc deleted file mode 100644 index 11fb1b457..000000000 --- a/test/syscalls/linux/mmap.cc +++ /dev/null @@ -1,1670 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <linux/magic.h> -#include <linux/unistd.h> -#include <signal.h> -#include <stdio.h> -#include <stdlib.h> -#include <string.h> -#include <sys/mman.h> -#include <sys/resource.h> -#include <sys/statfs.h> -#include <sys/syscall.h> -#include <sys/time.h> -#include <sys/types.h> -#include <sys/wait.h> -#include <unistd.h> - -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/escaping.h" -#include "absl/strings/str_split.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/memory_util.h" -#include "test/util/multiprocess_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -using ::testing::Gt; - -namespace gvisor { -namespace testing { - -namespace { - -PosixErrorOr<int64_t> VirtualMemorySize() { - ASSIGN_OR_RETURN_ERRNO(auto contents, GetContents("/proc/self/statm")); - std::vector<std::string> parts = absl::StrSplit(contents, ' '); - if (parts.empty()) { - return PosixError(EINVAL, "Unable to parse /proc/self/statm"); - } - ASSIGN_OR_RETURN_ERRNO(auto pages, Atoi<int64_t>(parts[0])); - return pages * getpagesize(); -} - -class MMapTest : public ::testing::Test { - protected: - // Unmap mapping, if one was made. - void TearDown() override { - if (addr_) { - EXPECT_THAT(Unmap(), SyscallSucceeds()); - } - } - - // Remembers mapping, so it can be automatically unmapped. - uintptr_t Map(uintptr_t addr, size_t length, int prot, int flags, int fd, - off_t offset) { - void* ret = - mmap(reinterpret_cast<void*>(addr), length, prot, flags, fd, offset); - - if (ret != MAP_FAILED) { - addr_ = ret; - length_ = length; - } - - return reinterpret_cast<uintptr_t>(ret); - } - - // Unmap previous mapping - int Unmap() { - if (!addr_) { - return -1; - } - - int ret = munmap(addr_, length_); - - addr_ = nullptr; - length_ = 0; - - return ret; - } - - // Msync the mapping. - int Msync() { return msync(addr_, length_, MS_SYNC); } - - // Mlock the mapping. - int Mlock() { return mlock(addr_, length_); } - - // Munlock the mapping. - int Munlock() { return munlock(addr_, length_); } - - int Protect(uintptr_t addr, size_t length, int prot) { - return mprotect(reinterpret_cast<void*>(addr), length, prot); - } - - void* addr_ = nullptr; - size_t length_ = 0; -}; - -// Matches if arg contains the same contents as string str. -MATCHER_P(EqualsMemory, str, "") { - if (0 == memcmp(arg, str.c_str(), str.size())) { - return true; - } - - *result_listener << "Memory did not match. Got:\n" - << absl::BytesToHexString( - std::string(static_cast<char*>(arg), str.size())) - << "Want:\n" - << absl::BytesToHexString(str); - return false; -} - -// We can't map pipes, but for different reasons. -TEST_F(MMapTest, MapPipe) { - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - EXPECT_THAT(Map(0, kPageSize, PROT_READ, MAP_PRIVATE, fds[0], 0), - SyscallFailsWithErrno(ENODEV)); - EXPECT_THAT(Map(0, kPageSize, PROT_READ, MAP_PRIVATE, fds[1], 0), - SyscallFailsWithErrno(EACCES)); - ASSERT_THAT(close(fds[0]), SyscallSucceeds()); - ASSERT_THAT(close(fds[1]), SyscallSucceeds()); -} - -// It's very common to mmap /dev/zero because anonymous mappings aren't part -// of POSIX although they are widely supported. So a zero initialized memory -// region would actually come from a "file backed" /dev/zero mapping. -TEST_F(MMapTest, MapDevZeroShared) { - // This test will verify that we're able to map a page backed by /dev/zero - // as MAP_SHARED. - const FileDescriptor dev_zero = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDWR)); - - // Test that we can create a RW SHARED mapping of /dev/zero. - ASSERT_THAT( - Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, dev_zero.get(), 0), - SyscallSucceeds()); -} - -TEST_F(MMapTest, MapDevZeroPrivate) { - // This test will verify that we're able to map a page backed by /dev/zero - // as MAP_PRIVATE. - const FileDescriptor dev_zero = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDWR)); - - // Test that we can create a RW SHARED mapping of /dev/zero. - ASSERT_THAT( - Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, dev_zero.get(), 0), - SyscallSucceeds()); -} - -TEST_F(MMapTest, MapDevZeroNoPersistence) { - // This test will verify that two independent mappings of /dev/zero do not - // appear to reference the same "backed file." - - const FileDescriptor dev_zero1 = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDWR)); - const FileDescriptor dev_zero2 = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDWR)); - - ASSERT_THAT( - Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, dev_zero1.get(), 0), - SyscallSucceeds()); - - // Create a second mapping via the second /dev/zero fd. - void* psec_map = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - dev_zero2.get(), 0); - ASSERT_THAT(reinterpret_cast<intptr_t>(psec_map), SyscallSucceeds()); - - // Always unmap. - auto cleanup_psec_map = Cleanup( - [&] { EXPECT_THAT(munmap(psec_map, kPageSize), SyscallSucceeds()); }); - - // Verify that we have independently addressed pages. - ASSERT_NE(psec_map, addr_); - - std::string buf_zero(kPageSize, 0x00); - std::string buf_ones(kPageSize, 0xFF); - - // Verify the first is actually all zeros after mmap. - EXPECT_THAT(addr_, EqualsMemory(buf_zero)); - - // Let's fill in the first mapping with 0xFF. - memcpy(addr_, buf_ones.data(), kPageSize); - - // Verify that the memcpy actually stuck in the page. - EXPECT_THAT(addr_, EqualsMemory(buf_ones)); - - // Verify that it didn't affect the second page which should be all zeros. - EXPECT_THAT(psec_map, EqualsMemory(buf_zero)); -} - -TEST_F(MMapTest, MapDevZeroSharedMultiplePages) { - // This will test that we're able to map /dev/zero over multiple pages. - const FileDescriptor dev_zero = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDWR)); - - // Test that we can create a RW SHARED mapping of /dev/zero. - ASSERT_THAT(Map(0, kPageSize * 2, PROT_READ | PROT_WRITE, MAP_PRIVATE, - dev_zero.get(), 0), - SyscallSucceeds()); - - std::string buf_zero(kPageSize * 2, 0x00); - std::string buf_ones(kPageSize * 2, 0xFF); - - // Verify the two pages are actually all zeros after mmap. - EXPECT_THAT(addr_, EqualsMemory(buf_zero)); - - // Fill out the pages with all ones. - memcpy(addr_, buf_ones.data(), kPageSize * 2); - - // Verify that the memcpy actually stuck in the pages. - EXPECT_THAT(addr_, EqualsMemory(buf_ones)); -} - -TEST_F(MMapTest, MapDevZeroSharedFdNoPersistence) { - // This test will verify that two independent mappings of /dev/zero do not - // appear to reference the same "backed file" even when mapped from the - // same initial fd. - const FileDescriptor dev_zero = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDWR)); - - ASSERT_THAT( - Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, dev_zero.get(), 0), - SyscallSucceeds()); - - // Create a second mapping via the same fd. - void* psec_map = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - dev_zero.get(), 0); - ASSERT_THAT(reinterpret_cast<int64_t>(psec_map), SyscallSucceeds()); - - // Always unmap. - auto cleanup_psec_map = Cleanup( - [&] { ASSERT_THAT(munmap(psec_map, kPageSize), SyscallSucceeds()); }); - - // Verify that we have independently addressed pages. - ASSERT_NE(psec_map, addr_); - - std::string buf_zero(kPageSize, 0x00); - std::string buf_ones(kPageSize, 0xFF); - - // Verify the first is actually all zeros after mmap. - EXPECT_THAT(addr_, EqualsMemory(buf_zero)); - - // Let's fill in the first mapping with 0xFF. - memcpy(addr_, buf_ones.data(), kPageSize); - - // Verify that the memcpy actually stuck in the page. - EXPECT_THAT(addr_, EqualsMemory(buf_ones)); - - // Verify that it didn't affect the second page which should be all zeros. - EXPECT_THAT(psec_map, EqualsMemory(buf_zero)); -} - -TEST_F(MMapTest, MapDevZeroSegfaultAfterUnmap) { - SetupGvisorDeathTest(); - - // This test will verify that we're able to map a page backed by /dev/zero - // as MAP_SHARED and after it's unmapped any access results in a SIGSEGV. - // This test is redundant but given the special nature of /dev/zero mappings - // it doesn't hurt. - const FileDescriptor dev_zero = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDWR)); - - const auto rest = [&] { - // Test that we can create a RW SHARED mapping of /dev/zero. - TEST_PCHECK(Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - dev_zero.get(), - 0) != reinterpret_cast<uintptr_t>(MAP_FAILED)); - - // Confirm that accesses after the unmap result in a SIGSEGV. - // - // N.B. We depend on this process being single-threaded to ensure there - // can't be another mmap to map addr before the dereference below. - void* addr_saved = addr_; // Unmap resets addr_. - TEST_PCHECK(Unmap() == 0); - *reinterpret_cast<volatile int*>(addr_saved) = 0xFF; - }; - - EXPECT_THAT(InForkedProcess(rest), - IsPosixErrorOkAndHolds(W_EXITCODE(0, SIGSEGV))); -} - -TEST_F(MMapTest, MapDevZeroUnaligned) { - const FileDescriptor dev_zero = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDWR)); - const size_t size = kPageSize + kPageSize / 2; - const std::string buf_zero(size, 0x00); - - ASSERT_THAT( - Map(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, dev_zero.get(), 0), - SyscallSucceeds()); - EXPECT_THAT(addr_, EqualsMemory(buf_zero)); - ASSERT_THAT(Unmap(), SyscallSucceeds()); - - ASSERT_THAT( - Map(0, size, PROT_READ | PROT_WRITE, MAP_PRIVATE, dev_zero.get(), 0), - SyscallSucceeds()); - EXPECT_THAT(addr_, EqualsMemory(buf_zero)); -} - -// We can't map _some_ character devices. -TEST_F(MMapTest, MapCharDevice) { - const FileDescriptor cdevfd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/random", 0, 0)); - EXPECT_THAT(Map(0, kPageSize, PROT_READ, MAP_PRIVATE, cdevfd.get(), 0), - SyscallFailsWithErrno(ENODEV)); -} - -// We can't map directories. -TEST_F(MMapTest, MapDirectory) { - const FileDescriptor dirfd = - ASSERT_NO_ERRNO_AND_VALUE(Open(GetAbsoluteTestTmpdir(), 0, 0)); - EXPECT_THAT(Map(0, kPageSize, PROT_READ, MAP_PRIVATE, dirfd.get(), 0), - SyscallFailsWithErrno(ENODEV)); -} - -// We can map *something* -TEST_F(MMapTest, MapAnything) { - EXPECT_THAT(Map(0, kPageSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceedsWithValue(Gt(0))); -} - -// Map length < PageSize allowed -TEST_F(MMapTest, SmallMap) { - EXPECT_THAT(Map(0, 128, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceeds()); -} - -// Hint address doesn't break anything. -// Note: there is no requirement we actually get the hint address -TEST_F(MMapTest, HintAddress) { - EXPECT_THAT( - Map(0x30000000, kPageSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceeds()); -} - -// MAP_FIXED gives us exactly the requested address -TEST_F(MMapTest, MapFixed) { - EXPECT_THAT(Map(0x30000000, kPageSize, PROT_NONE, - MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED, -1, 0), - SyscallSucceedsWithValue(0x30000000)); -} - -// 64-bit addresses work too -#ifdef __x86_64__ -TEST_F(MMapTest, MapFixed64) { - EXPECT_THAT(Map(0x300000000000, kPageSize, PROT_NONE, - MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED, -1, 0), - SyscallSucceedsWithValue(0x300000000000)); -} -#endif - -// MAP_STACK allowed. -// There isn't a good way to verify it did anything. -TEST_F(MMapTest, MapStack) { - EXPECT_THAT(Map(0, kPageSize, PROT_NONE, - MAP_PRIVATE | MAP_ANONYMOUS | MAP_STACK, -1, 0), - SyscallSucceeds()); -} - -// MAP_LOCKED allowed. -// There isn't a good way to verify it did anything. -TEST_F(MMapTest, MapLocked) { - EXPECT_THAT(Map(0, kPageSize, PROT_NONE, - MAP_PRIVATE | MAP_ANONYMOUS | MAP_LOCKED, -1, 0), - SyscallSucceeds()); -} - -// MAP_PRIVATE or MAP_SHARED must be passed -TEST_F(MMapTest, NotPrivateOrShared) { - EXPECT_THAT(Map(0, kPageSize, PROT_NONE, MAP_ANONYMOUS, -1, 0), - SyscallFailsWithErrno(EINVAL)); -} - -// Only one of MAP_PRIVATE or MAP_SHARED may be passed -TEST_F(MMapTest, PrivateAndShared) { - EXPECT_THAT(Map(0, kPageSize, PROT_NONE, - MAP_PRIVATE | MAP_SHARED | MAP_ANONYMOUS, -1, 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_F(MMapTest, FixedAlignment) { - // Addr must be page aligned (MAP_FIXED) - EXPECT_THAT(Map(0x30000001, kPageSize, PROT_NONE, - MAP_PRIVATE | MAP_FIXED | MAP_ANONYMOUS, -1, 0), - SyscallFailsWithErrno(EINVAL)); -} - -// Non-MAP_FIXED address does not need to be page aligned -TEST_F(MMapTest, NonFixedAlignment) { - EXPECT_THAT( - Map(0x30000001, kPageSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceeds()); -} - -// Length = 0 results in EINVAL. -TEST_F(MMapTest, InvalidLength) { - EXPECT_THAT(Map(0, 0, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallFailsWithErrno(EINVAL)); -} - -// Bad fd not allowed. -TEST_F(MMapTest, BadFd) { - EXPECT_THAT(Map(0, kPageSize, PROT_NONE, MAP_PRIVATE, 999, 0), - SyscallFailsWithErrno(EBADF)); -} - -// Mappings are writable. -TEST_F(MMapTest, ProtWrite) { - uint64_t addr; - constexpr uint8_t kFirstWord[] = {42, 42, 42, 42}; - - EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceeds()); - - // This shouldn't cause a SIGSEGV. - memset(reinterpret_cast<void*>(addr), 42, kPageSize); - - // The written data should actually be there. - EXPECT_EQ( - 0, memcmp(reinterpret_cast<void*>(addr), kFirstWord, sizeof(kFirstWord))); -} - -// "Write-only" mappings are writable *and* readable. -TEST_F(MMapTest, ProtWriteOnly) { - uint64_t addr; - constexpr uint8_t kFirstWord[] = {42, 42, 42, 42}; - - EXPECT_THAT( - addr = Map(0, kPageSize, PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceeds()); - - // This shouldn't cause a SIGSEGV. - memset(reinterpret_cast<void*>(addr), 42, kPageSize); - - // The written data should actually be there. - EXPECT_EQ( - 0, memcmp(reinterpret_cast<void*>(addr), kFirstWord, sizeof(kFirstWord))); -} - -// "Write-only" mappings are readable. -// -// This is distinct from above to ensure the page is accessible even if the -// initial fault is a write fault. -TEST_F(MMapTest, ProtWriteOnlyReadable) { - uint64_t addr; - constexpr uint64_t kFirstWord = 0; - - EXPECT_THAT( - addr = Map(0, kPageSize, PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceeds()); - - EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), &kFirstWord, - sizeof(kFirstWord))); -} - -// Mappings are writable after mprotect from PROT_NONE to PROT_READ|PROT_WRITE. -TEST_F(MMapTest, ProtectProtWrite) { - uint64_t addr; - constexpr uint8_t kFirstWord[] = {42, 42, 42, 42}; - - EXPECT_THAT( - addr = Map(0, kPageSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceeds()); - - ASSERT_THAT(Protect(addr, kPageSize, PROT_READ | PROT_WRITE), - SyscallSucceeds()); - - // This shouldn't cause a SIGSEGV. - memset(reinterpret_cast<void*>(addr), 42, kPageSize); - - // The written data should actually be there. - EXPECT_EQ( - 0, memcmp(reinterpret_cast<void*>(addr), kFirstWord, sizeof(kFirstWord))); -} - -// SIGSEGV raised when reading PROT_NONE memory -TEST_F(MMapTest, ProtNoneDeath) { - SetupGvisorDeathTest(); - - uintptr_t addr; - - ASSERT_THAT( - addr = Map(0, kPageSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceeds()); - - EXPECT_EXIT(*reinterpret_cast<volatile int*>(addr), - ::testing::KilledBySignal(SIGSEGV), ""); -} - -// SIGSEGV raised when writing PROT_READ only memory -TEST_F(MMapTest, ReadOnlyDeath) { - SetupGvisorDeathTest(); - - uintptr_t addr; - - ASSERT_THAT( - addr = Map(0, kPageSize, PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceeds()); - - EXPECT_EXIT(*reinterpret_cast<volatile int*>(addr) = 42, - ::testing::KilledBySignal(SIGSEGV), ""); -} - -// Writable mapping mprotect'd to read-only should not be writable. -TEST_F(MMapTest, MprotectReadOnlyDeath) { - SetupGvisorDeathTest(); - - uintptr_t addr; - - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceeds()); - - volatile int* val = reinterpret_cast<int*>(addr); - - // Copy to ensure page is mapped in. - *val = 42; - - ASSERT_THAT(Protect(addr, kPageSize, PROT_READ), SyscallSucceeds()); - - // Now it shouldn't be writable. - EXPECT_EXIT(*val = 0, ::testing::KilledBySignal(SIGSEGV), ""); -} - -// Verify that calling mprotect an address that's not page aligned fails. -TEST_F(MMapTest, MprotectNotPageAligned) { - uintptr_t addr; - - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceeds()); - ASSERT_THAT(Protect(addr + 1, kPageSize - 1, PROT_READ), - SyscallFailsWithErrno(EINVAL)); -} - -// Verify that calling mprotect with an absurdly huge length fails. -TEST_F(MMapTest, MprotectHugeLength) { - uintptr_t addr; - - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceeds()); - ASSERT_THAT(Protect(addr, static_cast<size_t>(-1), PROT_READ), - SyscallFailsWithErrno(ENOMEM)); -} - -#if defined(__x86_64__) || defined(__i386__) -// This code is equivalent in 32 and 64-bit mode -const uint8_t machine_code[] = { - 0xb8, 0x2a, 0x00, 0x00, 0x00, // movl $42, %eax - 0xc3, // retq -}; - -// PROT_EXEC allows code execution -TEST_F(MMapTest, ProtExec) { - uintptr_t addr; - uint32_t (*func)(void); - - EXPECT_THAT(addr = Map(0, kPageSize, PROT_EXEC | PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceeds()); - - memcpy(reinterpret_cast<void*>(addr), machine_code, sizeof(machine_code)); - - func = reinterpret_cast<uint32_t (*)(void)>(addr); - - EXPECT_EQ(42, func()); -} - -// No PROT_EXEC disallows code execution -TEST_F(MMapTest, NoProtExecDeath) { - SetupGvisorDeathTest(); - - uintptr_t addr; - uint32_t (*func)(void); - - EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceeds()); - - memcpy(reinterpret_cast<void*>(addr), machine_code, sizeof(machine_code)); - - func = reinterpret_cast<uint32_t (*)(void)>(addr); - - EXPECT_EXIT(func(), ::testing::KilledBySignal(SIGSEGV), ""); -} -#endif - -TEST_F(MMapTest, NoExceedLimitData) { - void* prevbrk; - void* target_brk; - struct rlimit setlim; - - prevbrk = sbrk(0); - ASSERT_NE(-1, reinterpret_cast<intptr_t>(prevbrk)); - target_brk = reinterpret_cast<char*>(prevbrk) + 1; - - setlim.rlim_cur = RLIM_INFINITY; - setlim.rlim_max = RLIM_INFINITY; - ASSERT_THAT(setrlimit(RLIMIT_DATA, &setlim), SyscallSucceeds()); - EXPECT_THAT(brk(target_brk), SyscallSucceedsWithValue(0)); -} - -TEST_F(MMapTest, ExceedLimitData) { - // To unit test this more precisely, we'd need access to the mm's start_brk - // and end_brk, which we don't have direct access to :/ - void* prevbrk; - void* target_brk; - struct rlimit setlim; - - prevbrk = sbrk(0); - ASSERT_NE(-1, reinterpret_cast<intptr_t>(prevbrk)); - target_brk = reinterpret_cast<char*>(prevbrk) + 8192; - - setlim.rlim_cur = 0; - setlim.rlim_max = RLIM_INFINITY; - // Set RLIMIT_DATA very low so any subsequent brk() calls fail. - // Reset RLIMIT_DATA during teardown step. - ASSERT_THAT(setrlimit(RLIMIT_DATA, &setlim), SyscallSucceeds()); - EXPECT_THAT(brk(target_brk), SyscallFailsWithErrno(ENOMEM)); - // Teardown step... - setlim.rlim_cur = RLIM_INFINITY; - ASSERT_THAT(setrlimit(RLIMIT_DATA, &setlim), SyscallSucceeds()); -} - -TEST_F(MMapTest, ExceedLimitDataPrlimit) { - // To unit test this more precisely, we'd need access to the mm's start_brk - // and end_brk, which we don't have direct access to :/ - void* prevbrk; - void* target_brk; - struct rlimit setlim; - - prevbrk = sbrk(0); - ASSERT_NE(-1, reinterpret_cast<intptr_t>(prevbrk)); - target_brk = reinterpret_cast<char*>(prevbrk) + 8192; - - setlim.rlim_cur = 0; - setlim.rlim_max = RLIM_INFINITY; - // Set RLIMIT_DATA very low so any subsequent brk() calls fail. - // Reset RLIMIT_DATA during teardown step. - ASSERT_THAT(prlimit(0, RLIMIT_DATA, &setlim, nullptr), SyscallSucceeds()); - EXPECT_THAT(brk(target_brk), SyscallFailsWithErrno(ENOMEM)); - // Teardown step... - setlim.rlim_cur = RLIM_INFINITY; - ASSERT_THAT(setrlimit(RLIMIT_DATA, &setlim), SyscallSucceeds()); -} - -TEST_F(MMapTest, ExceedLimitDataPrlimitPID) { - // To unit test this more precisely, we'd need access to the mm's start_brk - // and end_brk, which we don't have direct access to :/ - void* prevbrk; - void* target_brk; - struct rlimit setlim; - - prevbrk = sbrk(0); - ASSERT_NE(-1, reinterpret_cast<intptr_t>(prevbrk)); - target_brk = reinterpret_cast<char*>(prevbrk) + 8192; - - setlim.rlim_cur = 0; - setlim.rlim_max = RLIM_INFINITY; - // Set RLIMIT_DATA very low so any subsequent brk() calls fail. - // Reset RLIMIT_DATA during teardown step. - ASSERT_THAT(prlimit(syscall(__NR_gettid), RLIMIT_DATA, &setlim, nullptr), - SyscallSucceeds()); - EXPECT_THAT(brk(target_brk), SyscallFailsWithErrno(ENOMEM)); - // Teardown step... - setlim.rlim_cur = RLIM_INFINITY; - ASSERT_THAT(setrlimit(RLIMIT_DATA, &setlim), SyscallSucceeds()); -} - -TEST_F(MMapTest, NoExceedLimitAS) { - constexpr uint64_t kAllocBytes = 200 << 20; - // Add some headroom to the AS limit in case of e.g. unexpected stack - // expansion. - constexpr uint64_t kExtraASBytes = kAllocBytes + (20 << 20); - static_assert(kAllocBytes < kExtraASBytes, - "test depends on allocation not exceeding AS limit"); - - auto vss = ASSERT_NO_ERRNO_AND_VALUE(VirtualMemorySize()); - struct rlimit setlim; - setlim.rlim_cur = vss + kExtraASBytes; - setlim.rlim_max = RLIM_INFINITY; - ASSERT_THAT(setrlimit(RLIMIT_AS, &setlim), SyscallSucceeds()); - EXPECT_THAT( - Map(0, kAllocBytes, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceedsWithValue(Gt(0))); -} - -TEST_F(MMapTest, ExceedLimitAS) { - constexpr uint64_t kAllocBytes = 200 << 20; - // Add some headroom to the AS limit in case of e.g. unexpected stack - // expansion. - constexpr uint64_t kExtraASBytes = 20 << 20; - static_assert(kAllocBytes > kExtraASBytes, - "test depends on allocation exceeding AS limit"); - - auto vss = ASSERT_NO_ERRNO_AND_VALUE(VirtualMemorySize()); - struct rlimit setlim; - setlim.rlim_cur = vss + kExtraASBytes; - setlim.rlim_max = RLIM_INFINITY; - ASSERT_THAT(setrlimit(RLIMIT_AS, &setlim), SyscallSucceeds()); - EXPECT_THAT( - Map(0, kAllocBytes, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallFailsWithErrno(ENOMEM)); -} - -// Tests that setting an anonymous mmap to PROT_NONE doesn't free the memory. -TEST_F(MMapTest, SettingProtNoneDoesntFreeMemory) { - uintptr_t addr; - constexpr uint8_t kFirstWord[] = {42, 42, 42, 42}; - - EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceedsWithValue(Gt(0))); - - memset(reinterpret_cast<void*>(addr), 42, kPageSize); - - ASSERT_THAT(Protect(addr, kPageSize, PROT_NONE), SyscallSucceeds()); - ASSERT_THAT(Protect(addr, kPageSize, PROT_READ | PROT_WRITE), - SyscallSucceeds()); - - // The written data should still be there. - EXPECT_EQ( - 0, memcmp(reinterpret_cast<void*>(addr), kFirstWord, sizeof(kFirstWord))); -} - -constexpr char kFileContents[] = "Hello World!"; - -class MMapFileTest : public MMapTest { - protected: - FileDescriptor fd_; - std::string filename_; - - // Open a file for read/write - void SetUp() override { - MMapTest::SetUp(); - - filename_ = NewTempAbsPath(); - fd_ = ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_CREAT | O_RDWR, 0644)); - - // Extend file so it can be written once mapped. Deliberately make the file - // only half a page in size, so we can test what happens when we access the - // second half. - // Use ftruncate(2) once the sentry supports it. - char zero = 0; - size_t count = 0; - do { - const DisableSave ds; // saving 2048 times is slow and useless. - Write(&zero, 1), SyscallSucceedsWithValue(1); - } while (++count < (kPageSize / 2)); - ASSERT_THAT(lseek(fd_.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); - } - - // Close and delete file - void TearDown() override { - MMapTest::TearDown(); - fd_.reset(); // Make sure the files is closed before we unlink it. - ASSERT_THAT(unlink(filename_.c_str()), SyscallSucceeds()); - } - - ssize_t Read(char* buf, size_t count) { - ssize_t len = 0; - do { - ssize_t ret = read(fd_.get(), buf, count); - if (ret < 0) { - return ret; - } else if (ret == 0) { - return len; - } - - len += ret; - buf += ret; - } while (len < static_cast<ssize_t>(count)); - - return len; - } - - ssize_t Write(const char* buf, size_t count) { - ssize_t len = 0; - do { - ssize_t ret = write(fd_.get(), buf, count); - if (ret < 0) { - return ret; - } else if (ret == 0) { - return len; - } - - len += ret; - buf += ret; - } while (len < static_cast<ssize_t>(count)); - - return len; - } -}; - -class MMapFileParamTest - : public MMapFileTest, - public ::testing::WithParamInterface<std::tuple<int, int>> { - protected: - int prot() const { return std::get<0>(GetParam()); } - - int flags() const { return std::get<1>(GetParam()); } -}; - -// MAP_POPULATE allowed. -// There isn't a good way to verify it actually did anything. -TEST_P(MMapFileParamTest, MapPopulate) { - ASSERT_THAT(Map(0, kPageSize, prot(), flags() | MAP_POPULATE, fd_.get(), 0), - SyscallSucceeds()); -} - -// MAP_POPULATE on a short file. -TEST_P(MMapFileParamTest, MapPopulateShort) { - ASSERT_THAT( - Map(0, 2 * kPageSize, prot(), flags() | MAP_POPULATE, fd_.get(), 0), - SyscallSucceeds()); -} - -// Read contents from mapped file. -TEST_F(MMapFileTest, Read) { - size_t len = strlen(kFileContents); - ASSERT_EQ(len, Write(kFileContents, len)); - - uintptr_t addr; - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_PRIVATE, fd_.get(), 0), - SyscallSucceeds()); - - EXPECT_THAT(reinterpret_cast<char*>(addr), - EqualsMemory(std::string(kFileContents))); -} - -// Map at an offset. -TEST_F(MMapFileTest, MapOffset) { - ASSERT_THAT(lseek(fd_.get(), kPageSize, SEEK_SET), SyscallSucceeds()); - - size_t len = strlen(kFileContents); - ASSERT_EQ(len, Write(kFileContents, len)); - - uintptr_t addr; - ASSERT_THAT( - addr = Map(0, kPageSize, PROT_READ, MAP_PRIVATE, fd_.get(), kPageSize), - SyscallSucceeds()); - - EXPECT_THAT(reinterpret_cast<char*>(addr), - EqualsMemory(std::string(kFileContents))); -} - -TEST_F(MMapFileTest, MapOffsetBeyondEnd) { - SetupGvisorDeathTest(); - - uintptr_t addr; - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, - fd_.get(), 10 * kPageSize), - SyscallSucceeds()); - - // Touching the memory causes SIGBUS. - size_t len = strlen(kFileContents); - EXPECT_EXIT(std::copy(kFileContents, kFileContents + len, - reinterpret_cast<volatile char*>(addr)), - ::testing::KilledBySignal(SIGBUS), ""); -} - -// Verify mmap fails when sum of length and offset overflows. -TEST_F(MMapFileTest, MapLengthPlusOffsetOverflows) { - const size_t length = static_cast<size_t>(-kPageSize); - const off_t offset = kPageSize; - ASSERT_THAT(Map(0, length, PROT_READ, MAP_PRIVATE, fd_.get(), offset), - SyscallFailsWithErrno(ENOMEM)); -} - -// MAP_PRIVATE PROT_WRITE is allowed on read-only FDs. -TEST_F(MMapFileTest, WritePrivateOnReadOnlyFd) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_RDONLY)); - - uintptr_t addr; - EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, - fd.get(), 0), - SyscallSucceeds()); - - // Touch the page to ensure the kernel didn't lie about writability. - size_t len = strlen(kFileContents); - std::copy(kFileContents, kFileContents + len, - reinterpret_cast<volatile char*>(addr)); -} - -// MAP_SHARED PROT_WRITE not allowed on read-only FDs. -TEST_F(MMapFileTest, WriteSharedOnReadOnlyFd) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_RDONLY)); - - uintptr_t addr; - EXPECT_THAT( - addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0), - SyscallFailsWithErrno(EACCES)); -} - -// The FD must be readable. -TEST_P(MMapFileParamTest, WriteOnlyFd) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_WRONLY)); - - uintptr_t addr; - EXPECT_THAT(addr = Map(0, kPageSize, prot(), flags(), fd.get(), 0), - SyscallFailsWithErrno(EACCES)); -} - -// Overwriting the contents of a file mapped MAP_SHARED PROT_READ -// should cause the new data to be reflected in the mapping. -TEST_F(MMapFileTest, ReadSharedConsistentWithOverwrite) { - // Start from scratch. - EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds()); - - // Expand the file to two pages and dirty them. - std::string bufA(kPageSize, 'a'); - ASSERT_THAT(Write(bufA.c_str(), bufA.size()), - SyscallSucceedsWithValue(bufA.size())); - std::string bufB(kPageSize, 'b'); - ASSERT_THAT(Write(bufB.c_str(), bufB.size()), - SyscallSucceedsWithValue(bufB.size())); - - // Map the page. - uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0), - SyscallSucceeds()); - - // Check that the mapping contains the right file data. - EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), bufA.c_str(), kPageSize)); - EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr + kPageSize), bufB.c_str(), - kPageSize)); - - // Start at the beginning of the file. - ASSERT_THAT(lseek(fd_.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); - - // Swap the write pattern. - ASSERT_THAT(Write(bufB.c_str(), bufB.size()), - SyscallSucceedsWithValue(bufB.size())); - ASSERT_THAT(Write(bufA.c_str(), bufA.size()), - SyscallSucceedsWithValue(bufA.size())); - - // Check that the mapping got updated. - EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), bufB.c_str(), kPageSize)); - EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr + kPageSize), bufA.c_str(), - kPageSize)); -} - -// Partially overwriting a file mapped MAP_SHARED PROT_READ should be reflected -// in the mapping. -TEST_F(MMapFileTest, ReadSharedConsistentWithPartialOverwrite) { - // Start from scratch. - EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds()); - - // Expand the file to two pages and dirty them. - std::string bufA(kPageSize, 'a'); - ASSERT_THAT(Write(bufA.c_str(), bufA.size()), - SyscallSucceedsWithValue(bufA.size())); - std::string bufB(kPageSize, 'b'); - ASSERT_THAT(Write(bufB.c_str(), bufB.size()), - SyscallSucceedsWithValue(bufB.size())); - - // Map the page. - uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0), - SyscallSucceeds()); - - // Check that the mapping contains the right file data. - EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), bufA.c_str(), kPageSize)); - EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr + kPageSize), bufB.c_str(), - kPageSize)); - - // Start at the beginning of the file. - ASSERT_THAT(lseek(fd_.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); - - // Do a partial overwrite, spanning both pages. - std::string bufC(kPageSize + (kPageSize / 2), 'c'); - ASSERT_THAT(Write(bufC.c_str(), bufC.size()), - SyscallSucceedsWithValue(bufC.size())); - - // Check that the mapping got updated. - EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), bufC.c_str(), - kPageSize + (kPageSize / 2))); - EXPECT_EQ(0, - memcmp(reinterpret_cast<void*>(addr + kPageSize + (kPageSize / 2)), - bufB.c_str(), kPageSize / 2)); -} - -// Overwriting a file mapped MAP_SHARED PROT_READ should be reflected in the -// mapping and the file. -TEST_F(MMapFileTest, ReadSharedConsistentWithWriteAndFile) { - // Start from scratch. - EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds()); - - // Expand the file to two full pages and dirty it. - std::string bufA(2 * kPageSize, 'a'); - ASSERT_THAT(Write(bufA.c_str(), bufA.size()), - SyscallSucceedsWithValue(bufA.size())); - - // Map only the first page. - uintptr_t addr; - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0), - SyscallSucceeds()); - - // Prepare to overwrite the file contents. - ASSERT_THAT(lseek(fd_.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); - - // Overwrite everything, beyond the mapped portion. - std::string bufB(2 * kPageSize, 'b'); - ASSERT_THAT(Write(bufB.c_str(), bufB.size()), - SyscallSucceedsWithValue(bufB.size())); - - // What the mapped portion should now look like. - std::string bufMapped(kPageSize, 'b'); - - // Expect that the mapped portion is consistent. - EXPECT_EQ( - 0, memcmp(reinterpret_cast<void*>(addr), bufMapped.c_str(), kPageSize)); - - // Prepare to read the entire file contents. - ASSERT_THAT(lseek(fd_.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); - - // Expect that the file was fully updated. - std::vector<char> bufFile(2 * kPageSize); - ASSERT_THAT(Read(bufFile.data(), bufFile.size()), - SyscallSucceedsWithValue(bufFile.size())); - // Cast to void* to avoid EXPECT_THAT assuming bufFile.data() is a - // NUL-terminated C std::string. EXPECT_THAT will try to print a char* as a C - // std::string, possibly overruning the buffer. - EXPECT_THAT(reinterpret_cast<void*>(bufFile.data()), EqualsMemory(bufB)); -} - -// Write data to mapped file. -TEST_F(MMapFileTest, WriteShared) { - uintptr_t addr; - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - fd_.get(), 0), - SyscallSucceeds()); - - size_t len = strlen(kFileContents); - memcpy(reinterpret_cast<void*>(addr), kFileContents, len); - - // The file may not actually be updated until munmap is called. - ASSERT_THAT(Unmap(), SyscallSucceeds()); - - std::vector<char> buf(len); - ASSERT_THAT(Read(buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - // Cast to void* to avoid EXPECT_THAT assuming buf.data() is a - // NUL-terminated C string. EXPECT_THAT will try to print a char* as a C - // string, possibly overruning the buffer. - EXPECT_THAT(reinterpret_cast<void*>(buf.data()), - EqualsMemory(std::string(kFileContents))); -} - -// Write data to portion of mapped page beyond the end of the file. -// These writes are not reflected in the file. -TEST_F(MMapFileTest, WriteSharedBeyondEnd) { - // The file is only half of a page. We map an entire page. Writes to the - // end of the mapping must not be reflected in the file. - uintptr_t addr; - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - fd_.get(), 0), - SyscallSucceeds()); - - // First half; this is reflected in the file. - std::string first(kPageSize / 2, 'A'); - memcpy(reinterpret_cast<void*>(addr), first.c_str(), first.size()); - - // Second half; this is not reflected in the file. - std::string second(kPageSize / 2, 'B'); - memcpy(reinterpret_cast<void*>(addr + kPageSize / 2), second.c_str(), - second.size()); - - // The file may not actually be updated until munmap is called. - ASSERT_THAT(Unmap(), SyscallSucceeds()); - - // Big enough to fit the entire page, if the writes are mistakenly written to - // the file. - std::vector<char> buf(kPageSize); - - // Only the first half is in the file. - ASSERT_THAT(Read(buf.data(), buf.size()), - SyscallSucceedsWithValue(first.size())); - // Cast to void* to avoid EXPECT_THAT assuming buf.data() is a - // NUL-terminated C string. EXPECT_THAT will try to print a char* as a C - // NUL-terminated C std::string. EXPECT_THAT will try to print a char* as a C - // std::string, possibly overruning the buffer. - EXPECT_THAT(reinterpret_cast<void*>(buf.data()), EqualsMemory(first)); -} - -// The portion of a mapped page that becomes part of the file after a truncate -// is reflected in the file. -TEST_F(MMapFileTest, WriteSharedTruncateUp) { - // The file is only half of a page. We map an entire page. Writes to the - // end of the mapping must not be reflected in the file. - uintptr_t addr; - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - fd_.get(), 0), - SyscallSucceeds()); - - // First half; this is reflected in the file. - std::string first(kPageSize / 2, 'A'); - memcpy(reinterpret_cast<void*>(addr), first.c_str(), first.size()); - - // Second half; this is not reflected in the file now (see - // WriteSharedBeyondEnd), but will be after the truncate. - std::string second(kPageSize / 2, 'B'); - memcpy(reinterpret_cast<void*>(addr + kPageSize / 2), second.c_str(), - second.size()); - - // Extend the file to a full page. The second half of the page will be - // reflected in the file. - EXPECT_THAT(ftruncate(fd_.get(), kPageSize), SyscallSucceeds()); - - // The file may not actually be updated until munmap is called. - ASSERT_THAT(Unmap(), SyscallSucceeds()); - - // The whole page is in the file. - std::vector<char> buf(kPageSize); - ASSERT_THAT(Read(buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - // Cast to void* to avoid EXPECT_THAT assuming buf.data() is a - // NUL-terminated C string. EXPECT_THAT will try to print a char* as a C - // string, possibly overruning the buffer. - EXPECT_THAT(reinterpret_cast<void*>(buf.data()), EqualsMemory(first)); - EXPECT_THAT(reinterpret_cast<void*>(buf.data() + kPageSize / 2), - EqualsMemory(second)); -} - -TEST_F(MMapFileTest, ReadSharedTruncateDownThenUp) { - // Start from scratch. - EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds()); - - // Expand the file to a full page and dirty it. - std::string buf(kPageSize, 'a'); - ASSERT_THAT(Write(buf.c_str(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - - // Map the page. - uintptr_t addr; - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0), - SyscallSucceeds()); - - // Check that the memory contains the file data. - EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), buf.c_str(), kPageSize)); - - // Truncate down, then up. - EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds()); - EXPECT_THAT(ftruncate(fd_.get(), kPageSize), SyscallSucceeds()); - - // Check that the memory was zeroed. - std::string zeroed(kPageSize, '\0'); - EXPECT_EQ(0, - memcmp(reinterpret_cast<void*>(addr), zeroed.c_str(), kPageSize)); - - // The file may not actually be updated until msync is called. - ASSERT_THAT(Msync(), SyscallSucceeds()); - - // Prepare to read the entire file contents. - ASSERT_THAT(lseek(fd_.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); - - // Expect that the file is fully updated. - std::vector<char> bufFile(kPageSize); - ASSERT_THAT(Read(bufFile.data(), bufFile.size()), - SyscallSucceedsWithValue(bufFile.size())); - EXPECT_EQ(0, memcmp(bufFile.data(), zeroed.c_str(), kPageSize)); -} - -TEST_F(MMapFileTest, WriteSharedTruncateDownThenUp) { - // The file is only half of a page. We map an entire page. Writes to the - // end of the mapping must not be reflected in the file. - uintptr_t addr; - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - fd_.get(), 0), - SyscallSucceeds()); - - // First half; this will be deleted by truncate(0). - std::string first(kPageSize / 2, 'A'); - memcpy(reinterpret_cast<void*>(addr), first.c_str(), first.size()); - - // Truncate down, then up. - EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds()); - EXPECT_THAT(ftruncate(fd_.get(), kPageSize), SyscallSucceeds()); - - // The whole page is zeroed in memory. - std::string zeroed(kPageSize, '\0'); - EXPECT_EQ(0, - memcmp(reinterpret_cast<void*>(addr), zeroed.c_str(), kPageSize)); - - // The file may not actually be updated until munmap is called. - ASSERT_THAT(Unmap(), SyscallSucceeds()); - - // The whole file is also zeroed. - std::vector<char> buf(kPageSize); - ASSERT_THAT(Read(buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - // Cast to void* to avoid EXPECT_THAT assuming buf.data() is a - // NUL-terminated C string. EXPECT_THAT will try to print a char* as a C - // string, possibly overruning the buffer. - EXPECT_THAT(reinterpret_cast<void*>(buf.data()), EqualsMemory(zeroed)); -} - -TEST_F(MMapFileTest, ReadSharedTruncateSIGBUS) { - SetupGvisorDeathTest(); - - // Start from scratch. - EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds()); - - // Expand the file to a full page and dirty it. - std::string buf(kPageSize, 'a'); - ASSERT_THAT(Write(buf.c_str(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - - // Map the page. - uintptr_t addr; - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0), - SyscallSucceeds()); - - // Check that the mapping contains the file data. - EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), buf.c_str(), kPageSize)); - - // Truncate down. - EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds()); - - // Accessing the truncated region should cause a SIGBUS. - std::vector<char> in(kPageSize); - EXPECT_EXIT( - std::copy(reinterpret_cast<volatile char*>(addr), - reinterpret_cast<volatile char*>(addr) + kPageSize, in.data()), - ::testing::KilledBySignal(SIGBUS), ""); -} - -TEST_F(MMapFileTest, WriteSharedTruncateSIGBUS) { - SetupGvisorDeathTest(); - - uintptr_t addr; - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - fd_.get(), 0), - SyscallSucceeds()); - - // Touch the memory to be sure it really is mapped. - size_t len = strlen(kFileContents); - memcpy(reinterpret_cast<void*>(addr), kFileContents, len); - - // Truncate down. - EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds()); - - // Accessing the truncated file should cause a SIGBUS. - EXPECT_EXIT(std::copy(kFileContents, kFileContents + len, - reinterpret_cast<volatile char*>(addr)), - ::testing::KilledBySignal(SIGBUS), ""); -} - -TEST_F(MMapFileTest, ReadSharedTruncatePartialPage) { - // Start from scratch. - EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds()); - - // Dirty the file. - std::string buf(kPageSize, 'a'); - ASSERT_THAT(Write(buf.c_str(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - - // Map a page. - uintptr_t addr; - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0), - SyscallSucceeds()); - - // Truncate to half of the page. - EXPECT_THAT(ftruncate(fd_.get(), kPageSize / 2), SyscallSucceeds()); - - // First half of the page untouched. - EXPECT_EQ(0, - memcmp(reinterpret_cast<void*>(addr), buf.data(), kPageSize / 2)); - - // Second half is zeroed. - std::string zeroed(kPageSize / 2, '\0'); - EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr + kPageSize / 2), - zeroed.c_str(), kPageSize / 2)); -} - -// Page can still be accessed and contents are intact after truncating a partial -// page. -TEST_F(MMapFileTest, WriteSharedTruncatePartialPage) { - // Expand the file to a full page. - EXPECT_THAT(ftruncate(fd_.get(), kPageSize), SyscallSucceeds()); - - uintptr_t addr; - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - fd_.get(), 0), - SyscallSucceeds()); - - // Fill the entire page. - std::string contents(kPageSize, 'A'); - memcpy(reinterpret_cast<void*>(addr), contents.c_str(), contents.size()); - - // Truncate half of the page. - EXPECT_THAT(ftruncate(fd_.get(), kPageSize / 2), SyscallSucceeds()); - - // First half of the page untouched. - EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), contents.c_str(), - kPageSize / 2)); - - // Second half zeroed. - std::string zeroed(kPageSize / 2, '\0'); - EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr + kPageSize / 2), - zeroed.c_str(), kPageSize / 2)); -} - -// MAP_PRIVATE writes are not carried through to the underlying file. -TEST_F(MMapFileTest, WritePrivate) { - uintptr_t addr; - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, - fd_.get(), 0), - SyscallSucceeds()); - - size_t len = strlen(kFileContents); - memcpy(reinterpret_cast<void*>(addr), kFileContents, len); - - // The file should not be updated, but if it mistakenly is, it may not be - // until after munmap is called. - ASSERT_THAT(Unmap(), SyscallSucceeds()); - - std::vector<char> buf(len); - ASSERT_THAT(Read(buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - // Cast to void* to avoid EXPECT_THAT assuming buf.data() is a - // NUL-terminated C string. EXPECT_THAT will try to print a char* as a C - // string, possibly overruning the buffer. - EXPECT_THAT(reinterpret_cast<void*>(buf.data()), - EqualsMemory(std::string(len, '\0'))); -} - -// SIGBUS raised when reading or writing past end of a mapped file. -TEST_P(MMapFileParamTest, SigBusDeath) { - SetupGvisorDeathTest(); - - uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, prot(), flags(), fd_.get(), 0), - SyscallSucceeds()); - - auto* start = reinterpret_cast<volatile char*>(addr + kPageSize); - - // MMapFileTest makes a file kPageSize/2 long. The entire first page should be - // accessible, but anything beyond it should not. - if (prot() & PROT_WRITE) { - // Write beyond first page. - size_t len = strlen(kFileContents); - EXPECT_EXIT(std::copy(kFileContents, kFileContents + len, start), - ::testing::KilledBySignal(SIGBUS), ""); - } else { - // Read beyond first page. - std::vector<char> in(kPageSize); - EXPECT_EXIT(std::copy(start, start + kPageSize, in.data()), - ::testing::KilledBySignal(SIGBUS), ""); - } -} - -// Tests that SIGBUS is not raised when reading or writing to a file-mapped -// page before EOF, even if part of the mapping extends beyond EOF. -// -// See b/27877699. -TEST_P(MMapFileParamTest, NoSigBusOnPagesBeforeEOF) { - uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, prot(), flags(), fd_.get(), 0), - SyscallSucceeds()); - - // The test passes if this survives. - auto* start = reinterpret_cast<volatile char*>(addr + (kPageSize / 2) + 1); - size_t len = strlen(kFileContents); - if (prot() & PROT_WRITE) { - std::copy(kFileContents, kFileContents + len, start); - } else { - std::vector<char> in(len); - std::copy(start, start + len, in.data()); - } -} - -// Tests that SIGBUS is not raised when reading or writing from a file-mapped -// page containing EOF, *after* the EOF. -TEST_P(MMapFileParamTest, NoSigBusOnPageContainingEOF) { - uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, prot(), flags(), fd_.get(), 0), - SyscallSucceeds()); - - // The test passes if this survives. (Technically addr+kPageSize/2 is already - // beyond EOF, but +1 to check for fencepost errors.) - auto* start = reinterpret_cast<volatile char*>(addr + (kPageSize / 2) + 1); - size_t len = strlen(kFileContents); - if (prot() & PROT_WRITE) { - std::copy(kFileContents, kFileContents + len, start); - } else { - std::vector<char> in(len); - std::copy(start, start + len, in.data()); - } -} - -// Tests that reading from writable shared file-mapped pages succeeds. -// -// On most platforms this is trivial, but when the file is mapped via the sentry -// page cache (which does not yet support writing to shared mappings), a bug -// caused reads to fail unnecessarily on such mappings. See b/28913513. -TEST_F(MMapFileTest, ReadingWritableSharedFilePageSucceeds) { - uintptr_t addr; - size_t len = strlen(kFileContents); - - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - fd_.get(), 0), - SyscallSucceeds()); - - std::vector<char> buf(kPageSize); - // The test passes if this survives. - std::copy(reinterpret_cast<volatile char*>(addr), - reinterpret_cast<volatile char*>(addr) + len, buf.data()); -} - -// Tests that EFAULT is returned when invoking a syscall that requires the OS to -// read past end of file (resulting in a fault in sentry context in the gVisor -// case). See b/28913513. -TEST_F(MMapFileTest, InternalSigBus) { - uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, - fd_.get(), 0), - SyscallSucceeds()); - - // This depends on the fact that gVisor implements pipes internally. - int pipefd[2]; - ASSERT_THAT(pipe(pipefd), SyscallSucceeds()); - EXPECT_THAT( - write(pipefd[1], reinterpret_cast<void*>(addr + kPageSize), kPageSize), - SyscallFailsWithErrno(EFAULT)); - - EXPECT_THAT(close(pipefd[0]), SyscallSucceeds()); - EXPECT_THAT(close(pipefd[1]), SyscallSucceeds()); -} - -// Like InternalSigBus, but test the WriteZerosAt path by reading from -// /dev/zero to a shared mapping (so that the SIGBUS isn't caught during -// copy-on-write breaking). -TEST_F(MMapFileTest, InternalSigBusZeroing) { - uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - fd_.get(), 0), - SyscallSucceeds()); - - const FileDescriptor dev_zero = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDONLY)); - EXPECT_THAT(read(dev_zero.get(), reinterpret_cast<void*>(addr + kPageSize), - kPageSize), - SyscallFailsWithErrno(EFAULT)); -} - -// Checks that mmaps with a length of uint64_t(-PAGE_SIZE + 1) or greater do not -// induce a sentry panic (due to "rounding up" to 0). -TEST_F(MMapTest, HugeLength) { - EXPECT_THAT(Map(0, static_cast<uint64_t>(-kPageSize + 1), PROT_NONE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallFailsWithErrno(ENOMEM)); -} - -// Tests for a specific gVisor MM caching bug. -TEST_F(MMapTest, AccessCOWInvalidatesCachedSegments) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDWR)); - auto zero_fd = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDONLY)); - - // Get a two-page private mapping and fill it with 1s. - uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0), - SyscallSucceeds()); - memset(addr_, 1, 2 * kPageSize); - MaybeSave(); - - // Fork to make the mapping copy-on-write. - pid_t const pid = fork(); - if (pid == 0) { - // The child process waits for the parent to SIGKILL it. - while (true) { - pause(); - } - } - ASSERT_THAT(pid, SyscallSucceeds()); - auto cleanup_child = Cleanup([&] { - EXPECT_THAT(kill(pid, SIGKILL), SyscallSucceeds()); - int status; - EXPECT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - }); - - // Induce a read-only Access of the first page of the mapping, which will not - // cause a copy. The usermem.Segment should be cached. - ASSERT_THAT(PwriteFd(fd.get(), addr_, kPageSize, 0), - SyscallSucceedsWithValue(kPageSize)); - - // Induce a writable Access of both pages of the mapping. This should - // invalidate the cached Segment. - ASSERT_THAT(PreadFd(zero_fd.get(), addr_, 2 * kPageSize, 0), - SyscallSucceedsWithValue(2 * kPageSize)); - - // Induce a read-only Access of the first page of the mapping again. It should - // read the 0s that were stored in the mapping by the read from /dev/zero. If - // the read failed to invalidate the cached Segment, it will instead read the - // 1s in the stale page. - ASSERT_THAT(PwriteFd(fd.get(), addr_, kPageSize, 0), - SyscallSucceedsWithValue(kPageSize)); - std::vector<char> buf(kPageSize); - ASSERT_THAT(PreadFd(fd.get(), buf.data(), kPageSize, 0), - SyscallSucceedsWithValue(kPageSize)); - for (size_t i = 0; i < kPageSize; i++) { - ASSERT_EQ(0, buf[i]) << "at offset " << i; - } -} - -TEST_F(MMapTest, NoReserve) { - const size_t kSize = 10 * 1 << 20; // 10M - uintptr_t addr; - ASSERT_THAT(addr = Map(0, kSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS | MAP_NORESERVE, -1, 0), - SyscallSucceeds()); - EXPECT_GT(addr, 0); - - // Check that every page can be read/written. Technically, writing to memory - // could SIGSEGV in case there is no more memory available. In gVisor it - // would never happen though because NORESERVE is ignored. In Linux, it's - // possible to fail, but allocation is small enough that it's highly likely - // to succeed. - for (size_t j = 0; j < kSize; j += kPageSize) { - EXPECT_EQ(0, reinterpret_cast<char*>(addr)[j]); - reinterpret_cast<char*>(addr)[j] = j; - } -} - -// Map more than the gVisor page-cache map unit (64k) and ensure that -// it is consistent with reading from the file. -TEST_F(MMapFileTest, Bug38498194) { - // Choose a sufficiently large map unit. - constexpr int kSize = 4 * 1024 * 1024; - EXPECT_THAT(ftruncate(fd_.get(), kSize), SyscallSucceeds()); - - // Map a large enough region so that multiple internal segments - // are created to back the mapping. - uintptr_t addr; - ASSERT_THAT( - addr = Map(0, kSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd_.get(), 0), - SyscallSucceeds()); - - std::vector<char> expect(kSize, 'a'); - std::copy(expect.data(), expect.data() + expect.size(), - reinterpret_cast<volatile char*>(addr)); - - // Trigger writeback for gVisor. In Linux pages stay cached until - // it can't hold onto them anymore. - ASSERT_THAT(Unmap(), SyscallSucceeds()); - - std::vector<char> buf(kSize); - ASSERT_THAT(Read(buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - EXPECT_EQ(buf, expect) << std::string(buf.data(), buf.size()); -} - -// Tests that reading from a file to a memory mapping of the same file does not -// deadlock. See b/34813270. -TEST_F(MMapFileTest, SelfRead) { - uintptr_t addr; - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - fd_.get(), 0), - SyscallSucceeds()); - EXPECT_THAT(Read(reinterpret_cast<char*>(addr), kPageSize / 2), - SyscallSucceedsWithValue(kPageSize / 2)); - // The resulting file contents are poorly-specified and irrelevant. -} - -// Tests that writing to a file from a memory mapping of the same file does not -// deadlock. Regression test for b/34813270. -TEST_F(MMapFileTest, SelfWrite) { - uintptr_t addr; - ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0), - SyscallSucceeds()); - EXPECT_THAT(Write(reinterpret_cast<char*>(addr), kPageSize / 2), - SyscallSucceedsWithValue(kPageSize / 2)); - // The resulting file contents are poorly-specified and irrelevant. -} - -TEST(MMapDeathTest, TruncateAfterCOWBreak) { - SetupGvisorDeathTest(); - - // Create and map a single-page file. - auto const temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file.path(), O_RDWR)); - ASSERT_THAT(ftruncate(fd.get(), kPageSize), SyscallSucceeds()); - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(Mmap( - nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, fd.get(), 0)); - - // Write to this mapping, causing the page to be copied for write. - memset(mapping.ptr(), 'a', mapping.len()); - MaybeSave(); // Trigger a co-operative save cycle. - - // Truncate the file and expect it to invalidate the copied page. - ASSERT_THAT(ftruncate(fd.get(), 0), SyscallSucceeds()); - EXPECT_EXIT(*reinterpret_cast<volatile char*>(mapping.ptr()), - ::testing::KilledBySignal(SIGBUS), ""); -} - -// Regression test for #147. -TEST(MMapNoFixtureTest, MapReadOnlyAfterCreateWriteOnly) { - std::string filename = NewTempAbsPath(); - - // We have to create the file O_RDONLY to reproduce the bug because - // fsgofer.localFile.Create() silently upgrades O_WRONLY to O_RDWR, causing - // the cached "write-only" FD to be read/write and therefore usable by mmap(). - auto const ro_fd = ASSERT_NO_ERRNO_AND_VALUE( - Open(filename, O_RDONLY | O_CREAT | O_EXCL, 0666)); - - // Get a write-only FD for the same file, which should be ignored by mmap() - // (but isn't in #147). - auto const wo_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_WRONLY)); - ASSERT_THAT(ftruncate(wo_fd.get(), kPageSize), SyscallSucceeds()); - - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - Mmap(nullptr, kPageSize, PROT_READ, MAP_SHARED, ro_fd.get(), 0)); - std::vector<char> buf(kPageSize); - // The test passes if this survives. - std::copy(static_cast<char*>(mapping.ptr()), - static_cast<char*>(mapping.endptr()), buf.data()); -} - -// Conditional on MAP_32BIT. -#ifdef __x86_64__ - -TEST(MMapNoFixtureTest, Map32Bit) { - auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_NONE, MAP_PRIVATE | MAP_32BIT)); - EXPECT_LT(mapping.addr(), static_cast<uintptr_t>(1) << 32); - EXPECT_LE(mapping.endaddr(), static_cast<uintptr_t>(1) << 32); -} - -#endif // defined(__x86_64__) - -INSTANTIATE_TEST_SUITE_P( - ReadWriteSharedPrivate, MMapFileParamTest, - ::testing::Combine(::testing::ValuesIn({ - PROT_READ, - PROT_WRITE, - PROT_READ | PROT_WRITE, - }), - ::testing::ValuesIn({MAP_SHARED, MAP_PRIVATE}))); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc deleted file mode 100644 index a3e9745cf..000000000 --- a/test/syscalls/linux/mount.cc +++ /dev/null @@ -1,327 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <stdio.h> -#include <sys/mount.h> -#include <sys/stat.h> -#include <unistd.h> - -#include <functional> -#include <memory> -#include <string> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/string_view.h" -#include "absl/time/time.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/mount_util.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(MountTest, MountBadFilesystem) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - // Linux expects a valid target before it checks the file system name. - auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(mount("", dir.path().c_str(), "foobar", 0, ""), - SyscallFailsWithErrno(ENODEV)); -} - -TEST(MountTest, MountInvalidTarget) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - auto const dir = NewTempAbsPath(); - EXPECT_THAT(mount("", dir.c_str(), "tmpfs", 0, ""), - SyscallFailsWithErrno(ENOENT)); -} - -TEST(MountTest, MountPermDenied) { - // Clear CAP_SYS_ADMIN. - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))) { - EXPECT_NO_ERRNO(SetCapability(CAP_SYS_ADMIN, false)); - } - - // Linux expects a valid target before checking capability. - auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(mount("", dir.path().c_str(), "", 0, ""), - SyscallFailsWithErrno(EPERM)); -} - -TEST(MountTest, UmountPermDenied) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto const mount = - ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "tmpfs", 0, "", 0)); - - // Drop privileges in another thread, so we can still unmount the mounted - // directory. - ScopedThread([&]() { - EXPECT_NO_ERRNO(SetCapability(CAP_SYS_ADMIN, false)); - EXPECT_THAT(umount(dir.path().c_str()), SyscallFailsWithErrno(EPERM)); - }); -} - -TEST(MountTest, MountOverBusy) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto const fd = ASSERT_NO_ERRNO_AND_VALUE( - Open(JoinPath(dir.path(), "foo"), O_CREAT | O_RDWR, 0777)); - - // Should be able to mount over a busy directory. - ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "tmpfs", 0, "", 0)); -} - -TEST(MountTest, OpenFileBusy) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto const mount = ASSERT_NO_ERRNO_AND_VALUE( - Mount("", dir.path(), "tmpfs", 0, "mode=0700", 0)); - auto const fd = ASSERT_NO_ERRNO_AND_VALUE( - Open(JoinPath(dir.path(), "foo"), O_CREAT | O_RDWR, 0777)); - - // An open file should prevent unmounting. - EXPECT_THAT(umount(dir.path().c_str()), SyscallFailsWithErrno(EBUSY)); -} - -TEST(MountTest, UmountDetach) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - // structure: - // - // dir (mount point) - // subdir - // file - // - // We show that we can walk around in the mount after detach-unmount dir. - // - // We show that even though dir is unreachable from outside the mount, we can - // still reach dir's (former) parent! - auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - const struct stat before = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); - auto mount = - ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "tmpfs", 0, "mode=0700", - /* umountflags= */ MNT_DETACH)); - const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); - EXPECT_NE(before.st_ino, after.st_ino); - - // Create files in the new mount. - constexpr char kContents[] = "no no no"; - auto const subdir = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir.path())); - auto const file = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateFileWith(dir.path(), kContents, 0777)); - - auto const dir_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(subdir.path(), O_RDONLY | O_DIRECTORY)); - auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); - - // Unmount the tmpfs. - mount.Release()(); - - const struct stat after2 = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); - EXPECT_EQ(before.st_ino, after2.st_ino); - - // Can still read file after unmounting. - std::vector<char> buf(sizeof(kContents)); - EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()), SyscallSucceeds()); - - // Walk to dir. - auto const mounted_dir = ASSERT_NO_ERRNO_AND_VALUE( - OpenAt(dir_fd.get(), "..", O_DIRECTORY | O_RDONLY)); - // Walk to dir/file. - auto const fd_again = ASSERT_NO_ERRNO_AND_VALUE( - OpenAt(mounted_dir.get(), std::string(Basename(file.path())), O_RDONLY)); - - std::vector<char> buf2(sizeof(kContents)); - EXPECT_THAT(ReadFd(fd_again.get(), buf2.data(), buf2.size()), - SyscallSucceeds()); - EXPECT_EQ(buf, buf2); - - // Walking outside the unmounted realm should still work, too! - auto const dir_parent = ASSERT_NO_ERRNO_AND_VALUE( - OpenAt(mounted_dir.get(), "..", O_DIRECTORY | O_RDONLY)); -} - -TEST(MountTest, ActiveSubmountBusy) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto const mount1 = ASSERT_NO_ERRNO_AND_VALUE( - Mount("", dir.path(), "tmpfs", 0, "mode=0700", 0)); - - auto const dir2 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir.path())); - auto const mount2 = - ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir2.path(), "tmpfs", 0, "", 0)); - - // Since dir now has an active submount, should not be able to unmount. - EXPECT_THAT(umount(dir.path().c_str()), SyscallFailsWithErrno(EBUSY)); -} - -TEST(MountTest, MountTmpfs) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - // NOTE(b/129868551): Inode IDs are only stable across S/R if we have an open - // FD for that inode. Since we are going to compare inode IDs below, get a - // FileDescriptor for this directory here, which will be closed automatically - // at the end of the test. - auto const fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY, O_RDONLY)); - - const struct stat before = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); - - { - auto const mount = ASSERT_NO_ERRNO_AND_VALUE( - Mount("", dir.path(), "tmpfs", 0, "mode=0700", 0)); - - const struct stat s = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); - EXPECT_EQ(s.st_mode, S_IFDIR | 0700); - EXPECT_NE(s.st_ino, before.st_ino); - - EXPECT_NO_ERRNO(Open(JoinPath(dir.path(), "foo"), O_CREAT | O_RDWR, 0777)); - } - - // Now that dir is unmounted again, we should have the old inode back. - const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); - EXPECT_EQ(before.st_ino, after.st_ino); -} - -TEST(MountTest, MountTmpfsMagicValIgnored) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - auto const mount = ASSERT_NO_ERRNO_AND_VALUE( - Mount("", dir.path(), "tmpfs", MS_MGC_VAL, "mode=0700", 0)); -} - -// Passing nullptr to data is equivalent to "". -TEST(MountTest, NullData) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - EXPECT_THAT(mount("", dir.path().c_str(), "tmpfs", 0, nullptr), - SyscallSucceeds()); - EXPECT_THAT(umount2(dir.path().c_str(), 0), SyscallSucceeds()); -} - -TEST(MountTest, MountReadonly) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto const mount = ASSERT_NO_ERRNO_AND_VALUE( - Mount("", dir.path(), "tmpfs", MS_RDONLY, "mode=0777", 0)); - - const struct stat s = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); - EXPECT_EQ(s.st_mode, S_IFDIR | 0777); - - std::string const filename = JoinPath(dir.path(), "foo"); - EXPECT_THAT(open(filename.c_str(), O_RDWR | O_CREAT, 0777), - SyscallFailsWithErrno(EROFS)); -} - -PosixErrorOr<absl::Time> ATime(absl::string_view file) { - struct stat s = {}; - if (stat(std::string(file).c_str(), &s) == -1) { - return PosixError(errno, "stat failed"); - } - return absl::TimeFromTimespec(s.st_atim); -} - -TEST(MountTest, MountNoAtime) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto const mount = ASSERT_NO_ERRNO_AND_VALUE( - Mount("", dir.path(), "tmpfs", MS_NOATIME, "mode=0777", 0)); - - std::string const contents = "No no no, don't follow the instructions!"; - auto const file = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateFileWith(dir.path(), contents, 0777)); - - absl::Time const before = ASSERT_NO_ERRNO_AND_VALUE(ATime(file.path())); - - // Reading from the file should change the atime, but the MS_NOATIME flag - // should prevent that. - auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); - char buf[100]; - int read_n; - ASSERT_THAT(read_n = read(fd.get(), buf, sizeof(buf)), SyscallSucceeds()); - EXPECT_EQ(std::string(buf, read_n), contents); - - absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(ATime(file.path())); - - // Expect that atime hasn't changed. - EXPECT_EQ(before, after); -} - -TEST(MountTest, MountNoExec) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto const mount = ASSERT_NO_ERRNO_AND_VALUE( - Mount("", dir.path(), "tmpfs", MS_NOEXEC, "mode=0777", 0)); - - std::string const contents = "No no no, don't follow the instructions!"; - auto const file = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateFileWith(dir.path(), contents, 0777)); - - int execve_errno; - ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(file.path(), {}, {}, nullptr, &execve_errno)); - EXPECT_EQ(execve_errno, EACCES); -} - -TEST(MountTest, RenameRemoveMountPoint) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - auto const dir_parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto const dir = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir_parent.path())); - auto const new_dir = NewTempAbsPath(); - - auto const mount = - ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "tmpfs", 0, "", 0)); - - ASSERT_THAT(rename(dir.path().c_str(), new_dir.c_str()), - SyscallFailsWithErrno(EBUSY)); - - ASSERT_THAT(rmdir(dir.path().c_str()), SyscallFailsWithErrno(EBUSY)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/mremap.cc b/test/syscalls/linux/mremap.cc deleted file mode 100644 index f0e5f7d82..000000000 --- a/test/syscalls/linux/mremap.cc +++ /dev/null @@ -1,492 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <string.h> -#include <sys/mman.h> - -#include <string> - -#include "gmock/gmock.h" -#include "absl/strings/string_view.h" -#include "test/util/file_descriptor.h" -#include "test/util/logging.h" -#include "test/util/memory_util.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -using ::testing::_; - -namespace gvisor { -namespace testing { - -namespace { - -// Fixture for mremap tests parameterized by mmap flags. -using MremapParamTest = ::testing::TestWithParam<int>; - -TEST_P(MremapParamTest, Noop) { - Mapping const m = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, GetParam())); - - ASSERT_THAT(Mremap(m.ptr(), kPageSize, kPageSize, 0, nullptr), - IsPosixErrorOkAndHolds(m.ptr())); - EXPECT_TRUE(IsMapped(m.addr())); -} - -TEST_P(MremapParamTest, InPlace_ShrinkingWholeVMA) { - Mapping const m = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(2 * kPageSize, PROT_NONE, GetParam())); - - const auto rest = [&] { - // N.B. we must be in a single-threaded subprocess to ensure a - // background thread doesn't concurrently map the second page. - void* addr = mremap(m.ptr(), 2 * kPageSize, kPageSize, 0, nullptr); - TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed"); - TEST_CHECK(addr == m.ptr()); - MaybeSave(); - - TEST_CHECK(IsMapped(m.addr())); - TEST_CHECK(!IsMapped(m.addr() + kPageSize)); - }; - - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); -} - -TEST_P(MremapParamTest, InPlace_ShrinkingPartialVMA) { - Mapping const m = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(3 * kPageSize, PROT_NONE, GetParam())); - - const auto rest = [&] { - void* addr = mremap(m.ptr(), 2 * kPageSize, kPageSize, 0, nullptr); - TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed"); - TEST_CHECK(addr == m.ptr()); - MaybeSave(); - - TEST_CHECK(IsMapped(m.addr())); - TEST_CHECK(!IsMapped(m.addr() + kPageSize)); - TEST_CHECK(IsMapped(m.addr() + 2 * kPageSize)); - }; - - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); -} - -TEST_P(MremapParamTest, InPlace_ShrinkingAcrossVMAs) { - Mapping const m = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(3 * kPageSize, PROT_READ, GetParam())); - // Changing permissions on the first page forces it to become a separate vma. - ASSERT_THAT(mprotect(m.ptr(), kPageSize, PROT_NONE), SyscallSucceeds()); - - const auto rest = [&] { - // Both old_size and new_size now span two vmas; mremap - // shouldn't care. - void* addr = mremap(m.ptr(), 3 * kPageSize, 2 * kPageSize, 0, nullptr); - TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed"); - TEST_CHECK(addr == m.ptr()); - MaybeSave(); - - TEST_CHECK(IsMapped(m.addr())); - TEST_CHECK(IsMapped(m.addr() + kPageSize)); - TEST_CHECK(!IsMapped(m.addr() + 2 * kPageSize)); - }; - - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); -} - -TEST_P(MremapParamTest, InPlace_ExpansionSuccess) { - Mapping const m = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(2 * kPageSize, PROT_NONE, GetParam())); - - const auto rest = [&] { - // Unmap the second page so that the first can be expanded back into it. - // - // N.B. we must be in a single-threaded subprocess to ensure a - // background thread doesn't concurrently map this page. - TEST_PCHECK( - munmap(reinterpret_cast<void*>(m.addr() + kPageSize), kPageSize) == 0); - MaybeSave(); - - void* addr = mremap(m.ptr(), kPageSize, 2 * kPageSize, 0, nullptr); - TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed"); - TEST_CHECK(addr == m.ptr()); - MaybeSave(); - - TEST_CHECK(IsMapped(m.addr())); - TEST_CHECK(IsMapped(m.addr() + kPageSize)); - }; - - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); -} - -TEST_P(MremapParamTest, InPlace_ExpansionFailure) { - Mapping const m = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(3 * kPageSize, PROT_NONE, GetParam())); - - const auto rest = [&] { - // Unmap the second page, leaving a one-page hole. Trying to expand the - // first page to three pages should fail since the original third page - // is still mapped. - TEST_PCHECK( - munmap(reinterpret_cast<void*>(m.addr() + kPageSize), kPageSize) == 0); - MaybeSave(); - - void* addr = mremap(m.ptr(), kPageSize, 3 * kPageSize, 0, nullptr); - TEST_CHECK_MSG(addr == MAP_FAILED, "mremap unexpectedly succeeded"); - TEST_PCHECK_MSG(errno == ENOMEM, "mremap failed with wrong errno"); - MaybeSave(); - - TEST_CHECK(IsMapped(m.addr())); - TEST_CHECK(!IsMapped(m.addr() + kPageSize)); - TEST_CHECK(IsMapped(m.addr() + 2 * kPageSize)); - }; - - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); -} - -TEST_P(MremapParamTest, MayMove_Expansion) { - Mapping const m = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(3 * kPageSize, PROT_NONE, GetParam())); - - const auto rest = [&] { - // Unmap the second page, leaving a one-page hole. Trying to expand the - // first page to three pages with MREMAP_MAYMOVE should force the - // mapping to be relocated since the original third page is still - // mapped. - TEST_PCHECK( - munmap(reinterpret_cast<void*>(m.addr() + kPageSize), kPageSize) == 0); - MaybeSave(); - - void* addr2 = - mremap(m.ptr(), kPageSize, 3 * kPageSize, MREMAP_MAYMOVE, nullptr); - TEST_PCHECK_MSG(addr2 != MAP_FAILED, "mremap failed"); - MaybeSave(); - - const Mapping m2 = Mapping(addr2, 3 * kPageSize); - TEST_CHECK(m.addr() != m2.addr()); - - TEST_CHECK(!IsMapped(m.addr())); - TEST_CHECK(!IsMapped(m.addr() + kPageSize)); - TEST_CHECK(IsMapped(m.addr() + 2 * kPageSize)); - TEST_CHECK(IsMapped(m2.addr())); - TEST_CHECK(IsMapped(m2.addr() + kPageSize)); - TEST_CHECK(IsMapped(m2.addr() + 2 * kPageSize)); - }; - - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); -} - -TEST_P(MremapParamTest, Fixed_SourceAndDestinationCannotOverlap) { - Mapping const m = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, GetParam())); - - ASSERT_THAT(Mremap(m.ptr(), kPageSize, kPageSize, - MREMAP_MAYMOVE | MREMAP_FIXED, m.ptr()), - PosixErrorIs(EINVAL, _)); - EXPECT_TRUE(IsMapped(m.addr())); -} - -TEST_P(MremapParamTest, Fixed_SameSize) { - Mapping const src = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, GetParam())); - Mapping const dst = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, GetParam())); - - const auto rest = [&] { - // Unmap dst to create a hole. - TEST_PCHECK(munmap(dst.ptr(), kPageSize) == 0); - MaybeSave(); - - void* addr = mremap(src.ptr(), kPageSize, kPageSize, - MREMAP_MAYMOVE | MREMAP_FIXED, dst.ptr()); - TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed"); - TEST_CHECK(addr == dst.ptr()); - MaybeSave(); - - TEST_CHECK(!IsMapped(src.addr())); - TEST_CHECK(IsMapped(dst.addr())); - }; - - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); -} - -TEST_P(MremapParamTest, Fixed_SameSize_Unmapping) { - // Like the Fixed_SameSize case, but expect mremap to unmap the destination - // automatically. - Mapping const src = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, GetParam())); - Mapping const dst = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, GetParam())); - - const auto rest = [&] { - void* addr = mremap(src.ptr(), kPageSize, kPageSize, - MREMAP_MAYMOVE | MREMAP_FIXED, dst.ptr()); - TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed"); - TEST_CHECK(addr == dst.ptr()); - MaybeSave(); - - TEST_CHECK(!IsMapped(src.addr())); - TEST_CHECK(IsMapped(dst.addr())); - }; - - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); -} - -TEST_P(MremapParamTest, Fixed_ShrinkingWholeVMA) { - Mapping const src = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(2 * kPageSize, PROT_NONE, GetParam())); - Mapping const dst = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(2 * kPageSize, PROT_NONE, GetParam())); - - const auto rest = [&] { - // Unmap dst so we can check that mremap does not keep the - // second page. - TEST_PCHECK(munmap(dst.ptr(), 2 * kPageSize) == 0); - MaybeSave(); - - void* addr = mremap(src.ptr(), 2 * kPageSize, kPageSize, - MREMAP_MAYMOVE | MREMAP_FIXED, dst.ptr()); - TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed"); - TEST_CHECK(addr == dst.ptr()); - MaybeSave(); - - TEST_CHECK(!IsMapped(src.addr())); - TEST_CHECK(!IsMapped(src.addr() + kPageSize)); - TEST_CHECK(IsMapped(dst.addr())); - TEST_CHECK(!IsMapped(dst.addr() + kPageSize)); - }; - - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); -} - -TEST_P(MremapParamTest, Fixed_ShrinkingPartialVMA) { - Mapping const src = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(3 * kPageSize, PROT_NONE, GetParam())); - Mapping const dst = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(2 * kPageSize, PROT_NONE, GetParam())); - - const auto rest = [&] { - // Unmap dst so we can check that mremap does not keep the - // second page. - TEST_PCHECK(munmap(dst.ptr(), 2 * kPageSize) == 0); - MaybeSave(); - - void* addr = mremap(src.ptr(), 2 * kPageSize, kPageSize, - MREMAP_MAYMOVE | MREMAP_FIXED, dst.ptr()); - TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed"); - TEST_CHECK(addr == dst.ptr()); - MaybeSave(); - - TEST_CHECK(!IsMapped(src.addr())); - TEST_CHECK(!IsMapped(src.addr() + kPageSize)); - TEST_CHECK(IsMapped(src.addr() + 2 * kPageSize)); - TEST_CHECK(IsMapped(dst.addr())); - TEST_CHECK(!IsMapped(dst.addr() + kPageSize)); - }; - - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); -} - -TEST_P(MremapParamTest, Fixed_ShrinkingAcrossVMAs) { - Mapping const src = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(3 * kPageSize, PROT_READ, GetParam())); - // Changing permissions on the first page forces it to become a separate vma. - ASSERT_THAT(mprotect(src.ptr(), kPageSize, PROT_NONE), SyscallSucceeds()); - Mapping const dst = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(2 * kPageSize, PROT_NONE, GetParam())); - - const auto rest = [&] { - // Unlike flags=0, MREMAP_FIXED requires that [old_address, - // old_address+new_size) only spans a single vma. - void* addr = mremap(src.ptr(), 3 * kPageSize, 2 * kPageSize, - MREMAP_MAYMOVE | MREMAP_FIXED, dst.ptr()); - TEST_CHECK_MSG(addr == MAP_FAILED, "mremap unexpectedly succeeded"); - TEST_PCHECK_MSG(errno == EFAULT, "mremap failed with wrong errno"); - MaybeSave(); - - TEST_CHECK(IsMapped(src.addr())); - TEST_CHECK(IsMapped(src.addr() + kPageSize)); - // Despite failing, mremap should have unmapped [old_address+new_size, - // old_address+old_size) (i.e. the third page). - TEST_CHECK(!IsMapped(src.addr() + 2 * kPageSize)); - // Despite failing, mremap should have unmapped the destination pages. - TEST_CHECK(!IsMapped(dst.addr())); - TEST_CHECK(!IsMapped(dst.addr() + kPageSize)); - }; - - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); -} - -TEST_P(MremapParamTest, Fixed_Expansion) { - Mapping const src = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, GetParam())); - Mapping const dst = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(2 * kPageSize, PROT_NONE, GetParam())); - - const auto rest = [&] { - // Unmap dst so we can check that mremap actually maps all pages - // at the destination. - TEST_PCHECK(munmap(dst.ptr(), 2 * kPageSize) == 0); - MaybeSave(); - - void* addr = mremap(src.ptr(), kPageSize, 2 * kPageSize, - MREMAP_MAYMOVE | MREMAP_FIXED, dst.ptr()); - TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed"); - TEST_CHECK(addr == dst.ptr()); - MaybeSave(); - - TEST_CHECK(!IsMapped(src.addr())); - TEST_CHECK(IsMapped(dst.addr())); - TEST_CHECK(IsMapped(dst.addr() + kPageSize)); - }; - - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); -} - -INSTANTIATE_TEST_SUITE_P(PrivateShared, MremapParamTest, - ::testing::Values(MAP_PRIVATE, MAP_SHARED)); - -// mremap with old_size == 0 only works with MAP_SHARED after Linux 4.14 -// (dba58d3b8c50 "mm/mremap: fail map duplication attempts for private -// mappings"). - -TEST(MremapTest, InPlace_Copy) { - Mapping const m = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_SHARED)); - EXPECT_THAT(Mremap(m.ptr(), 0, kPageSize, 0, nullptr), - PosixErrorIs(ENOMEM, _)); -} - -TEST(MremapTest, MayMove_Copy) { - Mapping const m = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_SHARED)); - - // Remainder of this test executes in a subprocess to ensure that if mremap - // incorrectly removes m, it is not remapped by another thread. - const auto rest = [&] { - void* ptr = mremap(m.ptr(), 0, kPageSize, MREMAP_MAYMOVE, nullptr); - MaybeSave(); - TEST_PCHECK_MSG(ptr != MAP_FAILED, "mremap failed"); - TEST_CHECK(ptr != m.ptr()); - TEST_CHECK(IsMapped(m.addr())); - TEST_CHECK(IsMapped(reinterpret_cast<uintptr_t>(ptr))); - }; - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); -} - -TEST(MremapTest, MustMove_Copy) { - Mapping const src = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_SHARED)); - Mapping const dst = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_PRIVATE)); - - // Remainder of this test executes in a subprocess to ensure that if mremap - // incorrectly removes src, it is not remapped by another thread. - const auto rest = [&] { - void* ptr = mremap(src.ptr(), 0, kPageSize, MREMAP_MAYMOVE | MREMAP_FIXED, - dst.ptr()); - MaybeSave(); - TEST_PCHECK_MSG(ptr != MAP_FAILED, "mremap failed"); - TEST_CHECK(ptr == dst.ptr()); - TEST_CHECK(IsMapped(src.addr())); - TEST_CHECK(IsMapped(dst.addr())); - }; - EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); -} - -void ExpectAllBytesAre(absl::string_view v, char c) { - for (size_t i = 0; i < v.size(); i++) { - ASSERT_EQ(v[i], c) << "at offset " << i; - } -} - -TEST(MremapTest, ExpansionPreservesCOWPagesAndExposesNewFilePages) { - // Create a file with 3 pages. The first is filled with 'a', the second is - // filled with 'b', and the third is filled with 'c'. - TempPath const file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); - ASSERT_THAT(WriteFd(fd.get(), std::string(kPageSize, 'a').c_str(), kPageSize), - SyscallSucceedsWithValue(kPageSize)); - ASSERT_THAT(WriteFd(fd.get(), std::string(kPageSize, 'b').c_str(), kPageSize), - SyscallSucceedsWithValue(kPageSize)); - ASSERT_THAT(WriteFd(fd.get(), std::string(kPageSize, 'c').c_str(), kPageSize), - SyscallSucceedsWithValue(kPageSize)); - - // Create a private mapping of the first 2 pages, and fill the second page - // with 'd'. - Mapping const src = ASSERT_NO_ERRNO_AND_VALUE(Mmap(nullptr, 2 * kPageSize, - PROT_READ | PROT_WRITE, - MAP_PRIVATE, fd.get(), 0)); - memset(reinterpret_cast<void*>(src.addr() + kPageSize), 'd', kPageSize); - MaybeSave(); - - // Move the mapping while expanding it to 3 pages. The resulting mapping - // should contain the original first page of the file (filled with 'a'), - // followed by the private copy of the second page (filled with 'd'), followed - // by the newly-mapped third page of the file (filled with 'c'). - Mapping const dst = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(3 * kPageSize, PROT_NONE, MAP_PRIVATE)); - ASSERT_THAT(Mremap(src.ptr(), 2 * kPageSize, 3 * kPageSize, - MREMAP_MAYMOVE | MREMAP_FIXED, dst.ptr()), - IsPosixErrorOkAndHolds(dst.ptr())); - auto const v = dst.view(); - ExpectAllBytesAre(v.substr(0, kPageSize), 'a'); - ExpectAllBytesAre(v.substr(kPageSize, kPageSize), 'd'); - ExpectAllBytesAre(v.substr(2 * kPageSize, kPageSize), 'c'); -} - -TEST(MremapDeathTest, SharedAnon) { - SetupGvisorDeathTest(); - - // Reserve 4 pages of address space. - Mapping const reserved = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(4 * kPageSize, PROT_NONE, MAP_PRIVATE)); - - // Create a 2-page shared anonymous mapping at the beginning of the - // reservation. Fill the first page with 'a' and the second with 'b'. - Mapping const m = ASSERT_NO_ERRNO_AND_VALUE( - Mmap(reserved.ptr(), 2 * kPageSize, PROT_READ | PROT_WRITE, - MAP_SHARED | MAP_ANONYMOUS | MAP_FIXED, -1, 0)); - memset(m.ptr(), 'a', kPageSize); - memset(reinterpret_cast<void*>(m.addr() + kPageSize), 'b', kPageSize); - MaybeSave(); - - // Shrink the mapping to 1 page in-place. - ASSERT_THAT(Mremap(m.ptr(), 2 * kPageSize, kPageSize, 0, m.ptr()), - IsPosixErrorOkAndHolds(m.ptr())); - - // Expand the mapping to 3 pages, moving it forward by 1 page in the process - // since the old and new mappings can't overlap. - void* const new_m = reinterpret_cast<void*>(m.addr() + kPageSize); - ASSERT_THAT(Mremap(m.ptr(), kPageSize, 3 * kPageSize, - MREMAP_MAYMOVE | MREMAP_FIXED, new_m), - IsPosixErrorOkAndHolds(new_m)); - - // The first 2 pages of the mapping should still contain the data we wrote - // (i.e. shrinking should not have discarded the second page's data), while - // touching the third page should raise SIGBUS. - auto const v = - absl::string_view(static_cast<char const*>(new_m), 3 * kPageSize); - ExpectAllBytesAre(v.substr(0, kPageSize), 'a'); - ExpectAllBytesAre(v.substr(kPageSize, kPageSize), 'b'); - EXPECT_EXIT(ExpectAllBytesAre(v.substr(2 * kPageSize, kPageSize), '\0'), - ::testing::KilledBySignal(SIGBUS), ""); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/msync.cc b/test/syscalls/linux/msync.cc deleted file mode 100644 index 2b2b6aef9..000000000 --- a/test/syscalls/linux/msync.cc +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/mman.h> -#include <unistd.h> - -#include <functional> -#include <string> -#include <utility> -#include <vector> - -#include "test/util/file_descriptor.h" -#include "test/util/memory_util.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Parameters for msync tests. Use a std::tuple so we can use -// ::testing::Combine. -using MsyncTestParam = - std::tuple<int, // msync flags - std::function<PosixErrorOr<Mapping>()> // returns mapping to - // msync - >; - -class MsyncParameterizedTest : public ::testing::TestWithParam<MsyncTestParam> { - protected: - int msync_flags() const { return std::get<0>(GetParam()); } - - PosixErrorOr<Mapping> GetMapping() const { return std::get<1>(GetParam())(); } -}; - -// All valid msync(2) flag combinations, not including MS_INVALIDATE. ("Linux -// permits a call to msync() that specifies neither [MS_SYNC or MS_ASYNC], with -// semantics that are (currently) equivalent to specifying MS_ASYNC." - -// msync(2)) -constexpr std::initializer_list<int> kMsyncFlags = {MS_SYNC, MS_ASYNC, 0}; - -// Returns functions that return mappings that should be successfully -// msync()able. -std::vector<std::function<PosixErrorOr<Mapping>()>> SyncableMappings() { - std::vector<std::function<PosixErrorOr<Mapping>()>> funcs; - for (bool const writable : {false, true}) { - for (int const mflags : {MAP_PRIVATE, MAP_SHARED}) { - int const prot = PROT_READ | (writable ? PROT_WRITE : 0); - int const oflags = O_CREAT | (writable ? O_RDWR : O_RDONLY); - funcs.push_back([=] { return MmapAnon(kPageSize, prot, mflags); }); - funcs.push_back([=]() -> PosixErrorOr<Mapping> { - std::string const path = NewTempAbsPath(); - ASSIGN_OR_RETURN_ERRNO(auto fd, Open(path, oflags, 0644)); - // Don't unlink the file since that breaks save/restore. Just let the - // test infrastructure clean up all of our temporary files when we're - // done. - return Mmap(nullptr, kPageSize, prot, mflags, fd.get(), 0); - }); - } - } - return funcs; -} - -PosixErrorOr<Mapping> NoMappings() { - return PosixError(EINVAL, "unexpected attempt to create a mapping"); -} - -// "Fixture" for msync tests that hold for all valid flags, but do not create -// mappings. -using MsyncNoMappingTest = MsyncParameterizedTest; - -TEST_P(MsyncNoMappingTest, UnmappedAddressWithZeroLengthSucceeds) { - EXPECT_THAT(msync(nullptr, 0, msync_flags()), SyscallSucceeds()); -} - -TEST_P(MsyncNoMappingTest, UnmappedAddressWithNonzeroLengthFails) { - EXPECT_THAT(msync(nullptr, kPageSize, msync_flags()), - SyscallFailsWithErrno(ENOMEM)); -} - -INSTANTIATE_TEST_SUITE_P(All, MsyncNoMappingTest, - ::testing::Combine(::testing::ValuesIn(kMsyncFlags), - ::testing::Values(NoMappings))); - -// "Fixture" for msync tests that are not parameterized by msync flags, but do -// create mappings. -using MsyncNoFlagsTest = MsyncParameterizedTest; - -TEST_P(MsyncNoFlagsTest, BothSyncAndAsyncFails) { - auto m = ASSERT_NO_ERRNO_AND_VALUE(GetMapping()); - EXPECT_THAT(msync(m.ptr(), m.len(), MS_SYNC | MS_ASYNC), - SyscallFailsWithErrno(EINVAL)); -} - -INSTANTIATE_TEST_SUITE_P( - All, MsyncNoFlagsTest, - ::testing::Combine(::testing::Values(0), // ignored - ::testing::ValuesIn(SyncableMappings()))); - -// "Fixture" for msync tests parameterized by both msync flags and sources of -// mappings. -using MsyncFullParamTest = MsyncParameterizedTest; - -TEST_P(MsyncFullParamTest, NormallySucceeds) { - auto m = ASSERT_NO_ERRNO_AND_VALUE(GetMapping()); - EXPECT_THAT(msync(m.ptr(), m.len(), msync_flags()), SyscallSucceeds()); -} - -TEST_P(MsyncFullParamTest, UnalignedLengthSucceeds) { - auto m = ASSERT_NO_ERRNO_AND_VALUE(GetMapping()); - EXPECT_THAT(msync(m.ptr(), m.len() - 1, msync_flags()), SyscallSucceeds()); -} - -TEST_P(MsyncFullParamTest, UnalignedAddressFails) { - auto m = ASSERT_NO_ERRNO_AND_VALUE(GetMapping()); - EXPECT_THAT( - msync(reinterpret_cast<void*>(m.addr() + 1), m.len() - 1, msync_flags()), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(MsyncFullParamTest, InvalidateUnlockedSucceeds) { - auto m = ASSERT_NO_ERRNO_AND_VALUE(GetMapping()); - EXPECT_THAT(msync(m.ptr(), m.len(), msync_flags() | MS_INVALIDATE), - SyscallSucceeds()); -} - -// The test for MS_INVALIDATE on mlocked pages is in mlock.cc since it requires -// probing for mlock support. - -INSTANTIATE_TEST_SUITE_P( - All, MsyncFullParamTest, - ::testing::Combine(::testing::ValuesIn(kMsyncFlags), - ::testing::ValuesIn(SyncableMappings()))); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/munmap.cc b/test/syscalls/linux/munmap.cc deleted file mode 100644 index 067241f4d..000000000 --- a/test/syscalls/linux/munmap.cc +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/mman.h> - -#include "gtest/gtest.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -class MunmapTest : public ::testing::Test { - protected: - void SetUp() override { - m_ = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - ASSERT_NE(MAP_FAILED, m_); - } - - void* m_ = nullptr; -}; - -TEST_F(MunmapTest, HappyCase) { - EXPECT_THAT(munmap(m_, kPageSize), SyscallSucceeds()); -} - -TEST_F(MunmapTest, ZeroLength) { - EXPECT_THAT(munmap(m_, 0), SyscallFailsWithErrno(EINVAL)); -} - -TEST_F(MunmapTest, LastPageRoundUp) { - // Attempt to unmap up to and including the last page. - EXPECT_THAT(munmap(m_, static_cast<size_t>(-kPageSize + 1)), - SyscallFailsWithErrno(EINVAL)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/network_namespace.cc b/test/syscalls/linux/network_namespace.cc deleted file mode 100644 index 133fdecf0..000000000 --- a/test/syscalls/linux/network_namespace.cc +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <net/if.h> -#include <sched.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/capability_util.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { -namespace { - -TEST(NetworkNamespaceTest, LoopbackExists) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - - ScopedThread t([&] { - ASSERT_THAT(unshare(CLONE_NEWNET), SyscallSucceedsWithValue(0)); - - // TODO(gvisor.dev/issue/1833): Update this to test that only "lo" exists. - // Check loopback device exists. - int sock = socket(AF_INET, SOCK_DGRAM, 0); - ASSERT_THAT(sock, SyscallSucceeds()); - struct ifreq ifr; - strncpy(ifr.ifr_name, "lo", IFNAMSIZ); - EXPECT_THAT(ioctl(sock, SIOCGIFINDEX, &ifr), SyscallSucceeds()) - << "lo cannot be found"; - }); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc deleted file mode 100644 index 267ae19f6..000000000 --- a/test/syscalls/linux/open.cc +++ /dev/null @@ -1,400 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <linux/capability.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/memory/memory.h" -#include "test/syscalls/linux/file_base.h" -#include "test/util/capability_util.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// This test is currently very rudimentary. -// -// There are plenty of extra cases to cover once the sentry supports them. -// -// Different types of opens: -// * O_CREAT -// * O_DIRECTORY -// * O_NOFOLLOW -// * O_PATH <- Will we ever support this? -// -// Special operations on open: -// * O_EXCL -// -// Special files: -// * Blocking behavior for a named pipe. -// -// Different errors: -// * EACCES -// * EEXIST -// * ENAMETOOLONG -// * ELOOP -// * ENOTDIR -// * EPERM -class OpenTest : public FileTest { - void SetUp() override { - FileTest::SetUp(); - - ASSERT_THAT( - write(test_file_fd_.get(), test_data_.c_str(), test_data_.length()), - SyscallSucceedsWithValue(test_data_.length())); - EXPECT_THAT(lseek(test_file_fd_.get(), 0, SEEK_SET), SyscallSucceeds()); - } - - public: - const std::string test_data_ = "hello world\n"; -}; - -TEST_F(OpenTest, OTrunc) { - auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); - ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); - ASSERT_THAT(open(dirpath.c_str(), O_TRUNC, 0666), - SyscallFailsWithErrno(EISDIR)); -} - -TEST_F(OpenTest, OTruncAndReadOnlyDir) { - auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); - ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); - ASSERT_THAT(open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666), - SyscallFailsWithErrno(EISDIR)); -} - -TEST_F(OpenTest, OTruncAndReadOnlyFile) { - auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile"); - const FileDescriptor existing = - ASSERT_NO_ERRNO_AND_VALUE(Open(dirpath.c_str(), O_RDWR | O_CREAT, 0666)); - const FileDescriptor otrunc = ASSERT_NO_ERRNO_AND_VALUE( - Open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666)); -} - -TEST_F(OpenTest, ReadOnly) { - char buf; - const FileDescriptor ro_file = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY)); - - EXPECT_THAT(read(ro_file.get(), &buf, 1), SyscallSucceedsWithValue(1)); - EXPECT_THAT(lseek(ro_file.get(), 0, SEEK_SET), SyscallSucceeds()); - EXPECT_THAT(write(ro_file.get(), &buf, 1), SyscallFailsWithErrno(EBADF)); -} - -TEST_F(OpenTest, WriteOnly) { - char buf; - const FileDescriptor wo_file = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_WRONLY)); - - EXPECT_THAT(read(wo_file.get(), &buf, 1), SyscallFailsWithErrno(EBADF)); - EXPECT_THAT(lseek(wo_file.get(), 0, SEEK_SET), SyscallSucceeds()); - EXPECT_THAT(write(wo_file.get(), &buf, 1), SyscallSucceedsWithValue(1)); -} - -TEST_F(OpenTest, ReadWrite) { - char buf; - const FileDescriptor rw_file = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - - EXPECT_THAT(read(rw_file.get(), &buf, 1), SyscallSucceedsWithValue(1)); - EXPECT_THAT(lseek(rw_file.get(), 0, SEEK_SET), SyscallSucceeds()); - EXPECT_THAT(write(rw_file.get(), &buf, 1), SyscallSucceedsWithValue(1)); -} - -TEST_F(OpenTest, RelPath) { - auto name = std::string(Basename(test_file_name_)); - - ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds()); - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(name, O_RDONLY)); -} - -TEST_F(OpenTest, AbsPath) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY)); -} - -TEST_F(OpenTest, AtRelPath) { - auto name = std::string(Basename(test_file_name_)); - const FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE( - Open(GetAbsoluteTestTmpdir(), O_RDONLY | O_DIRECTORY)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(OpenAt(dirfd.get(), name, O_RDONLY)); -} - -TEST_F(OpenTest, AtAbsPath) { - const FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE( - Open(GetAbsoluteTestTmpdir(), O_RDONLY | O_DIRECTORY)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(OpenAt(dirfd.get(), test_file_name_, O_RDONLY)); -} - -TEST_F(OpenTest, OpenNoFollowSymlink) { - const std::string link_path = JoinPath(GetAbsoluteTestTmpdir(), "link"); - ASSERT_THAT(symlink(test_file_name_.c_str(), link_path.c_str()), - SyscallSucceeds()); - auto cleanup = Cleanup([link_path]() { - EXPECT_THAT(unlink(link_path.c_str()), SyscallSucceeds()); - }); - - // Open will succeed without O_NOFOLLOW and fails with O_NOFOLLOW. - const FileDescriptor fd2 = - ASSERT_NO_ERRNO_AND_VALUE(Open(link_path, O_RDONLY)); - ASSERT_THAT(open(link_path.c_str(), O_RDONLY | O_NOFOLLOW), - SyscallFailsWithErrno(ELOOP)); -} - -TEST_F(OpenTest, OpenNoFollowStillFollowsLinksInPath) { - // We will create the following structure: - // tmp_folder/real_folder/file - // tmp_folder/sym_folder -> tmp_folder/real_folder - // - // We will then open tmp_folder/sym_folder/file with O_NOFOLLOW and it - // should succeed as O_NOFOLLOW only applies to the final path component. - auto tmp_path = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(GetAbsoluteTestTmpdir())); - auto sym_path = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), tmp_path.path())); - auto file_path = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(tmp_path.path())); - - auto path_via_symlink = JoinPath(sym_path.path(), Basename(file_path.path())); - const FileDescriptor fd2 = - ASSERT_NO_ERRNO_AND_VALUE(Open(path_via_symlink, O_RDONLY | O_NOFOLLOW)); -} - -TEST_F(OpenTest, Fault) { - char* totally_not_null = nullptr; - ASSERT_THAT(open(totally_not_null, O_RDONLY), SyscallFailsWithErrno(EFAULT)); -} - -TEST_F(OpenTest, AppendOnly) { - // First write some data to the fresh file. - const int64_t kBufSize = 1024; - std::vector<char> buf(kBufSize, 'a'); - - FileDescriptor fd0 = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - - std::fill(buf.begin(), buf.end(), 'a'); - EXPECT_THAT(WriteFd(fd0.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - fd0.reset(); // Close the file early. - - // Next get two handles to the same file. We open two files because we want - // to make sure that appending is respected between them. - const FileDescriptor fd1 = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR | O_APPEND)); - EXPECT_THAT(lseek(fd1.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); - - const FileDescriptor fd2 = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR | O_APPEND)); - EXPECT_THAT(lseek(fd2.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); - - // Then try to write to the first file and make sure the bytes are appended. - EXPECT_THAT(WriteFd(fd1.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - - // Check that the size of the file is correct and that the offset has been - // incremented to that size. - struct stat s0; - EXPECT_THAT(fstat(fd1.get(), &s0), SyscallSucceeds()); - EXPECT_EQ(s0.st_size, kBufSize * 2); - EXPECT_THAT(lseek(fd1.get(), 0, SEEK_CUR), - SyscallSucceedsWithValue(kBufSize * 2)); - - // Then try to write to the second file and make sure the bytes are appended. - EXPECT_THAT(WriteFd(fd2.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - - // Check that the size of the file is correct and that the offset has been - // incremented to that size. - struct stat s1; - EXPECT_THAT(fstat(fd2.get(), &s1), SyscallSucceeds()); - EXPECT_EQ(s1.st_size, kBufSize * 3); - EXPECT_THAT(lseek(fd2.get(), 0, SEEK_CUR), - SyscallSucceedsWithValue(kBufSize * 3)); -} - -TEST_F(OpenTest, AppendConcurrentWrite) { - constexpr int kThreadCount = 5; - constexpr int kBytesPerThread = 10000; - std::unique_ptr<ScopedThread> threads[kThreadCount]; - - // In case of the uncached policy, we expect that a file system can be changed - // externally, so we create a new inode each time when we open a file and we - // can't guarantee that writes to files with O_APPEND will work correctly. - SKIP_IF(getenv("GVISOR_GOFER_UNCACHED")); - - EXPECT_THAT(truncate(test_file_name_.c_str(), 0), SyscallSucceeds()); - - std::string filename = test_file_name_; - DisableSave ds; // Too many syscalls. - // Start kThreadCount threads which will write concurrently into the same - // file. - for (int i = 0; i < kThreadCount; i++) { - threads[i] = absl::make_unique<ScopedThread>([filename]() { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_RDWR | O_APPEND)); - - for (int j = 0; j < kBytesPerThread; j++) { - EXPECT_THAT(WriteFd(fd.get(), &j, 1), SyscallSucceedsWithValue(1)); - } - }); - } - for (int i = 0; i < kThreadCount; i++) { - threads[i]->Join(); - } - - // Check that the size of the file is correct. - struct stat st; - EXPECT_THAT(stat(test_file_name_.c_str(), &st), SyscallSucceeds()); - EXPECT_EQ(st.st_size, kThreadCount * kBytesPerThread); -} - -TEST_F(OpenTest, Truncate) { - { - // First write some data to the new file and close it. - FileDescriptor fd0 = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_WRONLY)); - std::vector<char> orig(10, 'a'); - EXPECT_THAT(WriteFd(fd0.get(), orig.data(), orig.size()), - SyscallSucceedsWithValue(orig.size())); - } - - // Then open with truncate and verify that offset is set to 0. - const FileDescriptor fd1 = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR | O_TRUNC)); - EXPECT_THAT(lseek(fd1.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); - - // Then write less data to the file and ensure the old content is gone. - std::vector<char> want(5, 'b'); - EXPECT_THAT(WriteFd(fd1.get(), want.data(), want.size()), - SyscallSucceedsWithValue(want.size())); - - struct stat stat; - EXPECT_THAT(fstat(fd1.get(), &stat), SyscallSucceeds()); - EXPECT_EQ(stat.st_size, want.size()); - EXPECT_THAT(lseek(fd1.get(), 0, SEEK_CUR), - SyscallSucceedsWithValue(want.size())); - - // Read the data and ensure only the latest write is in the file. - std::vector<char> got(want.size() + 1, 'c'); - ASSERT_THAT(pread(fd1.get(), got.data(), got.size(), 0), - SyscallSucceedsWithValue(want.size())); - EXPECT_EQ(memcmp(want.data(), got.data(), want.size()), 0) - << "rbuf=" << got.data(); - EXPECT_EQ(got.back(), 'c'); // Last byte should not have been modified. -} - -TEST_F(OpenTest, NameTooLong) { - char buf[4097] = {}; - memset(buf, 'a', 4097); - EXPECT_THAT(open(buf, O_RDONLY), SyscallFailsWithErrno(ENAMETOOLONG)); -} - -TEST_F(OpenTest, DotsFromRoot) { - const FileDescriptor rootfd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/", O_RDONLY | O_DIRECTORY)); - const FileDescriptor other_rootfd = - ASSERT_NO_ERRNO_AND_VALUE(OpenAt(rootfd.get(), "..", O_RDONLY)); -} - -TEST_F(OpenTest, DirectoryWritableFails) { - ASSERT_THAT(open(GetAbsoluteTestTmpdir().c_str(), O_RDWR), - SyscallFailsWithErrno(EISDIR)); -} - -TEST_F(OpenTest, FileNotDirectory) { - // Create a file and try to open it with O_DIRECTORY. - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - ASSERT_THAT(open(file.path().c_str(), O_RDONLY | O_DIRECTORY), - SyscallFailsWithErrno(ENOTDIR)); -} - -TEST_F(OpenTest, Null) { - char c = '\0'; - ASSERT_THAT(open(&c, O_RDONLY), SyscallFailsWithErrno(ENOENT)); -} - -// NOTE(b/119785738): While the man pages specify that this behavior should be -// undefined, Linux truncates the file on opening read only if we have write -// permission, so we will too. -TEST_F(OpenTest, CanTruncateReadOnly) { - const FileDescriptor fd1 = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY | O_TRUNC)); - - struct stat stat; - EXPECT_THAT(fstat(fd1.get(), &stat), SyscallSucceeds()); - EXPECT_EQ(stat.st_size, 0); -} - -// If we don't have read permission on the file, opening with -// O_TRUNC should fail. -TEST_F(OpenTest, CanTruncateReadOnlyNoWritePermission_NoRandomSave) { - // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - - const DisableSave ds; // Permissions are dropped. - ASSERT_THAT(chmod(test_file_name_.c_str(), S_IRUSR | S_IRGRP), - SyscallSucceeds()); - - ASSERT_THAT(open(test_file_name_.c_str(), O_RDONLY | O_TRUNC), - SyscallFailsWithErrno(EACCES)); - - const FileDescriptor fd1 = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY)); - - struct stat stat; - EXPECT_THAT(fstat(fd1.get(), &stat), SyscallSucceeds()); - EXPECT_EQ(stat.st_size, test_data_.size()); -} - -// If we don't have read permission but have write permission, opening O_WRONLY -// and O_TRUNC should succeed. -TEST_F(OpenTest, CanTruncateWriteOnlyNoReadPermission_NoRandomSave) { - const DisableSave ds; // Permissions are dropped. - - EXPECT_THAT(fchmod(test_file_fd_.get(), S_IWUSR | S_IWGRP), - SyscallSucceeds()); - - const FileDescriptor fd1 = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_WRONLY | O_TRUNC)); - - EXPECT_THAT(fchmod(test_file_fd_.get(), S_IRUSR | S_IRGRP), - SyscallSucceeds()); - - const FileDescriptor fd2 = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY)); - - struct stat stat; - EXPECT_THAT(fstat(fd2.get(), &stat), SyscallSucceeds()); - EXPECT_EQ(stat.st_size, 0); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc deleted file mode 100644 index 51eacf3f2..000000000 --- a/test/syscalls/linux/open_create.cc +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/temp_path.h" -#include "test/util/temp_umask.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { -TEST(CreateTest, TmpFile) { - int fd; - EXPECT_THAT(fd = open(JoinPath(GetAbsoluteTestTmpdir(), "a").c_str(), - O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST(CreateTest, ExistingFile) { - int fd; - EXPECT_THAT( - fd = open(JoinPath(GetAbsoluteTestTmpdir(), "ExistingFile").c_str(), - O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - - EXPECT_THAT( - fd = open(JoinPath(GetAbsoluteTestTmpdir(), "ExistingFile").c_str(), - O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST(CreateTest, CreateAtFile) { - int dirfd; - EXPECT_THAT(dirfd = open(GetAbsoluteTestTmpdir().c_str(), O_DIRECTORY, 0666), - SyscallSucceeds()); - EXPECT_THAT(openat(dirfd, "CreateAtFile", O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - EXPECT_THAT(close(dirfd), SyscallSucceeds()); -} - -TEST(CreateTest, HonorsUmask_NoRandomSave) { - const DisableSave ds; // file cannot be re-opened as writable. - TempUmask mask(0222); - int fd; - ASSERT_THAT( - fd = open(JoinPath(GetAbsoluteTestTmpdir(), "UmaskedFile").c_str(), - O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - struct stat statbuf; - ASSERT_THAT(fstat(fd, &statbuf), SyscallSucceeds()); - EXPECT_EQ(0444, statbuf.st_mode & 0777); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST(CreateTest, CreateExclusively) { - std::string filename = NewTempAbsPath(); - - int fd; - ASSERT_THAT(fd = open(filename.c_str(), O_CREAT | O_RDWR, 0644), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - - EXPECT_THAT(open(filename.c_str(), O_CREAT | O_EXCL | O_RDWR, 0644), - SyscallFailsWithErrno(EEXIST)); -} - -TEST(CreateTeast, CreatWithOTrunc) { - std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); - ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); - ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC, 0666), - SyscallFailsWithErrno(EISDIR)); -} - -TEST(CreateTeast, CreatDirWithOTruncAndReadOnly) { - std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); - ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); - ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), - SyscallFailsWithErrno(EISDIR)); -} - -TEST(CreateTeast, CreatFileWithOTruncAndReadOnly) { - std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile"); - int dirfd; - ASSERT_THAT(dirfd = open(dirpath.c_str(), O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), - SyscallSucceeds()); - ASSERT_THAT(close(dirfd), SyscallSucceeds()); -} - -TEST(CreateTest, CreateFailsOnUnpermittedDir) { - // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to - // always override directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_THAT(open("/foo", O_CREAT | O_RDWR, 0644), - SyscallFailsWithErrno(EACCES)); -} - -TEST(CreateTest, CreateFailsOnDirWithoutWritePerms) { - // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to - // always override directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - auto parent = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0555)); - auto file = JoinPath(parent.path(), "foo"); - ASSERT_THAT(open(file.c_str(), O_CREAT | O_RDWR, 0644), - SyscallFailsWithErrno(EACCES)); -} - -// A file originally created RW, but opened RO can later be opened RW. -// Regression test for b/65385065. -TEST(CreateTest, OpenCreateROThenRW) { - TempPath file(NewTempAbsPath()); - - // Create a RW file, but only open it RO. - FileDescriptor fd1 = ASSERT_NO_ERRNO_AND_VALUE( - Open(file.path(), O_CREAT | O_EXCL | O_RDONLY, 0644)); - - // Now get a RW FD. - FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); - - // fd1 is not writable, but fd2 is. - char c = 'a'; - EXPECT_THAT(WriteFd(fd1.get(), &c, 1), SyscallFailsWithErrno(EBADF)); - EXPECT_THAT(WriteFd(fd2.get(), &c, 1), SyscallSucceedsWithValue(1)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc deleted file mode 100644 index 5ac68feb4..000000000 --- a/test/syscalls/linux/packet_socket.cc +++ /dev/null @@ -1,440 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <arpa/inet.h> -#include <ifaddrs.h> -#include <linux/capability.h> -#include <linux/if_arp.h> -#include <linux/if_packet.h> -#include <net/ethernet.h> -#include <netinet/in.h> -#include <netinet/ip.h> -#include <netinet/udp.h> -#include <poll.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "absl/base/internal/endian.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -// Some of these tests involve sending packets via AF_PACKET sockets and the -// loopback interface. Because AF_PACKET circumvents so much of the networking -// stack, Linux sees these packets as "martian", i.e. they claim to be to/from -// localhost but don't have the usual associated data. Thus Linux drops them by -// default. You can see where this happens by following the code at: -// -// - net/ipv4/ip_input.c:ip_rcv_finish, which calls -// - net/ipv4/route.c:ip_route_input_noref, which calls -// - net/ipv4/route.c:ip_route_input_slow, which finds and drops martian -// packets. -// -// To tell Linux not to drop these packets, you need to tell it to accept our -// funny packets (which are completely valid and correct, but lack associated -// in-kernel data because we use AF_PACKET): -// -// echo 1 >> /proc/sys/net/ipv4/conf/lo/accept_local -// echo 1 >> /proc/sys/net/ipv4/conf/lo/route_localnet -// -// These tests require CAP_NET_RAW to run. - -// TODO(gvisor.dev/issue/173): gVisor support. - -namespace gvisor { -namespace testing { - -namespace { - -using ::testing::AnyOf; -using ::testing::Eq; - -constexpr char kMessage[] = "soweoneul malhaebwa"; -constexpr in_port_t kPort = 0x409c; // htons(40000) - -// -// "Cooked" tests. Cooked AF_PACKET sockets do not contain link layer -// headers, and provide link layer destination/source information via a -// returned struct sockaddr_ll. -// - -// Send kMessage via sock to loopback -void SendUDPMessage(int sock) { - struct sockaddr_in dest = {}; - dest.sin_port = kPort; - dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - dest.sin_family = AF_INET; - EXPECT_THAT(sendto(sock, kMessage, sizeof(kMessage), 0, - reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), - SyscallSucceedsWithValue(sizeof(kMessage))); -} - -// Send an IP packet and make sure ETH_P_<something else> doesn't pick it up. -TEST(BasicCookedPacketTest, WrongType) { - if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP), - SyscallFailsWithErrno(EPERM)); - GTEST_SKIP(); - } - - FileDescriptor sock = ASSERT_NO_ERRNO_AND_VALUE( - Socket(AF_PACKET, SOCK_DGRAM, htons(ETH_P_PUP))); - - // Let's use a simple IP payload: a UDP datagram. - FileDescriptor udp_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - SendUDPMessage(udp_sock.get()); - - // Wait and make sure the socket never becomes readable. - struct pollfd pfd = {}; - pfd.fd = sock.get(); - pfd.events = POLLIN; - EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); -} - -// Tests for "cooked" (SOCK_DGRAM) packet(7) sockets. -class CookedPacketTest : public ::testing::TestWithParam<int> { - protected: - // Creates a socket to be used in tests. - void SetUp() override; - - // Closes the socket created by SetUp(). - void TearDown() override; - - // Gets the device index of the loopback device. - int GetLoopbackIndex(); - - // The socket used for both reading and writing. - int socket_; -}; - -void CookedPacketTest::SetUp() { - if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())), - SyscallFailsWithErrno(EPERM)); - GTEST_SKIP(); - } - - if (!IsRunningOnGvisor()) { - FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE( - Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY)); - FileDescriptor routeLocalnet = ASSERT_NO_ERRNO_AND_VALUE( - Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDONLY)); - char enabled; - ASSERT_THAT(read(acceptLocal.get(), &enabled, 1), - SyscallSucceedsWithValue(1)); - ASSERT_EQ(enabled, '1'); - ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1), - SyscallSucceedsWithValue(1)); - ASSERT_EQ(enabled, '1'); - } - - ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())), - SyscallSucceeds()); -} - -void CookedPacketTest::TearDown() { - // TearDown will be run even if we skip the test. - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - EXPECT_THAT(close(socket_), SyscallSucceeds()); - } -} - -int CookedPacketTest::GetLoopbackIndex() { - struct ifreq ifr; - snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); - EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); - EXPECT_NE(ifr.ifr_ifindex, 0); - return ifr.ifr_ifindex; -} - -// Receive and verify the message via packet socket on interface. -void ReceiveMessage(int sock, int ifindex) { - // Wait for the socket to become readable. - struct pollfd pfd = {}; - pfd.fd = sock; - pfd.events = POLLIN; - EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1)); - - // Read and verify the data. - constexpr size_t packet_size = - sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage); - char buf[64]; - struct sockaddr_ll src = {}; - socklen_t src_len = sizeof(src); - ASSERT_THAT(recvfrom(sock, buf, sizeof(buf), 0, - reinterpret_cast<struct sockaddr*>(&src), &src_len), - SyscallSucceedsWithValue(packet_size)); - - // sockaddr_ll ends with an 8 byte physical address field, but ethernet - // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 - // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns - // sizeof(sockaddr_ll). - ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); - - // TODO(b/129292371): Verify protocol once we return it. - // Verify the source address. - EXPECT_EQ(src.sll_family, AF_PACKET); - EXPECT_EQ(src.sll_ifindex, ifindex); - EXPECT_EQ(src.sll_halen, ETH_ALEN); - // This came from the loopback device, so the address is all 0s. - for (int i = 0; i < src.sll_halen; i++) { - EXPECT_EQ(src.sll_addr[i], 0); - } - - // Verify the IP header. We memcpy to deal with pointer aligment. - struct iphdr ip = {}; - memcpy(&ip, buf, sizeof(ip)); - EXPECT_EQ(ip.ihl, 5); - EXPECT_EQ(ip.version, 4); - EXPECT_EQ(ip.tot_len, htons(packet_size)); - EXPECT_EQ(ip.protocol, IPPROTO_UDP); - EXPECT_EQ(ip.daddr, htonl(INADDR_LOOPBACK)); - EXPECT_EQ(ip.saddr, htonl(INADDR_LOOPBACK)); - - // Verify the UDP header. We memcpy to deal with pointer aligment. - struct udphdr udp = {}; - memcpy(&udp, buf + sizeof(iphdr), sizeof(udp)); - EXPECT_EQ(udp.dest, kPort); - EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage))); - - // Verify the payload. - char* payload = reinterpret_cast<char*>(buf + sizeof(iphdr) + sizeof(udphdr)); - EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); -} - -// Receive via a packet socket. -TEST_P(CookedPacketTest, Receive) { - // Let's use a simple IP payload: a UDP datagram. - FileDescriptor udp_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - SendUDPMessage(udp_sock.get()); - - // Receive and verify the data. - int loopback_index = GetLoopbackIndex(); - ReceiveMessage(socket_, loopback_index); -} - -// Send via a packet socket. -TEST_P(CookedPacketTest, Send) { - // TODO(b/129292371): Remove once we support packet socket writing. - SKIP_IF(IsRunningOnGvisor()); - - // Let's send a UDP packet and receive it using a regular UDP socket. - FileDescriptor udp_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - struct sockaddr_in bind_addr = {}; - bind_addr.sin_family = AF_INET; - bind_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - bind_addr.sin_port = kPort; - ASSERT_THAT( - bind(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr)), - SyscallSucceeds()); - - // Set up the destination physical address. - struct sockaddr_ll dest = {}; - dest.sll_family = AF_PACKET; - dest.sll_halen = ETH_ALEN; - dest.sll_ifindex = GetLoopbackIndex(); - dest.sll_protocol = htons(ETH_P_IP); - // We're sending to the loopback device, so the address is all 0s. - memset(dest.sll_addr, 0x00, ETH_ALEN); - - // Set up the IP header. - struct iphdr iphdr = {0}; - iphdr.ihl = 5; - iphdr.version = 4; - iphdr.tos = 0; - iphdr.tot_len = - htons(sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage)); - // Get a pseudo-random ID. If we clash with an in-use ID the test will fail, - // but we have no way of getting an ID we know to be good. - srand(*reinterpret_cast<unsigned int*>(&iphdr)); - iphdr.id = rand(); - // Linux sets this bit ("do not fragment") for small packets. - iphdr.frag_off = 1 << 6; - iphdr.ttl = 64; - iphdr.protocol = IPPROTO_UDP; - iphdr.daddr = htonl(INADDR_LOOPBACK); - iphdr.saddr = htonl(INADDR_LOOPBACK); - iphdr.check = IPChecksum(iphdr); - - // Set up the UDP header. - struct udphdr udphdr = {}; - udphdr.source = kPort; - udphdr.dest = kPort; - udphdr.len = htons(sizeof(udphdr) + sizeof(kMessage)); - udphdr.check = UDPChecksum(iphdr, udphdr, kMessage, sizeof(kMessage)); - - // Copy both headers and the payload into our packet buffer. - char send_buf[sizeof(iphdr) + sizeof(udphdr) + sizeof(kMessage)]; - memcpy(send_buf, &iphdr, sizeof(iphdr)); - memcpy(send_buf + sizeof(iphdr), &udphdr, sizeof(udphdr)); - memcpy(send_buf + sizeof(iphdr) + sizeof(udphdr), kMessage, sizeof(kMessage)); - - // Send it. - ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0, - reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Wait for the packet to become available on both sockets. - struct pollfd pfd = {}; - pfd.fd = udp_sock.get(); - pfd.events = POLLIN; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); - pfd.fd = socket_; - pfd.events = POLLIN; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); - - // Receive on the packet socket. - char recv_buf[sizeof(send_buf)]; - ASSERT_THAT(recv(socket_, recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - ASSERT_EQ(memcmp(recv_buf, send_buf, sizeof(send_buf)), 0); - - // Receive on the UDP socket. - struct sockaddr_in src; - socklen_t src_len = sizeof(src); - ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT, - reinterpret_cast<struct sockaddr*>(&src), &src_len), - SyscallSucceedsWithValue(sizeof(kMessage))); - // Check src and payload. - EXPECT_EQ(strncmp(recv_buf, kMessage, sizeof(kMessage)), 0); - EXPECT_EQ(src.sin_family, AF_INET); - EXPECT_EQ(src.sin_port, kPort); - EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); -} - -// Bind and receive via packet socket. -TEST_P(CookedPacketTest, BindReceive) { - struct sockaddr_ll bind_addr = {}; - bind_addr.sll_family = AF_PACKET; - bind_addr.sll_protocol = htons(GetParam()); - bind_addr.sll_ifindex = GetLoopbackIndex(); - - ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr)), - SyscallSucceeds()); - - // Let's use a simple IP payload: a UDP datagram. - FileDescriptor udp_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - SendUDPMessage(udp_sock.get()); - - // Receive and verify the data. - ReceiveMessage(socket_, bind_addr.sll_ifindex); -} - -// Double Bind socket. -TEST_P(CookedPacketTest, DoubleBind) { - struct sockaddr_ll bind_addr = {}; - bind_addr.sll_family = AF_PACKET; - bind_addr.sll_protocol = htons(GetParam()); - bind_addr.sll_ifindex = GetLoopbackIndex(); - - ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr)), - SyscallSucceeds()); - - // Binding socket again should fail. - ASSERT_THAT( - bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr)), - // Linux 4.09 returns EINVAL here, but some time before 4.19 it switched - // to EADDRINUSE. - AnyOf(SyscallFailsWithErrno(EADDRINUSE), SyscallFailsWithErrno(EINVAL))); -} - -// Bind and verify we do not receive data on interface which is not bound -TEST_P(CookedPacketTest, BindDrop) { - // Let's use a simple IP payload: a UDP datagram. - FileDescriptor udp_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - - struct ifaddrs* if_addr_list = nullptr; - auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); }); - - ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds()); - - // Get interface other than loopback. - struct ifreq ifr = {}; - for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) { - if (strcmp(i->ifa_name, "lo") != 0) { - strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name)); - break; - } - } - - // Skip if no interface is available other than loopback. - if (strlen(ifr.ifr_name) == 0) { - GTEST_SKIP(); - } - - // Get interface index. - EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); - EXPECT_NE(ifr.ifr_ifindex, 0); - - // Bind to packet socket requires only family, protocol and ifindex. - struct sockaddr_ll bind_addr = {}; - bind_addr.sll_family = AF_PACKET; - bind_addr.sll_protocol = htons(GetParam()); - bind_addr.sll_ifindex = ifr.ifr_ifindex; - - ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr)), - SyscallSucceeds()); - - // Send to loopback interface. - struct sockaddr_in dest = {}; - dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - dest.sin_family = AF_INET; - dest.sin_port = kPort; - EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0, - reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), - SyscallSucceedsWithValue(sizeof(kMessage))); - - // Wait and make sure the socket never receives any data. - struct pollfd pfd = {}; - pfd.fd = socket_; - pfd.events = POLLIN; - EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); -} - -// Bind with invalid address. -TEST_P(CookedPacketTest, BindFail) { - // Null address. - ASSERT_THAT( - bind(socket_, nullptr, sizeof(struct sockaddr)), - AnyOf(SyscallFailsWithErrno(EFAULT), SyscallFailsWithErrno(EINVAL))); - - // Address of size 1. - uint8_t addr = 0; - ASSERT_THAT( - bind(socket_, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), - SyscallFailsWithErrno(EINVAL)); -} - -INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest, - ::testing::Values(ETH_P_IP, ETH_P_ALL)); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc deleted file mode 100644 index d258d353c..000000000 --- a/test/syscalls/linux/packet_socket_raw.cc +++ /dev/null @@ -1,318 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <arpa/inet.h> -#include <linux/capability.h> -#include <linux/if_arp.h> -#include <linux/if_packet.h> -#include <net/ethernet.h> -#include <netinet/in.h> -#include <netinet/ip.h> -#include <netinet/udp.h> -#include <poll.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/base/internal/endian.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -// Some of these tests involve sending packets via AF_PACKET sockets and the -// loopback interface. Because AF_PACKET circumvents so much of the networking -// stack, Linux sees these packets as "martian", i.e. they claim to be to/from -// localhost but don't have the usual associated data. Thus Linux drops them by -// default. You can see where this happens by following the code at: -// -// - net/ipv4/ip_input.c:ip_rcv_finish, which calls -// - net/ipv4/route.c:ip_route_input_noref, which calls -// - net/ipv4/route.c:ip_route_input_slow, which finds and drops martian -// packets. -// -// To tell Linux not to drop these packets, you need to tell it to accept our -// funny packets (which are completely valid and correct, but lack associated -// in-kernel data because we use AF_PACKET): -// -// echo 1 >> /proc/sys/net/ipv4/conf/lo/accept_local -// echo 1 >> /proc/sys/net/ipv4/conf/lo/route_localnet -// -// These tests require CAP_NET_RAW to run. - -// TODO(gvisor.dev/issue/173): gVisor support. - -namespace gvisor { -namespace testing { - -namespace { - -using ::testing::AnyOf; -using ::testing::Eq; - -constexpr char kMessage[] = "soweoneul malhaebwa"; -constexpr in_port_t kPort = 0x409c; // htons(40000) - -// Send kMessage via sock to loopback -void SendUDPMessage(int sock) { - struct sockaddr_in dest = {}; - dest.sin_port = kPort; - dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - dest.sin_family = AF_INET; - EXPECT_THAT(sendto(sock, kMessage, sizeof(kMessage), 0, - reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), - SyscallSucceedsWithValue(sizeof(kMessage))); -} - -// -// Raw tests. Packets sent with raw AF_PACKET sockets always include link layer -// headers. -// - -// Tests for "raw" (SOCK_RAW) packet(7) sockets. -class RawPacketTest : public ::testing::TestWithParam<int> { - protected: - // Creates a socket to be used in tests. - void SetUp() override; - - // Closes the socket created by SetUp(). - void TearDown() override; - - // Gets the device index of the loopback device. - int GetLoopbackIndex(); - - // The socket used for both reading and writing. - int socket_; -}; - -void RawPacketTest::SetUp() { - if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - ASSERT_THAT(socket(AF_PACKET, SOCK_RAW, htons(GetParam())), - SyscallFailsWithErrno(EPERM)); - GTEST_SKIP(); - } - - if (!IsRunningOnGvisor()) { - FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE( - Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY)); - FileDescriptor routeLocalnet = ASSERT_NO_ERRNO_AND_VALUE( - Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDONLY)); - char enabled; - ASSERT_THAT(read(acceptLocal.get(), &enabled, 1), - SyscallSucceedsWithValue(1)); - ASSERT_EQ(enabled, '1'); - ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1), - SyscallSucceedsWithValue(1)); - ASSERT_EQ(enabled, '1'); - } - - ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_RAW, htons(GetParam())), - SyscallSucceeds()); -} - -void RawPacketTest::TearDown() { - // TearDown will be run even if we skip the test. - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - EXPECT_THAT(close(socket_), SyscallSucceeds()); - } -} - -int RawPacketTest::GetLoopbackIndex() { - struct ifreq ifr; - snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); - EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); - EXPECT_NE(ifr.ifr_ifindex, 0); - return ifr.ifr_ifindex; -} - -// Receive via a packet socket. -TEST_P(RawPacketTest, Receive) { - // Let's use a simple IP payload: a UDP datagram. - FileDescriptor udp_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - SendUDPMessage(udp_sock.get()); - - // Wait for the socket to become readable. - struct pollfd pfd = {}; - pfd.fd = socket_; - pfd.events = POLLIN; - EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1)); - - // Read and verify the data. - constexpr size_t packet_size = sizeof(struct ethhdr) + sizeof(struct iphdr) + - sizeof(struct udphdr) + sizeof(kMessage); - char buf[64]; - struct sockaddr_ll src = {}; - socklen_t src_len = sizeof(src); - ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, - reinterpret_cast<struct sockaddr*>(&src), &src_len), - SyscallSucceedsWithValue(packet_size)); - // sockaddr_ll ends with an 8 byte physical address field, but ethernet - // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 - // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns - // sizeof(sockaddr_ll). - ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); - - // TODO(b/129292371): Verify protocol once we return it. - // Verify the source address. - EXPECT_EQ(src.sll_family, AF_PACKET); - EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); - EXPECT_EQ(src.sll_halen, ETH_ALEN); - // This came from the loopback device, so the address is all 0s. - for (int i = 0; i < src.sll_halen; i++) { - EXPECT_EQ(src.sll_addr[i], 0); - } - - // Verify the ethernet header. We memcpy to deal with pointer alignment. - struct ethhdr eth = {}; - memcpy(ð, buf, sizeof(eth)); - // The destination and source address should be 0, for loopback. - for (int i = 0; i < ETH_ALEN; i++) { - EXPECT_EQ(eth.h_dest[i], 0); - EXPECT_EQ(eth.h_source[i], 0); - } - EXPECT_EQ(eth.h_proto, htons(ETH_P_IP)); - - // Verify the IP header. We memcpy to deal with pointer aligment. - struct iphdr ip = {}; - memcpy(&ip, buf + sizeof(ethhdr), sizeof(ip)); - EXPECT_EQ(ip.ihl, 5); - EXPECT_EQ(ip.version, 4); - EXPECT_EQ(ip.tot_len, htons(packet_size - sizeof(eth))); - EXPECT_EQ(ip.protocol, IPPROTO_UDP); - EXPECT_EQ(ip.daddr, htonl(INADDR_LOOPBACK)); - EXPECT_EQ(ip.saddr, htonl(INADDR_LOOPBACK)); - - // Verify the UDP header. We memcpy to deal with pointer aligment. - struct udphdr udp = {}; - memcpy(&udp, buf + sizeof(eth) + sizeof(iphdr), sizeof(udp)); - EXPECT_EQ(udp.dest, kPort); - EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage))); - - // Verify the payload. - char* payload = reinterpret_cast<char*>(buf + sizeof(eth) + sizeof(iphdr) + - sizeof(udphdr)); - EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); -} - -// Send via a packet socket. -TEST_P(RawPacketTest, Send) { - // TODO(b/129292371): Remove once we support packet socket writing. - SKIP_IF(IsRunningOnGvisor()); - - // Let's send a UDP packet and receive it using a regular UDP socket. - FileDescriptor udp_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - struct sockaddr_in bind_addr = {}; - bind_addr.sin_family = AF_INET; - bind_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - bind_addr.sin_port = kPort; - ASSERT_THAT( - bind(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr)), - SyscallSucceeds()); - - // Set up the destination physical address. - struct sockaddr_ll dest = {}; - dest.sll_family = AF_PACKET; - dest.sll_halen = ETH_ALEN; - dest.sll_ifindex = GetLoopbackIndex(); - dest.sll_protocol = htons(ETH_P_IP); - // We're sending to the loopback device, so the address is all 0s. - memset(dest.sll_addr, 0x00, ETH_ALEN); - - // Set up the ethernet header. The kernel takes care of the footer. - // We're sending to and from hardware address 0 (loopback). - struct ethhdr eth = {}; - eth.h_proto = htons(ETH_P_IP); - - // Set up the IP header. - struct iphdr iphdr = {}; - iphdr.ihl = 5; - iphdr.version = 4; - iphdr.tos = 0; - iphdr.tot_len = - htons(sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage)); - // Get a pseudo-random ID. If we clash with an in-use ID the test will fail, - // but we have no way of getting an ID we know to be good. - srand(*reinterpret_cast<unsigned int*>(&iphdr)); - iphdr.id = rand(); - // Linux sets this bit ("do not fragment") for small packets. - iphdr.frag_off = 1 << 6; - iphdr.ttl = 64; - iphdr.protocol = IPPROTO_UDP; - iphdr.daddr = htonl(INADDR_LOOPBACK); - iphdr.saddr = htonl(INADDR_LOOPBACK); - iphdr.check = IPChecksum(iphdr); - - // Set up the UDP header. - struct udphdr udphdr = {}; - udphdr.source = kPort; - udphdr.dest = kPort; - udphdr.len = htons(sizeof(udphdr) + sizeof(kMessage)); - udphdr.check = UDPChecksum(iphdr, udphdr, kMessage, sizeof(kMessage)); - - // Copy both headers and the payload into our packet buffer. - char - send_buf[sizeof(eth) + sizeof(iphdr) + sizeof(udphdr) + sizeof(kMessage)]; - memcpy(send_buf, ð, sizeof(eth)); - memcpy(send_buf + sizeof(ethhdr), &iphdr, sizeof(iphdr)); - memcpy(send_buf + sizeof(ethhdr) + sizeof(iphdr), &udphdr, sizeof(udphdr)); - memcpy(send_buf + sizeof(ethhdr) + sizeof(iphdr) + sizeof(udphdr), kMessage, - sizeof(kMessage)); - - // Send it. - ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0, - reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Wait for the packet to become available on both sockets. - struct pollfd pfd = {}; - pfd.fd = udp_sock.get(); - pfd.events = POLLIN; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); - pfd.fd = socket_; - pfd.events = POLLIN; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); - - // Receive on the packet socket. - char recv_buf[sizeof(send_buf)]; - ASSERT_THAT(recv(socket_, recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - ASSERT_EQ(memcmp(recv_buf, send_buf, sizeof(send_buf)), 0); - - // Receive on the UDP socket. - struct sockaddr_in src; - socklen_t src_len = sizeof(src); - ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT, - reinterpret_cast<struct sockaddr*>(&src), &src_len), - SyscallSucceedsWithValue(sizeof(kMessage))); - // Check src and payload. - EXPECT_EQ(strncmp(recv_buf, kMessage, sizeof(kMessage)), 0); - EXPECT_EQ(src.sin_family, AF_INET); - EXPECT_EQ(src.sin_port, kPort); - EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); -} - -INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest, - ::testing::Values(ETH_P_IP, ETH_P_ALL)); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/partial_bad_buffer.cc b/test/syscalls/linux/partial_bad_buffer.cc deleted file mode 100644 index df7129acc..000000000 --- a/test/syscalls/linux/partial_bad_buffer.cc +++ /dev/null @@ -1,405 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <netinet/in.h> -#include <netinet/tcp.h> -#include <sys/mman.h> -#include <sys/socket.h> -#include <sys/stat.h> -#include <sys/syscall.h> -#include <sys/types.h> -#include <sys/uio.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -using ::testing::Gt; - -namespace gvisor { -namespace testing { - -namespace { - -constexpr char kMessage[] = "hello world"; - -// PartialBadBufferTest checks the result of various IO syscalls when passed a -// buffer that does not have the space specified in the syscall (most of it is -// PROT_NONE). Linux is annoyingly inconsistent among different syscalls, so we -// test all of them. -class PartialBadBufferTest : public ::testing::Test { - protected: - void SetUp() override { - // Create and open a directory for getdents cases. - directory_ = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - ASSERT_THAT( - directory_fd_ = open(directory_.path().c_str(), O_RDONLY | O_DIRECTORY), - SyscallSucceeds()); - - // Create and open a normal file, placing it in the directory - // so the getdents cases have some dirents. - name_ = JoinPath(directory_.path(), "a"); - ASSERT_THAT(fd_ = open(name_.c_str(), O_RDWR | O_CREAT, 0644), - SyscallSucceeds()); - - // Write some initial data. - size_t size = sizeof(kMessage) - 1; - EXPECT_THAT(WriteFd(fd_, &kMessage, size), SyscallSucceedsWithValue(size)); - ASSERT_THAT(lseek(fd_, 0, SEEK_SET), SyscallSucceeds()); - - // Map a useable buffer. - addr_ = mmap(0, 2 * kPageSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - ASSERT_NE(addr_, MAP_FAILED); - char* buf = reinterpret_cast<char*>(addr_); - - // Guard page for our read to run into. - ASSERT_THAT(mprotect(reinterpret_cast<void*>(buf + kPageSize), kPageSize, - PROT_NONE), - SyscallSucceeds()); - - // Leave only one free byte in the buffer. - bad_buffer_ = buf + kPageSize - 1; - } - - off_t Size() { - struct stat st; - int rc = fstat(fd_, &st); - if (rc < 0) { - return static_cast<off_t>(rc); - } - return st.st_size; - } - - void TearDown() override { - EXPECT_THAT(munmap(addr_, 2 * kPageSize), SyscallSucceeds()) << addr_; - EXPECT_THAT(close(fd_), SyscallSucceeds()); - EXPECT_THAT(unlink(name_.c_str()), SyscallSucceeds()); - EXPECT_THAT(close(directory_fd_), SyscallSucceeds()); - } - - // Return buffer with n bytes of free space. - // N.B. this is the same buffer used to back bad_buffer_. - char* FreeBytes(size_t n) { - TEST_CHECK(n <= static_cast<size_t>(4096)); - return reinterpret_cast<char*>(addr_) + kPageSize - n; - } - - std::string name_; - int fd_; - TempPath directory_; - int directory_fd_; - void* addr_; - char* bad_buffer_; -}; - -// We do both "big" and "small" tests to try to hit the "zero copy" and -// non-"zero copy" paths, which have different code paths for handling faults. - -TEST_F(PartialBadBufferTest, ReadBig) { - EXPECT_THAT(RetryEINTR(read)(fd_, bad_buffer_, kPageSize), - SyscallSucceedsWithValue(1)); - EXPECT_EQ('h', bad_buffer_[0]); -} - -TEST_F(PartialBadBufferTest, ReadSmall) { - EXPECT_THAT(RetryEINTR(read)(fd_, bad_buffer_, 10), - SyscallSucceedsWithValue(1)); - EXPECT_EQ('h', bad_buffer_[0]); -} - -TEST_F(PartialBadBufferTest, PreadBig) { - EXPECT_THAT(RetryEINTR(pread)(fd_, bad_buffer_, kPageSize, 0), - SyscallSucceedsWithValue(1)); - EXPECT_EQ('h', bad_buffer_[0]); -} - -TEST_F(PartialBadBufferTest, PreadSmall) { - EXPECT_THAT(RetryEINTR(pread)(fd_, bad_buffer_, 10, 0), - SyscallSucceedsWithValue(1)); - EXPECT_EQ('h', bad_buffer_[0]); -} - -TEST_F(PartialBadBufferTest, ReadvBig) { - struct iovec vec; - vec.iov_base = bad_buffer_; - vec.iov_len = kPageSize; - - EXPECT_THAT(RetryEINTR(readv)(fd_, &vec, 1), SyscallSucceedsWithValue(1)); - EXPECT_EQ('h', bad_buffer_[0]); -} - -TEST_F(PartialBadBufferTest, ReadvSmall) { - struct iovec vec; - vec.iov_base = bad_buffer_; - vec.iov_len = 10; - - EXPECT_THAT(RetryEINTR(readv)(fd_, &vec, 1), SyscallSucceedsWithValue(1)); - EXPECT_EQ('h', bad_buffer_[0]); -} - -TEST_F(PartialBadBufferTest, PreadvBig) { - struct iovec vec; - vec.iov_base = bad_buffer_; - vec.iov_len = kPageSize; - - EXPECT_THAT(RetryEINTR(preadv)(fd_, &vec, 1, 0), SyscallSucceedsWithValue(1)); - EXPECT_EQ('h', bad_buffer_[0]); -} - -TEST_F(PartialBadBufferTest, PreadvSmall) { - struct iovec vec; - vec.iov_base = bad_buffer_; - vec.iov_len = 10; - - EXPECT_THAT(RetryEINTR(preadv)(fd_, &vec, 1, 0), SyscallSucceedsWithValue(1)); - EXPECT_EQ('h', bad_buffer_[0]); -} - -TEST_F(PartialBadBufferTest, WriteBig) { - off_t orig_size = Size(); - int n; - - ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds()); - EXPECT_THAT( - (n = RetryEINTR(write)(fd_, bad_buffer_, kPageSize)), - AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1))); - EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0)); -} - -TEST_F(PartialBadBufferTest, WriteSmall) { - off_t orig_size = Size(); - int n; - - ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds()); - EXPECT_THAT( - (n = RetryEINTR(write)(fd_, bad_buffer_, 10)), - AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1))); - EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0)); -} - -TEST_F(PartialBadBufferTest, PwriteBig) { - off_t orig_size = Size(); - int n; - - EXPECT_THAT( - (n = RetryEINTR(pwrite)(fd_, bad_buffer_, kPageSize, orig_size)), - AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1))); - EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0)); -} - -TEST_F(PartialBadBufferTest, PwriteSmall) { - off_t orig_size = Size(); - int n; - - EXPECT_THAT( - (n = RetryEINTR(pwrite)(fd_, bad_buffer_, 10, orig_size)), - AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1))); - EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0)); -} - -TEST_F(PartialBadBufferTest, WritevBig) { - struct iovec vec; - vec.iov_base = bad_buffer_; - vec.iov_len = kPageSize; - off_t orig_size = Size(); - int n; - - ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds()); - EXPECT_THAT( - (n = RetryEINTR(writev)(fd_, &vec, 1)), - AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1))); - EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0)); -} - -TEST_F(PartialBadBufferTest, WritevSmall) { - struct iovec vec; - vec.iov_base = bad_buffer_; - vec.iov_len = 10; - off_t orig_size = Size(); - int n; - - ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds()); - EXPECT_THAT( - (n = RetryEINTR(writev)(fd_, &vec, 1)), - AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1))); - EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0)); -} - -TEST_F(PartialBadBufferTest, PwritevBig) { - struct iovec vec; - vec.iov_base = bad_buffer_; - vec.iov_len = kPageSize; - off_t orig_size = Size(); - int n; - - EXPECT_THAT( - (n = RetryEINTR(pwritev)(fd_, &vec, 1, orig_size)), - AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1))); - EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0)); -} - -TEST_F(PartialBadBufferTest, PwritevSmall) { - struct iovec vec; - vec.iov_base = bad_buffer_; - vec.iov_len = 10; - off_t orig_size = Size(); - int n; - - EXPECT_THAT( - (n = RetryEINTR(pwritev)(fd_, &vec, 1, orig_size)), - AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1))); - EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0)); -} - -// getdents returns EFAULT when the you claim the buffer is large enough, but -// it actually isn't. -TEST_F(PartialBadBufferTest, GetdentsBig) { - EXPECT_THAT(RetryEINTR(syscall)(SYS_getdents64, directory_fd_, bad_buffer_, - kPageSize), - SyscallFailsWithErrno(EFAULT)); -} - -// getdents returns EINVAL when the you claim the buffer is too small. -TEST_F(PartialBadBufferTest, GetdentsSmall) { - EXPECT_THAT( - RetryEINTR(syscall)(SYS_getdents64, directory_fd_, bad_buffer_, 10), - SyscallFailsWithErrno(EINVAL)); -} - -// getdents will write entries into a buffer if there is space before it faults. -TEST_F(PartialBadBufferTest, GetdentsOneEntry) { - // 30 bytes is enough for one (small) entry. - char* buf = FreeBytes(30); - - EXPECT_THAT( - RetryEINTR(syscall)(SYS_getdents64, directory_fd_, buf, kPageSize), - SyscallSucceedsWithValue(Gt(0))); -} - -PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) { - struct sockaddr_storage addr; - memset(&addr, 0, sizeof(addr)); - addr.ss_family = family; - switch (family) { - case AF_INET: - reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr = - htonl(INADDR_LOOPBACK); - break; - case AF_INET6: - reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr = - in6addr_loopback; - break; - default: - return PosixError(EINVAL, - absl::StrCat("unknown socket family: ", family)); - } - return addr; -} - -// SendMsgTCP verifies that calling sendmsg with a bad address returns an -// EFAULT. It also verifies that passing a buffer which is made up of 2 -// pages one valid and one guard page succeeds as long as the write is -// for exactly the size of 1 page. -TEST_F(PartialBadBufferTest, SendMsgTCP) { - auto listen_socket = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); - - // Initialize address to the loopback one. - sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET)); - socklen_t addrlen = sizeof(addr); - - // Bind to some port then start listening. - ASSERT_THAT(bind(listen_socket.get(), - reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); - - ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds()); - - // Get the address we're listening on, then connect to it. We need to do this - // because we're allowing the stack to pick a port for us. - ASSERT_THAT(getsockname(listen_socket.get(), - reinterpret_cast<struct sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - auto send_socket = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); - - ASSERT_THAT( - RetryEINTR(connect)(send_socket.get(), - reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); - - // Accept the connection. - auto recv_socket = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr)); - - // TODO(gvisor.dev/issue/674): Update this once Netstack matches linux - // behaviour on a setsockopt of SO_RCVBUF/SO_SNDBUF. - // - // Set SO_SNDBUF for socket to exactly kPageSize+1. - // - // gVisor does not double the value passed in SO_SNDBUF like linux does so we - // just increase it by 1 byte here for gVisor so that we can test writing 1 - // byte past the valid page and check that it triggers an EFAULT - // correctly. Otherwise in gVisor the sendmsg call will just return with no - // error with kPageSize bytes written successfully. - const uint32_t buf_size = kPageSize + 1; - ASSERT_THAT(setsockopt(send_socket.get(), SOL_SOCKET, SO_SNDBUF, &buf_size, - sizeof(buf_size)), - SyscallSucceedsWithValue(0)); - - struct msghdr hdr = {}; - struct iovec iov = {}; - iov.iov_base = bad_buffer_; - iov.iov_len = kPageSize; - hdr.msg_iov = &iov; - hdr.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0), - SyscallFailsWithErrno(EFAULT)); - - // Now assert that writing kPageSize from addr_ succeeds. - iov.iov_base = addr_; - ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0), - SyscallSucceedsWithValue(kPageSize)); - // Read all the data out so that we drain the socket SND_BUF on the sender. - std::vector<char> buffer(kPageSize); - ASSERT_THAT(RetryEINTR(read)(recv_socket.get(), buffer.data(), kPageSize), - SyscallSucceedsWithValue(kPageSize)); - - // Sleep for a shortwhile to ensure that we have time to process the - // ACKs. This is not strictly required unless running under gotsan which is a - // lot slower and can result in the next write to write only 1 byte instead of - // our intended kPageSize + 1. - absl::SleepFor(absl::Milliseconds(50)); - - // Now assert that writing > kPageSize results in EFAULT. - iov.iov_len = kPageSize + 1; - ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0), - SyscallFailsWithErrno(EFAULT)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/pause.cc b/test/syscalls/linux/pause.cc deleted file mode 100644 index 8c05efd6f..000000000 --- a/test/syscalls/linux/pause.cc +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <signal.h> -#include <sys/syscall.h> -#include <sys/types.h> -#include <unistd.h> - -#include <atomic> - -#include "gtest/gtest.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -void NoopSignalHandler(int sig, siginfo_t* info, void* context) {} - -} // namespace - -TEST(PauseTest, OnlyReturnsWhenSignalHandled) { - struct sigaction sa; - sigfillset(&sa.sa_mask); - - // Ensure that SIGUSR1 is ignored. - sa.sa_handler = SIG_IGN; - ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds()); - - // Register a handler for SIGUSR2. - sa.sa_sigaction = NoopSignalHandler; - sa.sa_flags = SA_SIGINFO; - ASSERT_THAT(sigaction(SIGUSR2, &sa, nullptr), SyscallSucceeds()); - - // The child sets their own tid. - absl::Mutex mu; - pid_t child_tid = 0; - bool child_tid_available = false; - std::atomic<int> sent_signal{0}; - std::atomic<int> waking_signal{0}; - ScopedThread t([&] { - mu.Lock(); - child_tid = gettid(); - child_tid_available = true; - mu.Unlock(); - EXPECT_THAT(pause(), SyscallFailsWithErrno(EINTR)); - waking_signal.store(sent_signal.load()); - }); - mu.Lock(); - mu.Await(absl::Condition(&child_tid_available)); - mu.Unlock(); - - // Wait a bit to let the child enter pause(). - absl::SleepFor(absl::Seconds(3)); - - // The child should not be woken by SIGUSR1. - sent_signal.store(SIGUSR1); - ASSERT_THAT(tgkill(getpid(), child_tid, SIGUSR1), SyscallSucceeds()); - absl::SleepFor(absl::Seconds(3)); - - // The child should be woken by SIGUSR2. - sent_signal.store(SIGUSR2); - ASSERT_THAT(tgkill(getpid(), child_tid, SIGUSR2), SyscallSucceeds()); - absl::SleepFor(absl::Seconds(3)); - - EXPECT_EQ(SIGUSR2, waking_signal.load()); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc deleted file mode 100644 index d8e19e910..000000000 --- a/test/syscalls/linux/pipe.cc +++ /dev/null @@ -1,663 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> /* Obtain O_* constant definitions */ -#include <sys/ioctl.h> -#include <sys/uio.h> -#include <unistd.h> - -#include <vector> - -#include "gtest/gtest.h" -#include "absl/strings/str_cat.h" -#include "absl/synchronization/notification.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Used as a non-zero sentinel value, below. -constexpr int kTestValue = 0x12345678; - -// Used for synchronization in race tests. -const absl::Duration syncDelay = absl::Seconds(2); - -struct PipeCreator { - std::string name_; - - // void (fds, is_blocking, is_namedpipe). - std::function<void(int[2], bool*, bool*)> create_; -}; - -class PipeTest : public ::testing::TestWithParam<PipeCreator> { - public: - static void SetUpTestSuite() { - // Tests intentionally generate SIGPIPE. - TEST_PCHECK(signal(SIGPIPE, SIG_IGN) != SIG_ERR); - } - - // Initializes rfd_ and wfd_ as a blocking pipe. - // - // The return value indicates success: the test should be skipped otherwise. - bool CreateBlocking() { return create(true); } - - // Initializes rfd_ and wfd_ as a non-blocking pipe. - // - // The return value is per CreateBlocking. - bool CreateNonBlocking() { return create(false); } - - // Returns true iff the pipe represents a named pipe. - bool IsNamedPipe() const { return named_pipe_; } - - int Size() const { - int s1 = fcntl(rfd_.get(), F_GETPIPE_SZ); - int s2 = fcntl(wfd_.get(), F_GETPIPE_SZ); - EXPECT_GT(s1, 0); - EXPECT_GT(s2, 0); - EXPECT_EQ(s1, s2); - return s1; - } - - static void TearDownTestSuite() { - TEST_PCHECK(signal(SIGPIPE, SIG_DFL) != SIG_ERR); - } - - private: - bool create(bool wants_blocking) { - // Generate the pipe. - int fds[2] = {-1, -1}; - bool is_blocking = false; - GetParam().create_(fds, &is_blocking, &named_pipe_); - if (fds[0] < 0 || fds[1] < 0) { - return false; - } - - // Save descriptors. - rfd_.reset(fds[0]); - wfd_.reset(fds[1]); - - // Adjust blocking, if needed. - if (!is_blocking && wants_blocking) { - // Clear the blocking flag. - EXPECT_THAT(fcntl(fds[0], F_SETFL, 0), SyscallSucceeds()); - EXPECT_THAT(fcntl(fds[1], F_SETFL, 0), SyscallSucceeds()); - } else if (is_blocking && !wants_blocking) { - // Set the descriptors to blocking. - EXPECT_THAT(fcntl(fds[0], F_SETFL, O_NONBLOCK), SyscallSucceeds()); - EXPECT_THAT(fcntl(fds[1], F_SETFL, O_NONBLOCK), SyscallSucceeds()); - } - - return true; - } - - protected: - FileDescriptor rfd_; - FileDescriptor wfd_; - - private: - bool named_pipe_ = false; -}; - -TEST_P(PipeTest, Inode) { - SKIP_IF(!CreateBlocking()); - - // Ensure that the inode number is the same for each end. - struct stat rst; - ASSERT_THAT(fstat(rfd_.get(), &rst), SyscallSucceeds()); - struct stat wst; - ASSERT_THAT(fstat(wfd_.get(), &wst), SyscallSucceeds()); - EXPECT_EQ(rst.st_ino, wst.st_ino); -} - -TEST_P(PipeTest, Permissions) { - SKIP_IF(!CreateBlocking()); - - // Attempt bad operations. - int buf = kTestValue; - ASSERT_THAT(write(rfd_.get(), &buf, sizeof(buf)), - SyscallFailsWithErrno(EBADF)); - EXPECT_THAT(read(wfd_.get(), &buf, sizeof(buf)), - SyscallFailsWithErrno(EBADF)); -} - -TEST_P(PipeTest, Flags) { - SKIP_IF(!CreateBlocking()); - - if (IsNamedPipe()) { - // May be stubbed to zero; define locally. - EXPECT_THAT(fcntl(rfd_.get(), F_GETFL), - SyscallSucceedsWithValue(kOLargeFile | O_RDONLY)); - EXPECT_THAT(fcntl(wfd_.get(), F_GETFL), - SyscallSucceedsWithValue(kOLargeFile | O_WRONLY)); - } else { - EXPECT_THAT(fcntl(rfd_.get(), F_GETFL), SyscallSucceedsWithValue(O_RDONLY)); - EXPECT_THAT(fcntl(wfd_.get(), F_GETFL), SyscallSucceedsWithValue(O_WRONLY)); - } -} - -TEST_P(PipeTest, Write) { - SKIP_IF(!CreateBlocking()); - - int wbuf = kTestValue; - int rbuf = ~kTestValue; - ASSERT_THAT(write(wfd_.get(), &wbuf, sizeof(wbuf)), - SyscallSucceedsWithValue(sizeof(wbuf))); - ASSERT_THAT(read(rfd_.get(), &rbuf, sizeof(rbuf)), - SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(wbuf, rbuf); -} - -TEST_P(PipeTest, WritePage) { - SKIP_IF(!CreateBlocking()); - - std::vector<char> wbuf(kPageSize); - RandomizeBuffer(wbuf.data(), wbuf.size()); - std::vector<char> rbuf(wbuf.size()); - - ASSERT_THAT(write(wfd_.get(), wbuf.data(), wbuf.size()), - SyscallSucceedsWithValue(wbuf.size())); - ASSERT_THAT(read(rfd_.get(), rbuf.data(), rbuf.size()), - SyscallSucceedsWithValue(rbuf.size())); - EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), wbuf.size()), 0); -} - -TEST_P(PipeTest, NonBlocking) { - SKIP_IF(!CreateNonBlocking()); - - int wbuf = kTestValue; - int rbuf = ~kTestValue; - EXPECT_THAT(read(rfd_.get(), &rbuf, sizeof(rbuf)), - SyscallFailsWithErrno(EWOULDBLOCK)); - ASSERT_THAT(write(wfd_.get(), &wbuf, sizeof(wbuf)), - SyscallSucceedsWithValue(sizeof(wbuf))); - - ASSERT_THAT(read(rfd_.get(), &rbuf, sizeof(rbuf)), - SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(wbuf, rbuf); - EXPECT_THAT(read(rfd_.get(), &rbuf, sizeof(rbuf)), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST(Pipe2Test, CloExec) { - int fds[2]; - ASSERT_THAT(pipe2(fds, O_CLOEXEC), SyscallSucceeds()); - EXPECT_THAT(fcntl(fds[0], F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); - EXPECT_THAT(fcntl(fds[1], F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); - EXPECT_THAT(close(fds[0]), SyscallSucceeds()); - EXPECT_THAT(close(fds[1]), SyscallSucceeds()); -} - -TEST(Pipe2Test, BadOptions) { - int fds[2]; - EXPECT_THAT(pipe2(fds, 0xDEAD), SyscallFailsWithErrno(EINVAL)); -} - -// Tests that opening named pipes with O_TRUNC shouldn't cause an error, but -// calls to (f)truncate should. -TEST(NamedPipeTest, Truncate) { - const std::string tmp_path = NewTempAbsPath(); - SKIP_IF(mkfifo(tmp_path.c_str(), 0644) != 0); - - ASSERT_THAT(open(tmp_path.c_str(), O_NONBLOCK | O_RDONLY), SyscallSucceeds()); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( - Open(tmp_path.c_str(), O_RDWR | O_NONBLOCK | O_TRUNC)); - - ASSERT_THAT(truncate(tmp_path.c_str(), 0), SyscallFailsWithErrno(EINVAL)); - ASSERT_THAT(ftruncate(fd.get(), 0), SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(PipeTest, Seek) { - SKIP_IF(!CreateBlocking()); - - for (int i = 0; i < 4; i++) { - // Attempt absolute seeks. - EXPECT_THAT(lseek(rfd_.get(), 0, SEEK_SET), SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(lseek(rfd_.get(), 4, SEEK_SET), SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(lseek(wfd_.get(), 0, SEEK_SET), SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(lseek(wfd_.get(), 4, SEEK_SET), SyscallFailsWithErrno(ESPIPE)); - - // Attempt relative seeks. - EXPECT_THAT(lseek(rfd_.get(), 0, SEEK_CUR), SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(lseek(rfd_.get(), 4, SEEK_CUR), SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(lseek(wfd_.get(), 0, SEEK_CUR), SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(lseek(wfd_.get(), 4, SEEK_CUR), SyscallFailsWithErrno(ESPIPE)); - - // Attempt end-of-file seeks. - EXPECT_THAT(lseek(rfd_.get(), 0, SEEK_CUR), SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(lseek(rfd_.get(), -4, SEEK_END), SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(lseek(wfd_.get(), 0, SEEK_CUR), SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(lseek(wfd_.get(), -4, SEEK_END), SyscallFailsWithErrno(ESPIPE)); - - // Add some more data to the pipe. - int buf = kTestValue; - ASSERT_THAT(write(wfd_.get(), &buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - } -} - -TEST_P(PipeTest, OffsetCalls) { - SKIP_IF(!CreateBlocking()); - - int buf; - EXPECT_THAT(pread(wfd_.get(), &buf, sizeof(buf), 0), - SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(pwrite(rfd_.get(), &buf, sizeof(buf), 0), - SyscallFailsWithErrno(ESPIPE)); - - struct iovec iov; - EXPECT_THAT(preadv(wfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(pwritev(rfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE)); -} - -TEST_P(PipeTest, WriterSideCloses) { - SKIP_IF(!CreateBlocking()); - - ScopedThread t([this]() { - int buf = ~kTestValue; - ASSERT_THAT(read(rfd_.get(), &buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - EXPECT_EQ(buf, kTestValue); - // This will return when the close() completes. - ASSERT_THAT(read(rfd_.get(), &buf, sizeof(buf)), SyscallSucceeds()); - // This will return straight away. - ASSERT_THAT(read(rfd_.get(), &buf, sizeof(buf)), - SyscallSucceedsWithValue(0)); - }); - - // Sleep a bit so the thread can block. - absl::SleepFor(syncDelay); - - // Write to unblock. - int buf = kTestValue; - ASSERT_THAT(write(wfd_.get(), &buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - // Sleep a bit so the thread can block again. - absl::SleepFor(syncDelay); - - // Allow the thread to complete. - ASSERT_THAT(close(wfd_.release()), SyscallSucceeds()); - t.Join(); -} - -TEST_P(PipeTest, WriterSideClosesReadDataFirst) { - SKIP_IF(!CreateBlocking()); - - int wbuf = kTestValue; - ASSERT_THAT(write(wfd_.get(), &wbuf, sizeof(wbuf)), - SyscallSucceedsWithValue(sizeof(wbuf))); - ASSERT_THAT(close(wfd_.release()), SyscallSucceeds()); - - int rbuf; - ASSERT_THAT(read(rfd_.get(), &rbuf, sizeof(rbuf)), - SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(wbuf, rbuf); - EXPECT_THAT(read(rfd_.get(), &rbuf, sizeof(rbuf)), - SyscallSucceedsWithValue(0)); -} - -TEST_P(PipeTest, ReaderSideCloses) { - SKIP_IF(!CreateBlocking()); - - ASSERT_THAT(close(rfd_.release()), SyscallSucceeds()); - int buf = kTestValue; - EXPECT_THAT(write(wfd_.get(), &buf, sizeof(buf)), - SyscallFailsWithErrno(EPIPE)); -} - -TEST_P(PipeTest, CloseTwice) { - SKIP_IF(!CreateBlocking()); - - int reader = rfd_.release(); - int writer = wfd_.release(); - ASSERT_THAT(close(reader), SyscallSucceeds()); - ASSERT_THAT(close(writer), SyscallSucceeds()); - EXPECT_THAT(close(reader), SyscallFailsWithErrno(EBADF)); - EXPECT_THAT(close(writer), SyscallFailsWithErrno(EBADF)); -} - -// Blocking write returns EPIPE when read end is closed if nothing has been -// written. -TEST_P(PipeTest, BlockWriteClosed) { - SKIP_IF(!CreateBlocking()); - - absl::Notification notify; - ScopedThread t([this, ¬ify]() { - std::vector<char> buf(Size()); - // Exactly fill the pipe buffer. - ASSERT_THAT(WriteFd(wfd_.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - - notify.Notify(); - - // Attempt to write one more byte. Blocks. - // N.B. Don't use WriteFd, we don't want a retry. - EXPECT_THAT(write(wfd_.get(), buf.data(), 1), SyscallFailsWithErrno(EPIPE)); - }); - - notify.WaitForNotification(); - ASSERT_THAT(close(rfd_.release()), SyscallSucceeds()); - t.Join(); -} - -// Blocking write returns EPIPE when read end is closed even if something has -// been written. -TEST_P(PipeTest, BlockPartialWriteClosed) { - SKIP_IF(!CreateBlocking()); - - ScopedThread t([this]() { - const int pipe_size = Size(); - std::vector<char> buf(2 * pipe_size); - - // Write more than fits in the buffer. Blocks then returns partial write - // when the other end is closed. The next call returns EPIPE. - ASSERT_THAT(write(wfd_.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(pipe_size)); - EXPECT_THAT(write(wfd_.get(), buf.data(), buf.size()), - SyscallFailsWithErrno(EPIPE)); - }); - - // Leave time for write to become blocked. - absl::SleepFor(syncDelay); - - // Unblock the above. - ASSERT_THAT(close(rfd_.release()), SyscallSucceeds()); - t.Join(); -} - -TEST_P(PipeTest, ReadFromClosedFd_NoRandomSave) { - SKIP_IF(!CreateBlocking()); - - absl::Notification notify; - ScopedThread t([this, ¬ify]() { - notify.Notify(); - int buf; - ASSERT_THAT(read(rfd_.get(), &buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - ASSERT_EQ(kTestValue, buf); - }); - notify.WaitForNotification(); - - // Make sure that the thread gets to read(). - absl::SleepFor(syncDelay); - - { - // We cannot save/restore here as the read end of pipe is closed but there - // is ongoing read() above. We will not be able to restart the read() - // successfully in restore run since the read fd is closed. - const DisableSave ds; - ASSERT_THAT(close(rfd_.release()), SyscallSucceeds()); - int buf = kTestValue; - ASSERT_THAT(write(wfd_.get(), &buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - t.Join(); - } -} - -TEST_P(PipeTest, FionRead) { - SKIP_IF(!CreateBlocking()); - - int n; - ASSERT_THAT(ioctl(rfd_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - ASSERT_THAT(ioctl(wfd_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - std::vector<char> buf(Size()); - ASSERT_THAT(write(wfd_.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - - EXPECT_THAT(ioctl(rfd_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, buf.size()); - EXPECT_THAT(ioctl(wfd_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, buf.size()); -} - -// Test that opening an empty anonymous pipe RDONLY via /proc/self/fd/N does not -// block waiting for a writer. -TEST_P(PipeTest, OpenViaProcSelfFD) { - SKIP_IF(!CreateBlocking()); - SKIP_IF(IsNamedPipe()); - - // Close the write end of the pipe. - ASSERT_THAT(close(wfd_.release()), SyscallSucceeds()); - - // Open other side via /proc/self/fd. It should not block. - FileDescriptor proc_self_fd = ASSERT_NO_ERRNO_AND_VALUE( - Open(absl::StrCat("/proc/self/fd/", rfd_.get()), O_RDONLY)); -} - -// Test that opening and reading from an anonymous pipe (with existing writes) -// RDONLY via /proc/self/fd/N returns the existing data. -TEST_P(PipeTest, OpenViaProcSelfFDWithWrites) { - SKIP_IF(!CreateBlocking()); - SKIP_IF(IsNamedPipe()); - - // Write to the pipe and then close the write fd. - int wbuf = kTestValue; - ASSERT_THAT(write(wfd_.get(), &wbuf, sizeof(wbuf)), - SyscallSucceedsWithValue(sizeof(wbuf))); - ASSERT_THAT(close(wfd_.release()), SyscallSucceeds()); - - // Open read side via /proc/self/fd, and read from it. - FileDescriptor proc_self_fd = ASSERT_NO_ERRNO_AND_VALUE( - Open(absl::StrCat("/proc/self/fd/", rfd_.get()), O_RDONLY)); - int rbuf; - ASSERT_THAT(read(proc_self_fd.get(), &rbuf, sizeof(rbuf)), - SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(wbuf, rbuf); -} - -// Test that accesses of /proc/<PID>/fd correctly decrement the refcount. -TEST_P(PipeTest, ProcFDReleasesFile) { - SKIP_IF(!CreateBlocking()); - - // Stat the pipe FD, which shouldn't alter the refcount. - struct stat wst; - ASSERT_THAT(lstat(absl::StrCat("/proc/self/fd/", wfd_.get()).c_str(), &wst), - SyscallSucceeds()); - - // Close the write end and ensure that read indicates EOF. - wfd_.reset(); - char buf; - ASSERT_THAT(read(rfd_.get(), &buf, 1), SyscallSucceedsWithValue(0)); -} - -// Same for /proc/<PID>/fdinfo. -TEST_P(PipeTest, ProcFDInfoReleasesFile) { - SKIP_IF(!CreateBlocking()); - - // Stat the pipe FD, which shouldn't alter the refcount. - struct stat wst; - ASSERT_THAT( - lstat(absl::StrCat("/proc/self/fdinfo/", wfd_.get()).c_str(), &wst), - SyscallSucceeds()); - - // Close the write end and ensure that read indicates EOF. - wfd_.reset(); - char buf; - ASSERT_THAT(read(rfd_.get(), &buf, 1), SyscallSucceedsWithValue(0)); -} - -TEST_P(PipeTest, SizeChange) { - SKIP_IF(!CreateBlocking()); - - // Set the minimum possible size. - ASSERT_THAT(fcntl(rfd_.get(), F_SETPIPE_SZ, 0), SyscallSucceeds()); - int min = Size(); - EXPECT_GT(min, 0); // Should be rounded up. - - // Set from the read end. - ASSERT_THAT(fcntl(rfd_.get(), F_SETPIPE_SZ, min + 1), SyscallSucceeds()); - int med = Size(); - EXPECT_GT(med, min); // Should have grown, may be rounded. - - // Set from the write end. - ASSERT_THAT(fcntl(wfd_.get(), F_SETPIPE_SZ, med + 1), SyscallSucceeds()); - int max = Size(); - EXPECT_GT(max, med); // Ditto. -} - -TEST_P(PipeTest, SizeChangeMax) { - SKIP_IF(!CreateBlocking()); - - // Assert there's some maximum. - EXPECT_THAT(fcntl(rfd_.get(), F_SETPIPE_SZ, 0x7fffffffffffffff), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(fcntl(wfd_.get(), F_SETPIPE_SZ, 0x7fffffffffffffff), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(PipeTest, SizeChangeFull) { - SKIP_IF(!CreateBlocking()); - - // Ensure that we adjust to a large enough size to avoid rounding when we - // perform the size decrease. If rounding occurs, we may not actually - // adjust the size and the call below will return success. It was found via - // experimentation that this granularity avoids the rounding for Linux. - constexpr int kDelta = 64 * 1024; - ASSERT_THAT(fcntl(wfd_.get(), F_SETPIPE_SZ, Size() + kDelta), - SyscallSucceeds()); - - // Fill the buffer and try to change down. - std::vector<char> buf(Size()); - ASSERT_THAT(write(wfd_.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - EXPECT_THAT(fcntl(wfd_.get(), F_SETPIPE_SZ, Size() - kDelta), - SyscallFailsWithErrno(EBUSY)); -} - -TEST_P(PipeTest, Streaming) { - SKIP_IF(!CreateBlocking()); - - // We make too many calls to go through full save cycles. - DisableSave ds; - - // Size() requires 2 syscalls, call it once and remember the value. - const int pipe_size = Size(); - - absl::Notification notify; - ScopedThread t([this, ¬ify, pipe_size]() { - // Don't start until it's full. - notify.WaitForNotification(); - for (int i = 0; i < pipe_size; i++) { - int rbuf; - ASSERT_THAT(read(rfd_.get(), &rbuf, sizeof(rbuf)), - SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(rbuf, i); - } - }); - - // Write 4 bytes * pipe_size. It will fill up the pipe once, notify the reader - // to start. Then we write pipe size worth 3 more times to ensure the reader - // can follow along. - ssize_t total = 0; - for (int i = 0; i < pipe_size; i++) { - ssize_t written = write(wfd_.get(), &i, sizeof(i)); - ASSERT_THAT(written, SyscallSucceedsWithValue(sizeof(i))); - total += written; - - // Is the next write about to fill up the buffer? Wake up the reader once. - if (total < pipe_size && (total + written) >= pipe_size) { - notify.Notify(); - } - } -} - -std::string PipeCreatorName(::testing::TestParamInfo<PipeCreator> info) { - return info.param.name_; // Use the name specified. -} - -INSTANTIATE_TEST_SUITE_P( - Pipes, PipeTest, - ::testing::Values( - PipeCreator{ - "pipe", - [](int fds[2], bool* is_blocking, bool* is_namedpipe) { - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - *is_blocking = true; - *is_namedpipe = false; - }, - }, - PipeCreator{ - "pipe2blocking", - [](int fds[2], bool* is_blocking, bool* is_namedpipe) { - ASSERT_THAT(pipe2(fds, 0), SyscallSucceeds()); - *is_blocking = true; - *is_namedpipe = false; - }, - }, - PipeCreator{ - "pipe2nonblocking", - [](int fds[2], bool* is_blocking, bool* is_namedpipe) { - ASSERT_THAT(pipe2(fds, O_NONBLOCK), SyscallSucceeds()); - *is_blocking = false; - *is_namedpipe = false; - }, - }, - PipeCreator{ - "smallbuffer", - [](int fds[2], bool* is_blocking, bool* is_namedpipe) { - // Set to the minimum available size (will round up). - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - ASSERT_THAT(fcntl(fds[0], F_SETPIPE_SZ, 0), SyscallSucceeds()); - *is_blocking = true; - *is_namedpipe = false; - }, - }, - PipeCreator{ - "namednonblocking", - [](int fds[2], bool* is_blocking, bool* is_namedpipe) { - // Create a new file-based pipe (non-blocking). - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - ASSERT_THAT(unlink(file.path().c_str()), SyscallSucceeds()); - SKIP_IF(mkfifo(file.path().c_str(), 0644) != 0); - fds[0] = open(file.path().c_str(), O_NONBLOCK | O_RDONLY); - fds[1] = open(file.path().c_str(), O_NONBLOCK | O_WRONLY); - MaybeSave(); - *is_blocking = false; - *is_namedpipe = true; - }, - }, - PipeCreator{ - "namedblocking", - [](int fds[2], bool* is_blocking, bool* is_namedpipe) { - // Create a new file-based pipe (blocking). - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - ASSERT_THAT(unlink(file.path().c_str()), SyscallSucceeds()); - SKIP_IF(mkfifo(file.path().c_str(), 0644) != 0); - ScopedThread t([&file, &fds]() { - fds[1] = open(file.path().c_str(), O_WRONLY); - }); - fds[0] = open(file.path().c_str(), O_RDONLY); - t.Join(); - MaybeSave(); - *is_blocking = true; - *is_namedpipe = true; - }, - }), - PipeCreatorName); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/poll.cc b/test/syscalls/linux/poll.cc deleted file mode 100644 index c42472474..000000000 --- a/test/syscalls/linux/poll.cc +++ /dev/null @@ -1,294 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <poll.h> -#include <sys/resource.h> -#include <sys/socket.h> -#include <sys/types.h> - -#include <algorithm> -#include <iostream> - -#include "gtest/gtest.h" -#include "absl/synchronization/notification.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/base_poll_test.h" -#include "test/util/eventfd_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/logging.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { -namespace { - -class PollTest : public BasePollTest { - protected: - void SetUp() override { BasePollTest::SetUp(); } - void TearDown() override { BasePollTest::TearDown(); } -}; - -TEST_F(PollTest, InvalidFds) { - // fds is invalid because it's null, but we tell ppoll the length is non-zero. - EXPECT_THAT(poll(nullptr, 1, 1), SyscallFailsWithErrno(EFAULT)); - EXPECT_THAT(poll(nullptr, -1, 1), SyscallFailsWithErrno(EINVAL)); -} - -TEST_F(PollTest, NullFds) { - EXPECT_THAT(poll(nullptr, 0, 10), SyscallSucceeds()); -} - -TEST_F(PollTest, ZeroTimeout) { - EXPECT_THAT(poll(nullptr, 0, 0), SyscallSucceeds()); -} - -// If random S/R interrupts the poll, SIGALRM may be delivered before poll -// restarts, causing the poll to hang forever. -TEST_F(PollTest, NegativeTimeout_NoRandomSave) { - // Negative timeout mean wait forever so set a timer. - SetTimer(absl::Milliseconds(100)); - EXPECT_THAT(poll(nullptr, 0, -1), SyscallFailsWithErrno(EINTR)); - EXPECT_TRUE(TimerFired()); -} - -TEST_F(PollTest, NonBlockingEventPOLLIN) { - // Create a pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - - FileDescriptor fd0(fds[0]); - FileDescriptor fd1(fds[1]); - - // Write some data to the pipe. - char s[] = "foo\n"; - ASSERT_THAT(WriteFd(fd1.get(), s, strlen(s) + 1), SyscallSucceeds()); - - // Poll on the reader fd with POLLIN event. - struct pollfd poll_fd = {fd0.get(), POLLIN, 0}; - EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 0), SyscallSucceedsWithValue(1)); - - // Should trigger POLLIN event. - EXPECT_EQ(poll_fd.revents & POLLIN, POLLIN); -} - -TEST_F(PollTest, BlockingEventPOLLIN) { - // Create a pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - - FileDescriptor fd0(fds[0]); - FileDescriptor fd1(fds[1]); - - // Start a blocking poll on the read fd. - absl::Notification notify; - ScopedThread t([&fd0, ¬ify]() { - notify.Notify(); - - // Poll on the reader fd with POLLIN event. - struct pollfd poll_fd = {fd0.get(), POLLIN, 0}; - EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, -1), SyscallSucceedsWithValue(1)); - - // Should trigger POLLIN event. - EXPECT_EQ(poll_fd.revents & POLLIN, POLLIN); - }); - - notify.WaitForNotification(); - absl::SleepFor(absl::Seconds(1.0)); - - // Write some data to the pipe. - char s[] = "foo\n"; - ASSERT_THAT(WriteFd(fd1.get(), s, strlen(s) + 1), SyscallSucceeds()); -} - -TEST_F(PollTest, NonBlockingEventPOLLHUP) { - // Create a pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - - FileDescriptor fd0(fds[0]); - FileDescriptor fd1(fds[1]); - - // Close the writer fd. - fd1.reset(); - - // Poll on the reader fd with POLLIN event. - struct pollfd poll_fd = {fd0.get(), POLLIN, 0}; - EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 0), SyscallSucceedsWithValue(1)); - - // Should trigger POLLHUP event. - EXPECT_EQ(poll_fd.revents & POLLHUP, POLLHUP); - - // Should not trigger POLLIN event. - EXPECT_EQ(poll_fd.revents & POLLIN, 0); -} - -TEST_F(PollTest, BlockingEventPOLLHUP) { - // Create a pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - - FileDescriptor fd0(fds[0]); - FileDescriptor fd1(fds[1]); - - // Start a blocking poll on the read fd. - absl::Notification notify; - ScopedThread t([&fd0, ¬ify]() { - notify.Notify(); - - // Poll on the reader fd with POLLIN event. - struct pollfd poll_fd = {fd0.get(), POLLIN, 0}; - EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, -1), SyscallSucceedsWithValue(1)); - - // Should trigger POLLHUP event. - EXPECT_EQ(poll_fd.revents & POLLHUP, POLLHUP); - - // Should not trigger POLLIN event. - EXPECT_EQ(poll_fd.revents & POLLIN, 0); - }); - - notify.WaitForNotification(); - absl::SleepFor(absl::Seconds(1.0)); - - // Write some data and close the writer fd. - fd1.reset(); -} - -TEST_F(PollTest, NonBlockingEventPOLLERR) { - // Create a pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - - FileDescriptor fd0(fds[0]); - FileDescriptor fd1(fds[1]); - - // Close the reader fd. - fd0.reset(); - - // Poll on the writer fd with POLLOUT event. - struct pollfd poll_fd = {fd1.get(), POLLOUT, 0}; - EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 0), SyscallSucceedsWithValue(1)); - - // Should trigger POLLERR event. - EXPECT_EQ(poll_fd.revents & POLLERR, POLLERR); - - // Should also trigger POLLOUT event. - EXPECT_EQ(poll_fd.revents & POLLOUT, POLLOUT); -} - -// This test will validate that if an FD is already ready on some event, whether -// it's POLLIN or POLLOUT it will not immediately return unless that's actually -// what the caller was interested in. -TEST_F(PollTest, ImmediatelyReturnOnlyOnPollEvents) { - // Create a pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - - FileDescriptor fd0(fds[0]); - FileDescriptor fd1(fds[1]); - - // Wait for read related event on the write side of the pipe, since a write - // is possible on fds[1] it would mean that POLLOUT would return immediately. - // We should make sure that we're not woken up with that state that we didn't - // specificially request. - constexpr int kTimeoutMs = 100; - struct pollfd poll_fd = {fd1.get(), POLLIN | POLLPRI | POLLRDHUP, 0}; - EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, kTimeoutMs), - SyscallSucceedsWithValue(0)); // We should timeout. - EXPECT_EQ(poll_fd.revents, 0); // Nothing should be in returned events. - - // Now let's poll on POLLOUT and we should get back 1 fd as being ready and - // it should contain POLLOUT in the revents. - poll_fd.events = POLLOUT; - EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, kTimeoutMs), - SyscallSucceedsWithValue(1)); // 1 fd should have an event. - EXPECT_EQ(poll_fd.revents, POLLOUT); // POLLOUT should be in revents. -} - -// This test validates that poll(2) while data is available immediately returns. -TEST_F(PollTest, PollLevelTriggered) { - int fds[2] = {}; - ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, /*protocol=*/0, fds), - SyscallSucceeds()); - - FileDescriptor fd0(fds[0]); - FileDescriptor fd1(fds[1]); - - // Write two bytes to the socket. - const char* kBuf = "aa"; - ASSERT_THAT(RetryEINTR(send)(fd0.get(), kBuf, /*len=*/2, /*flags=*/0), - SyscallSucceedsWithValue(2)); // 2 bytes should be written. - - // Poll(2) should immediately return as there is data available to read. - constexpr int kInfiniteTimeout = -1; - struct pollfd poll_fd = {fd1.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&poll_fd, /*nfds=*/1, kInfiniteTimeout), - SyscallSucceedsWithValue(1)); // 1 fd should be ready to read. - EXPECT_NE(poll_fd.revents & POLLIN, 0); - - // Read a single byte. - char read_byte = 0; - ASSERT_THAT(RetryEINTR(recv)(fd1.get(), &read_byte, /*len=*/1, /*flags=*/0), - SyscallSucceedsWithValue(1)); // 1 byte should be read. - ASSERT_EQ(read_byte, 'a'); // We should have read a single 'a'. - - // Create a separate pollfd for our second poll. - struct pollfd poll_fd_after = {fd1.get(), POLLIN, 0}; - - // Poll(2) should again immediately return since we only read one byte. - ASSERT_THAT(RetryEINTR(poll)(&poll_fd_after, /*nfds=*/1, kInfiniteTimeout), - SyscallSucceedsWithValue(1)); // 1 fd should be ready to read. - EXPECT_NE(poll_fd_after.revents & POLLIN, 0); -} - -TEST_F(PollTest, Nfds) { - // Stash value of RLIMIT_NOFILES. - struct rlimit rlim; - TEST_PCHECK(getrlimit(RLIMIT_NOFILE, &rlim) == 0); - - // gVisor caps the number of FDs that epoll can use beyond RLIMIT_NOFILE. - constexpr rlim_t gVisorMax = 1048576; - if (rlim.rlim_cur > gVisorMax) { - rlim.rlim_cur = gVisorMax; - TEST_PCHECK(setrlimit(RLIMIT_NOFILE, &rlim) == 0); - } - - rlim_t max_fds = rlim.rlim_cur; - std::cout << "Using limit: " << max_fds; - - // Create an eventfd. Since its value is initially zero, it is writable. - FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()); - - // Create the biggest possible pollfd array such that each element is valid. - // Each entry in the 'fds' array refers to the eventfd and polls for - // "writable" events (events=POLLOUT). This essentially guarantees that the - // poll() is a no-op and allows negative testing of the 'nfds' parameter. - std::vector<struct pollfd> fds(max_fds + 1, - {.fd = efd.get(), .events = POLLOUT}); - - // Verify that 'nfds' up to RLIMIT_NOFILE are allowed. - EXPECT_THAT(RetryEINTR(poll)(fds.data(), 1, 1), SyscallSucceedsWithValue(1)); - EXPECT_THAT(RetryEINTR(poll)(fds.data(), max_fds / 2, 1), - SyscallSucceedsWithValue(max_fds / 2)); - EXPECT_THAT(RetryEINTR(poll)(fds.data(), max_fds, 1), - SyscallSucceedsWithValue(max_fds)); - - // If 'nfds' exceeds RLIMIT_NOFILE then it must fail with EINVAL. - EXPECT_THAT(poll(fds.data(), max_fds + 1, 1), SyscallFailsWithErrno(EINVAL)); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/ppoll.cc b/test/syscalls/linux/ppoll.cc deleted file mode 100644 index 8245a11e8..000000000 --- a/test/syscalls/linux/ppoll.cc +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <poll.h> -#include <signal.h> -#include <sys/syscall.h> -#include <sys/time.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/base_poll_test.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -// Linux and glibc have a different idea of the sizeof sigset_t. When calling -// the syscall directly, use what the kernel expects. -unsigned kSigsetSize = SIGRTMAX / 8; - -// Linux ppoll(2) differs from the glibc wrapper function in that Linux updates -// the timeout with the amount of time remaining. In order to test this behavior -// we need to use the syscall directly. -int syscallPpoll(struct pollfd* fds, nfds_t nfds, struct timespec* timeout_ts, - const sigset_t* sigmask, unsigned mask_size) { - return syscall(SYS_ppoll, fds, nfds, timeout_ts, sigmask, mask_size); -} - -class PpollTest : public BasePollTest { - protected: - void SetUp() override { BasePollTest::SetUp(); } - void TearDown() override { BasePollTest::TearDown(); } -}; - -TEST_F(PpollTest, InvalidFds) { - // fds is invalid because it's null, but we tell ppoll the length is non-zero. - struct timespec timeout = {}; - sigset_t sigmask; - TEST_PCHECK(sigemptyset(&sigmask) == 0); - EXPECT_THAT(syscallPpoll(nullptr, 1, &timeout, &sigmask, kSigsetSize), - SyscallFailsWithErrno(EFAULT)); - EXPECT_THAT(syscallPpoll(nullptr, -1, &timeout, &sigmask, kSigsetSize), - SyscallFailsWithErrno(EINVAL)); -} - -// See that when fds is null, ppoll behaves like sleep. -TEST_F(PpollTest, NullFds) { - struct timespec timeout = absl::ToTimespec(absl::Milliseconds(10)); - ASSERT_THAT(syscallPpoll(nullptr, 0, &timeout, nullptr, 0), - SyscallSucceeds()); - EXPECT_EQ(timeout.tv_sec, 0); - EXPECT_EQ(timeout.tv_nsec, 0); -} - -TEST_F(PpollTest, ZeroTimeout) { - struct timespec timeout = {}; - ASSERT_THAT(syscallPpoll(nullptr, 0, &timeout, nullptr, 0), - SyscallSucceeds()); - EXPECT_EQ(timeout.tv_sec, 0); - EXPECT_EQ(timeout.tv_nsec, 0); -} - -// If random S/R interrupts the ppoll, SIGALRM may be delivered before ppoll -// restarts, causing the ppoll to hang forever. -TEST_F(PpollTest, NoTimeout_NoRandomSave) { - // When there's no timeout, ppoll may never return so set a timer. - SetTimer(absl::Milliseconds(100)); - // See that we get interrupted by the timer. - ASSERT_THAT(syscallPpoll(nullptr, 0, nullptr, nullptr, 0), - SyscallFailsWithErrno(EINTR)); - EXPECT_TRUE(TimerFired()); -} - -TEST_F(PpollTest, InvalidTimeoutNegative) { - struct timespec timeout = absl::ToTimespec(absl::Nanoseconds(-1)); - EXPECT_THAT(syscallPpoll(nullptr, 0, &timeout, nullptr, 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_F(PpollTest, InvalidTimeoutNotNormalized) { - struct timespec timeout = {0, 1000000001}; - EXPECT_THAT(syscallPpoll(nullptr, 0, &timeout, nullptr, 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_F(PpollTest, InvalidMaskSize) { - struct timespec timeout = {}; - sigset_t sigmask; - TEST_PCHECK(sigemptyset(&sigmask) == 0); - EXPECT_THAT(syscallPpoll(nullptr, 0, &timeout, &sigmask, 128), - SyscallFailsWithErrno(EINVAL)); -} - -// Verify that signals blocked by the ppoll mask (that would otherwise be -// allowed) do not interrupt ppoll. -TEST_F(PpollTest, SignalMaskBlocksSignal) { - absl::Duration duration(absl::Seconds(30)); - struct timespec timeout = absl::ToTimespec(duration); - absl::Duration timer_duration(absl::Seconds(10)); - - // Call with a mask that blocks SIGALRM. See that ppoll is not interrupted - // (i.e. returns 0) and that upon completion, the timer has fired. - sigset_t mask; - ASSERT_THAT(sigprocmask(0, nullptr, &mask), SyscallSucceeds()); - TEST_PCHECK(sigaddset(&mask, SIGALRM) == 0); - SetTimer(timer_duration); - MaybeSave(); - ASSERT_FALSE(TimerFired()); - ASSERT_THAT(syscallPpoll(nullptr, 0, &timeout, &mask, kSigsetSize), - SyscallSucceeds()); - EXPECT_TRUE(TimerFired()); - EXPECT_EQ(absl::DurationFromTimespec(timeout), absl::Duration()); -} - -// Verify that signals allowed by the ppoll mask (that would otherwise be -// blocked) interrupt ppoll. -TEST_F(PpollTest, SignalMaskAllowsSignal) { - absl::Duration duration(absl::Seconds(30)); - struct timespec timeout = absl::ToTimespec(duration); - absl::Duration timer_duration(absl::Seconds(10)); - - sigset_t mask; - ASSERT_THAT(sigprocmask(0, nullptr, &mask), SyscallSucceeds()); - - // Block SIGALRM. - auto cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, SIGALRM)); - - // Call with a mask that unblocks SIGALRM. See that ppoll is interrupted. - SetTimer(timer_duration); - MaybeSave(); - ASSERT_FALSE(TimerFired()); - ASSERT_THAT(syscallPpoll(nullptr, 0, &timeout, &mask, kSigsetSize), - SyscallFailsWithErrno(EINTR)); - EXPECT_TRUE(TimerFired()); - EXPECT_GT(absl::DurationFromTimespec(timeout), absl::Duration()); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc deleted file mode 100644 index 04c5161f5..000000000 --- a/test/syscalls/linux/prctl.cc +++ /dev/null @@ -1,230 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/prctl.h> -#include <sys/ptrace.h> -#include <sys/types.h> -#include <sys/wait.h> -#include <unistd.h> - -#include <string> - -#include "gtest/gtest.h" -#include "absl/flags/flag.h" -#include "test/util/capability_util.h" -#include "test/util/cleanup.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -ABSL_FLAG(bool, prctl_no_new_privs_test_child, false, - "If true, exit with the return value of prctl(PR_GET_NO_NEW_PRIVS) " - "plus an offset (see test source)."); - -namespace gvisor { -namespace testing { - -namespace { - -#ifndef SUID_DUMP_DISABLE -#define SUID_DUMP_DISABLE 0 -#endif /* SUID_DUMP_DISABLE */ -#ifndef SUID_DUMP_USER -#define SUID_DUMP_USER 1 -#endif /* SUID_DUMP_USER */ -#ifndef SUID_DUMP_ROOT -#define SUID_DUMP_ROOT 2 -#endif /* SUID_DUMP_ROOT */ - -TEST(PrctlTest, NameInitialized) { - const size_t name_length = 20; - char name[name_length] = {}; - ASSERT_THAT(prctl(PR_GET_NAME, name), SyscallSucceeds()); - ASSERT_NE(std::string(name), ""); -} - -TEST(PrctlTest, SetNameLongName) { - const size_t name_length = 20; - const std::string long_name(name_length, 'A'); - ASSERT_THAT(prctl(PR_SET_NAME, long_name.c_str()), SyscallSucceeds()); - char truncated_name[name_length] = {}; - ASSERT_THAT(prctl(PR_GET_NAME, truncated_name), SyscallSucceeds()); - const size_t truncated_length = 15; - ASSERT_EQ(long_name.substr(0, truncated_length), std::string(truncated_name)); -} - -TEST(PrctlTest, ChildProcessName) { - constexpr size_t kMaxNameLength = 15; - - char parent_name[kMaxNameLength + 1] = {}; - memset(parent_name, 'a', kMaxNameLength); - - ASSERT_THAT(prctl(PR_SET_NAME, parent_name), SyscallSucceeds()); - - pid_t child_pid = fork(); - TEST_PCHECK(child_pid >= 0); - if (child_pid == 0) { - char child_name[kMaxNameLength + 1] = {}; - TEST_PCHECK(prctl(PR_GET_NAME, child_name) >= 0); - TEST_CHECK(memcmp(parent_name, child_name, sizeof(parent_name)) == 0); - _exit(0); - } - - int status; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "status =" << status; -} - -// Offset added to exit code from test child to distinguish from other abnormal -// exits. -constexpr int kPrctlNoNewPrivsTestChildExitBase = 100; - -TEST(PrctlTest, NoNewPrivsPreservedAcrossCloneForkAndExecve) { - // Check if no_new_privs is already set. If it is, we can still test that it's - // preserved across clone/fork/execve, but we also expect it to still be set - // at the end of the test. Otherwise, call prctl(PR_SET_NO_NEW_PRIVS) so as - // not to contaminate the original thread. - int no_new_privs; - ASSERT_THAT(no_new_privs = prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0), - SyscallSucceeds()); - ScopedThread([] { - ASSERT_THAT(prctl(PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0), SyscallSucceeds()); - EXPECT_THAT(prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0), - SyscallSucceedsWithValue(1)); - ScopedThread([] { - EXPECT_THAT(prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0), - SyscallSucceedsWithValue(1)); - // Note that these ASSERT_*s failing will only return from this thread, - // but this is the intended behavior. - pid_t child_pid = -1; - int execve_errno = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec("/proc/self/exe", - {"/proc/self/exe", "--prctl_no_new_privs_test_child"}, {}, - nullptr, &child_pid, &execve_errno)); - - ASSERT_GT(child_pid, 0); - ASSERT_EQ(execve_errno, 0); - - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), - SyscallSucceeds()); - ASSERT_TRUE(WIFEXITED(status)); - ASSERT_EQ(WEXITSTATUS(status), kPrctlNoNewPrivsTestChildExitBase + 1); - - EXPECT_THAT(prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0), - SyscallSucceedsWithValue(1)); - }); - EXPECT_THAT(prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0), - SyscallSucceedsWithValue(1)); - }); - EXPECT_THAT(prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0), - SyscallSucceedsWithValue(no_new_privs)); -} - -TEST(PrctlTest, PDeathSig) { - pid_t child_pid; - - // Make the new process' parent a separate thread since the parent death - // signal fires when the parent *thread* exits. - ScopedThread([&] { - child_pid = fork(); - TEST_CHECK(child_pid >= 0); - if (child_pid == 0) { - // In child process. - TEST_CHECK(prctl(PR_SET_PDEATHSIG, SIGKILL) >= 0); - int signo; - TEST_CHECK(prctl(PR_GET_PDEATHSIG, &signo) >= 0); - TEST_CHECK(signo == SIGKILL); - // Enable tracing, then raise SIGSTOP and expect our parent to suppress - // it. - TEST_CHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) >= 0); - raise(SIGSTOP); - // Sleep until killed by our parent death signal. sleep(3) is - // async-signal-safe, absl::SleepFor isn't. - while (true) { - sleep(10); - } - } - // In parent process. - - // Wait for the child to send itself SIGSTOP and enter signal-delivery-stop. - int status; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP) - << "status = " << status; - - // Suppress the SIGSTOP and detach from the child. - ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds()); - }); - - // The child should have been killed by its parent death SIGKILL. - int status; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL) - << "status = " << status; -} - -// This test is to validate that calling prctl with PR_SET_MM without the -// CAP_SYS_RESOURCE returns EPERM. -TEST(PrctlTest, InvalidPrSetMM) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_RESOURCE))) { - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_RESOURCE, - false)); // Drop capability to test below. - } - ASSERT_THAT(prctl(PR_SET_MM, 0, 0, 0, 0), SyscallFailsWithErrno(EPERM)); -} - -// Sanity check that dumpability is remembered. -TEST(PrctlTest, SetGetDumpability) { - int before; - ASSERT_THAT(before = prctl(PR_GET_DUMPABLE), SyscallSucceeds()); - auto cleanup = Cleanup([before] { - ASSERT_THAT(prctl(PR_SET_DUMPABLE, before), SyscallSucceeds()); - }); - - EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_DISABLE), SyscallSucceeds()); - EXPECT_THAT(prctl(PR_GET_DUMPABLE), - SyscallSucceedsWithValue(SUID_DUMP_DISABLE)); - - EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_USER), SyscallSucceeds()); - EXPECT_THAT(prctl(PR_GET_DUMPABLE), SyscallSucceedsWithValue(SUID_DUMP_USER)); -} - -// SUID_DUMP_ROOT cannot be set via PR_SET_DUMPABLE. -TEST(PrctlTest, RootDumpability) { - EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_ROOT), - SyscallFailsWithErrno(EINVAL)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor - -int main(int argc, char** argv) { - gvisor::testing::TestInit(&argc, &argv); - - if (absl::GetFlag(FLAGS_prctl_no_new_privs_test_child)) { - exit(gvisor::testing::kPrctlNoNewPrivsTestChildExitBase + - prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0)); - } - - return gvisor::testing::RunAllTests(); -} diff --git a/test/syscalls/linux/prctl_setuid.cc b/test/syscalls/linux/prctl_setuid.cc deleted file mode 100644 index c4e9cf528..000000000 --- a/test/syscalls/linux/prctl_setuid.cc +++ /dev/null @@ -1,268 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sched.h> -#include <sys/prctl.h> - -#include <string> - -#include "gtest/gtest.h" -#include "absl/flags/flag.h" -#include "test/util/capability_util.h" -#include "test/util/logging.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID"); -// This flag is used to verify that after an exec PR_GET_KEEPCAPS -// returns 0, the return code will be offset by kPrGetKeepCapsExitBase. -ABSL_FLAG(bool, prctl_pr_get_keepcaps, false, - "If true the test will verify that prctl with pr_get_keepcaps" - "returns 0. The test will exit with the result of that check."); - -// These tests exist seperately from prctl because we need to start -// them as root. Setuid() has the behavior that permissions are fully -// removed if one of the UIDs were 0 before a setuid() call. This -// behavior can be changed by using PR_SET_KEEPCAPS and that is what -// is tested here. -// -// Reference setuid(2): -// The setuid() function checks the effective user ID of -// the caller and if it is the superuser, all process-related user ID's -// are set to uid. After this has occurred, it is impossible for the -// program to regain root privileges. -// -// Thus, a set-user-ID-root program wishing to temporarily drop root -// privileges, assume the identity of an unprivileged user, and then -// regain root privileges afterward cannot use setuid(). You can -// accomplish this with seteuid(2). -namespace gvisor { -namespace testing { - -// Offset added to exit code from test child to distinguish from other abnormal -// exits. -constexpr int kPrGetKeepCapsExitBase = 100; - -namespace { - -class PrctlKeepCapsSetuidTest : public ::testing::Test { - protected: - void SetUp() override { - // PR_GET_KEEPCAPS will only return 0 or 1 (on success). - ASSERT_THAT(original_keepcaps_ = prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0), - SyscallSucceeds()); - ASSERT_TRUE(original_keepcaps_ == 0 || original_keepcaps_ == 1); - } - - void TearDown() override { - // Restore PR_SET_KEEPCAPS. - ASSERT_THAT(prctl(PR_SET_KEEPCAPS, original_keepcaps_, 0, 0, 0), - SyscallSucceeds()); - - // Verify that it was restored. - ASSERT_THAT(prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0), - SyscallSucceedsWithValue(original_keepcaps_)); - } - - // The original keep caps value exposed so tests can use it if they need. - int original_keepcaps_ = 0; -}; - -// This test will verify that a bad value, eg. not 0 or 1 for -// PR_SET_KEEPCAPS will return EINVAL as required by prctl(2). -TEST_F(PrctlKeepCapsSetuidTest, PrctlBadArgsToKeepCaps) { - ASSERT_THAT(prctl(PR_SET_KEEPCAPS, 2, 0, 0, 0), - SyscallFailsWithErrno(EINVAL)); -} - -// This test will verify that a setuid(2) without PR_SET_KEEPCAPS will cause -// all capabilities to be dropped. -TEST_F(PrctlKeepCapsSetuidTest, SetUidNoKeepCaps) { - // getuid(2) never fails. - if (getuid() != 0) { - SKIP_IF(!IsRunningOnGvisor()); - FAIL() << "User is not root on gvisor platform."; - } - - // Do setuid in a separate thread so that after finishing this test, the - // process can still open files the test harness created before starting - // this test. Otherwise, the files are created by root (UID before the - // test), but cannot be opened by the `uid` set below after the test. After - // calling setuid(non-zero-UID), there is no way to get root privileges - // back. - ScopedThread([] { - // Start by verifying we have a capability. - TEST_CHECK(HaveCapability(CAP_SYS_ADMIN).ValueOrDie()); - - // Verify that PR_GET_KEEPCAPS is disabled. - ASSERT_THAT(prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0), - SyscallSucceedsWithValue(0)); - - // Use syscall instead of glibc setuid wrapper because we want this setuid - // call to only apply to this task. POSIX threads, however, require that - // all threads have the same UIDs, so using the setuid wrapper sets all - // threads' real UID. - EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)), - SyscallSucceeds()); - - // Verify that we changed uid. - EXPECT_THAT(getuid(), - SyscallSucceedsWithValue(absl::GetFlag(FLAGS_scratch_uid))); - - // Verify we lost the capability in the effective set, this always happens. - TEST_CHECK(!HaveCapability(CAP_SYS_ADMIN).ValueOrDie()); - - // We should have also lost it in the permitted set by the setuid() so - // SetCapability should fail when we try to add it back to the effective set - ASSERT_FALSE(SetCapability(CAP_SYS_ADMIN, true).ok()); - }); -} - -// This test will verify that a setuid with PR_SET_KEEPCAPS will cause -// capabilities to be retained after we switch away from the root user. -TEST_F(PrctlKeepCapsSetuidTest, SetUidKeepCaps) { - // getuid(2) never fails. - if (getuid() != 0) { - SKIP_IF(!IsRunningOnGvisor()); - FAIL() << "User is not root on gvisor platform."; - } - - // Do setuid in a separate thread so that after finishing this test, the - // process can still open files the test harness created before starting - // this test. Otherwise, the files are created by root (UID before the - // test), but cannot be opened by the `uid` set below after the test. After - // calling setuid(non-zero-UID), there is no way to get root privileges - // back. - ScopedThread([] { - // Start by verifying we have a capability. - TEST_CHECK(HaveCapability(CAP_SYS_ADMIN).ValueOrDie()); - - // Set PR_SET_KEEPCAPS. - ASSERT_THAT(prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0), SyscallSucceeds()); - - // Verify PR_SET_KEEPCAPS was set before we proceed. - ASSERT_THAT(prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0), - SyscallSucceedsWithValue(1)); - - // Use syscall instead of glibc setuid wrapper because we want this setuid - // call to only apply to this task. POSIX threads, however, require that - // all threads have the same UIDs, so using the setuid wrapper sets all - // threads' real UID. - EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)), - SyscallSucceeds()); - - // Verify that we changed uid. - EXPECT_THAT(getuid(), - SyscallSucceedsWithValue(absl::GetFlag(FLAGS_scratch_uid))); - - // Verify we lost the capability in the effective set, this always happens. - TEST_CHECK(!HaveCapability(CAP_SYS_ADMIN).ValueOrDie()); - - // We lost the capability in the effective set, but it will still - // exist in the permitted set so we can elevate the capability. - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_ADMIN, true)); - - // Verify we got back the capability in the effective set. - TEST_CHECK(HaveCapability(CAP_SYS_ADMIN).ValueOrDie()); - }); -} - -// This test will verify that PR_SET_KEEPCAPS is not retained -// across an execve. According to prctl(2): -// "The "keep capabilities" value will be reset to 0 on subsequent -// calls to execve(2)." -TEST_F(PrctlKeepCapsSetuidTest, NoKeepCapsAfterExec) { - ASSERT_THAT(prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0), SyscallSucceeds()); - - // Verify PR_SET_KEEPCAPS was set before we proceed. - ASSERT_THAT(prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0), SyscallSucceedsWithValue(1)); - - pid_t child_pid = -1; - int execve_errno = 0; - // Do an exec and then verify that PR_GET_KEEPCAPS returns 0 - // see the body of main below. - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec( - "/proc/self/exe", {"/proc/self/exe", "--prctl_pr_get_keepcaps"}, {}, - nullptr, &child_pid, &execve_errno)); - - ASSERT_GT(child_pid, 0); - ASSERT_EQ(execve_errno, 0); - - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - ASSERT_TRUE(WIFEXITED(status)); - // PR_SET_KEEPCAPS should have been cleared by the exec. - // Success should return gvisor::testing::kPrGetKeepCapsExitBase + 0 - ASSERT_EQ(WEXITSTATUS(status), kPrGetKeepCapsExitBase); -} - -TEST_F(PrctlKeepCapsSetuidTest, NoKeepCapsAfterNewUserNamespace) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace())); - - // Fork to avoid changing the user namespace of the original test process. - pid_t const child_pid = fork(); - - if (child_pid == 0) { - // Verify that the keepcaps flag is set to 0 when we change user namespaces. - TEST_PCHECK(prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0) == 0); - MaybeSave(); - - TEST_PCHECK(prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0) == 1); - MaybeSave(); - - TEST_PCHECK(unshare(CLONE_NEWUSER) == 0); - MaybeSave(); - - TEST_PCHECK(prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0) == 0); - MaybeSave(); - - _exit(0); - } - - int status; - ASSERT_THAT(child_pid, SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "status = " << status; -} - -// This test will verify that PR_SET_KEEPCAPS and PR_GET_KEEPCAPS work correctly -TEST_F(PrctlKeepCapsSetuidTest, PrGetKeepCaps) { - // Set PR_SET_KEEPCAPS to the negation of the original. - ASSERT_THAT(prctl(PR_SET_KEEPCAPS, !original_keepcaps_, 0, 0, 0), - SyscallSucceeds()); - - // Verify it was set. - ASSERT_THAT(prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0), - SyscallSucceedsWithValue(!original_keepcaps_)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor - -int main(int argc, char** argv) { - gvisor::testing::TestInit(&argc, &argv); - - if (absl::GetFlag(FLAGS_prctl_pr_get_keepcaps)) { - return gvisor::testing::kPrGetKeepCapsExitBase + - prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0); - } - - return gvisor::testing::RunAllTests(); -} diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc deleted file mode 100644 index 2cecf2e5f..000000000 --- a/test/syscalls/linux/pread64.cc +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <sys/mman.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -class Pread64Test : public ::testing::Test { - void SetUp() override { - name_ = NewTempAbsPath(); - ASSERT_NO_ERRNO_AND_VALUE(Open(name_, O_CREAT, 0644)); - } - - void TearDown() override { unlink(name_.c_str()); } - - public: - std::string name_; -}; - -TEST(Pread64TestNoTempFile, BadFileDescriptor) { - char buf[1024]; - EXPECT_THAT(pread64(-1, buf, 1024, 0), SyscallFailsWithErrno(EBADF)); -} - -TEST_F(Pread64Test, ZeroBuffer) { - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(name_, O_RDWR)); - - char msg[] = "hello world"; - EXPECT_THAT(pwrite64(fd.get(), msg, strlen(msg), 0), - SyscallSucceedsWithValue(strlen(msg))); - - char buf[10]; - EXPECT_THAT(pread64(fd.get(), buf, 0, 0), SyscallSucceedsWithValue(0)); -} - -TEST_F(Pread64Test, BadBuffer) { - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(name_, O_RDWR)); - - char msg[] = "hello world"; - EXPECT_THAT(pwrite64(fd.get(), msg, strlen(msg), 0), - SyscallSucceedsWithValue(strlen(msg))); - - char* bad_buffer = nullptr; - EXPECT_THAT(pread64(fd.get(), bad_buffer, 1024, 0), - SyscallFailsWithErrno(EFAULT)); -} - -TEST_F(Pread64Test, WriteOnlyNotReadable) { - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(name_, O_WRONLY)); - - char buf[1024]; - EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallFailsWithErrno(EBADF)); -} - -TEST_F(Pread64Test, DirNotReadable) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(GetAbsoluteTestTmpdir(), O_RDONLY)); - - char buf[1024]; - EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallFailsWithErrno(EISDIR)); -} - -TEST_F(Pread64Test, BadOffset) { - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(name_, O_RDONLY)); - - char buf[1024]; - EXPECT_THAT(pread64(fd.get(), buf, 1024, -1), SyscallFailsWithErrno(EINVAL)); -} - -TEST_F(Pread64Test, OffsetNotIncremented) { - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(name_, O_RDWR)); - - char msg[] = "hello world"; - EXPECT_THAT(write(fd.get(), msg, strlen(msg)), - SyscallSucceedsWithValue(strlen(msg))); - int offset; - EXPECT_THAT(offset = lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds()); - - char buf1[1024]; - EXPECT_THAT(pread64(fd.get(), buf1, 1024, 0), - SyscallSucceedsWithValue(strlen(msg))); - EXPECT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(offset)); - - char buf2[1024]; - EXPECT_THAT(pread64(fd.get(), buf2, 1024, 3), - SyscallSucceedsWithValue(strlen(msg) - 3)); - EXPECT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(offset)); -} - -TEST_F(Pread64Test, EndOfFile) { - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(name_, O_RDONLY)); - - char buf[1024]; - EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallSucceedsWithValue(0)); -} - -TEST(Pread64TestNoTempFile, CantReadSocketPair_NoRandomSave) { - int sock_fds[2]; - EXPECT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sock_fds), SyscallSucceeds()); - - char buf[1024]; - EXPECT_THAT(pread64(sock_fds[0], buf, 1024, 0), - SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(pread64(sock_fds[1], buf, 1024, 0), - SyscallFailsWithErrno(ESPIPE)); - - EXPECT_THAT(close(sock_fds[0]), SyscallSucceeds()); - EXPECT_THAT(close(sock_fds[1]), SyscallSucceeds()); -} - -TEST(Pread64TestNoTempFile, CantReadPipe) { - char buf[1024]; - - int pipe_fds[2]; - EXPECT_THAT(pipe(pipe_fds), SyscallSucceeds()); - - EXPECT_THAT(pread64(pipe_fds[0], buf, 1024, 0), - SyscallFailsWithErrno(ESPIPE)); - - EXPECT_THAT(close(pipe_fds[0]), SyscallSucceeds()); - EXPECT_THAT(close(pipe_fds[1]), SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/preadv.cc b/test/syscalls/linux/preadv.cc deleted file mode 100644 index 5b0743fe9..000000000 --- a/test/syscalls/linux/preadv.cc +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/syscall.h> -#include <sys/types.h> -#include <sys/uio.h> -#include <sys/wait.h> -#include <unistd.h> - -#include <atomic> -#include <string> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/file_descriptor.h" -#include "test/util/logging.h" -#include "test/util/memory_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" -#include "test/util/timer_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Stress copy-on-write. Attempts to reproduce b/38430174. -TEST(PreadvTest, MMConcurrencyStress) { - // Fill a one-page file with zeroes (the contents don't really matter). - const auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - /* parent = */ GetAbsoluteTestTmpdir(), - /* content = */ std::string(kPageSize, 0), TempPath::kDefaultFileMode)); - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY)); - - // Get a one-page private mapping to read to. - const Mapping m = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - - // Repeatedly fork in a separate thread to force the mapping to become - // copy-on-write. - std::atomic<bool> done(false); - const ScopedThread t([&] { - while (!done.load()) { - const pid_t pid = fork(); - TEST_CHECK(pid >= 0); - if (pid == 0) { - // In child. The parent was obviously multithreaded, so it's neither - // safe nor necessary to do much more than exit. - syscall(SYS_exit_group, 0); - } - int status; - ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0), - SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "status = " << status; - } - }); - - // Repeatedly read to the mapping. - struct iovec iov[2]; - iov[0].iov_base = m.ptr(); - iov[0].iov_len = kPageSize / 2; - iov[1].iov_base = reinterpret_cast<void*>(m.addr() + kPageSize / 2); - iov[1].iov_len = kPageSize / 2; - constexpr absl::Duration kTestDuration = absl::Seconds(5); - const absl::Time end = absl::Now() + kTestDuration; - while (absl::Now() < end) { - // Among other causes, save/restore cycles may cause interruptions resulting - // in partial reads, so we don't expect any particular return value. - EXPECT_THAT(RetryEINTR(preadv)(fd.get(), iov, 2, 0), SyscallSucceeds()); - } - - // Stop the other thread. - done.store(true); - - // The test passes if it neither deadlocks nor crashes the OS. -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/preadv2.cc b/test/syscalls/linux/preadv2.cc deleted file mode 100644 index 4a9acd7ae..000000000 --- a/test/syscalls/linux/preadv2.cc +++ /dev/null @@ -1,280 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <sys/syscall.h> -#include <sys/types.h> -#include <sys/uio.h> - -#include <string> -#include <vector> - -#include "gtest/gtest.h" -#include "absl/memory/memory.h" -#include "test/syscalls/linux/file_base.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -#ifndef SYS_preadv2 -#if defined(__x86_64__) -#define SYS_preadv2 327 -#elif defined(__aarch64__) -#define SYS_preadv2 286 -#else -#error "Unknown architecture" -#endif -#endif // SYS_preadv2 - -#ifndef RWF_HIPRI -#define RWF_HIPRI 0x1 -#endif // RWF_HIPRI - -constexpr int kBufSize = 1024; - -std::string SetContent() { - std::string content; - for (int i = 0; i < kBufSize; i++) { - content += static_cast<char>((i % 10) + '0'); - } - return content; -} - -ssize_t preadv2(unsigned long fd, const struct iovec* iov, unsigned long iovcnt, - off_t offset, unsigned long flags) { - // syscall on preadv2 does some weird things (see man syscall and search - // preadv2), so we insert a 0 to word align the flags argument on native. - return syscall(SYS_preadv2, fd, iov, iovcnt, offset, 0, flags); -} - -// This test is the base case where we call preadv (no offset, no flags). -TEST(Preadv2Test, TestBaseCall) { - SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - std::string content = SetContent(); - - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), content, TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); - - std::vector<char> buf(kBufSize); - struct iovec iov[2]; - iov[0].iov_base = buf.data(); - iov[0].iov_len = buf.size() / 2; - iov[1].iov_base = static_cast<char*>(iov[0].iov_base) + (content.size() / 2); - iov[1].iov_len = content.size() / 2; - - EXPECT_THAT(preadv2(fd.get(), iov, /*iovcnt*/ 2, /*offset=*/0, /*flags=*/0), - SyscallSucceedsWithValue(kBufSize)); - - EXPECT_EQ(content, std::string(buf.data(), buf.size())); -} - -// This test is where we call preadv with an offset and no flags. -TEST(Preadv2Test, TestValidPositiveOffset) { - SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - std::string content = SetContent(); - const std::string prefix = "0"; - - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), prefix + content, TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); - - std::vector<char> buf(kBufSize, '0'); - struct iovec iov; - iov.iov_base = buf.data(); - iov.iov_len = buf.size(); - - EXPECT_THAT(preadv2(fd.get(), &iov, /*iovcnt=*/1, /*offset=*/prefix.size(), - /*flags=*/0), - SyscallSucceedsWithValue(kBufSize)); - - EXPECT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); - - EXPECT_EQ(content, std::string(buf.data(), buf.size())); -} - -// This test is the base case where we call readv by using -1 as the offset. The -// read should use the file offset, so the test increments it by one prior to -// calling preadv2. -TEST(Preadv2Test, TestNegativeOneOffset) { - SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - std::string content = SetContent(); - const std::string prefix = "231"; - - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), prefix + content, TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); - - ASSERT_THAT(lseek(fd.get(), prefix.size(), SEEK_SET), - SyscallSucceedsWithValue(prefix.size())); - - std::vector<char> buf(kBufSize, '0'); - struct iovec iov; - iov.iov_base = buf.data(); - iov.iov_len = buf.size(); - - EXPECT_THAT(preadv2(fd.get(), &iov, /*iovcnt=*/1, /*offset=*/-1, /*flags=*/0), - SyscallSucceedsWithValue(kBufSize)); - - EXPECT_THAT(lseek(fd.get(), 0, SEEK_CUR), - SyscallSucceedsWithValue(prefix.size() + buf.size())); - - EXPECT_EQ(content, std::string(buf.data(), buf.size())); -} - -// preadv2 requires if the RWF_HIPRI flag is passed, the fd must be opened with -// O_DIRECT. This test implements a correct call with the RWF_HIPRI flag. -TEST(Preadv2Test, TestCallWithRWF_HIPRI) { - SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - std::string content = SetContent(); - - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), content, TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); - - EXPECT_THAT(fsync(fd.get()), SyscallSucceeds()); - - std::vector<char> buf(kBufSize, '0'); - struct iovec iov; - iov.iov_base = buf.data(); - iov.iov_len = buf.size(); - - EXPECT_THAT( - preadv2(fd.get(), &iov, /*iovcnt=*/1, /*offset=*/0, /*flags=*/RWF_HIPRI), - SyscallSucceedsWithValue(kBufSize)); - - EXPECT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); - - EXPECT_EQ(content, std::string(buf.data(), buf.size())); -} -// This test calls preadv2 with an invalid flag. -TEST(Preadv2Test, TestInvalidFlag) { - SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY | O_DIRECT)); - - std::vector<char> buf(kBufSize, '0'); - struct iovec iov; - iov.iov_base = buf.data(); - iov.iov_len = buf.size(); - - EXPECT_THAT(preadv2(fd.get(), &iov, /*iovcnt=*/1, - /*offset=*/0, /*flags=*/0xF0), - SyscallFailsWithErrno(EOPNOTSUPP)); -} - -// This test calls preadv2 with an invalid offset. -TEST(Preadv2Test, TestInvalidOffset) { - SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY | O_DIRECT)); - - auto iov = absl::make_unique<struct iovec[]>(1); - iov[0].iov_base = nullptr; - iov[0].iov_len = 0; - - EXPECT_THAT(preadv2(fd.get(), iov.get(), /*iovcnt=*/1, /*offset=*/-8, - /*flags=*/0), - SyscallFailsWithErrno(EINVAL)); -} - -// This test calls preadv with a file set O_WRONLY. -TEST(Preadv2Test, TestUnreadableFile) { - SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY)); - - auto iov = absl::make_unique<struct iovec[]>(1); - iov[0].iov_base = nullptr; - iov[0].iov_len = 0; - - EXPECT_THAT(preadv2(fd.get(), iov.get(), /*iovcnt=*/1, - /*offset=*/0, /*flags=*/0), - SyscallFailsWithErrno(EBADF)); -} - -// Calling preadv2 with a non-negative offset calls preadv. Calling preadv with -// an unseekable file is not allowed. A pipe is used for an unseekable file. -TEST(Preadv2Test, TestUnseekableFileInvalid) { - SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - int pipe_fds[2]; - - ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds()); - - auto iov = absl::make_unique<struct iovec[]>(1); - iov[0].iov_base = nullptr; - iov[0].iov_len = 0; - - EXPECT_THAT(preadv2(pipe_fds[0], iov.get(), /*iovcnt=*/1, - /*offset=*/2, /*flags=*/0), - SyscallFailsWithErrno(ESPIPE)); - - EXPECT_THAT(close(pipe_fds[0]), SyscallSucceeds()); - EXPECT_THAT(close(pipe_fds[1]), SyscallSucceeds()); -} - -TEST(Preadv2Test, TestUnseekableFileValid) { - SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - int pipe_fds[2]; - - ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds()); - - std::vector<char> content(32, 'X'); - - EXPECT_THAT(write(pipe_fds[1], content.data(), content.size()), - SyscallSucceedsWithValue(content.size())); - - std::vector<char> buf(content.size()); - auto iov = absl::make_unique<struct iovec[]>(1); - iov[0].iov_base = buf.data(); - iov[0].iov_len = buf.size(); - - EXPECT_THAT(preadv2(pipe_fds[0], iov.get(), /*iovcnt=*/1, - /*offset=*/static_cast<off_t>(-1), /*flags=*/0), - SyscallSucceedsWithValue(buf.size())); - - EXPECT_EQ(content, buf); - - EXPECT_THAT(close(pipe_fds[0]), SyscallSucceeds()); - EXPECT_THAT(close(pipe_fds[1]), SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/priority.cc b/test/syscalls/linux/priority.cc deleted file mode 100644 index 1d9bdfa70..000000000 --- a/test/syscalls/linux/priority.cc +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/resource.h> -#include <sys/time.h> -#include <sys/types.h> -#include <unistd.h> - -#include <string> -#include <vector> - -#include "gtest/gtest.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_split.h" -#include "test/util/capability_util.h" -#include "test/util/fs_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// These tests are for both the getpriority(2) and setpriority(2) syscalls -// These tests are very rudimentary because getpriority and setpriority -// have not yet been fully implemented. - -// Getpriority does something -TEST(GetpriorityTest, Implemented) { - // "getpriority() can legitimately return the value -1, it is necessary to - // clear the external variable errno prior to the call" - errno = 0; - EXPECT_THAT(getpriority(PRIO_PROCESS, /*who=*/0), SyscallSucceeds()); -} - -// Invalid which -TEST(GetpriorityTest, InvalidWhich) { - errno = 0; - EXPECT_THAT(getpriority(/*which=*/3, /*who=*/0), - SyscallFailsWithErrno(EINVAL)); -} - -// Process is found when which=PRIO_PROCESS -TEST(GetpriorityTest, ValidWho) { - errno = 0; - EXPECT_THAT(getpriority(PRIO_PROCESS, getpid()), SyscallSucceeds()); -} - -// Process is not found when which=PRIO_PROCESS -TEST(GetpriorityTest, InvalidWho) { - errno = 0; - // Flaky, but it's tough to avoid a race condition when finding an unused pid - EXPECT_THAT(getpriority(PRIO_PROCESS, /*who=*/INT_MAX - 1), - SyscallFailsWithErrno(ESRCH)); -} - -// Setpriority does something -TEST(SetpriorityTest, Implemented) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_NICE))); - - // No need to clear errno for setpriority(): - // "The setpriority() call returns 0 if there is no error, or -1 if there is" - EXPECT_THAT(setpriority(PRIO_PROCESS, /*who=*/0, /*nice=*/16), - SyscallSucceeds()); -} - -// Invalid which -TEST(Setpriority, InvalidWhich) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_NICE))); - - EXPECT_THAT(setpriority(/*which=*/3, /*who=*/0, /*nice=*/16), - SyscallFailsWithErrno(EINVAL)); -} - -// Process is found when which=PRIO_PROCESS -TEST(SetpriorityTest, ValidWho) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_NICE))); - - EXPECT_THAT(setpriority(PRIO_PROCESS, getpid(), /*nice=*/16), - SyscallSucceeds()); -} - -// niceval is within the range [-20, 19] -TEST(SetpriorityTest, InsideRange) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_NICE))); - - // Set 0 < niceval < 19 - int nice = 12; - EXPECT_THAT(setpriority(PRIO_PROCESS, getpid(), nice), SyscallSucceeds()); - - errno = 0; - EXPECT_THAT(getpriority(PRIO_PROCESS, getpid()), - SyscallSucceedsWithValue(nice)); - - // Set -20 < niceval < 0 - nice = -12; - EXPECT_THAT(setpriority(PRIO_PROCESS, getpid(), nice), SyscallSucceeds()); - - errno = 0; - EXPECT_THAT(getpriority(PRIO_PROCESS, getpid()), - SyscallSucceedsWithValue(nice)); -} - -// Verify that priority/niceness are exposed via /proc/PID/stat. -TEST(SetpriorityTest, NicenessExposedViaProcfs) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_NICE))); - - constexpr int kNiceVal = 12; - ASSERT_THAT(setpriority(PRIO_PROCESS, getpid(), kNiceVal), SyscallSucceeds()); - - errno = 0; - ASSERT_THAT(getpriority(PRIO_PROCESS, getpid()), - SyscallSucceedsWithValue(kNiceVal)); - - // Now verify we can read that same value via /proc/self/stat. - std::string proc_stat; - ASSERT_NO_ERRNO(GetContents("/proc/self/stat", &proc_stat)); - std::vector<std::string> pieces = absl::StrSplit(proc_stat, ' '); - ASSERT_GT(pieces.size(), 20); - - int niceness_procfs = 0; - ASSERT_TRUE(absl::SimpleAtoi(pieces[18], &niceness_procfs)); - EXPECT_EQ(niceness_procfs, kNiceVal); -} - -// In the kernel's implementation, values outside the range of [-20, 19] are -// truncated to these minimum and maximum values. See -// https://elixir.bootlin.com/linux/v4.4/source/kernel/sys.c#L190 -TEST(SetpriorityTest, OutsideRange) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_NICE))); - - // Set niceval > 19 - EXPECT_THAT(setpriority(PRIO_PROCESS, getpid(), /*nice=*/100), - SyscallSucceeds()); - - errno = 0; - // Test niceval truncated to 19 - EXPECT_THAT(getpriority(PRIO_PROCESS, getpid()), - SyscallSucceedsWithValue(/*maxnice=*/19)); - - // Set niceval < -20 - EXPECT_THAT(setpriority(PRIO_PROCESS, getpid(), /*nice=*/-100), - SyscallSucceeds()); - - errno = 0; - // Test niceval truncated to -20 - EXPECT_THAT(getpriority(PRIO_PROCESS, getpid()), - SyscallSucceedsWithValue(/*minnice=*/-20)); -} - -// Process is not found when which=PRIO_PROCESS -TEST(SetpriorityTest, InvalidWho) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_NICE))); - - // Flaky, but it's tough to avoid a race condition when finding an unused pid - EXPECT_THAT(setpriority(PRIO_PROCESS, - /*who=*/INT_MAX - 1, - /*nice=*/16), - SyscallFailsWithErrno(ESRCH)); -} - -// Nice succeeds, correctly modifies (or in this case does not -// modify priority of process -TEST(SetpriorityTest, NiceSucceeds) { - errno = 0; - const int priority_before = getpriority(PRIO_PROCESS, /*who=*/0); - ASSERT_THAT(nice(/*inc=*/0), SyscallSucceeds()); - - // nice(0) should not change priority - EXPECT_EQ(priority_before, getpriority(PRIO_PROCESS, /*who=*/0)); -} - -// Threads resulting from clone() maintain parent's priority -// Changes to child priority do not affect parent's priority -TEST(GetpriorityTest, CloneMaintainsPriority) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_NICE))); - - constexpr int kParentPriority = 16; - constexpr int kChildPriority = 14; - ASSERT_THAT(setpriority(PRIO_PROCESS, getpid(), kParentPriority), - SyscallSucceeds()); - - ScopedThread th([]() { - // Check that priority equals that of parent thread - pid_t my_tid; - EXPECT_THAT(my_tid = syscall(__NR_gettid), SyscallSucceeds()); - EXPECT_THAT(getpriority(PRIO_PROCESS, my_tid), - SyscallSucceedsWithValue(kParentPriority)); - - // Change the child thread's priority - EXPECT_THAT(setpriority(PRIO_PROCESS, my_tid, kChildPriority), - SyscallSucceeds()); - }); - th.Join(); - - // Check that parent's priority reemained the same even though - // the child's priority was altered - EXPECT_EQ(kParentPriority, getpriority(PRIO_PROCESS, syscall(__NR_gettid))); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/priority_execve.cc b/test/syscalls/linux/priority_execve.cc deleted file mode 100644 index 5cb343bad..000000000 --- a/test/syscalls/linux/priority_execve.cc +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <stdio.h> -#include <stdlib.h> -#include <sys/resource.h> -#include <sys/time.h> -#include <sys/types.h> -#include <unistd.h> - -int main(int argc, char** argv, char** envp) { - errno = 0; - int prio = getpriority(PRIO_PROCESS, getpid()); - - // NOTE: getpriority() can legitimately return negative values - // in the range [-20, 0). If errno is set, exit with a value that - // could not be reached by a valid priority. Valid exit values - // for the test are in the range [1, 40], so we'll use 0. - if (errno != 0) { - printf("getpriority() failed with errno = %d\n", errno); - exit(0); - } - - // Used by test to verify priority is being maintained through - // calls to execve(). Since prio should always be in the range - // [-20, 19], we offset by 20 so as not to have negative exit codes. - exit(20 - prio); - - return 0; -} diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc deleted file mode 100644 index 5a70f6c3b..000000000 --- a/test/syscalls/linux/proc.cc +++ /dev/null @@ -1,2101 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <elf.h> -#include <errno.h> -#include <fcntl.h> -#include <limits.h> -#include <sched.h> -#include <signal.h> -#include <stddef.h> -#include <stdint.h> -#include <stdio.h> -#include <stdlib.h> -#include <string.h> -#include <sys/mman.h> -#include <sys/prctl.h> -#include <sys/stat.h> -#include <sys/utsname.h> -#include <syscall.h> -#include <unistd.h> - -#include <algorithm> -#include <atomic> -#include <functional> -#include <iostream> -#include <map> -#include <memory> -#include <ostream> -#include <regex> -#include <string> -#include <unordered_set> -#include <utility> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/ascii.h" -#include "absl/strings/match.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "absl/synchronization/notification.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/capability_util.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/memory_util.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" -#include "test/util/time_util.h" -#include "test/util/timer_util.h" - -// NOTE(magi): No, this isn't really a syscall but this is a really simple -// way to get it tested on both gVisor, PTrace and Linux. - -using ::testing::AllOf; -using ::testing::AnyOf; -using ::testing::ContainerEq; -using ::testing::Contains; -using ::testing::ContainsRegex; -using ::testing::Eq; -using ::testing::Gt; -using ::testing::HasSubstr; -using ::testing::IsSupersetOf; -using ::testing::Pair; -using ::testing::UnorderedElementsAre; -using ::testing::UnorderedElementsAreArray; - -// Exported by glibc. -extern char** environ; - -namespace gvisor { -namespace testing { -namespace { - -#ifndef SUID_DUMP_DISABLE -#define SUID_DUMP_DISABLE 0 -#endif /* SUID_DUMP_DISABLE */ -#ifndef SUID_DUMP_USER -#define SUID_DUMP_USER 1 -#endif /* SUID_DUMP_USER */ -#ifndef SUID_DUMP_ROOT -#define SUID_DUMP_ROOT 2 -#endif /* SUID_DUMP_ROOT */ - -#if defined(__x86_64__) || defined(__i386__) -// This list of "required" fields is taken from reading the file -// arch/x86/kernel/cpu/proc.c and seeing which fields will be unconditionally -// printed by the kernel. -static const char* required_fields[] = { - "processor", - "vendor_id", - "cpu family", - "model\t\t:", - "model name", - "stepping", - "cpu MHz", - "fpu\t\t:", - "fpu_exception", - "cpuid level", - "wp", - "bogomips", - "clflush size", - "cache_alignment", - "address sizes", - "power management", -}; -#elif __aarch64__ -// This list of "required" fields is taken from reading the file -// arch/arm64/kernel/cpuinfo.c and seeing which fields will be unconditionally -// printed by the kernel. -static const char* required_fields[] = { - "processor", "BogoMIPS", "Features", "CPU implementer", - "CPU architecture", "CPU variant", "CPU part", "CPU revision", -}; -#else -#error "Unknown architecture" -#endif - -// Takes the subprocess command line and pid. -// If it returns !OK, WithSubprocess returns immediately. -using SubprocessCallback = std::function<PosixError(int)>; - -std::vector<std::string> saved_argv; // NOLINT - -// Helper function to dump /proc/{pid}/status and check the -// state data. State should = "Z" for zombied or "RSD" for -// running, interruptible sleeping (S), or uninterruptible sleep -// (D). -void CompareProcessState(absl::string_view state, int pid) { - auto status_file = ASSERT_NO_ERRNO_AND_VALUE( - GetContents(absl::StrCat("/proc/", pid, "/status"))); - // N.B. POSIX extended regexes don't support shorthand character classes (\w) - // inside of brackets. - EXPECT_THAT(status_file, - ContainsRegex(absl::StrCat("State:.[", state, - R"EOL(]\s+\([a-zA-Z ]+\))EOL"))); -} - -// Run callbacks while a subprocess is running, zombied, and/or exited. -PosixError WithSubprocess(SubprocessCallback const& running, - SubprocessCallback const& zombied, - SubprocessCallback const& exited) { - int pipe_fds[2] = {}; - if (pipe(pipe_fds) < 0) { - return PosixError(errno, "pipe"); - } - - int child_pid = fork(); - if (child_pid < 0) { - return PosixError(errno, "fork"); - } - - if (child_pid == 0) { - close(pipe_fds[0]); // Close the read end. - const DisableSave ds; // Timing issues. - - // Write to the pipe to tell it we're ready. - char buf = 'a'; - int res = 0; - res = WriteFd(pipe_fds[1], &buf, sizeof(buf)); - TEST_CHECK_MSG(res == sizeof(buf), "Write failure in subprocess"); - - while (true) { - SleepSafe(absl::Milliseconds(100)); - } - } - - close(pipe_fds[1]); // Close the write end. - - int status = 0; - auto wait_cleanup = Cleanup([child_pid, &status] { - EXPECT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds()); - }); - auto kill_cleanup = Cleanup([child_pid] { - EXPECT_THAT(kill(child_pid, SIGKILL), SyscallSucceeds()); - }); - - // Wait for the child. - char buf = 0; - int res = ReadFd(pipe_fds[0], &buf, sizeof(buf)); - if (res < 0) { - return PosixError(errno, "Read from pipe"); - } else if (res == 0) { - return PosixError(EPIPE, "Unable to read from pipe: EOF"); - } - - if (running) { - // The first arg, RSD, refers to a "running process", or a process with a - // state of Running (R), Interruptable Sleep (S) or Uninterruptable - // Sleep (D). - CompareProcessState("RSD", child_pid); - RETURN_IF_ERRNO(running(child_pid)); - } - - // Kill the process. - kill_cleanup.Release()(); - siginfo_t info; - // Wait until the child process has exited (WEXITED flag) but don't - // reap the child (WNOWAIT flag). - EXPECT_THAT(waitid(P_PID, child_pid, &info, WNOWAIT | WEXITED), - SyscallSucceeds()); - - if (zombied) { - // Arg of "Z" refers to a Zombied Process. - CompareProcessState("Z", child_pid); - RETURN_IF_ERRNO(zombied(child_pid)); - } - - // Wait on the process. - wait_cleanup.Release()(); - // If the process is reaped, then then this should return - // with ECHILD. - EXPECT_THAT(waitpid(child_pid, &status, WNOHANG), - SyscallFailsWithErrno(ECHILD)); - - if (exited) { - RETURN_IF_ERRNO(exited(child_pid)); - } - - return NoError(); -} - -// Access the file returned by name when a subprocess is running. -PosixError AccessWhileRunning(std::function<std::string(int pid)> name, - int flags, std::function<void(int fd)> access) { - FileDescriptor fd; - return WithSubprocess( - [&](int pid) -> PosixError { - // Running. - ASSIGN_OR_RETURN_ERRNO(fd, Open(name(pid), flags)); - - access(fd.get()); - return NoError(); - }, - nullptr, nullptr); -} - -// Access the file returned by name when the a subprocess is zombied. -PosixError AccessWhileZombied(std::function<std::string(int pid)> name, - int flags, std::function<void(int fd)> access) { - FileDescriptor fd; - return WithSubprocess( - [&](int pid) -> PosixError { - // Running. - ASSIGN_OR_RETURN_ERRNO(fd, Open(name(pid), flags)); - return NoError(); - }, - [&](int pid) -> PosixError { - // Zombied. - access(fd.get()); - return NoError(); - }, - nullptr); -} - -// Access the file returned by name when the a subprocess is exited. -PosixError AccessWhileExited(std::function<std::string(int pid)> name, - int flags, std::function<void(int fd)> access) { - FileDescriptor fd; - return WithSubprocess( - [&](int pid) -> PosixError { - // Running. - ASSIGN_OR_RETURN_ERRNO(fd, Open(name(pid), flags)); - return NoError(); - }, - nullptr, - [&](int pid) -> PosixError { - // Exited. - access(fd.get()); - return NoError(); - }); -} - -// ReadFd(fd=/proc/PID/basename) while PID is running. -int ReadWhileRunning(std::string const& basename, void* buf, size_t count) { - int ret = 0; - int err = 0; - EXPECT_NO_ERRNO(AccessWhileRunning( - [&](int pid) -> std::string { - return absl::StrCat("/proc/", pid, "/", basename); - }, - O_RDONLY, - [&](int fd) { - ret = ReadFd(fd, buf, count); - err = errno; - })); - errno = err; - return ret; -} - -// ReadFd(fd=/proc/PID/basename) while PID is zombied. -int ReadWhileZombied(std::string const& basename, void* buf, size_t count) { - int ret = 0; - int err = 0; - EXPECT_NO_ERRNO(AccessWhileZombied( - [&](int pid) -> std::string { - return absl::StrCat("/proc/", pid, "/", basename); - }, - O_RDONLY, - [&](int fd) { - ret = ReadFd(fd, buf, count); - err = errno; - })); - errno = err; - return ret; -} - -// ReadFd(fd=/proc/PID/basename) while PID is exited. -int ReadWhileExited(std::string const& basename, void* buf, size_t count) { - int ret = 0; - int err = 0; - EXPECT_NO_ERRNO(AccessWhileExited( - [&](int pid) -> std::string { - return absl::StrCat("/proc/", pid, "/", basename); - }, - O_RDONLY, - [&](int fd) { - ret = ReadFd(fd, buf, count); - err = errno; - })); - errno = err; - return ret; -} - -// readlinkat(fd=/proc/PID/, basename) while PID is running. -int ReadlinkWhileRunning(std::string const& basename, char* buf, size_t count) { - int ret = 0; - int err = 0; - EXPECT_NO_ERRNO(AccessWhileRunning( - [&](int pid) -> std::string { return absl::StrCat("/proc/", pid, "/"); }, - O_DIRECTORY, - [&](int fd) { - ret = readlinkat(fd, basename.c_str(), buf, count); - err = errno; - })); - errno = err; - return ret; -} - -// readlinkat(fd=/proc/PID/, basename) while PID is zombied. -int ReadlinkWhileZombied(std::string const& basename, char* buf, size_t count) { - int ret = 0; - int err = 0; - EXPECT_NO_ERRNO(AccessWhileZombied( - [&](int pid) -> std::string { return absl::StrCat("/proc/", pid, "/"); }, - O_DIRECTORY, - [&](int fd) { - ret = readlinkat(fd, basename.c_str(), buf, count); - err = errno; - })); - errno = err; - return ret; -} - -// readlinkat(fd=/proc/PID/, basename) while PID is exited. -int ReadlinkWhileExited(std::string const& basename, char* buf, size_t count) { - int ret = 0; - int err = 0; - EXPECT_NO_ERRNO(AccessWhileExited( - [&](int pid) -> std::string { return absl::StrCat("/proc/", pid, "/"); }, - O_DIRECTORY, - [&](int fd) { - ret = readlinkat(fd, basename.c_str(), buf, count); - err = errno; - })); - errno = err; - return ret; -} - -TEST(ProcTest, NotFoundInRoot) { - struct stat s; - EXPECT_THAT(stat("/proc/foobar", &s), SyscallFailsWithErrno(ENOENT)); -} - -TEST(ProcSelfTest, IsThreadGroupLeader) { - ScopedThread([] { - const pid_t tgid = getpid(); - const pid_t tid = syscall(SYS_gettid); - EXPECT_NE(tgid, tid); - auto link = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self")); - EXPECT_EQ(link, absl::StrCat(tgid)); - }); -} - -TEST(ProcThreadSelfTest, Basic) { - const pid_t tgid = getpid(); - const pid_t tid = syscall(SYS_gettid); - EXPECT_EQ(tgid, tid); - auto link_threadself = - ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/thread-self")); - EXPECT_EQ(link_threadself, absl::StrCat(tgid, "/task/", tid)); - // Just read one file inside thread-self to ensure that the link is valid. - auto link_threadself_exe = - ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/thread-self/exe")); - auto link_procself_exe = - ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self/exe")); - EXPECT_EQ(link_threadself_exe, link_procself_exe); -} - -TEST(ProcThreadSelfTest, Thread) { - ScopedThread([] { - const pid_t tgid = getpid(); - const pid_t tid = syscall(SYS_gettid); - EXPECT_NE(tgid, tid); - auto link_threadself = - ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/thread-self")); - - EXPECT_EQ(link_threadself, absl::StrCat(tgid, "/task/", tid)); - // Just read one file inside thread-self to ensure that the link is valid. - auto link_threadself_exe = - ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/thread-self/exe")); - auto link_procself_exe = - ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self/exe")); - EXPECT_EQ(link_threadself_exe, link_procself_exe); - // A thread should not have "/proc/<tid>/task". - struct stat s; - EXPECT_THAT(stat("/proc/thread-self/task", &s), - SyscallFailsWithErrno(ENOENT)); - }); -} - -// Returns the /proc/PID/maps entry for the MAP_PRIVATE | MAP_ANONYMOUS mapping -// m with start address addr and length len. -std::string AnonymousMapsEntry(uintptr_t addr, size_t len, int prot) { - return absl::StrCat(absl::Hex(addr, absl::PadSpec::kZeroPad8), "-", - absl::Hex(addr + len, absl::PadSpec::kZeroPad8), " ", - prot & PROT_READ ? "r" : "-", - prot & PROT_WRITE ? "w" : "-", - prot & PROT_EXEC ? "x" : "-", "p 00000000 00:00 0 "); -} - -std::string AnonymousMapsEntryForMapping(const Mapping& m, int prot) { - return AnonymousMapsEntry(m.addr(), m.len(), prot); -} - -PosixErrorOr<std::map<uint64_t, uint64_t>> ReadProcSelfAuxv() { - std::string auxv_file; - RETURN_IF_ERRNO(GetContents("/proc/self/auxv", &auxv_file)); - const Elf64_auxv_t* auxv_data = - reinterpret_cast<const Elf64_auxv_t*>(auxv_file.data()); - std::map<uint64_t, uint64_t> auxv_entries; - for (int i = 0; auxv_data[i].a_type != AT_NULL; i++) { - auto a_type = auxv_data[i].a_type; - EXPECT_EQ(0, auxv_entries.count(a_type)) << "a_type: " << a_type; - auxv_entries.emplace(a_type, auxv_data[i].a_un.a_val); - } - return auxv_entries; -} - -TEST(ProcSelfAuxv, EntryPresence) { - auto auxv_entries = ASSERT_NO_ERRNO_AND_VALUE(ReadProcSelfAuxv()); - - EXPECT_EQ(auxv_entries.count(AT_ENTRY), 1); - EXPECT_EQ(auxv_entries.count(AT_PHDR), 1); - EXPECT_EQ(auxv_entries.count(AT_PHENT), 1); - EXPECT_EQ(auxv_entries.count(AT_PHNUM), 1); - EXPECT_EQ(auxv_entries.count(AT_BASE), 1); - EXPECT_EQ(auxv_entries.count(AT_UID), 1); - EXPECT_EQ(auxv_entries.count(AT_EUID), 1); - EXPECT_EQ(auxv_entries.count(AT_GID), 1); - EXPECT_EQ(auxv_entries.count(AT_EGID), 1); - EXPECT_EQ(auxv_entries.count(AT_SECURE), 1); - EXPECT_EQ(auxv_entries.count(AT_CLKTCK), 1); - EXPECT_EQ(auxv_entries.count(AT_RANDOM), 1); - EXPECT_EQ(auxv_entries.count(AT_EXECFN), 1); - EXPECT_EQ(auxv_entries.count(AT_PAGESZ), 1); - EXPECT_EQ(auxv_entries.count(AT_SYSINFO_EHDR), 1); -} - -TEST(ProcSelfAuxv, EntryValues) { - auto proc_auxv = ASSERT_NO_ERRNO_AND_VALUE(ReadProcSelfAuxv()); - - // We need to find the ELF auxiliary vector. The section of memory pointed to - // by envp contains some pointers to non-null pointers, followed by a single - // pointer to a null pointer, followed by the auxiliary vector. - char** envpi = environ; - while (*envpi) { - ++envpi; - } - - const Elf64_auxv_t* envp_auxv = - reinterpret_cast<const Elf64_auxv_t*>(envpi + 1); - int i; - for (i = 0; envp_auxv[i].a_type != AT_NULL; i++) { - auto a_type = envp_auxv[i].a_type; - EXPECT_EQ(proc_auxv.count(a_type), 1); - EXPECT_EQ(proc_auxv[a_type], envp_auxv[i].a_un.a_val) - << "a_type: " << a_type; - } - EXPECT_EQ(i, proc_auxv.size()); -} - -// Just open and read /proc/self/maps, check that we can find [stack] -TEST(ProcSelfMaps, Basic) { - auto proc_self_maps = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps")); - - std::vector<std::string> strings = absl::StrSplit(proc_self_maps, '\n'); - std::vector<std::string> stacks; - // Make sure there's a stack in there. - for (const auto& str : strings) { - if (str.find("[stack]") != std::string::npos) { - stacks.push_back(str); - } - } - ASSERT_EQ(1, stacks.size()) << "[stack] not found in: " << proc_self_maps; - // Linux pads to 73 characters then we add 7. - EXPECT_EQ(80, stacks[0].length()); -} - -TEST(ProcSelfMaps, Map1) { - Mapping mapping = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_READ, MAP_PRIVATE)); - auto proc_self_maps = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps")); - std::vector<std::string> strings = absl::StrSplit(proc_self_maps, '\n'); - std::vector<std::string> addrs; - // Make sure if is listed. - for (const auto& str : strings) { - if (str == AnonymousMapsEntryForMapping(mapping, PROT_READ)) { - addrs.push_back(str); - } - } - ASSERT_EQ(1, addrs.size()); -} - -TEST(ProcSelfMaps, Map2) { - // NOTE(magi): The permissions must be different or the pages will get merged. - Mapping map1 = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_EXEC, MAP_PRIVATE)); - Mapping map2 = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_WRITE, MAP_PRIVATE)); - - auto proc_self_maps = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps")); - std::vector<std::string> strings = absl::StrSplit(proc_self_maps, '\n'); - std::vector<std::string> addrs; - // Make sure if is listed. - for (const auto& str : strings) { - if (str == AnonymousMapsEntryForMapping(map1, PROT_READ | PROT_EXEC)) { - addrs.push_back(str); - } - } - ASSERT_EQ(1, addrs.size()); - addrs.clear(); - for (const auto& str : strings) { - if (str == AnonymousMapsEntryForMapping(map2, PROT_WRITE)) { - addrs.push_back(str); - } - } - ASSERT_EQ(1, addrs.size()); -} - -TEST(ProcSelfMaps, MapUnmap) { - Mapping map1 = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kPageSize, PROT_READ | PROT_EXEC, MAP_PRIVATE)); - Mapping map2 = - ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_WRITE, MAP_PRIVATE)); - - auto proc_self_maps = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps")); - std::vector<std::string> strings = absl::StrSplit(proc_self_maps, '\n'); - std::vector<std::string> addrs; - // Make sure if is listed. - for (const auto& str : strings) { - if (str == AnonymousMapsEntryForMapping(map1, PROT_READ | PROT_EXEC)) { - addrs.push_back(str); - } - } - ASSERT_EQ(1, addrs.size()) << proc_self_maps; - addrs.clear(); - for (const auto& str : strings) { - if (str == AnonymousMapsEntryForMapping(map2, PROT_WRITE)) { - addrs.push_back(str); - } - } - ASSERT_EQ(1, addrs.size()); - - map2.reset(); - - // Read it again. - proc_self_maps = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps")); - strings = absl::StrSplit(proc_self_maps, '\n'); - // First entry should be there. - addrs.clear(); - for (const auto& str : strings) { - if (str == AnonymousMapsEntryForMapping(map1, PROT_READ | PROT_EXEC)) { - addrs.push_back(str); - } - } - ASSERT_EQ(1, addrs.size()); - addrs.clear(); - // But not the second. - for (const auto& str : strings) { - if (str == AnonymousMapsEntryForMapping(map2, PROT_WRITE)) { - addrs.push_back(str); - } - } - ASSERT_EQ(0, addrs.size()); -} - -TEST(ProcSelfMaps, Mprotect) { - // FIXME(jamieliu): Linux's mprotect() sometimes fails to merge VMAs in this - // case. - SKIP_IF(!IsRunningOnGvisor()); - - // Reserve 5 pages of address space. - Mapping m = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(5 * kPageSize, PROT_NONE, MAP_PRIVATE)); - - // Change the permissions on the middle 3 pages. (The first and last pages may - // be merged with other vmas on either side, so they aren't tested directly; - // they just ensure that the middle 3 pages are bracketed by VMAs with - // incompatible permissions.) - ASSERT_THAT(mprotect(reinterpret_cast<void*>(m.addr() + kPageSize), - 3 * kPageSize, PROT_READ), - SyscallSucceeds()); - - // Check that the middle 3 pages make up a single VMA. - auto proc_self_maps = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps")); - std::vector<std::string> strings = absl::StrSplit(proc_self_maps, '\n'); - EXPECT_THAT(strings, Contains(AnonymousMapsEntry(m.addr() + kPageSize, - 3 * kPageSize, PROT_READ))); - - // Change the permissions on the middle page only. - ASSERT_THAT(mprotect(reinterpret_cast<void*>(m.addr() + 2 * kPageSize), - kPageSize, PROT_READ | PROT_WRITE), - SyscallSucceeds()); - - // Check that the single VMA has been split into 3 VMAs. - proc_self_maps = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps")); - strings = absl::StrSplit(proc_self_maps, '\n'); - EXPECT_THAT( - strings, - IsSupersetOf( - {AnonymousMapsEntry(m.addr() + kPageSize, kPageSize, PROT_READ), - AnonymousMapsEntry(m.addr() + 2 * kPageSize, kPageSize, - PROT_READ | PROT_WRITE), - AnonymousMapsEntry(m.addr() + 3 * kPageSize, kPageSize, - PROT_READ)})); - - // Change the permissions on the middle page back. - ASSERT_THAT(mprotect(reinterpret_cast<void*>(m.addr() + 2 * kPageSize), - kPageSize, PROT_READ), - SyscallSucceeds()); - - // Check that the 3 VMAs have been merged back into a single VMA. - proc_self_maps = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps")); - strings = absl::StrSplit(proc_self_maps, '\n'); - EXPECT_THAT(strings, Contains(AnonymousMapsEntry(m.addr() + kPageSize, - 3 * kPageSize, PROT_READ))); -} - -TEST(ProcSelfFd, OpenFd) { - int pipe_fds[2]; - ASSERT_THAT(pipe2(pipe_fds, O_CLOEXEC), SyscallSucceeds()); - - // Reopen the write end. - const std::string path = absl::StrCat("/proc/self/fd/", pipe_fds[1]); - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_WRONLY)); - - // Ensure that a read/write works. - const std::string data = "hello"; - std::unique_ptr<char[]> buffer(new char[data.size()]); - EXPECT_THAT(write(fd.get(), data.c_str(), data.size()), - SyscallSucceedsWithValue(5)); - EXPECT_THAT(read(pipe_fds[0], buffer.get(), data.size()), - SyscallSucceedsWithValue(5)); - EXPECT_EQ(strncmp(buffer.get(), data.c_str(), data.size()), 0); - - // Cleanup. - ASSERT_THAT(close(pipe_fds[0]), SyscallSucceeds()); - ASSERT_THAT(close(pipe_fds[1]), SyscallSucceeds()); -} - -TEST(ProcSelfFdInfo, CorrectFds) { - // Make sure there is at least one open file. - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY)); - - // Get files in /proc/self/fd. - auto fd_files = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/proc/self/fd", false)); - - // Get files in /proc/self/fdinfo. - auto fdinfo_files = - ASSERT_NO_ERRNO_AND_VALUE(ListDir("/proc/self/fdinfo", false)); - - // They should contain the same fds. - EXPECT_THAT(fd_files, UnorderedElementsAreArray(fdinfo_files)); - - // Both should contain fd. - auto fd_s = absl::StrCat(fd.get()); - EXPECT_THAT(fd_files, Contains(fd_s)); -} - -TEST(ProcSelfFdInfo, Flags) { - std::string path = NewTempAbsPath(); - - // Create file here with O_CREAT to test that O_CREAT does not appear in - // fdinfo flags. - int flags = O_CREAT | O_RDWR | O_APPEND | O_CLOEXEC; - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, flags, 0644)); - - // Automatically delete path. - TempPath temp_path(path); - - // O_CREAT does not appear in fdinfo flags. - flags &= ~O_CREAT; - - // O_LARGEFILE always appears (on x86_64). - flags |= kOLargeFile; - - auto fd_info = ASSERT_NO_ERRNO_AND_VALUE( - GetContents(absl::StrCat("/proc/self/fdinfo/", fd.get()))); - EXPECT_THAT(fd_info, HasSubstr(absl::StrFormat("flags:\t%#o", flags))); -} - -TEST(ProcSelfExe, Absolute) { - auto exe = ASSERT_NO_ERRNO_AND_VALUE( - ReadLink(absl::StrCat("/proc/", getpid(), "/exe"))); - EXPECT_EQ(exe[0], '/'); -} - -// Sanity check for /proc/cpuinfo fields that must be present. -TEST(ProcCpuinfo, RequiredFieldsArePresent) { - std::string proc_cpuinfo = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/cpuinfo")); - ASSERT_FALSE(proc_cpuinfo.empty()); - std::vector<std::string> cpuinfo_fields = absl::StrSplit(proc_cpuinfo, '\n'); - - // Check that the usual fields are there. We don't really care about the - // contents. - for (const std::string& field : required_fields) { - EXPECT_THAT(proc_cpuinfo, HasSubstr(field)); - } -} - -TEST(ProcCpuinfo, DeniesWrite) { - EXPECT_THAT(open("/proc/cpuinfo", O_WRONLY), SyscallFailsWithErrno(EACCES)); -} - -// Sanity checks that uptime is present. -TEST(ProcUptime, IsPresent) { - std::string proc_uptime = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/uptime")); - ASSERT_FALSE(proc_uptime.empty()); - std::vector<std::string> uptime_parts = absl::StrSplit(proc_uptime, ' '); - - // Parse once. - double uptime0, uptime1, idletime0, idletime1; - ASSERT_TRUE(absl::SimpleAtod(uptime_parts[0], &uptime0)); - ASSERT_TRUE(absl::SimpleAtod(uptime_parts[1], &idletime0)); - - // Sleep for one second. - absl::SleepFor(absl::Seconds(1)); - - // Parse again. - proc_uptime = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/uptime")); - ASSERT_FALSE(proc_uptime.empty()); - uptime_parts = absl::StrSplit(proc_uptime, ' '); - ASSERT_TRUE(absl::SimpleAtod(uptime_parts[0], &uptime1)); - ASSERT_TRUE(absl::SimpleAtod(uptime_parts[1], &idletime1)); - - // Sanity check. - // - // We assert that between 0.99 and 59.99 seconds have passed. If more than a - // minute has passed, then we must be executing really, really slowly. - EXPECT_GE(uptime0, 0.0); - EXPECT_GE(idletime0, 0.0); - EXPECT_GT(uptime1, uptime0); - EXPECT_GE(uptime1, uptime0 + 0.99); - EXPECT_LE(uptime1, uptime0 + 59.99); - EXPECT_GE(idletime1, idletime0); -} - -TEST(ProcMeminfo, ContainsBasicFields) { - std::string proc_meminfo = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/meminfo")); - EXPECT_THAT(proc_meminfo, AllOf(ContainsRegex(R"(MemTotal:\s+[0-9]+ kB)"), - ContainsRegex(R"(MemFree:\s+[0-9]+ kB)"))); -} - -TEST(ProcStat, ContainsBasicFields) { - std::string proc_stat = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/stat")); - - std::vector<std::string> names; - for (auto const& line : absl::StrSplit(proc_stat, '\n')) { - std::vector<std::string> fields = - absl::StrSplit(line, ' ', absl::SkipWhitespace()); - if (fields.empty()) { - continue; - } - names.push_back(fields[0]); - } - - EXPECT_THAT(names, - IsSupersetOf({"cpu", "intr", "ctxt", "btime", "processes", - "procs_running", "procs_blocked", "softirq"})); -} - -TEST(ProcStat, EndsWithNewline) { - std::string proc_stat = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/stat")); - EXPECT_EQ(proc_stat.back(), '\n'); -} - -TEST(ProcStat, Fields) { - std::string proc_stat = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/stat")); - - std::vector<std::string> names; - for (auto const& line : absl::StrSplit(proc_stat, '\n')) { - std::vector<std::string> fields = - absl::StrSplit(line, ' ', absl::SkipWhitespace()); - if (fields.empty()) { - continue; - } - - if (absl::StartsWith(fields[0], "cpu")) { - // As of Linux 3.11, each CPU entry has 10 fields, plus the name. - EXPECT_GE(fields.size(), 11) << proc_stat; - } else if (fields[0] == "ctxt") { - // Single field. - EXPECT_EQ(fields.size(), 2) << proc_stat; - } else if (fields[0] == "btime") { - // Single field. - EXPECT_EQ(fields.size(), 2) << proc_stat; - } else if (fields[0] == "itime") { - // Single field. - ASSERT_EQ(fields.size(), 2) << proc_stat; - // This is the only floating point field. - double val; - EXPECT_TRUE(absl::SimpleAtod(fields[1], &val)) << proc_stat; - continue; - } else if (fields[0] == "processes") { - // Single field. - EXPECT_EQ(fields.size(), 2) << proc_stat; - } else if (fields[0] == "procs_running") { - // Single field. - EXPECT_EQ(fields.size(), 2) << proc_stat; - } else if (fields[0] == "procs_blocked") { - // Single field. - EXPECT_EQ(fields.size(), 2) << proc_stat; - } else if (fields[0] == "softirq") { - // As of Linux 3.11, there are 10 softirqs. 12 fields for name + total. - EXPECT_GE(fields.size(), 12) << proc_stat; - } - - // All fields besides itime are valid base 10 numbers. - for (size_t i = 1; i < fields.size(); i++) { - uint64_t val; - EXPECT_TRUE(absl::SimpleAtoi(fields[i], &val)) << proc_stat; - } - } -} - -TEST(ProcLoadavg, EndsWithNewline) { - std::string proc_loadvg = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/loadavg")); - EXPECT_EQ(proc_loadvg.back(), '\n'); -} - -TEST(ProcLoadavg, Fields) { - std::string proc_loadvg = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/loadavg")); - std::vector<std::string> lines = absl::StrSplit(proc_loadvg, '\n'); - - // Single line. - EXPECT_EQ(lines.size(), 2) << proc_loadvg; - - std::vector<std::string> fields = - absl::StrSplit(lines[0], absl::ByAnyChar(" /"), absl::SkipWhitespace()); - - // Six fields. - EXPECT_EQ(fields.size(), 6) << proc_loadvg; - - double val; - uint64_t val2; - // First three fields are floating point numbers. - EXPECT_TRUE(absl::SimpleAtod(fields[0], &val)) << proc_loadvg; - EXPECT_TRUE(absl::SimpleAtod(fields[1], &val)) << proc_loadvg; - EXPECT_TRUE(absl::SimpleAtod(fields[2], &val)) << proc_loadvg; - // Rest of the fields are valid base 10 numbers. - EXPECT_TRUE(absl::SimpleAtoi(fields[3], &val2)) << proc_loadvg; - EXPECT_TRUE(absl::SimpleAtoi(fields[4], &val2)) << proc_loadvg; - EXPECT_TRUE(absl::SimpleAtoi(fields[5], &val2)) << proc_loadvg; -} - -// NOTE: Tests in priority.cc also check certain priority related fields in -// /proc/self/stat. - -class ProcPidStatTest : public ::testing::TestWithParam<std::string> {}; - -TEST_P(ProcPidStatTest, HasBasicFields) { - std::string proc_pid_stat = ASSERT_NO_ERRNO_AND_VALUE( - GetContents(absl::StrCat("/proc/", GetParam(), "/stat"))); - - ASSERT_FALSE(proc_pid_stat.empty()); - std::vector<std::string> fields = absl::StrSplit(proc_pid_stat, ' '); - ASSERT_GE(fields.size(), 24); - EXPECT_EQ(absl::StrCat(getpid()), fields[0]); - // fields[1] is the thread name. - EXPECT_EQ("R", fields[2]); // task state - EXPECT_EQ(absl::StrCat(getppid()), fields[3]); - - // If the test starts up quickly, then the process start time and the kernel - // boot time will be very close, and the proc starttime field (which is the - // delta of the two times) will be 0. For that unfortunate reason, we can - // only check that starttime >= 0, and not that it is strictly > 0. - uint64_t starttime; - ASSERT_TRUE(absl::SimpleAtoi(fields[21], &starttime)); - EXPECT_GE(starttime, 0); - - uint64_t vss; - ASSERT_TRUE(absl::SimpleAtoi(fields[22], &vss)); - EXPECT_GT(vss, 0); - - uint64_t rss; - ASSERT_TRUE(absl::SimpleAtoi(fields[23], &rss)); - EXPECT_GT(rss, 0); - - uint64_t rsslim; - ASSERT_TRUE(absl::SimpleAtoi(fields[24], &rsslim)); - EXPECT_GT(rsslim, 0); -} - -INSTANTIATE_TEST_SUITE_P(SelfAndNumericPid, ProcPidStatTest, - ::testing::Values("self", absl::StrCat(getpid()))); - -using ProcPidStatmTest = ::testing::TestWithParam<std::string>; - -TEST_P(ProcPidStatmTest, HasBasicFields) { - std::string proc_pid_statm = ASSERT_NO_ERRNO_AND_VALUE( - GetContents(absl::StrCat("/proc/", GetParam(), "/statm"))); - ASSERT_FALSE(proc_pid_statm.empty()); - std::vector<std::string> fields = absl::StrSplit(proc_pid_statm, ' '); - ASSERT_GE(fields.size(), 7); - - uint64_t vss; - ASSERT_TRUE(absl::SimpleAtoi(fields[0], &vss)); - EXPECT_GT(vss, 0); - - uint64_t rss; - ASSERT_TRUE(absl::SimpleAtoi(fields[1], &rss)); - EXPECT_GT(rss, 0); -} - -INSTANTIATE_TEST_SUITE_P(SelfAndNumericPid, ProcPidStatmTest, - ::testing::Values("self", absl::StrCat(getpid()))); - -PosixErrorOr<uint64_t> CurrentRSS() { - ASSIGN_OR_RETURN_ERRNO(auto proc_self_stat, GetContents("/proc/self/stat")); - if (proc_self_stat.empty()) { - return PosixError(EINVAL, "empty /proc/self/stat"); - } - - std::vector<std::string> fields = absl::StrSplit(proc_self_stat, ' '); - if (fields.size() < 24) { - return PosixError( - EINVAL, - absl::StrCat("/proc/self/stat has too few fields: ", proc_self_stat)); - } - - uint64_t rss; - if (!absl::SimpleAtoi(fields[23], &rss)) { - return PosixError( - EINVAL, absl::StrCat("/proc/self/stat RSS field is not a number: ", - fields[23])); - } - - // RSS is given in number of pages. - return rss * kPageSize; -} - -// The size of mapping created by MapPopulateRSS. -constexpr uint64_t kMappingSize = 100 << 20; - -// Tolerance on RSS comparisons to account for background thread mappings, -// reclaimed pages, newly faulted pages, etc. -constexpr uint64_t kRSSTolerance = 5 << 20; - -// Capture RSS before and after an anonymous mapping with passed prot. -void MapPopulateRSS(int prot, uint64_t* before, uint64_t* after) { - *before = ASSERT_NO_ERRNO_AND_VALUE(CurrentRSS()); - - // N.B. The kernel asynchronously accumulates per-task RSS counters into the - // mm RSS, which is exposed by /proc/PID/stat. Task exit is a synchronization - // point (kernel/exit.c:do_exit -> sync_mm_rss), so perform the mapping on - // another thread to ensure it is reflected in RSS after the thread exits. - Mapping mapping; - ScopedThread t([&mapping, prot] { - mapping = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(kMappingSize, prot, MAP_PRIVATE | MAP_POPULATE)); - }); - t.Join(); - - *after = ASSERT_NO_ERRNO_AND_VALUE(CurrentRSS()); -} - -// TODO(b/73896574): Test for PROT_READ + MAP_POPULATE anonymous mappings. Their -// semantics are more subtle: -// -// Small pages -> Zero page mapped, not counted in RSS -// (mm/memory.c:do_anonymous_page). -// -// Huge pages (THP enabled, use_zero_page=0) -> Pages committed -// (mm/memory.c:__handle_mm_fault -> create_huge_pmd). -// -// Huge pages (THP enabled, use_zero_page=1) -> Zero page mapped, not counted in -// RSS (mm/huge_memory.c:do_huge_pmd_anonymous_page). - -// PROT_WRITE + MAP_POPULATE anonymous mappings are always committed. -TEST(ProcSelfStat, PopulateWriteRSS) { - uint64_t before, after; - MapPopulateRSS(PROT_READ | PROT_WRITE, &before, &after); - - // Mapping is committed. - EXPECT_NEAR(before + kMappingSize, after, kRSSTolerance); -} - -// PROT_NONE + MAP_POPULATE anonymous mappings are never committed. -TEST(ProcSelfStat, PopulateNoneRSS) { - uint64_t before, after; - MapPopulateRSS(PROT_NONE, &before, &after); - - // Mapping not committed. - EXPECT_NEAR(before, after, kRSSTolerance); -} - -// Returns the calling thread's name. -PosixErrorOr<std::string> ThreadName() { - // "The buffer should allow space for up to 16 bytes; the returned std::string - // will be null-terminated if it is shorter than that." - prctl(2). But we - // always want the thread name to be null-terminated. - char thread_name[17]; - int rc = prctl(PR_GET_NAME, thread_name, 0, 0, 0); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "prctl(PR_GET_NAME)"); - } - thread_name[16] = '\0'; - return std::string(thread_name); -} - -// Parses the contents of a /proc/[pid]/status file into a collection of -// key-value pairs. -PosixErrorOr<std::map<std::string, std::string>> ParseProcStatus( - absl::string_view status_str) { - std::map<std::string, std::string> fields; - for (absl::string_view const line : - absl::StrSplit(status_str, '\n', absl::SkipWhitespace())) { - const std::pair<absl::string_view, absl::string_view> kv = - absl::StrSplit(line, absl::MaxSplits(":\t", 1)); - if (kv.first.empty()) { - return PosixError( - EINVAL, absl::StrCat("failed to parse key in line \"", line, "\"")); - } - std::string key(kv.first); - if (fields.count(key)) { - return PosixError(EINVAL, - absl::StrCat("duplicate key \"", kv.first, "\"")); - } - std::string value(kv.second); - absl::StripLeadingAsciiWhitespace(&value); - fields.emplace(std::move(key), std::move(value)); - } - return fields; -} - -TEST(ParseProcStatusTest, ParsesSimpleStatusFileWithMixedWhitespaceCorrectly) { - EXPECT_THAT( - ParseProcStatus( - "Name:\tinit\nState:\tS (sleeping)\nCapEff:\t 0000001fffffffff\n"), - IsPosixErrorOkAndHolds(UnorderedElementsAre( - Pair("Name", "init"), Pair("State", "S (sleeping)"), - Pair("CapEff", "0000001fffffffff")))); -} - -TEST(ParseProcStatusTest, DetectsDuplicateKeys) { - auto proc_status_or = ParseProcStatus("Name:\tfoo\nName:\tfoo\n"); - EXPECT_THAT(proc_status_or, - PosixErrorIs(EINVAL, ::testing::StrEq("duplicate key \"Name\""))); -} - -TEST(ParseProcStatusTest, DetectsMissingTabs) { - EXPECT_THAT(ParseProcStatus("Name:foo\nPid: 1\n"), - IsPosixErrorOkAndHolds(UnorderedElementsAre(Pair("Name:foo", ""), - Pair("Pid: 1", "")))); -} - -TEST(ProcPidStatusTest, HasBasicFields) { - // Do this on a separate thread since we want tgid != tid. - ScopedThread([] { - const pid_t tgid = getpid(); - const pid_t tid = syscall(SYS_gettid); - EXPECT_NE(tgid, tid); - const auto thread_name = ASSERT_NO_ERRNO_AND_VALUE(ThreadName()); - - std::string status_str = ASSERT_NO_ERRNO_AND_VALUE( - GetContents(absl::StrCat("/proc/", tid, "/status"))); - - ASSERT_FALSE(status_str.empty()); - const auto status = ASSERT_NO_ERRNO_AND_VALUE(ParseProcStatus(status_str)); - EXPECT_THAT(status, IsSupersetOf({Pair("Name", thread_name), - Pair("Tgid", absl::StrCat(tgid)), - Pair("Pid", absl::StrCat(tid)), - Pair("PPid", absl::StrCat(getppid()))})); - }); -} - -TEST(ProcPidStatusTest, StateRunning) { - // Task must be running when reading the file. - const pid_t tid = syscall(SYS_gettid); - std::string status_str = ASSERT_NO_ERRNO_AND_VALUE( - GetContents(absl::StrCat("/proc/", tid, "/status"))); - - EXPECT_THAT(ParseProcStatus(status_str), - IsPosixErrorOkAndHolds(Contains(Pair("State", "R (running)")))); -} - -TEST(ProcPidStatusTest, StateSleeping_NoRandomSave) { - // Starts a child process that blocks and checks that State is sleeping. - auto res = WithSubprocess( - [&](int pid) -> PosixError { - // Because this test is timing based we will disable cooperative saving - // and the test itself also has random saving disabled. - const DisableSave ds; - // Try multiple times in case the child isn't sleeping when status file - // is read. - MonotonicTimer timer; - timer.Start(); - for (;;) { - ASSIGN_OR_RETURN_ERRNO( - std::string status_str, - GetContents(absl::StrCat("/proc/", pid, "/status"))); - ASSIGN_OR_RETURN_ERRNO(auto map, ParseProcStatus(status_str)); - if (map["State"] == std::string("S (sleeping)")) { - // Test passed! - return NoError(); - } - if (timer.Duration() > absl::Seconds(10)) { - return PosixError(ETIMEDOUT, "Timeout waiting for child to sleep"); - } - absl::SleepFor(absl::Milliseconds(10)); - } - }, - nullptr, nullptr); - ASSERT_NO_ERRNO(res); -} - -TEST(ProcPidStatusTest, ValuesAreTabDelimited) { - std::string status_str = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/status")); - ASSERT_FALSE(status_str.empty()); - for (absl::string_view const line : - absl::StrSplit(status_str, '\n', absl::SkipWhitespace())) { - EXPECT_NE(std::string::npos, line.find(":\t")); - } -} - -// Threads properly counts running threads. -// -// TODO(mpratt): Test zombied threads while the thread group leader is still -// running with generalized fork and clone children from the wait test. -TEST(ProcPidStatusTest, Threads) { - char buf[4096] = {}; - EXPECT_THAT(ReadWhileRunning("status", buf, sizeof(buf) - 1), - SyscallSucceedsWithValue(Gt(0))); - - auto status = ASSERT_NO_ERRNO_AND_VALUE(ParseProcStatus(buf)); - auto it = status.find("Threads"); - ASSERT_NE(it, status.end()); - int threads = -1; - EXPECT_TRUE(absl::SimpleAtoi(it->second, &threads)) - << "Threads value " << it->second << " is not a number"; - // Don't make assumptions about the exact number of threads, as it may not be - // constant. - EXPECT_GE(threads, 1); - - memset(buf, 0, sizeof(buf)); - EXPECT_THAT(ReadWhileZombied("status", buf, sizeof(buf) - 1), - SyscallSucceedsWithValue(Gt(0))); - - status = ASSERT_NO_ERRNO_AND_VALUE(ParseProcStatus(buf)); - it = status.find("Threads"); - ASSERT_NE(it, status.end()); - threads = -1; - EXPECT_TRUE(absl::SimpleAtoi(it->second, &threads)) - << "Threads value " << it->second << " is not a number"; - // There must be only the thread group leader remaining, zombied. - EXPECT_EQ(threads, 1); -} - -// Returns true if all characters in s are digits. -bool IsDigits(absl::string_view s) { - return std::all_of(s.begin(), s.end(), absl::ascii_isdigit); -} - -TEST(ProcPidStatTest, VmStats) { - std::string status_str = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/status")); - ASSERT_FALSE(status_str.empty()); - auto status = ASSERT_NO_ERRNO_AND_VALUE(ParseProcStatus(status_str)); - - const auto vss_it = status.find("VmSize"); - ASSERT_NE(vss_it, status.end()); - - absl::string_view vss_str(vss_it->second); - - // Room for the " kB" suffix plus at least one digit. - ASSERT_GT(vss_str.length(), 3); - EXPECT_TRUE(absl::EndsWith(vss_str, " kB")); - // Everything else is part of a number. - EXPECT_TRUE(IsDigits(vss_str.substr(0, vss_str.length() - 3))) << vss_str; - // ... which is not 0. - EXPECT_NE('0', vss_str[0]); - - const auto rss_it = status.find("VmRSS"); - ASSERT_NE(rss_it, status.end()); - - absl::string_view rss_str(rss_it->second); - - // Room for the " kB" suffix plus at least one digit. - ASSERT_GT(rss_str.length(), 3); - EXPECT_TRUE(absl::EndsWith(rss_str, " kB")); - // Everything else is part of a number. - EXPECT_TRUE(IsDigits(rss_str.substr(0, rss_str.length() - 3))) << rss_str; - // ... which is not 0. - EXPECT_NE('0', rss_str[0]); - - const auto data_it = status.find("VmData"); - ASSERT_NE(data_it, status.end()); - - absl::string_view data_str(data_it->second); - - // Room for the " kB" suffix plus at least one digit. - ASSERT_GT(data_str.length(), 3); - EXPECT_TRUE(absl::EndsWith(data_str, " kB")); - // Everything else is part of a number. - EXPECT_TRUE(IsDigits(data_str.substr(0, data_str.length() - 3))) << data_str; - // ... which is not 0. - EXPECT_NE('0', data_str[0]); -} - -// Parse an array of NUL-terminated char* arrays, returning a vector of -// strings. -std::vector<std::string> ParseNulTerminatedStrings(std::string contents) { - EXPECT_EQ('\0', contents.back()); - // The split will leave an empty string if the NUL-byte remains, so pop - // it. - contents.pop_back(); - - return absl::StrSplit(contents, '\0'); -} - -TEST(ProcPidCmdline, MatchesArgv) { - std::vector<std::string> proc_cmdline = ParseNulTerminatedStrings( - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/cmdline"))); - EXPECT_THAT(saved_argv, ContainerEq(proc_cmdline)); -} - -TEST(ProcPidEnviron, MatchesEnviron) { - std::vector<std::string> proc_environ = ParseNulTerminatedStrings( - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/environ"))); - // Get the environment from the environ variable, which we will compare with - // /proc/self/environ. - std::vector<std::string> env; - for (char** v = environ; *v; v++) { - env.push_back(*v); - } - EXPECT_THAT(env, ContainerEq(proc_environ)); -} - -TEST(ProcPidCmdline, SubprocessForkSameCmdline) { - std::vector<std::string> proc_cmdline_parent; - std::vector<std::string> proc_cmdline; - proc_cmdline_parent = ParseNulTerminatedStrings( - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/cmdline"))); - auto res = WithSubprocess( - [&](int pid) -> PosixError { - ASSIGN_OR_RETURN_ERRNO( - auto raw_cmdline, - GetContents(absl::StrCat("/proc/", pid, "/cmdline"))); - proc_cmdline = ParseNulTerminatedStrings(raw_cmdline); - return NoError(); - }, - nullptr, nullptr); - ASSERT_NO_ERRNO(res); - - for (size_t i = 0; i < proc_cmdline_parent.size(); i++) { - EXPECT_EQ(proc_cmdline_parent[i], proc_cmdline[i]); - } -} - -// Test whether /proc/PID/ symlinks can be read for a running process. -TEST(ProcPidSymlink, SubprocessRunning) { - char buf[1]; - - EXPECT_THAT(ReadlinkWhileRunning("exe", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadlinkWhileRunning("ns/net", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadlinkWhileRunning("ns/pid", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadlinkWhileRunning("ns/user", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); -} - -// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux -// on proc files. -TEST(ProcPidSymlink, SubprocessZombied) { - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - char buf[1]; - - int want = EACCES; - if (!IsRunningOnGvisor()) { - auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion()); - if (version.major == 4 && version.minor > 3) { - want = ENOENT; - } - } - - EXPECT_THAT(ReadlinkWhileZombied("exe", buf, sizeof(buf)), - SyscallFailsWithErrno(want)); - - if (!IsRunningOnGvisor()) { - EXPECT_THAT(ReadlinkWhileZombied("ns/net", buf, sizeof(buf)), - SyscallFailsWithErrno(want)); - } - - // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux - // on proc files. - // - // ~4.3: Syscall fails with EACCES. - // 4.17 & gVisor: Syscall succeeds and returns 1. - // - // EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)), - // SyscallFailsWithErrno(EACCES)); - - // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux - // on proc files. - // - // ~4.3: Syscall fails with EACCES. - // 4.17 & gVisor: Syscall succeeds and returns 1. - // - // EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)), - // SyscallFailsWithErrno(EACCES)); -} - -// Test whether /proc/PID/ symlinks can be read for an exited process. -TEST(ProcPidSymlink, SubprocessExited) { - // FIXME(gvisor.dev/issue/164): These all succeed on gVisor. - SKIP_IF(IsRunningOnGvisor()); - - char buf[1]; - - EXPECT_THAT(ReadlinkWhileExited("exe", buf, sizeof(buf)), - SyscallFailsWithErrno(ESRCH)); - - EXPECT_THAT(ReadlinkWhileExited("ns/net", buf, sizeof(buf)), - SyscallFailsWithErrno(ESRCH)); - - EXPECT_THAT(ReadlinkWhileExited("ns/pid", buf, sizeof(buf)), - SyscallFailsWithErrno(ESRCH)); - - EXPECT_THAT(ReadlinkWhileExited("ns/user", buf, sizeof(buf)), - SyscallFailsWithErrno(ESRCH)); -} - -// /proc/PID/exe points to the correct binary. -TEST(ProcPidExe, Subprocess) { - auto link = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self/exe")); - auto expected_absolute_path = - ASSERT_NO_ERRNO_AND_VALUE(MakeAbsolute(link, "")); - - char actual[PATH_MAX + 1] = {}; - ASSERT_THAT(ReadlinkWhileRunning("exe", actual, sizeof(actual)), - SyscallSucceedsWithValue(Gt(0))); - EXPECT_EQ(actual, expected_absolute_path); -} - -// Test whether /proc/PID/ files can be read for a running process. -TEST(ProcPidFile, SubprocessRunning) { - char buf[1]; - - EXPECT_THAT(ReadWhileRunning("auxv", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadWhileRunning("cmdline", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadWhileRunning("comm", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadWhileRunning("gid_map", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadWhileRunning("io", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadWhileRunning("maps", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadWhileRunning("stat", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadWhileRunning("status", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadWhileRunning("uid_map", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadWhileRunning("oom_score", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadWhileRunning("oom_score_adj", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); -} - -// Test whether /proc/PID/ files can be read for a zombie process. -TEST(ProcPidFile, SubprocessZombie) { - char buf[1]; - - // FIXME(gvisor.dev/issue/164): Loosen requirement due to inconsistent - // behavior on different kernels. - // - // ~4.3: Succeds and returns 0. - // 4.17: Succeeds and returns 1. - // gVisor: Succeeds and returns 0. - EXPECT_THAT(ReadWhileZombied("auxv", buf, sizeof(buf)), SyscallSucceeds()); - - EXPECT_THAT(ReadWhileZombied("cmdline", buf, sizeof(buf)), - SyscallSucceedsWithValue(0)); - - EXPECT_THAT(ReadWhileZombied("comm", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadWhileZombied("gid_map", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadWhileZombied("maps", buf, sizeof(buf)), - SyscallSucceedsWithValue(0)); - - EXPECT_THAT(ReadWhileZombied("stat", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadWhileZombied("status", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadWhileZombied("uid_map", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadWhileZombied("oom_score", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(ReadWhileZombied("oom_score_adj", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux - // on proc files. - // - // ~4.3: Fails and returns EACCES. - // gVisor & 4.17: Succeeds and returns 1. - // - // EXPECT_THAT(ReadWhileZombied("io", buf, sizeof(buf)), - // SyscallFailsWithErrno(EACCES)); -} - -// Test whether /proc/PID/ files can be read for an exited process. -TEST(ProcPidFile, SubprocessExited) { - char buf[1]; - - // FIXME(gvisor.dev/issue/164): Inconsistent behavior between kernels. - // - // ~4.3: Fails and returns ESRCH. - // gVisor: Fails with ESRCH. - // 4.17: Succeeds and returns 1. - // - // EXPECT_THAT(ReadWhileExited("auxv", buf, sizeof(buf)), - // SyscallFailsWithErrno(ESRCH)); - - EXPECT_THAT(ReadWhileExited("cmdline", buf, sizeof(buf)), - SyscallFailsWithErrno(ESRCH)); - - if (!IsRunningOnGvisor()) { - // FIXME(gvisor.dev/issue/164): Succeeds on gVisor. - EXPECT_THAT(ReadWhileExited("comm", buf, sizeof(buf)), - SyscallFailsWithErrno(ESRCH)); - } - - EXPECT_THAT(ReadWhileExited("gid_map", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - if (!IsRunningOnGvisor()) { - // FIXME(gvisor.dev/issue/164): Succeeds on gVisor. - EXPECT_THAT(ReadWhileExited("io", buf, sizeof(buf)), - SyscallFailsWithErrno(ESRCH)); - } - - if (!IsRunningOnGvisor()) { - // FIXME(gvisor.dev/issue/164): Returns EOF on gVisor. - EXPECT_THAT(ReadWhileExited("maps", buf, sizeof(buf)), - SyscallFailsWithErrno(ESRCH)); - } - - if (!IsRunningOnGvisor()) { - // FIXME(gvisor.dev/issue/164): Succeeds on gVisor. - EXPECT_THAT(ReadWhileExited("stat", buf, sizeof(buf)), - SyscallFailsWithErrno(ESRCH)); - } - - if (!IsRunningOnGvisor()) { - // FIXME(gvisor.dev/issue/164): Succeeds on gVisor. - EXPECT_THAT(ReadWhileExited("status", buf, sizeof(buf)), - SyscallFailsWithErrno(ESRCH)); - } - - EXPECT_THAT(ReadWhileExited("uid_map", buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - if (!IsRunningOnGvisor()) { - // FIXME(gvisor.dev/issue/164): Succeeds on gVisor. - EXPECT_THAT(ReadWhileExited("oom_score", buf, sizeof(buf)), - SyscallFailsWithErrno(ESRCH)); - } - - EXPECT_THAT(ReadWhileExited("oom_score_adj", buf, sizeof(buf)), - SyscallFailsWithErrno(ESRCH)); -} - -PosixError DirContainsImpl(absl::string_view path, - const std::vector<std::string>& targets, - bool strict) { - ASSIGN_OR_RETURN_ERRNO(auto listing, ListDir(path, false)); - bool success = true; - - for (auto& expected_entry : targets) { - auto cursor = std::find(listing.begin(), listing.end(), expected_entry); - if (cursor == listing.end()) { - success = false; - } - } - - if (!success) { - return PosixError( - ENOENT, - absl::StrCat("Failed to find one or more paths in '", path, "'")); - } - - if (strict) { - if (targets.size() != listing.size()) { - return PosixError( - EINVAL, - absl::StrCat("Expected to find ", targets.size(), " elements in '", - path, "', but found ", listing.size())); - } - } - - return NoError(); -} - -PosixError DirContains(absl::string_view path, - const std::vector<std::string>& targets) { - return DirContainsImpl(path, targets, false); -} - -PosixError DirContainsExactly(absl::string_view path, - const std::vector<std::string>& targets) { - return DirContainsImpl(path, targets, true); -} - -PosixError EventuallyDirContainsExactly( - absl::string_view path, const std::vector<std::string>& targets) { - constexpr int kRetryCount = 100; - const absl::Duration kRetryDelay = absl::Milliseconds(100); - - for (int i = 0; i < kRetryCount; ++i) { - auto res = DirContainsExactly(path, targets); - if (res.ok()) { - return res; - } else if (i < kRetryCount - 1) { - // Sleep if this isn't the final iteration. - absl::SleepFor(kRetryDelay); - } - } - return PosixError(ETIMEDOUT, - "Timed out while waiting for directory to contain files "); -} - -TEST(ProcTask, Basic) { - EXPECT_NO_ERRNO( - DirContains("/proc/self/task", {".", "..", absl::StrCat(getpid())})); -} - -std::vector<std::string> TaskFiles( - const std::vector<std::string>& initial_contents, - const std::vector<pid_t>& pids) { - return VecCat<std::string>( - initial_contents, - ApplyVec<std::string>([](const pid_t p) { return absl::StrCat(p); }, - pids)); -} - -std::vector<std::string> TaskFiles(const std::vector<pid_t>& pids) { - return TaskFiles({".", "..", absl::StrCat(getpid())}, pids); -} - -// Helper class for creating a new task in the current thread group. -class BlockingChild { - public: - BlockingChild() : thread_([=] { Start(); }) {} - ~BlockingChild() { Join(); } - - pid_t Tid() const { - absl::MutexLock ml(&mu_); - mu_.Await(absl::Condition(&tid_ready_)); - return tid_; - } - - void Join() { Stop(); } - - private: - void Start() { - absl::MutexLock ml(&mu_); - tid_ = syscall(__NR_gettid); - tid_ready_ = true; - mu_.Await(absl::Condition(&stop_)); - } - - void Stop() { - absl::MutexLock ml(&mu_); - stop_ = true; - } - - mutable absl::Mutex mu_; - bool stop_ ABSL_GUARDED_BY(mu_) = false; - pid_t tid_; - bool tid_ready_ ABSL_GUARDED_BY(mu_) = false; - - // Must be last to ensure that the destructor for the thread is run before - // any other member of the object is destroyed. - ScopedThread thread_; -}; - -TEST(ProcTask, NewThreadAppears) { - auto initial = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/proc/self/task", false)); - BlockingChild child1; - EXPECT_NO_ERRNO(DirContainsExactly("/proc/self/task", - TaskFiles(initial, {child1.Tid()}))); -} - -TEST(ProcTask, KilledThreadsDisappear) { - auto initial = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/proc/self/task/", false)); - - BlockingChild child1; - EXPECT_NO_ERRNO(DirContainsExactly("/proc/self/task", - TaskFiles(initial, {child1.Tid()}))); - - // Stat child1's task file. Regression test for b/32097707. - struct stat statbuf; - const std::string child1_task_file = - absl::StrCat("/proc/self/task/", child1.Tid()); - EXPECT_THAT(stat(child1_task_file.c_str(), &statbuf), SyscallSucceeds()); - - BlockingChild child2; - EXPECT_NO_ERRNO(DirContainsExactly( - "/proc/self/task", TaskFiles(initial, {child1.Tid(), child2.Tid()}))); - - BlockingChild child3; - BlockingChild child4; - BlockingChild child5; - EXPECT_NO_ERRNO(DirContainsExactly( - "/proc/self/task", - TaskFiles(initial, {child1.Tid(), child2.Tid(), child3.Tid(), - child4.Tid(), child5.Tid()}))); - - child2.Join(); - EXPECT_NO_ERRNO(EventuallyDirContainsExactly( - "/proc/self/task", TaskFiles(initial, {child1.Tid(), child3.Tid(), - child4.Tid(), child5.Tid()}))); - - child1.Join(); - child4.Join(); - EXPECT_NO_ERRNO(EventuallyDirContainsExactly( - "/proc/self/task", TaskFiles(initial, {child3.Tid(), child5.Tid()}))); - - // Stat child1's task file again. This time it should fail. See b/32097707. - EXPECT_THAT(stat(child1_task_file.c_str(), &statbuf), - SyscallFailsWithErrno(ENOENT)); - - child3.Join(); - child5.Join(); - EXPECT_NO_ERRNO(EventuallyDirContainsExactly("/proc/self/task", initial)); -} - -TEST(ProcTask, ChildTaskDir) { - BlockingChild child1; - EXPECT_NO_ERRNO(DirContains("/proc/self/task", TaskFiles({child1.Tid()}))); - EXPECT_NO_ERRNO(DirContains(absl::StrCat("/proc/", child1.Tid(), "/task"), - TaskFiles({child1.Tid()}))); -} - -PosixError VerifyPidDir(std::string path) { - return DirContains(path, {"exe", "fd", "io", "maps", "ns", "stat", "status"}); -} - -TEST(ProcTask, VerifyTaskDir) { - EXPECT_NO_ERRNO(VerifyPidDir("/proc/self")); - - EXPECT_NO_ERRNO(VerifyPidDir(absl::StrCat("/proc/self/task/", getpid()))); - BlockingChild child1; - EXPECT_NO_ERRNO(VerifyPidDir(absl::StrCat("/proc/self/task/", child1.Tid()))); - - // Only the first level of task directories should contain the 'task' - // directory. That is: - // - // /proc/1234/task <- should exist - // /proc/1234/task/1234/task <- should not exist - // /proc/1234/task/1235/task <- should not exist (where 1235 is in the same - // thread group as 1234). - EXPECT_FALSE( - DirContains(absl::StrCat("/proc/self/task/", getpid()), {"task"}).ok()) - << "Found 'task' directory in an inner directory."; -} - -TEST(ProcTask, TaskDirCannotBeDeleted) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - - EXPECT_THAT(rmdir("/proc/self/task"), SyscallFails()); - EXPECT_THAT(rmdir(absl::StrCat("/proc/self/task/", getpid()).c_str()), - SyscallFailsWithErrno(EACCES)); -} - -TEST(ProcTask, TaskDirHasCorrectMetadata) { - struct stat st; - EXPECT_THAT(stat("/proc/self/task", &st), SyscallSucceeds()); - EXPECT_TRUE(S_ISDIR(st.st_mode)); - - // Verify file is readable and executable by everyone. - mode_t expected_permissions = - S_IRUSR | S_IXUSR | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH; - mode_t permissions = st.st_mode & (S_IRWXU | S_IRWXG | S_IRWXO); - EXPECT_EQ(expected_permissions, permissions); -} - -TEST(ProcTask, TaskDirCanSeekToEnd) { - const FileDescriptor dirfd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/task", O_RDONLY)); - EXPECT_THAT(lseek(dirfd.get(), 0, SEEK_END), SyscallSucceeds()); -} - -TEST(ProcTask, VerifyTaskDirNlinks) { - // A task directory will have 3 links if the taskgroup has a single - // thread. For example, the following shows where the links to - // '/proc/12345/task comes' from for a single threaded process with pid 12345: - // - // /proc/12345/task <-- 1 link for the directory itself - // . <-- link from "." - // .. - // 12345 - // . - // .. <-- link from ".." to parent. - // <other contents of a task dir> - // - // We can't assert an absolute number of links since we don't control how many - // threads the test framework spawns. Instead, we'll ensure creating a new - // thread increases the number of links as expected. - - // Once we reach the test body, we can count on the thread count being stable - // unless we spawn a new one. - uint64_t initial_links = ASSERT_NO_ERRNO_AND_VALUE(Links("/proc/self/task")); - ASSERT_GE(initial_links, 3); - - // For each new subtask, we should gain a new link. - BlockingChild child1; - EXPECT_THAT(Links("/proc/self/task"), - IsPosixErrorOkAndHolds(initial_links + 1)); - BlockingChild child2; - EXPECT_THAT(Links("/proc/self/task"), - IsPosixErrorOkAndHolds(initial_links + 2)); -} - -TEST(ProcTask, CommContainsThreadNameAndTrailingNewline) { - constexpr char kThreadName[] = "TestThread12345"; - ASSERT_THAT(prctl(PR_SET_NAME, kThreadName), SyscallSucceeds()); - - auto thread_name = ASSERT_NO_ERRNO_AND_VALUE( - GetContents(JoinPath("/proc", absl::StrCat(getpid()), "task", - absl::StrCat(syscall(SYS_gettid)), "comm"))); - EXPECT_EQ(absl::StrCat(kThreadName, "\n"), thread_name); -} - -TEST(ProcTaskNs, NsDirExistsAndHasCorrectMetadata) { - EXPECT_NO_ERRNO(DirContains("/proc/self/ns", {"net", "pid", "user"})); - - // Let's just test the 'pid' entry, all of them are very similar. - struct stat st; - EXPECT_THAT(lstat("/proc/self/ns/pid", &st), SyscallSucceeds()); - EXPECT_TRUE(S_ISLNK(st.st_mode)); - - auto link = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self/ns/pid")); - EXPECT_THAT(link, ::testing::StartsWith("pid:[")); -} - -TEST(ProcTaskNs, AccessOnNsNodeSucceeds) { - EXPECT_THAT(access("/proc/self/ns/pid", F_OK), SyscallSucceeds()); -} - -TEST(ProcSysKernelHostname, Exists) { - EXPECT_THAT(open("/proc/sys/kernel/hostname", O_RDONLY), SyscallSucceeds()); -} - -TEST(ProcSysKernelHostname, MatchesUname) { - struct utsname buf; - EXPECT_THAT(uname(&buf), SyscallSucceeds()); - const std::string hostname = absl::StrCat(buf.nodename, "\n"); - auto procfs_hostname = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/hostname")); - EXPECT_EQ(procfs_hostname, hostname); -} - -TEST(ProcSysVmMmapMinAddr, HasNumericValue) { - const std::string mmap_min_addr_str = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/vm/mmap_min_addr")); - uintptr_t mmap_min_addr; - EXPECT_TRUE(absl::SimpleAtoi(mmap_min_addr_str, &mmap_min_addr)) - << "/proc/sys/vm/mmap_min_addr does not contain a numeric value: " - << mmap_min_addr_str; -} - -TEST(ProcSysVmOvercommitMemory, HasNumericValue) { - const std::string overcommit_memory_str = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/vm/overcommit_memory")); - uintptr_t overcommit_memory; - EXPECT_TRUE(absl::SimpleAtoi(overcommit_memory_str, &overcommit_memory)) - << "/proc/sys/vm/overcommit_memory does not contain a numeric value: " - << overcommit_memory; -} - -// Check that link for proc fd entries point the target node, not the -// symlink itself. Regression test for b/31155070. -TEST(ProcTaskFd, FstatatFollowsSymlink) { - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); - - struct stat sproc = {}; - EXPECT_THAT( - fstatat(-1, absl::StrCat("/proc/self/fd/", fd.get()).c_str(), &sproc, 0), - SyscallSucceeds()); - - struct stat sfile = {}; - EXPECT_THAT(fstatat(-1, file.path().c_str(), &sfile, 0), SyscallSucceeds()); - - // If fstatat follows the fd symlink, the device and inode numbers should - // match at a minimum. - EXPECT_EQ(sproc.st_dev, sfile.st_dev); - EXPECT_EQ(sproc.st_ino, sfile.st_ino); - EXPECT_EQ(0, memcmp(&sfile, &sproc, sizeof(sfile))); -} - -TEST(ProcFilesystems, Bug65172365) { - std::string proc_filesystems = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/filesystems")); - ASSERT_FALSE(proc_filesystems.empty()); -} - -TEST(ProcFilesystems, PresenceOfShmMaxMniAll) { - uint64_t shmmax = 0; - uint64_t shmall = 0; - uint64_t shmmni = 0; - std::string proc_file; - proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/shmmax")); - ASSERT_FALSE(proc_file.empty()); - ASSERT_TRUE(absl::SimpleAtoi(proc_file, &shmmax)); - proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/shmall")); - ASSERT_FALSE(proc_file.empty()); - ASSERT_TRUE(absl::SimpleAtoi(proc_file, &shmall)); - proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/shmmni")); - ASSERT_FALSE(proc_file.empty()); - ASSERT_TRUE(absl::SimpleAtoi(proc_file, &shmmni)); - - ASSERT_GT(shmmax, 0); - ASSERT_GT(shmall, 0); - ASSERT_GT(shmmni, 0); - ASSERT_LE(shmall, shmmax); - - // These values should never be higher than this by default, for more - // information see uapi/linux/shm.h - ASSERT_LE(shmmax, ULONG_MAX - (1UL << 24)); - ASSERT_LE(shmall, ULONG_MAX - (1UL << 24)); -} - -// Check that /proc/mounts is a symlink to self/mounts. -TEST(ProcMounts, IsSymlink) { - auto link = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/mounts")); - EXPECT_EQ(link, "self/mounts"); -} - -TEST(ProcSelfMountinfo, RequiredFieldsArePresent) { - auto mountinfo = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mountinfo")); - EXPECT_THAT( - mountinfo, - AllOf( - // Root mount. - ContainsRegex( - R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ / / (rw|ro).*- \S+ \S+ (rw|ro)\S*)"), - // Proc mount - always rw. - ContainsRegex( - R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ / /proc rw.*- \S+ \S+ rw\S*)"))); -} - -// Check that /proc/self/mounts looks something like a real mounts file. -TEST(ProcSelfMounts, RequiredFieldsArePresent) { - auto mounts = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mounts")); - EXPECT_THAT(mounts, - AllOf( - // Root mount. - ContainsRegex(R"(\S+ / \S+ (rw|ro)\S* [0-9]+ [0-9]+\s)"), - // Root mount. - ContainsRegex(R"(\S+ /proc \S+ rw\S* [0-9]+ [0-9]+\s)"))); -} - -void CheckDuplicatesRecursively(std::string path) { - errno = 0; - DIR* dir = opendir(path.c_str()); - if (dir == nullptr) { - // Ignore any directories we can't read or missing directories as the - // directory could have been deleted/mutated from the time the parent - // directory contents were read. - return; - } - auto dir_closer = Cleanup([&dir]() { closedir(dir); }); - std::unordered_set<std::string> children; - while (true) { - // Readdir(3): If the end of the directory stream is reached, NULL is - // returned and errno is not changed. If an error occurs, NULL is returned - // and errno is set appropriately. To distinguish end of stream and from an - // error, set errno to zero before calling readdir() and then check the - // value of errno if NULL is returned. - errno = 0; - struct dirent* dp = readdir(dir); - if (dp == nullptr) { - ASSERT_EQ(errno, 0) << path; - break; // We're done. - } - - if (strcmp(dp->d_name, ".") == 0 || strcmp(dp->d_name, "..") == 0) { - continue; - } - - ASSERT_EQ(children.find(std::string(dp->d_name)), children.end()) - << dp->d_name; - children.insert(std::string(dp->d_name)); - - ASSERT_NE(dp->d_type, DT_UNKNOWN); - - if (dp->d_type != DT_DIR) { - continue; - } - CheckDuplicatesRecursively(absl::StrCat(path, "/", dp->d_name)); - } -} - -TEST(Proc, NoDuplicates) { CheckDuplicatesRecursively("/proc"); } - -// Most /proc/PID files are owned by the task user with SUID_DUMP_USER. -TEST(ProcPid, UserDumpableOwner) { - int before; - ASSERT_THAT(before = prctl(PR_GET_DUMPABLE), SyscallSucceeds()); - auto cleanup = Cleanup([before] { - ASSERT_THAT(prctl(PR_SET_DUMPABLE, before), SyscallSucceeds()); - }); - - EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_USER), SyscallSucceeds()); - - // This applies to the task directory itself and files inside. - struct stat st; - ASSERT_THAT(stat("/proc/self/", &st), SyscallSucceeds()); - EXPECT_EQ(st.st_uid, geteuid()); - EXPECT_EQ(st.st_gid, getegid()); - - ASSERT_THAT(stat("/proc/self/stat", &st), SyscallSucceeds()); - EXPECT_EQ(st.st_uid, geteuid()); - EXPECT_EQ(st.st_gid, getegid()); -} - -// /proc/PID files are owned by root with SUID_DUMP_DISABLE. -TEST(ProcPid, RootDumpableOwner) { - int before; - ASSERT_THAT(before = prctl(PR_GET_DUMPABLE), SyscallSucceeds()); - auto cleanup = Cleanup([before] { - ASSERT_THAT(prctl(PR_SET_DUMPABLE, before), SyscallSucceeds()); - }); - - EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_DISABLE), SyscallSucceeds()); - - // This *does not* applies to the task directory itself (or other 0555 - // directories), but does to files inside. - struct stat st; - ASSERT_THAT(stat("/proc/self/", &st), SyscallSucceeds()); - EXPECT_EQ(st.st_uid, geteuid()); - EXPECT_EQ(st.st_gid, getegid()); - - // This file is owned by root. Also allow nobody in case this test is running - // in a userns without root mapped. - ASSERT_THAT(stat("/proc/self/stat", &st), SyscallSucceeds()); - EXPECT_THAT(st.st_uid, AnyOf(Eq(0), Eq(65534))); - EXPECT_THAT(st.st_gid, AnyOf(Eq(0), Eq(65534))); -} - -TEST(Proc, GetdentsEnoent) { - FileDescriptor fd; - ASSERT_NO_ERRNO(WithSubprocess( - [&](int pid) -> PosixError { - // Running. - ASSIGN_OR_RETURN_ERRNO(fd, Open(absl::StrCat("/proc/", pid, "/task"), - O_RDONLY | O_DIRECTORY)); - - return NoError(); - }, - nullptr, nullptr)); - char buf[1024]; - ASSERT_THAT(syscall(SYS_getdents64, fd.get(), buf, sizeof(buf)), - SyscallFailsWithErrno(ENOENT)); -} - -void CheckSyscwFromIOFile(const std::string& path, const std::string& regex) { - std::string output; - ASSERT_NO_ERRNO(GetContents(path, &output)); - ASSERT_THAT(output, ContainsRegex(absl::StrCat("syscw:\\s+", regex, "\n"))); -} - -// Checks that there is variable accounting of IO between threads/tasks. -TEST(Proc, PidTidIOAccounting) { - absl::Notification notification; - - // Run a thread with a bunch of writes. Check that io account records exactly - // the number of write calls. File open/close is there to prevent buffering. - ScopedThread writer([¬ification] { - const int num_writes = 100; - for (int i = 0; i < num_writes; i++) { - auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - ASSERT_NO_ERRNO(SetContents(path.path(), "a")); - } - notification.Notify(); - const std::string& writer_dir = - absl::StrCat("/proc/", getpid(), "/task/", gettid(), "/io"); - - CheckSyscwFromIOFile(writer_dir, std::to_string(num_writes)); - }); - - // Run a thread and do no writes. Check that no writes are recorded. - ScopedThread noop([¬ification] { - notification.WaitForNotification(); - const std::string& noop_dir = - absl::StrCat("/proc/", getpid(), "/task/", gettid(), "/io"); - - CheckSyscwFromIOFile(noop_dir, "0"); - }); - - writer.Join(); - noop.Join(); -} - -} // namespace -} // namespace testing -} // namespace gvisor - -int main(int argc, char** argv) { - for (int i = 0; i < argc; ++i) { - gvisor::testing::saved_argv.emplace_back(std::string(argv[i])); - } - - gvisor::testing::TestInit(&argc, &argv); - return gvisor::testing::RunAllTests(); -} diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc deleted file mode 100644 index 05c952b99..000000000 --- a/test/syscalls/linux/proc_net.cc +++ /dev/null @@ -1,356 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <arpa/inet.h> -#include <errno.h> -#include <netinet/in.h> -#include <poll.h> -#include <sys/socket.h> -#include <sys/syscall.h> -#include <sys/types.h> - -#include "gtest/gtest.h" -#include "absl/strings/str_split.h" -#include "absl/time/clock.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -constexpr const char kProcNet[] = "/proc/net"; - -TEST(ProcNetSymlinkTarget, FileMode) { - struct stat s; - ASSERT_THAT(stat(kProcNet, &s), SyscallSucceeds()); - EXPECT_EQ(s.st_mode & S_IFMT, S_IFDIR); - EXPECT_EQ(s.st_mode & 0777, 0555); -} - -TEST(ProcNetSymlink, FileMode) { - struct stat s; - ASSERT_THAT(lstat(kProcNet, &s), SyscallSucceeds()); - EXPECT_EQ(s.st_mode & S_IFMT, S_IFLNK); - EXPECT_EQ(s.st_mode & 0777, 0777); -} - -TEST(ProcNetSymlink, Contents) { - char buf[40] = {}; - int n = readlink(kProcNet, buf, sizeof(buf)); - ASSERT_THAT(n, SyscallSucceeds()); - - buf[n] = 0; - EXPECT_STREQ(buf, "self/net"); -} - -TEST(ProcNetIfInet6, Format) { - auto ifinet6 = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/if_inet6")); - EXPECT_THAT(ifinet6, - ::testing::MatchesRegex( - // Ex: "00000000000000000000000000000001 01 80 10 80 lo\n" - "^([a-f0-9]{32}( [a-f0-9]{2}){4} +[a-z][a-z0-9]*\n)+$")); -} - -TEST(ProcSysNetIpv4Sack, Exists) { - EXPECT_THAT(open("/proc/sys/net/ipv4/tcp_sack", O_RDONLY), SyscallSucceeds()); -} - -TEST(ProcSysNetIpv4Sack, CanReadAndWrite) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); - - auto const fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/sys/net/ipv4/tcp_sack", O_RDWR)); - - char buf; - EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_TRUE(buf == '0' || buf == '1') << "unexpected tcp_sack: " << buf; - - char to_write = (buf == '1') ? '0' : '1'; - EXPECT_THAT(PwriteFd(fd.get(), &to_write, sizeof(to_write), 0), - SyscallSucceedsWithValue(sizeof(to_write))); - - buf = 0; - EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), - SyscallSucceedsWithValue(sizeof(buf))); - EXPECT_EQ(buf, to_write); -} - -PosixErrorOr<uint64_t> GetSNMPMetricFromProc(const std::string snmp, - const std::string& type, - const std::string& item) { - std::vector<std::string> snmp_vec = absl::StrSplit(snmp, '\n'); - - // /proc/net/snmp prints a line of headers followed by a line of metrics. - // Only search the headers. - for (unsigned i = 0; i < snmp_vec.size(); i = i + 2) { - if (!absl::StartsWith(snmp_vec[i], type)) continue; - - std::vector<std::string> fields = - absl::StrSplit(snmp_vec[i], ' ', absl::SkipWhitespace()); - - EXPECT_TRUE((i + 1) < snmp_vec.size()); - std::vector<std::string> values = - absl::StrSplit(snmp_vec[i + 1], ' ', absl::SkipWhitespace()); - - EXPECT_TRUE(!fields.empty() && fields.size() == values.size()); - - // Metrics start at the first index. - for (unsigned j = 1; j < fields.size(); j++) { - if (fields[j] == item) { - uint64_t val; - if (!absl::SimpleAtoi(values[j], &val)) { - return PosixError(EINVAL, - absl::StrCat("field is not a number: ", values[j])); - } - - return val; - } - } - } - // We should never get here. - return PosixError( - EINVAL, absl::StrCat("failed to find ", type, "/", item, " in:", snmp)); -} - -TEST(ProcNetSnmp, TcpReset_NoRandomSave) { - // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. - DisableSave ds; - - uint64_t oldAttemptFails; - uint64_t oldActiveOpens; - uint64_t oldOutRsts; - auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); - oldActiveOpens = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens")); - oldOutRsts = - ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "OutRsts")); - oldAttemptFails = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Tcp", "AttemptFails")); - - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0)); - - struct sockaddr_in sin = { - .sin_family = AF_INET, - .sin_port = htons(1234), - }; - - ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); - ASSERT_THAT(connect(s.get(), (struct sockaddr*)&sin, sizeof(sin)), - SyscallFailsWithErrno(ECONNREFUSED)); - - uint64_t newAttemptFails; - uint64_t newActiveOpens; - uint64_t newOutRsts; - snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); - newActiveOpens = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens")); - newOutRsts = - ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "OutRsts")); - newAttemptFails = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Tcp", "AttemptFails")); - - EXPECT_EQ(oldActiveOpens, newActiveOpens - 1); - EXPECT_EQ(oldOutRsts, newOutRsts - 1); - EXPECT_EQ(oldAttemptFails, newAttemptFails - 1); -} - -TEST(ProcNetSnmp, TcpEstab_NoRandomSave) { - // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. - DisableSave ds; - - uint64_t oldEstabResets; - uint64_t oldActiveOpens; - uint64_t oldPassiveOpens; - uint64_t oldCurrEstab; - auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); - oldActiveOpens = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens")); - oldPassiveOpens = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Tcp", "PassiveOpens")); - oldCurrEstab = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab")); - oldEstabResets = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Tcp", "EstabResets")); - - FileDescriptor s_listen = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0)); - struct sockaddr_in sin = { - .sin_family = AF_INET, - .sin_port = 0, - }; - - ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); - ASSERT_THAT(bind(s_listen.get(), (struct sockaddr*)&sin, sizeof(sin)), - SyscallSucceeds()); - ASSERT_THAT(listen(s_listen.get(), 1), SyscallSucceeds()); - - // Get the port bound by the listening socket. - socklen_t addrlen = sizeof(sin); - ASSERT_THAT( - getsockname(s_listen.get(), reinterpret_cast<sockaddr*>(&sin), &addrlen), - SyscallSucceeds()); - - FileDescriptor s_connect = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0)); - ASSERT_THAT(connect(s_connect.get(), (struct sockaddr*)&sin, sizeof(sin)), - SyscallSucceeds()); - - auto s_accept = - ASSERT_NO_ERRNO_AND_VALUE(Accept(s_listen.get(), nullptr, nullptr)); - - uint64_t newEstabResets; - uint64_t newActiveOpens; - uint64_t newPassiveOpens; - uint64_t newCurrEstab; - snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); - newActiveOpens = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens")); - newPassiveOpens = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Tcp", "PassiveOpens")); - newCurrEstab = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab")); - - EXPECT_EQ(oldActiveOpens, newActiveOpens - 1); - EXPECT_EQ(oldPassiveOpens, newPassiveOpens - 1); - EXPECT_EQ(oldCurrEstab, newCurrEstab - 2); - - // Send 1 byte from client to server. - ASSERT_THAT(send(s_connect.get(), "a", 1, 0), SyscallSucceedsWithValue(1)); - - constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data. - - // Wait until server-side fd sees the data on its side but don't read it. - struct pollfd poll_fd = {s_accept.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - - // Now close server-side fd without reading the data which leads to a RST - // packet sent to client side. - s_accept.reset(-1); - - // Wait until client-side fd sees RST packet. - struct pollfd poll_fd1 = {s_connect.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&poll_fd1, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - - // Now close client-side fd. - s_connect.reset(-1); - - // Wait until the process of the netstack. - absl::SleepFor(absl::Seconds(1)); - - snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); - newCurrEstab = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab")); - newEstabResets = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Tcp", "EstabResets")); - - EXPECT_EQ(oldCurrEstab, newCurrEstab); - EXPECT_EQ(oldEstabResets, newEstabResets - 2); -} - -TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) { - // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. - DisableSave ds; - - uint64_t oldOutDatagrams; - uint64_t oldNoPorts; - auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); - oldOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams")); - oldNoPorts = - ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "NoPorts")); - - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - - struct sockaddr_in sin = { - .sin_family = AF_INET, - .sin_port = htons(4444), - }; - ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); - ASSERT_THAT(sendto(s.get(), "a", 1, 0, (struct sockaddr*)&sin, sizeof(sin)), - SyscallSucceedsWithValue(1)); - - uint64_t newOutDatagrams; - uint64_t newNoPorts; - snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); - newOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams")); - newNoPorts = - ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "NoPorts")); - - EXPECT_EQ(oldOutDatagrams, newOutDatagrams - 1); - EXPECT_EQ(oldNoPorts, newNoPorts - 1); -} - -TEST(ProcNetSnmp, UdpIn) { - // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. - const DisableSave ds; - - uint64_t oldOutDatagrams; - uint64_t oldInDatagrams; - auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); - oldOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams")); - oldInDatagrams = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Udp", "InDatagrams")); - - std::cerr << "snmp: " << std::endl << snmp << std::endl; - FileDescriptor server = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - struct sockaddr_in sin = { - .sin_family = AF_INET, - .sin_port = htons(0), - }; - ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); - ASSERT_THAT(bind(server.get(), (struct sockaddr*)&sin, sizeof(sin)), - SyscallSucceeds()); - // Get the port bound by the server socket. - socklen_t addrlen = sizeof(sin); - ASSERT_THAT( - getsockname(server.get(), reinterpret_cast<sockaddr*>(&sin), &addrlen), - SyscallSucceeds()); - - FileDescriptor client = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - ASSERT_THAT( - sendto(client.get(), "a", 1, 0, (struct sockaddr*)&sin, sizeof(sin)), - SyscallSucceedsWithValue(1)); - - char buf[128]; - ASSERT_THAT(recvfrom(server.get(), buf, sizeof(buf), 0, NULL, NULL), - SyscallSucceedsWithValue(1)); - - uint64_t newOutDatagrams; - uint64_t newInDatagrams; - snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); - std::cerr << "new snmp: " << std::endl << snmp << std::endl; - newOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams")); - newInDatagrams = ASSERT_NO_ERRNO_AND_VALUE( - GetSNMPMetricFromProc(snmp, "Udp", "InDatagrams")); - - EXPECT_EQ(oldOutDatagrams, newOutDatagrams - 1); - EXPECT_EQ(oldInDatagrams, newInDatagrams - 1); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/proc_net_tcp.cc b/test/syscalls/linux/proc_net_tcp.cc deleted file mode 100644 index 5b6e3e3cd..000000000 --- a/test/syscalls/linux/proc_net_tcp.cc +++ /dev/null @@ -1,496 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <netinet/tcp.h> -#include <sys/socket.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -using absl::StrCat; -using absl::StrSplit; - -constexpr char kProcNetTCPHeader[] = - " sl local_address rem_address st tx_queue rx_queue tr tm->when " - "retrnsmt uid timeout inode " - " "; - -// TCPEntry represents a single entry from /proc/net/tcp. -struct TCPEntry { - uint32_t local_addr; - uint16_t local_port; - - uint32_t remote_addr; - uint16_t remote_port; - - uint64_t state; - uint64_t uid; - uint64_t inode; -}; - -// Finds the first entry in 'entries' for which 'predicate' returns true. -// Returns true on match, and sets 'match' to a copy of the matching entry. If -// 'match' is null, it's ignored. -bool FindBy(const std::vector<TCPEntry>& entries, TCPEntry* match, - std::function<bool(const TCPEntry&)> predicate) { - for (const TCPEntry& entry : entries) { - if (predicate(entry)) { - if (match != nullptr) { - *match = entry; - } - return true; - } - } - return false; -} - -bool FindByLocalAddr(const std::vector<TCPEntry>& entries, TCPEntry* match, - const struct sockaddr* addr) { - uint32_t host = IPFromInetSockaddr(addr); - uint16_t port = PortFromInetSockaddr(addr); - return FindBy(entries, match, [host, port](const TCPEntry& e) { - return (e.local_addr == host && e.local_port == port); - }); -} - -bool FindByRemoteAddr(const std::vector<TCPEntry>& entries, TCPEntry* match, - const struct sockaddr* addr) { - uint32_t host = IPFromInetSockaddr(addr); - uint16_t port = PortFromInetSockaddr(addr); - return FindBy(entries, match, [host, port](const TCPEntry& e) { - return (e.remote_addr == host && e.remote_port == port); - }); -} - -// Returns a parsed representation of /proc/net/tcp entries. -PosixErrorOr<std::vector<TCPEntry>> ProcNetTCPEntries() { - std::string content; - RETURN_IF_ERRNO(GetContents("/proc/net/tcp", &content)); - - bool found_header = false; - std::vector<TCPEntry> entries; - std::vector<std::string> lines = StrSplit(content, '\n'); - std::cerr << "<contents of /proc/net/tcp>" << std::endl; - for (const std::string& line : lines) { - std::cerr << line << std::endl; - - if (!found_header) { - EXPECT_EQ(line, kProcNetTCPHeader); - found_header = true; - continue; - } - if (line.empty()) { - continue; - } - - // Parse a single entry from /proc/net/tcp. - // - // Example entries: - // - // clang-format off - // - // sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode - // 0: 00000000:006F 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 1968 1 0000000000000000 100 0 0 10 0 - // 1: 0100007F:7533 00000000:0000 0A 00000000:00000000 00:00000000 00000000 120 0 10684 1 0000000000000000 100 0 0 10 0 - // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ - // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 - // - // clang-format on - - TCPEntry entry; - std::vector<std::string> fields = - StrSplit(line, absl::ByAnyChar(": "), absl::SkipEmpty()); - - ASSIGN_OR_RETURN_ERRNO(entry.local_addr, AtoiBase(fields[1], 16)); - ASSIGN_OR_RETURN_ERRNO(entry.local_port, AtoiBase(fields[2], 16)); - - ASSIGN_OR_RETURN_ERRNO(entry.remote_addr, AtoiBase(fields[3], 16)); - ASSIGN_OR_RETURN_ERRNO(entry.remote_port, AtoiBase(fields[4], 16)); - - ASSIGN_OR_RETURN_ERRNO(entry.state, AtoiBase(fields[5], 16)); - ASSIGN_OR_RETURN_ERRNO(entry.uid, Atoi<uint64_t>(fields[11])); - ASSIGN_OR_RETURN_ERRNO(entry.inode, Atoi<uint64_t>(fields[13])); - - entries.push_back(entry); - } - std::cerr << "<end of /proc/net/tcp>" << std::endl; - - return entries; -} - -TEST(ProcNetTCP, Exists) { - const std::string content = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/tcp")); - const std::string header_line = StrCat(kProcNetTCPHeader, "\n"); - if (IsRunningOnGvisor()) { - // Should be just the header since we don't have any tcp sockets yet. - EXPECT_EQ(content, header_line); - } else { - // On a general linux machine, we could have abitrary sockets on the system, - // so just check the header. - EXPECT_THAT(content, ::testing::StartsWith(header_line)); - } -} - -TEST(ProcNetTCP, EntryUID) { - auto sockets = - ASSERT_NO_ERRNO_AND_VALUE(IPv4TCPAcceptBindSocketPair(0).Create()); - std::vector<TCPEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries()); - TCPEntry e; - ASSERT_TRUE(FindByLocalAddr(entries, &e, sockets->first_addr())); - EXPECT_EQ(e.uid, geteuid()); - ASSERT_TRUE(FindByRemoteAddr(entries, &e, sockets->first_addr())); - EXPECT_EQ(e.uid, geteuid()); -} - -TEST(ProcNetTCP, BindAcceptConnect) { - auto sockets = - ASSERT_NO_ERRNO_AND_VALUE(IPv4TCPAcceptBindSocketPair(0).Create()); - std::vector<TCPEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries()); - // We can only make assertions about the total number of entries if we control - // the entire "machine". - if (IsRunningOnGvisor()) { - EXPECT_EQ(entries.size(), 2); - } - - EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->first_addr())); - EXPECT_TRUE(FindByRemoteAddr(entries, nullptr, sockets->first_addr())); -} - -TEST(ProcNetTCP, InodeReasonable) { - auto sockets = - ASSERT_NO_ERRNO_AND_VALUE(IPv4TCPAcceptBindSocketPair(0).Create()); - std::vector<TCPEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries()); - - TCPEntry accepted_entry; - ASSERT_TRUE(FindByLocalAddr(entries, &accepted_entry, sockets->first_addr())); - EXPECT_NE(accepted_entry.inode, 0); - - TCPEntry client_entry; - ASSERT_TRUE(FindByRemoteAddr(entries, &client_entry, sockets->first_addr())); - EXPECT_NE(client_entry.inode, 0); - EXPECT_NE(accepted_entry.inode, client_entry.inode); -} - -TEST(ProcNetTCP, State) { - std::unique_ptr<FileDescriptor> server = - ASSERT_NO_ERRNO_AND_VALUE(IPv4TCPUnboundSocket(0).Create()); - - auto test_addr = V4Loopback(); - ASSERT_THAT( - bind(server->get(), reinterpret_cast<struct sockaddr*>(&test_addr.addr), - test_addr.addr_len), - SyscallSucceeds()); - - struct sockaddr addr; - socklen_t addrlen = sizeof(struct sockaddr); - ASSERT_THAT(getsockname(server->get(), &addr, &addrlen), SyscallSucceeds()); - ASSERT_EQ(addrlen, sizeof(struct sockaddr)); - - ASSERT_THAT(listen(server->get(), 10), SyscallSucceeds()); - std::vector<TCPEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries()); - TCPEntry listen_entry; - ASSERT_TRUE(FindByLocalAddr(entries, &listen_entry, &addr)); - EXPECT_EQ(listen_entry.state, TCP_LISTEN); - - std::unique_ptr<FileDescriptor> client = - ASSERT_NO_ERRNO_AND_VALUE(IPv4TCPUnboundSocket(0).Create()); - ASSERT_THAT(RetryEINTR(connect)(client->get(), &addr, addrlen), - SyscallSucceeds()); - entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries()); - ASSERT_TRUE(FindByLocalAddr(entries, &listen_entry, &addr)); - EXPECT_EQ(listen_entry.state, TCP_LISTEN); - TCPEntry client_entry; - ASSERT_TRUE(FindByRemoteAddr(entries, &client_entry, &addr)); - EXPECT_EQ(client_entry.state, TCP_ESTABLISHED); - - FileDescriptor accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(server->get(), nullptr, nullptr)); - - const uint32_t accepted_local_host = IPFromInetSockaddr(&addr); - const uint16_t accepted_local_port = PortFromInetSockaddr(&addr); - - entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries()); - TCPEntry accepted_entry; - ASSERT_TRUE(FindBy(entries, &accepted_entry, - [client_entry, accepted_local_host, - accepted_local_port](const TCPEntry& e) { - return e.local_addr == accepted_local_host && - e.local_port == accepted_local_port && - e.remote_addr == client_entry.local_addr && - e.remote_port == client_entry.local_port; - })); - EXPECT_EQ(accepted_entry.state, TCP_ESTABLISHED); -} - -constexpr char kProcNetTCP6Header[] = - " sl local_address remote_address" - " st tx_queue rx_queue tr tm->when retrnsmt" - " uid timeout inode"; - -// TCP6Entry represents a single entry from /proc/net/tcp6. -struct TCP6Entry { - struct in6_addr local_addr; - uint16_t local_port; - - struct in6_addr remote_addr; - uint16_t remote_port; - - uint64_t state; - uint64_t uid; - uint64_t inode; -}; - -bool IPv6AddrEqual(const struct in6_addr* a1, const struct in6_addr* a2) { - return memcmp(a1, a2, sizeof(struct in6_addr)) == 0; -} - -// Finds the first entry in 'entries' for which 'predicate' returns true. -// Returns true on match, and sets 'match' to a copy of the matching entry. If -// 'match' is null, it's ignored. -bool FindBy6(const std::vector<TCP6Entry>& entries, TCP6Entry* match, - std::function<bool(const TCP6Entry&)> predicate) { - for (const TCP6Entry& entry : entries) { - if (predicate(entry)) { - if (match != nullptr) { - *match = entry; - } - return true; - } - } - return false; -} - -const struct in6_addr* IP6FromInetSockaddr(const struct sockaddr* addr) { - auto* addr6 = reinterpret_cast<const struct sockaddr_in6*>(addr); - return &addr6->sin6_addr; -} - -bool FindByLocalAddr6(const std::vector<TCP6Entry>& entries, TCP6Entry* match, - const struct sockaddr* addr) { - const struct in6_addr* local = IP6FromInetSockaddr(addr); - uint16_t port = PortFromInetSockaddr(addr); - return FindBy6(entries, match, [local, port](const TCP6Entry& e) { - return (IPv6AddrEqual(&e.local_addr, local) && e.local_port == port); - }); -} - -bool FindByRemoteAddr6(const std::vector<TCP6Entry>& entries, TCP6Entry* match, - const struct sockaddr* addr) { - const struct in6_addr* remote = IP6FromInetSockaddr(addr); - uint16_t port = PortFromInetSockaddr(addr); - return FindBy6(entries, match, [remote, port](const TCP6Entry& e) { - return (IPv6AddrEqual(&e.remote_addr, remote) && e.remote_port == port); - }); -} - -void ReadIPv6Address(std::string s, struct in6_addr* addr) { - uint32_t a0, a1, a2, a3; - const char* fmt = "%08X%08X%08X%08X"; - EXPECT_EQ(sscanf(s.c_str(), fmt, &a0, &a1, &a2, &a3), 4); - - uint8_t* b = addr->s6_addr; - *((uint32_t*)&b[0]) = a0; - *((uint32_t*)&b[4]) = a1; - *((uint32_t*)&b[8]) = a2; - *((uint32_t*)&b[12]) = a3; -} - -// Returns a parsed representation of /proc/net/tcp6 entries. -PosixErrorOr<std::vector<TCP6Entry>> ProcNetTCP6Entries() { - std::string content; - RETURN_IF_ERRNO(GetContents("/proc/net/tcp6", &content)); - - bool found_header = false; - std::vector<TCP6Entry> entries; - std::vector<std::string> lines = StrSplit(content, '\n'); - std::cerr << "<contents of /proc/net/tcp6>" << std::endl; - for (const std::string& line : lines) { - std::cerr << line << std::endl; - - if (!found_header) { - EXPECT_EQ(line, kProcNetTCP6Header); - found_header = true; - continue; - } - if (line.empty()) { - continue; - } - - // Parse a single entry from /proc/net/tcp6. - // - // Example entries: - // - // clang-format off - // - // sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode - // 0: 00000000000000000000000000000000:1F90 00000000000000000000000000000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 876340 1 ffff8803da9c9380 100 0 0 10 0 - // 1: 00000000000000000000000000000000:C350 00000000000000000000000000000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 876987 1 ffff8803ec408000 100 0 0 10 0 - // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ - // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 - // - // clang-format on - - TCP6Entry entry; - std::vector<std::string> fields = - StrSplit(line, absl::ByAnyChar(": "), absl::SkipEmpty()); - - ReadIPv6Address(fields[1], &entry.local_addr); - ASSIGN_OR_RETURN_ERRNO(entry.local_port, AtoiBase(fields[2], 16)); - ReadIPv6Address(fields[3], &entry.remote_addr); - ASSIGN_OR_RETURN_ERRNO(entry.remote_port, AtoiBase(fields[4], 16)); - ASSIGN_OR_RETURN_ERRNO(entry.state, AtoiBase(fields[5], 16)); - ASSIGN_OR_RETURN_ERRNO(entry.uid, Atoi<uint64_t>(fields[11])); - ASSIGN_OR_RETURN_ERRNO(entry.inode, Atoi<uint64_t>(fields[13])); - - entries.push_back(entry); - } - std::cerr << "<end of /proc/net/tcp6>" << std::endl; - - return entries; -} - -TEST(ProcNetTCP6, Exists) { - const std::string content = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/tcp6")); - const std::string header_line = StrCat(kProcNetTCP6Header, "\n"); - if (IsRunningOnGvisor()) { - // Should be just the header since we don't have any tcp sockets yet. - EXPECT_EQ(content, header_line); - } else { - // On a general linux machine, we could have abitrary sockets on the system, - // so just check the header. - EXPECT_THAT(content, ::testing::StartsWith(header_line)); - } -} - -TEST(ProcNetTCP6, EntryUID) { - auto sockets = - ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPAcceptBindSocketPair(0).Create()); - std::vector<TCP6Entry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); - TCP6Entry e; - - ASSERT_TRUE(FindByLocalAddr6(entries, &e, sockets->first_addr())); - EXPECT_EQ(e.uid, geteuid()); - ASSERT_TRUE(FindByRemoteAddr6(entries, &e, sockets->first_addr())); - EXPECT_EQ(e.uid, geteuid()); -} - -TEST(ProcNetTCP6, BindAcceptConnect) { - auto sockets = - ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPAcceptBindSocketPair(0).Create()); - std::vector<TCP6Entry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); - // We can only make assertions about the total number of entries if we control - // the entire "machine". - if (IsRunningOnGvisor()) { - EXPECT_EQ(entries.size(), 2); - } - - EXPECT_TRUE(FindByLocalAddr6(entries, nullptr, sockets->first_addr())); - EXPECT_TRUE(FindByRemoteAddr6(entries, nullptr, sockets->first_addr())); -} - -TEST(ProcNetTCP6, InodeReasonable) { - auto sockets = - ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPAcceptBindSocketPair(0).Create()); - std::vector<TCP6Entry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); - - TCP6Entry accepted_entry; - - ASSERT_TRUE( - FindByLocalAddr6(entries, &accepted_entry, sockets->first_addr())); - EXPECT_NE(accepted_entry.inode, 0); - - TCP6Entry client_entry; - ASSERT_TRUE(FindByRemoteAddr6(entries, &client_entry, sockets->first_addr())); - EXPECT_NE(client_entry.inode, 0); - EXPECT_NE(accepted_entry.inode, client_entry.inode); -} - -TEST(ProcNetTCP6, State) { - std::unique_ptr<FileDescriptor> server = - ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPUnboundSocket(0).Create()); - - auto test_addr = V6Loopback(); - ASSERT_THAT( - bind(server->get(), reinterpret_cast<struct sockaddr*>(&test_addr.addr), - test_addr.addr_len), - SyscallSucceeds()); - - struct sockaddr_in6 addr6; - socklen_t addrlen = sizeof(struct sockaddr_in6); - auto* addr = reinterpret_cast<struct sockaddr*>(&addr6); - ASSERT_THAT(getsockname(server->get(), addr, &addrlen), SyscallSucceeds()); - ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6)); - - ASSERT_THAT(listen(server->get(), 10), SyscallSucceeds()); - std::vector<TCP6Entry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); - TCP6Entry listen_entry; - - ASSERT_TRUE(FindByLocalAddr6(entries, &listen_entry, addr)); - EXPECT_EQ(listen_entry.state, TCP_LISTEN); - - std::unique_ptr<FileDescriptor> client = - ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPUnboundSocket(0).Create()); - ASSERT_THAT(RetryEINTR(connect)(client->get(), addr, addrlen), - SyscallSucceeds()); - entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); - ASSERT_TRUE(FindByLocalAddr6(entries, &listen_entry, addr)); - EXPECT_EQ(listen_entry.state, TCP_LISTEN); - TCP6Entry client_entry; - ASSERT_TRUE(FindByRemoteAddr6(entries, &client_entry, addr)); - EXPECT_EQ(client_entry.state, TCP_ESTABLISHED); - - FileDescriptor accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(server->get(), nullptr, nullptr)); - - const struct in6_addr* local = IP6FromInetSockaddr(addr); - const uint16_t accepted_local_port = PortFromInetSockaddr(addr); - - entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); - TCP6Entry accepted_entry; - ASSERT_TRUE(FindBy6( - entries, &accepted_entry, - [client_entry, local, accepted_local_port](const TCP6Entry& e) { - return IPv6AddrEqual(&e.local_addr, local) && - e.local_port == accepted_local_port && - IPv6AddrEqual(&e.remote_addr, &client_entry.local_addr) && - e.remote_port == client_entry.local_port; - })); - EXPECT_EQ(accepted_entry.state, TCP_ESTABLISHED); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/proc_net_udp.cc b/test/syscalls/linux/proc_net_udp.cc deleted file mode 100644 index 786b4b4af..000000000 --- a/test/syscalls/linux/proc_net_udp.cc +++ /dev/null @@ -1,309 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <netinet/tcp.h> -#include <sys/socket.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -using absl::StrCat; -using absl::StrFormat; -using absl::StrSplit; - -constexpr char kProcNetUDPHeader[] = - " sl local_address rem_address st tx_queue rx_queue tr tm->when " - "retrnsmt uid timeout inode ref pointer drops "; - -// UDPEntry represents a single entry from /proc/net/udp. -struct UDPEntry { - uint32_t local_addr; - uint16_t local_port; - - uint32_t remote_addr; - uint16_t remote_port; - - uint64_t state; - uint64_t uid; - uint64_t inode; -}; - -std::string DescribeFirstInetSocket(const SocketPair& sockets) { - const struct sockaddr* addr = sockets.first_addr(); - return StrFormat("First test socket: fd:%d %8X:%4X", sockets.first_fd(), - IPFromInetSockaddr(addr), PortFromInetSockaddr(addr)); -} - -std::string DescribeSecondInetSocket(const SocketPair& sockets) { - const struct sockaddr* addr = sockets.second_addr(); - return StrFormat("Second test socket fd:%d %8X:%4X", sockets.second_fd(), - IPFromInetSockaddr(addr), PortFromInetSockaddr(addr)); -} - -// Finds the first entry in 'entries' for which 'predicate' returns true. -// Returns true on match, and set 'match' to a copy of the matching entry. If -// 'match' is null, it's ignored. -bool FindBy(const std::vector<UDPEntry>& entries, UDPEntry* match, - std::function<bool(const UDPEntry&)> predicate) { - for (const UDPEntry& entry : entries) { - if (predicate(entry)) { - if (match != nullptr) { - *match = entry; - } - return true; - } - } - return false; -} - -bool FindByLocalAddr(const std::vector<UDPEntry>& entries, UDPEntry* match, - const struct sockaddr* addr) { - uint32_t host = IPFromInetSockaddr(addr); - uint16_t port = PortFromInetSockaddr(addr); - return FindBy(entries, match, [host, port](const UDPEntry& e) { - return (e.local_addr == host && e.local_port == port); - }); -} - -bool FindByRemoteAddr(const std::vector<UDPEntry>& entries, UDPEntry* match, - const struct sockaddr* addr) { - uint32_t host = IPFromInetSockaddr(addr); - uint16_t port = PortFromInetSockaddr(addr); - return FindBy(entries, match, [host, port](const UDPEntry& e) { - return (e.remote_addr == host && e.remote_port == port); - }); -} - -PosixErrorOr<uint64_t> InodeFromSocketFD(int fd) { - ASSIGN_OR_RETURN_ERRNO(struct stat s, Fstat(fd)); - if (!S_ISSOCK(s.st_mode)) { - return PosixError(EINVAL, StrFormat("FD %d is not a socket", fd)); - } - return s.st_ino; -} - -PosixErrorOr<bool> FindByFD(const std::vector<UDPEntry>& entries, - UDPEntry* match, int fd) { - ASSIGN_OR_RETURN_ERRNO(uint64_t inode, InodeFromSocketFD(fd)); - return FindBy(entries, match, - [inode](const UDPEntry& e) { return (e.inode == inode); }); -} - -// Returns a parsed representation of /proc/net/udp entries. -PosixErrorOr<std::vector<UDPEntry>> ProcNetUDPEntries() { - std::string content; - RETURN_IF_ERRNO(GetContents("/proc/net/udp", &content)); - - bool found_header = false; - std::vector<UDPEntry> entries; - std::vector<std::string> lines = StrSplit(content, '\n'); - std::cerr << "<contents of /proc/net/udp>" << std::endl; - for (const std::string& line : lines) { - std::cerr << line << std::endl; - - if (!found_header) { - EXPECT_EQ(line, kProcNetUDPHeader); - found_header = true; - continue; - } - if (line.empty()) { - continue; - } - - // Parse a single entry from /proc/net/udp. - // - // Example entries: - // - // clang-format off - // - // sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops - // 3503: 0100007F:0035 00000000:0000 07 00000000:00000000 00:00000000 00000000 0 0 33317 2 0000000000000000 0 - // 3518: 00000000:0044 00000000:0000 07 00000000:00000000 00:00000000 00000000 0 0 40394 2 0000000000000000 0 - // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ - // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 - // - // clang-format on - - UDPEntry entry; - std::vector<std::string> fields = - StrSplit(line, absl::ByAnyChar(": "), absl::SkipEmpty()); - - ASSIGN_OR_RETURN_ERRNO(entry.local_addr, AtoiBase(fields[1], 16)); - ASSIGN_OR_RETURN_ERRNO(entry.local_port, AtoiBase(fields[2], 16)); - - ASSIGN_OR_RETURN_ERRNO(entry.remote_addr, AtoiBase(fields[3], 16)); - ASSIGN_OR_RETURN_ERRNO(entry.remote_port, AtoiBase(fields[4], 16)); - - ASSIGN_OR_RETURN_ERRNO(entry.state, AtoiBase(fields[5], 16)); - ASSIGN_OR_RETURN_ERRNO(entry.uid, Atoi<uint64_t>(fields[11])); - ASSIGN_OR_RETURN_ERRNO(entry.inode, Atoi<uint64_t>(fields[13])); - - // Linux shares internal data structures between TCP and UDP sockets. The - // proc entries for UDP sockets share some fields with TCP sockets, but - // these fields should always be zero as they're not meaningful for UDP - // sockets. - EXPECT_EQ(fields[8], "00") << StrFormat("sl:%s, tr", fields[0]); - EXPECT_EQ(fields[9], "00000000") << StrFormat("sl:%s, tm->when", fields[0]); - EXPECT_EQ(fields[10], "00000000") - << StrFormat("sl:%s, retrnsmt", fields[0]); - EXPECT_EQ(fields[12], "0") << StrFormat("sl:%s, timeout", fields[0]); - - entries.push_back(entry); - } - std::cerr << "<end of /proc/net/udp>" << std::endl; - - return entries; -} - -TEST(ProcNetUDP, Exists) { - const std::string content = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/udp")); - const std::string header_line = StrCat(kProcNetUDPHeader, "\n"); - EXPECT_THAT(content, ::testing::StartsWith(header_line)); -} - -TEST(ProcNetUDP, EntryUID) { - auto sockets = - ASSERT_NO_ERRNO_AND_VALUE(IPv4UDPBidirectionalBindSocketPair(0).Create()); - std::vector<UDPEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries()); - UDPEntry e; - ASSERT_TRUE(FindByLocalAddr(entries, &e, sockets->first_addr())) - << DescribeFirstInetSocket(*sockets); - EXPECT_EQ(e.uid, geteuid()); - ASSERT_TRUE(FindByRemoteAddr(entries, &e, sockets->first_addr())) - << DescribeSecondInetSocket(*sockets); - EXPECT_EQ(e.uid, geteuid()); -} - -TEST(ProcNetUDP, FindMutualEntries) { - auto sockets = - ASSERT_NO_ERRNO_AND_VALUE(IPv4UDPBidirectionalBindSocketPair(0).Create()); - std::vector<UDPEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries()); - - EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->first_addr())) - << DescribeFirstInetSocket(*sockets); - EXPECT_TRUE(FindByRemoteAddr(entries, nullptr, sockets->first_addr())) - << DescribeSecondInetSocket(*sockets); - - EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->second_addr())) - << DescribeSecondInetSocket(*sockets); - EXPECT_TRUE(FindByRemoteAddr(entries, nullptr, sockets->second_addr())) - << DescribeFirstInetSocket(*sockets); -} - -TEST(ProcNetUDP, EntriesRemovedOnClose) { - auto sockets = - ASSERT_NO_ERRNO_AND_VALUE(IPv4UDPBidirectionalBindSocketPair(0).Create()); - std::vector<UDPEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries()); - - EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->first_addr())) - << DescribeFirstInetSocket(*sockets); - EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->second_addr())) - << DescribeSecondInetSocket(*sockets); - - EXPECT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); - entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries()); - // First socket's entry should be gone, but the second socket's entry should - // still exist. - EXPECT_FALSE(FindByLocalAddr(entries, nullptr, sockets->first_addr())) - << DescribeFirstInetSocket(*sockets); - EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->second_addr())) - << DescribeSecondInetSocket(*sockets); - - EXPECT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); - entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries()); - // Both entries should be gone. - EXPECT_FALSE(FindByLocalAddr(entries, nullptr, sockets->first_addr())) - << DescribeFirstInetSocket(*sockets); - EXPECT_FALSE(FindByLocalAddr(entries, nullptr, sockets->second_addr())) - << DescribeSecondInetSocket(*sockets); -} - -PosixErrorOr<std::unique_ptr<FileDescriptor>> BoundUDPSocket() { - ASSIGN_OR_RETURN_ERRNO(std::unique_ptr<FileDescriptor> socket, - IPv4UDPUnboundSocket(0).Create()); - struct sockaddr_in addr; - addr.sin_family = AF_INET; - addr.sin_addr.s_addr = htonl(INADDR_ANY); - addr.sin_port = 0; - - int res = bind(socket->get(), reinterpret_cast<const struct sockaddr*>(&addr), - sizeof(addr)); - if (res) { - return PosixError(errno, "bind()"); - } - return socket; -} - -TEST(ProcNetUDP, BoundEntry) { - std::unique_ptr<FileDescriptor> socket = - ASSERT_NO_ERRNO_AND_VALUE(BoundUDPSocket()); - struct sockaddr addr; - socklen_t len = sizeof(addr); - ASSERT_THAT(getsockname(socket->get(), &addr, &len), SyscallSucceeds()); - uint16_t port = PortFromInetSockaddr(&addr); - - std::vector<UDPEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries()); - UDPEntry e; - ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(FindByFD(entries, &e, socket->get()))); - EXPECT_EQ(e.local_port, port); - EXPECT_EQ(e.remote_addr, 0); - EXPECT_EQ(e.remote_port, 0); -} - -TEST(ProcNetUDP, BoundSocketStateClosed) { - std::unique_ptr<FileDescriptor> socket = - ASSERT_NO_ERRNO_AND_VALUE(BoundUDPSocket()); - std::vector<UDPEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries()); - UDPEntry e; - ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(FindByFD(entries, &e, socket->get()))); - EXPECT_EQ(e.state, TCP_CLOSE); -} - -TEST(ProcNetUDP, ConnectedSocketStateEstablished) { - auto sockets = - ASSERT_NO_ERRNO_AND_VALUE(IPv4UDPBidirectionalBindSocketPair(0).Create()); - std::vector<UDPEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries()); - - UDPEntry e; - ASSERT_TRUE(FindByLocalAddr(entries, &e, sockets->first_addr())) - << DescribeFirstInetSocket(*sockets); - EXPECT_EQ(e.state, TCP_ESTABLISHED); - - ASSERT_TRUE(FindByLocalAddr(entries, &e, sockets->second_addr())) - << DescribeSecondInetSocket(*sockets); - EXPECT_EQ(e.state, TCP_ESTABLISHED); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc deleted file mode 100644 index 66db0acaa..000000000 --- a/test/syscalls/linux/proc_net_unix.cc +++ /dev/null @@ -1,443 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "gtest/gtest.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -using absl::StrCat; -using absl::StreamFormat; -using absl::StrFormat; - -constexpr char kProcNetUnixHeader[] = - "Num RefCount Protocol Flags Type St Inode Path"; - -// Possible values of the "st" field in a /proc/net/unix entry. Source: Linux -// kernel, include/uapi/linux/net.h. -enum { - SS_FREE = 0, // Not allocated - SS_UNCONNECTED, // Unconnected to any socket - SS_CONNECTING, // In process of connecting - SS_CONNECTED, // Connected to socket - SS_DISCONNECTING // In process of disconnecting -}; - -// UnixEntry represents a single entry from /proc/net/unix. -struct UnixEntry { - uintptr_t addr; - uint64_t refs; - uint64_t protocol; - uint64_t flags; - uint64_t type; - uint64_t state; - uint64_t inode; - std::string path; -}; - -// Abstract socket paths can have either trailing null bytes or '@'s as padding -// at the end, depending on the linux version. This function strips any such -// padding. -void StripAbstractPathPadding(std::string* s) { - const char pad_char = s->back(); - if (pad_char != '\0' && pad_char != '@') { - return; - } - - const auto last_pos = s->find_last_not_of(pad_char); - if (last_pos != std::string::npos) { - s->resize(last_pos + 1); - } -} - -// Precondition: addr must be a unix socket address (i.e. sockaddr_un) and -// addr->sun_path must be null-terminated. This is always the case if addr comes -// from Linux: -// -// Per man unix(7): -// -// "When the address of a pathname socket is returned (by [getsockname(2)]), its -// length is -// -// offsetof(struct sockaddr_un, sun_path) + strlen(sun_path) + 1 -// -// and sun_path contains the null-terminated pathname." -std::string ExtractPath(const struct sockaddr* addr) { - const char* path = - reinterpret_cast<const struct sockaddr_un*>(addr)->sun_path; - // Note: sockaddr_un.sun_path is an embedded character array of length - // UNIX_PATH_MAX, so we can always safely dereference the first 2 bytes below. - // - // We also rely on the path being null-terminated. - if (path[0] == 0) { - std::string abstract_path = StrCat("@", &path[1]); - StripAbstractPathPadding(&abstract_path); - return abstract_path; - } - return std::string(path); -} - -// Returns a parsed representation of /proc/net/unix entries. -PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() { - std::string content; - RETURN_IF_ERRNO(GetContents("/proc/net/unix", &content)); - - bool skipped_header = false; - std::vector<UnixEntry> entries; - std::vector<std::string> lines = absl::StrSplit(content, '\n'); - std::cerr << "<contents of /proc/net/unix>" << std::endl; - for (std::string line : lines) { - // Emit the proc entry to the test output to provide context for the test - // results. - std::cerr << line << std::endl; - - if (!skipped_header) { - EXPECT_EQ(line, kProcNetUnixHeader); - skipped_header = true; - continue; - } - if (line.empty()) { - continue; - } - - // Parse a single entry from /proc/net/unix. - // - // Sample file: - // - // clang-format off - // - // Num RefCount Protocol Flags Type St Inode Path" - // ffffa130e7041c00: 00000002 00000000 00010000 0001 01 1299413685 /tmp/control_server/13293772586877554487 - // ffffa14f547dc400: 00000002 00000000 00010000 0001 01 3793 @remote_coredump - // - // clang-format on - // - // Note that from the second entry, the inode number can be padded using - // spaces, so we need to handle it separately during parsing. See - // net/unix/af_unix.c:unix_seq_show() for how these entries are produced. In - // particular, only the inode field is padded with spaces. - UnixEntry entry; - - // Process the first 6 fields, up to but not including "Inode". - std::vector<std::string> fields = - absl::StrSplit(line, absl::MaxSplits(' ', 6)); - - if (fields.size() < 7) { - return PosixError(EINVAL, StrFormat("Invalid entry: '%s'\n", line)); - } - - // AtoiBase can't handle the ':' in the "Num" field, so strip it out. - std::vector<std::string> addr = absl::StrSplit(fields[0], ':'); - ASSIGN_OR_RETURN_ERRNO(entry.addr, AtoiBase(addr[0], 16)); - - ASSIGN_OR_RETURN_ERRNO(entry.refs, AtoiBase(fields[1], 16)); - ASSIGN_OR_RETURN_ERRNO(entry.protocol, AtoiBase(fields[2], 16)); - ASSIGN_OR_RETURN_ERRNO(entry.flags, AtoiBase(fields[3], 16)); - ASSIGN_OR_RETURN_ERRNO(entry.type, AtoiBase(fields[4], 16)); - ASSIGN_OR_RETURN_ERRNO(entry.state, AtoiBase(fields[5], 16)); - - absl::string_view rest = absl::StripAsciiWhitespace(fields[6]); - fields = absl::StrSplit(rest, absl::MaxSplits(' ', 1)); - if (fields.empty()) { - return PosixError( - EINVAL, StrFormat("Invalid entry, missing 'Inode': '%s'\n", line)); - } - ASSIGN_OR_RETURN_ERRNO(entry.inode, AtoiBase(fields[0], 10)); - - entry.path = ""; - if (fields.size() > 1) { - entry.path = fields[1]; - StripAbstractPathPadding(&entry.path); - } - - entries.push_back(entry); - } - std::cerr << "<end of /proc/net/unix>" << std::endl; - - return entries; -} - -// Finds the first entry in 'entries' for which 'predicate' returns true. -// Returns true on match, and sets 'match' to point to the matching entry. -bool FindBy(std::vector<UnixEntry> entries, UnixEntry* match, - std::function<bool(const UnixEntry&)> predicate) { - for (int i = 0; i < entries.size(); ++i) { - if (predicate(entries[i])) { - *match = entries[i]; - return true; - } - } - return false; -} - -bool FindByPath(std::vector<UnixEntry> entries, UnixEntry* match, - const std::string& path) { - return FindBy(entries, match, - [path](const UnixEntry& e) { return e.path == path; }); -} - -TEST(ProcNetUnix, Exists) { - const std::string content = - ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/unix")); - const std::string header_line = StrCat(kProcNetUnixHeader, "\n"); - if (IsRunningOnGvisor()) { - // Should be just the header since we don't have any unix domain sockets - // yet. - EXPECT_EQ(content, header_line); - } else { - // However, on a general linux machine, we could have abitrary sockets on - // the system, so just check the header. - EXPECT_THAT(content, ::testing::StartsWith(header_line)); - } -} - -TEST(ProcNetUnix, FilesystemBindAcceptConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE( - FilesystemBoundUnixDomainSocketPair(SOCK_STREAM).Create()); - - std::string path1 = ExtractPath(sockets->first_addr()); - std::string path2 = ExtractPath(sockets->second_addr()); - std::cerr << StreamFormat("Server socket address (path1): %s\n", path1); - std::cerr << StreamFormat("Client socket address (path2): %s\n", path2); - - std::vector<UnixEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - if (IsRunningOnGvisor()) { - EXPECT_EQ(entries.size(), 2); - } - - // The server-side socket's path is listed in the socket entry... - UnixEntry s1; - EXPECT_TRUE(FindByPath(entries, &s1, path1)); - - // ... but the client-side socket's path is not. - UnixEntry s2; - EXPECT_FALSE(FindByPath(entries, &s2, path2)); -} - -TEST(ProcNetUnix, AbstractBindAcceptConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE( - AbstractBoundUnixDomainSocketPair(SOCK_STREAM).Create()); - - std::string path1 = ExtractPath(sockets->first_addr()); - std::string path2 = ExtractPath(sockets->second_addr()); - std::cerr << StreamFormat("Server socket address (path1): '%s'\n", path1); - std::cerr << StreamFormat("Client socket address (path2): '%s'\n", path2); - - std::vector<UnixEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - if (IsRunningOnGvisor()) { - EXPECT_EQ(entries.size(), 2); - } - - // The server-side socket's path is listed in the socket entry... - UnixEntry s1; - EXPECT_TRUE(FindByPath(entries, &s1, path1)); - - // ... but the client-side socket's path is not. - UnixEntry s2; - EXPECT_FALSE(FindByPath(entries, &s2, path2)); -} - -TEST(ProcNetUnix, SocketPair) { - // Under gvisor, ensure a socketpair() syscall creates exactly 2 new - // entries. We have no way to verify this under Linux, as we have no control - // over socket creation on a general Linux machine. - SKIP_IF(!IsRunningOnGvisor()); - - std::vector<UnixEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - ASSERT_EQ(entries.size(), 0); - - auto sockets = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_STREAM).Create()); - - entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - EXPECT_EQ(entries.size(), 2); -} - -TEST(ProcNetUnix, StreamSocketStateUnconnectedOnBind) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE( - AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - std::vector<UnixEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - - const std::string address = ExtractPath(sockets->first_addr()); - UnixEntry bind_entry; - ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); - EXPECT_EQ(bind_entry.state, SS_UNCONNECTED); -} - -TEST(ProcNetUnix, StreamSocketStateStateUnconnectedOnListen) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE( - AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - std::vector<UnixEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - - const std::string address = ExtractPath(sockets->first_addr()); - UnixEntry bind_entry; - ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); - EXPECT_EQ(bind_entry.state, SS_UNCONNECTED); - - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - - entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - UnixEntry listen_entry; - ASSERT_TRUE( - FindByPath(entries, &listen_entry, ExtractPath(sockets->first_addr()))); - EXPECT_EQ(listen_entry.state, SS_UNCONNECTED); - // The bind and listen entries should refer to the same socket. - EXPECT_EQ(listen_entry.inode, bind_entry.inode); -} - -TEST(ProcNetUnix, StreamSocketStateStateConnectedOnAccept) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE( - AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create()); - const std::string address = ExtractPath(sockets->first_addr()); - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); - std::vector<UnixEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - UnixEntry listen_entry; - ASSERT_TRUE( - FindByPath(entries, &listen_entry, ExtractPath(sockets->first_addr()))); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - int clientfd; - ASSERT_THAT(clientfd = accept(sockets->first_fd(), nullptr, nullptr), - SyscallSucceeds()); - - // Find the entry for the accepted socket. UDS proc entries don't have a - // remote address, so we distinguish the accepted socket from the listen - // socket by checking for a different inode. - entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - UnixEntry accept_entry; - ASSERT_TRUE(FindBy( - entries, &accept_entry, [address, listen_entry](const UnixEntry& e) { - return e.path == address && e.inode != listen_entry.inode; - })); - EXPECT_EQ(accept_entry.state, SS_CONNECTED); - // Listen entry should still be in SS_UNCONNECTED state. - ASSERT_TRUE(FindBy(entries, &listen_entry, - [&sockets, listen_entry](const UnixEntry& e) { - return e.path == ExtractPath(sockets->first_addr()) && - e.inode == listen_entry.inode; - })); - EXPECT_EQ(listen_entry.state, SS_UNCONNECTED); -} - -TEST(ProcNetUnix, DgramSocketStateDisconnectingOnBind) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE( - AbstractUnboundUnixDomainSocketPair(SOCK_DGRAM).Create()); - - std::vector<UnixEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - - // On gVisor, the only two UDS on the system are the ones we just created and - // we rely on this to locate the test socket entries in the remainder of the - // test. On a generic Linux system, we have no easy way to locate the - // corresponding entries, as they don't have an address yet. - if (IsRunningOnGvisor()) { - ASSERT_EQ(entries.size(), 2); - for (auto e : entries) { - ASSERT_EQ(e.state, SS_DISCONNECTING); - } - } - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - const std::string address = ExtractPath(sockets->first_addr()); - UnixEntry bind_entry; - ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); - EXPECT_EQ(bind_entry.state, SS_UNCONNECTED); -} - -TEST(ProcNetUnix, DgramSocketStateConnectingOnConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE( - AbstractUnboundUnixDomainSocketPair(SOCK_DGRAM).Create()); - - std::vector<UnixEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - - // On gVisor, the only two UDS on the system are the ones we just created and - // we rely on this to locate the test socket entries in the remainder of the - // test. On a generic Linux system, we have no easy way to locate the - // corresponding entries, as they don't have an address yet. - if (IsRunningOnGvisor()) { - ASSERT_EQ(entries.size(), 2); - for (auto e : entries) { - ASSERT_EQ(e.state, SS_DISCONNECTING); - } - } - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - const std::string address = ExtractPath(sockets->first_addr()); - UnixEntry bind_entry; - ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); - - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - - // Once again, we have no easy way to identify the connecting socket as it has - // no listed address. We can only identify the entry as the "non-bind socket - // entry" on gVisor, where we're guaranteed to have only the two entries we - // create during this test. - if (IsRunningOnGvisor()) { - ASSERT_EQ(entries.size(), 2); - UnixEntry connect_entry; - ASSERT_TRUE( - FindBy(entries, &connect_entry, [bind_entry](const UnixEntry& e) { - return e.inode != bind_entry.inode; - })); - EXPECT_EQ(connect_entry.state, SS_CONNECTING); - } -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/proc_pid_oomscore.cc b/test/syscalls/linux/proc_pid_oomscore.cc deleted file mode 100644 index 707821a3f..000000000 --- a/test/syscalls/linux/proc_pid_oomscore.cc +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> - -#include <exception> -#include <iostream> -#include <string> - -#include "test/util/fs_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -PosixErrorOr<int> ReadProcNumber(std::string path) { - ASSIGN_OR_RETURN_ERRNO(std::string contents, GetContents(path)); - EXPECT_EQ(contents[contents.length() - 1], '\n'); - - int num; - if (!absl::SimpleAtoi(contents, &num)) { - return PosixError(EINVAL, "invalid value: " + contents); - } - - return num; -} - -TEST(ProcPidOomscoreTest, BasicRead) { - auto const oom_score = - ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score")); - EXPECT_LE(oom_score, 1000); - EXPECT_GE(oom_score, -1000); -} - -TEST(ProcPidOomscoreAdjTest, BasicRead) { - auto const oom_score = - ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score_adj")); - - // oom_score_adj defaults to 0. - EXPECT_EQ(oom_score, 0); -} - -TEST(ProcPidOomscoreAdjTest, BasicWrite) { - constexpr int test_value = 7; - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/oom_score_adj", O_WRONLY)); - ASSERT_THAT( - RetryEINTR(write)(fd.get(), std::to_string(test_value).c_str(), 1), - SyscallSucceeds()); - - auto const oom_score = - ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score_adj")); - EXPECT_EQ(oom_score, test_value); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/proc_pid_smaps.cc b/test/syscalls/linux/proc_pid_smaps.cc deleted file mode 100644 index 7f2e8f203..000000000 --- a/test/syscalls/linux/proc_pid_smaps.cc +++ /dev/null @@ -1,468 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stddef.h> -#include <stdint.h> - -#include <algorithm> -#include <iostream> -#include <string> -#include <utility> -#include <vector> - -#include "absl/container/flat_hash_set.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/memory_util.h" -#include "test/util/posix_error.h" -#include "test/util/proc_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -using ::testing::Contains; -using ::testing::ElementsAreArray; -using ::testing::IsSupersetOf; -using ::testing::Not; -using ::testing::Optional; - -namespace gvisor { -namespace testing { - -namespace { - -struct ProcPidSmapsEntry { - ProcMapsEntry maps_entry; - - // These fields should always exist, as they were included in e070ad49f311 - // "[PATCH] add /proc/pid/smaps". - size_t size_kb; - size_t rss_kb; - size_t shared_clean_kb; - size_t shared_dirty_kb; - size_t private_clean_kb; - size_t private_dirty_kb; - - // These fields were added later and may not be present. - absl::optional<size_t> pss_kb; - absl::optional<size_t> referenced_kb; - absl::optional<size_t> anonymous_kb; - absl::optional<size_t> anon_huge_pages_kb; - absl::optional<size_t> shared_hugetlb_kb; - absl::optional<size_t> private_hugetlb_kb; - absl::optional<size_t> swap_kb; - absl::optional<size_t> swap_pss_kb; - absl::optional<size_t> kernel_page_size_kb; - absl::optional<size_t> mmu_page_size_kb; - absl::optional<size_t> locked_kb; - - // Caution: "Note that there is no guarantee that every flag and associated - // mnemonic will be present in all further kernel releases. Things get - // changed, the flags may be vanished or the reverse -- new added." - Linux - // Documentation/filesystems/proc.txt, on VmFlags. Avoid checking for any - // flags that are not extremely well-established. - absl::optional<std::vector<std::string>> vm_flags; -}; - -// Given the value part of a /proc/[pid]/smaps field containing a value in kB -// (for example, " 4 kB", returns the value in kB (in this example, 4). -PosixErrorOr<size_t> SmapsValueKb(absl::string_view value) { - // TODO(jamieliu): let us use RE2 or <regex> - std::pair<absl::string_view, absl::string_view> parts = - absl::StrSplit(value, ' ', absl::SkipEmpty()); - if (parts.second != "kB") { - return PosixError(EINVAL, - absl::StrCat("invalid smaps field value: ", value)); - } - ASSIGN_OR_RETURN_ERRNO(auto val_kb, Atoi<size_t>(parts.first)); - return val_kb; -} - -PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps( - absl::string_view contents) { - std::vector<ProcPidSmapsEntry> entries; - absl::optional<ProcPidSmapsEntry> entry; - bool have_size_kb = false; - bool have_rss_kb = false; - bool have_shared_clean_kb = false; - bool have_shared_dirty_kb = false; - bool have_private_clean_kb = false; - bool have_private_dirty_kb = false; - - auto const finish_entry = [&] { - if (entry) { - if (!have_size_kb) { - return PosixError(EINVAL, "smaps entry is missing Size"); - } - if (!have_rss_kb) { - return PosixError(EINVAL, "smaps entry is missing Rss"); - } - if (!have_shared_clean_kb) { - return PosixError(EINVAL, "smaps entry is missing Shared_Clean"); - } - if (!have_shared_dirty_kb) { - return PosixError(EINVAL, "smaps entry is missing Shared_Dirty"); - } - if (!have_private_clean_kb) { - return PosixError(EINVAL, "smaps entry is missing Private_Clean"); - } - if (!have_private_dirty_kb) { - return PosixError(EINVAL, "smaps entry is missing Private_Dirty"); - } - // std::move(entry.value()) instead of std::move(entry).value(), because - // otherwise tools may report a "use-after-move" warning, which is - // spurious because entry.emplace() below resets entry to a new - // ProcPidSmapsEntry. - entries.emplace_back(std::move(entry.value())); - } - entry.emplace(); - have_size_kb = false; - have_rss_kb = false; - have_shared_clean_kb = false; - have_shared_dirty_kb = false; - have_private_clean_kb = false; - have_private_dirty_kb = false; - return NoError(); - }; - - // Holds key/value pairs from smaps field lines. Declared here so it can be - // captured by reference by the following lambdas. - std::vector<absl::string_view> key_value; - - auto const on_required_field_kb = [&](size_t* field, bool* have_field) { - if (*have_field) { - return PosixError( - EINVAL, - absl::StrFormat("smaps entry has duplicate %s line", key_value[0])); - } - ASSIGN_OR_RETURN_ERRNO(*field, SmapsValueKb(key_value[1])); - *have_field = true; - return NoError(); - }; - - auto const on_optional_field_kb = [&](absl::optional<size_t>* field) { - if (*field) { - return PosixError( - EINVAL, - absl::StrFormat("smaps entry has duplicate %s line", key_value[0])); - } - ASSIGN_OR_RETURN_ERRNO(*field, SmapsValueKb(key_value[1])); - return NoError(); - }; - - absl::flat_hash_set<std::string> unknown_fields; - auto const on_unknown_field = [&] { - absl::string_view key = key_value[0]; - // Don't mention unknown fields more than once. - if (unknown_fields.count(key)) { - return; - } - unknown_fields.insert(std::string(key)); - std::cerr << "skipping unknown smaps field " << key; - }; - - auto lines = absl::StrSplit(contents, '\n', absl::SkipEmpty()); - for (absl::string_view l : lines) { - // Is this line a valid /proc/[pid]/maps entry? - auto maybe_maps_entry = ParseProcMapsLine(l); - if (maybe_maps_entry.ok()) { - // This marks the beginning of a new /proc/[pid]/smaps entry. - RETURN_IF_ERRNO(finish_entry()); - entry->maps_entry = std::move(maybe_maps_entry).ValueOrDie(); - continue; - } - // Otherwise it's a field in an existing /proc/[pid]/smaps entry of the form - // "key:value" (where value in practice will be preceded by a variable - // amount of whitespace). - if (!entry) { - std::cerr << "smaps line not considered a maps line: " - << maybe_maps_entry.error_message(); - return PosixError( - EINVAL, - absl::StrCat("smaps field line without preceding maps line: ", l)); - } - key_value = absl::StrSplit(l, absl::MaxSplits(':', 1)); - if (key_value.size() != 2) { - return PosixError(EINVAL, absl::StrCat("invalid smaps field line: ", l)); - } - absl::string_view const key = key_value[0]; - if (key == "Size") { - RETURN_IF_ERRNO(on_required_field_kb(&entry->size_kb, &have_size_kb)); - } else if (key == "Rss") { - RETURN_IF_ERRNO(on_required_field_kb(&entry->rss_kb, &have_rss_kb)); - } else if (key == "Shared_Clean") { - RETURN_IF_ERRNO( - on_required_field_kb(&entry->shared_clean_kb, &have_shared_clean_kb)); - } else if (key == "Shared_Dirty") { - RETURN_IF_ERRNO( - on_required_field_kb(&entry->shared_dirty_kb, &have_shared_dirty_kb)); - } else if (key == "Private_Clean") { - RETURN_IF_ERRNO(on_required_field_kb(&entry->private_clean_kb, - &have_private_clean_kb)); - } else if (key == "Private_Dirty") { - RETURN_IF_ERRNO(on_required_field_kb(&entry->private_dirty_kb, - &have_private_dirty_kb)); - } else if (key == "Pss") { - RETURN_IF_ERRNO(on_optional_field_kb(&entry->pss_kb)); - } else if (key == "Referenced") { - RETURN_IF_ERRNO(on_optional_field_kb(&entry->referenced_kb)); - } else if (key == "Anonymous") { - RETURN_IF_ERRNO(on_optional_field_kb(&entry->anonymous_kb)); - } else if (key == "AnonHugePages") { - RETURN_IF_ERRNO(on_optional_field_kb(&entry->anon_huge_pages_kb)); - } else if (key == "Shared_Hugetlb") { - RETURN_IF_ERRNO(on_optional_field_kb(&entry->shared_hugetlb_kb)); - } else if (key == "Private_Hugetlb") { - RETURN_IF_ERRNO(on_optional_field_kb(&entry->private_hugetlb_kb)); - } else if (key == "Swap") { - RETURN_IF_ERRNO(on_optional_field_kb(&entry->swap_kb)); - } else if (key == "SwapPss") { - RETURN_IF_ERRNO(on_optional_field_kb(&entry->swap_pss_kb)); - } else if (key == "KernelPageSize") { - RETURN_IF_ERRNO(on_optional_field_kb(&entry->kernel_page_size_kb)); - } else if (key == "MMUPageSize") { - RETURN_IF_ERRNO(on_optional_field_kb(&entry->mmu_page_size_kb)); - } else if (key == "Locked") { - RETURN_IF_ERRNO(on_optional_field_kb(&entry->locked_kb)); - } else if (key == "VmFlags") { - if (entry->vm_flags) { - return PosixError(EINVAL, "duplicate VmFlags line"); - } - entry->vm_flags = absl::StrSplit(key_value[1], ' ', absl::SkipEmpty()); - } else { - on_unknown_field(); - } - } - RETURN_IF_ERRNO(finish_entry()); - return entries; -}; - -TEST(ParseProcPidSmapsTest, Correctness) { - auto entries = ASSERT_NO_ERRNO_AND_VALUE( - ParseProcPidSmaps("0-10000 rw-s 00000000 00:00 0 " - " /dev/zero (deleted)\n" - "Size: 0 kB\n" - "Rss: 1 kB\n" - "Pss: 2 kB\n" - "Shared_Clean: 3 kB\n" - "Shared_Dirty: 4 kB\n" - "Private_Clean: 5 kB\n" - "Private_Dirty: 6 kB\n" - "Referenced: 7 kB\n" - "Anonymous: 8 kB\n" - "AnonHugePages: 9 kB\n" - "Shared_Hugetlb: 10 kB\n" - "Private_Hugetlb: 11 kB\n" - "Swap: 12 kB\n" - "SwapPss: 13 kB\n" - "KernelPageSize: 14 kB\n" - "MMUPageSize: 15 kB\n" - "Locked: 16 kB\n" - "FutureUnknownKey: 17 kB\n" - "VmFlags: rd wr sh mr mw me ms lo ?? sd \n")); - ASSERT_EQ(entries.size(), 1); - auto& entry = entries[0]; - EXPECT_EQ(entry.maps_entry.filename, "/dev/zero (deleted)"); - EXPECT_EQ(entry.size_kb, 0); - EXPECT_EQ(entry.rss_kb, 1); - EXPECT_THAT(entry.pss_kb, Optional(2)); - EXPECT_EQ(entry.shared_clean_kb, 3); - EXPECT_EQ(entry.shared_dirty_kb, 4); - EXPECT_EQ(entry.private_clean_kb, 5); - EXPECT_EQ(entry.private_dirty_kb, 6); - EXPECT_THAT(entry.referenced_kb, Optional(7)); - EXPECT_THAT(entry.anonymous_kb, Optional(8)); - EXPECT_THAT(entry.anon_huge_pages_kb, Optional(9)); - EXPECT_THAT(entry.shared_hugetlb_kb, Optional(10)); - EXPECT_THAT(entry.private_hugetlb_kb, Optional(11)); - EXPECT_THAT(entry.swap_kb, Optional(12)); - EXPECT_THAT(entry.swap_pss_kb, Optional(13)); - EXPECT_THAT(entry.kernel_page_size_kb, Optional(14)); - EXPECT_THAT(entry.mmu_page_size_kb, Optional(15)); - EXPECT_THAT(entry.locked_kb, Optional(16)); - EXPECT_THAT(entry.vm_flags, - Optional(ElementsAreArray({"rd", "wr", "sh", "mr", "mw", "me", - "ms", "lo", "??", "sd"}))); -} - -// Returns the unique entry in entries containing the given address. -PosixErrorOr<ProcPidSmapsEntry> FindUniqueSmapsEntry( - std::vector<ProcPidSmapsEntry> const& entries, uintptr_t addr) { - auto const pred = [&](ProcPidSmapsEntry const& entry) { - return entry.maps_entry.start <= addr && addr < entry.maps_entry.end; - }; - auto const it = std::find_if(entries.begin(), entries.end(), pred); - if (it == entries.end()) { - return PosixError(EINVAL, - absl::StrFormat("no entry contains address %#x", addr)); - } - auto const it2 = std::find_if(it + 1, entries.end(), pred); - if (it2 != entries.end()) { - return PosixError( - EINVAL, - absl::StrFormat("overlapping entries [%#x-%#x) and [%#x-%#x) both " - "contain address %#x", - it->maps_entry.start, it->maps_entry.end, - it2->maps_entry.start, it2->maps_entry.end, addr)); - } - return *it; -} - -PosixErrorOr<std::vector<ProcPidSmapsEntry>> ReadProcSelfSmaps() { - ASSIGN_OR_RETURN_ERRNO(std::string contents, GetContents("/proc/self/smaps")); - return ParseProcPidSmaps(contents); -} - -TEST(ProcPidSmapsTest, SharedAnon) { - // Map with MAP_POPULATE so we get some RSS. - Mapping const m = ASSERT_NO_ERRNO_AND_VALUE(MmapAnon( - 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_POPULATE)); - auto const entries = ASSERT_NO_ERRNO_AND_VALUE(ReadProcSelfSmaps()); - auto const entry = - ASSERT_NO_ERRNO_AND_VALUE(FindUniqueSmapsEntry(entries, m.addr())); - - EXPECT_EQ(entry.size_kb, m.len() / 1024); - // It's possible that populated pages have been swapped out, so RSS might be - // less than size. - EXPECT_LE(entry.rss_kb, entry.size_kb); - - if (entry.pss_kb) { - // PSS should be exactly equal to RSS since no other address spaces should - // be sharing our new mapping. - EXPECT_EQ(entry.pss_kb.value(), entry.rss_kb); - } - - // "Shared" and "private" in smaps refers to whether or not *physical pages* - // are shared; thus all pages in our MAP_SHARED mapping should nevertheless - // be private. - EXPECT_EQ(entry.shared_clean_kb, 0); - EXPECT_EQ(entry.shared_dirty_kb, 0); - EXPECT_EQ(entry.private_clean_kb + entry.private_dirty_kb, entry.rss_kb) - << "Private_Clean = " << entry.private_clean_kb - << " kB, Private_Dirty = " << entry.private_dirty_kb << " kB"; - - // Shared anonymous mappings are implemented as a shmem file, so their pages - // are not PageAnon. - if (entry.anonymous_kb) { - EXPECT_EQ(entry.anonymous_kb.value(), 0); - } - - if (entry.vm_flags) { - EXPECT_THAT(entry.vm_flags.value(), - IsSupersetOf({"rd", "wr", "sh", "mr", "mw", "me", "ms"})); - EXPECT_THAT(entry.vm_flags.value(), Not(Contains("ex"))); - } -} - -TEST(ProcPidSmapsTest, PrivateAnon) { - // Map with MAP_POPULATE so we get some RSS. - Mapping const m = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(2 * kPageSize, PROT_WRITE, MAP_PRIVATE | MAP_POPULATE)); - auto const entries = ASSERT_NO_ERRNO_AND_VALUE(ReadProcSelfSmaps()); - auto const entry = - ASSERT_NO_ERRNO_AND_VALUE(FindUniqueSmapsEntry(entries, m.addr())); - - // It's possible that our mapping was merged with another vma, so the smaps - // entry might be bigger than our original mapping. - EXPECT_GE(entry.size_kb, m.len() / 1024); - EXPECT_LE(entry.rss_kb, entry.size_kb); - if (entry.pss_kb) { - EXPECT_LE(entry.pss_kb.value(), entry.rss_kb); - } - - if (entry.anonymous_kb) { - EXPECT_EQ(entry.anonymous_kb.value(), entry.rss_kb); - } - - if (entry.vm_flags) { - EXPECT_THAT(entry.vm_flags.value(), IsSupersetOf({"wr", "mr", "mw", "me"})); - // We passed PROT_WRITE to mmap. On at least x86, the mapping is in - // practice readable because there is no way to configure the MMU to make - // pages writable but not readable. However, VmFlags should reflect the - // flags set on the VMA, so "rd" (VM_READ) should not appear in VmFlags. - EXPECT_THAT(entry.vm_flags.value(), Not(Contains("rd"))); - EXPECT_THAT(entry.vm_flags.value(), Not(Contains("ex"))); - EXPECT_THAT(entry.vm_flags.value(), Not(Contains("sh"))); - EXPECT_THAT(entry.vm_flags.value(), Not(Contains("ms"))); - } -} - -TEST(ProcPidSmapsTest, SharedReadOnlyFile) { - size_t const kFileSize = kPageSize; - - auto const temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - ASSERT_THAT(truncate(temp_file.path().c_str(), kFileSize), SyscallSucceeds()); - auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file.path(), O_RDONLY)); - - auto const m = ASSERT_NO_ERRNO_AND_VALUE(Mmap( - nullptr, kFileSize, PROT_READ, MAP_SHARED | MAP_POPULATE, fd.get(), 0)); - auto const entries = ASSERT_NO_ERRNO_AND_VALUE(ReadProcSelfSmaps()); - auto const entry = - ASSERT_NO_ERRNO_AND_VALUE(FindUniqueSmapsEntry(entries, m.addr())); - - // Most of the same logic as the SharedAnon case applies. - EXPECT_EQ(entry.size_kb, kFileSize / 1024); - EXPECT_LE(entry.rss_kb, entry.size_kb); - if (entry.pss_kb) { - EXPECT_EQ(entry.pss_kb.value(), entry.rss_kb); - } - EXPECT_EQ(entry.shared_clean_kb, 0); - EXPECT_EQ(entry.shared_dirty_kb, 0); - EXPECT_EQ(entry.private_clean_kb + entry.private_dirty_kb, entry.rss_kb) - << "Private_Clean = " << entry.private_clean_kb - << " kB, Private_Dirty = " << entry.private_dirty_kb << " kB"; - if (entry.anonymous_kb) { - EXPECT_EQ(entry.anonymous_kb.value(), 0); - } - - if (entry.vm_flags) { - EXPECT_THAT(entry.vm_flags.value(), IsSupersetOf({"rd", "mr", "me", "ms"})); - EXPECT_THAT(entry.vm_flags.value(), Not(Contains("wr"))); - EXPECT_THAT(entry.vm_flags.value(), Not(Contains("ex"))); - // Because the mapped file was opened O_RDONLY, the VMA is !VM_MAYWRITE and - // also !VM_SHARED. - EXPECT_THAT(entry.vm_flags.value(), Not(Contains("sh"))); - EXPECT_THAT(entry.vm_flags.value(), Not(Contains("mw"))); - } -} - -// Tests that gVisor's /proc/[pid]/smaps provides all of the fields we expect it -// to, which as of this writing is all fields provided by Linux 4.4. -TEST(ProcPidSmapsTest, GvisorFields) { - SKIP_IF(!IsRunningOnGvisor()); - auto const entries = ASSERT_NO_ERRNO_AND_VALUE(ReadProcSelfSmaps()); - for (auto const& entry : entries) { - EXPECT_TRUE(entry.pss_kb); - EXPECT_TRUE(entry.referenced_kb); - EXPECT_TRUE(entry.anonymous_kb); - EXPECT_TRUE(entry.anon_huge_pages_kb); - EXPECT_TRUE(entry.shared_hugetlb_kb); - EXPECT_TRUE(entry.private_hugetlb_kb); - EXPECT_TRUE(entry.swap_kb); - EXPECT_TRUE(entry.swap_pss_kb); - EXPECT_THAT(entry.kernel_page_size_kb, Optional(kPageSize / 1024)); - EXPECT_THAT(entry.mmu_page_size_kb, Optional(kPageSize / 1024)); - EXPECT_TRUE(entry.locked_kb); - EXPECT_TRUE(entry.vm_flags); - } -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/proc_pid_uid_gid_map.cc b/test/syscalls/linux/proc_pid_uid_gid_map.cc deleted file mode 100644 index 748f7be58..000000000 --- a/test/syscalls/linux/proc_pid_uid_gid_map.cc +++ /dev/null @@ -1,311 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <sched.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include <functional> -#include <string> -#include <tuple> -#include <utility> -#include <vector> - -#include "gtest/gtest.h" -#include "absl/strings/ascii.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_split.h" -#include "test/util/capability_util.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/logging.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/save_util.h" -#include "test/util/test_util.h" -#include "test/util/time_util.h" - -namespace gvisor { -namespace testing { - -PosixErrorOr<int> InNewUserNamespace(const std::function<void()>& fn) { - return InForkedProcess([&] { - TEST_PCHECK(unshare(CLONE_NEWUSER) == 0); - MaybeSave(); - fn(); - }); -} - -PosixErrorOr<std::tuple<pid_t, Cleanup>> CreateProcessInNewUserNamespace() { - int pipefd[2]; - if (pipe(pipefd) < 0) { - return PosixError(errno, "pipe failed"); - } - const auto cleanup_pipe_read = - Cleanup([&] { EXPECT_THAT(close(pipefd[0]), SyscallSucceeds()); }); - auto cleanup_pipe_write = - Cleanup([&] { EXPECT_THAT(close(pipefd[1]), SyscallSucceeds()); }); - pid_t child_pid = fork(); - if (child_pid < 0) { - return PosixError(errno, "fork failed"); - } - if (child_pid == 0) { - // Close our copy of the pipe's read end, which doesn't really matter. - TEST_PCHECK(close(pipefd[0]) >= 0); - TEST_PCHECK(unshare(CLONE_NEWUSER) == 0); - MaybeSave(); - // Indicate that we've switched namespaces by unblocking the parent's read. - TEST_PCHECK(close(pipefd[1]) >= 0); - while (true) { - SleepSafe(absl::Minutes(1)); - } - } - auto cleanup_child = Cleanup([child_pid] { - EXPECT_THAT(kill(child_pid, SIGKILL), SyscallSucceeds()); - int status; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL) - << "status = " << status; - }); - // Close our copy of the pipe's write end, then wait for the child to close - // its copy, indicating that it's switched namespaces. - cleanup_pipe_write.Release()(); - char buf; - if (RetryEINTR(read)(pipefd[0], &buf, 1) < 0) { - return PosixError(errno, "reading from pipe failed"); - } - MaybeSave(); - return std::make_tuple(child_pid, std::move(cleanup_child)); -} - -// TEST_CHECK-fails on error, since this function is used in contexts that -// require async-signal-safety. -void DenySetgroupsByPath(const char* path) { - int fd = open(path, O_WRONLY); - if (fd < 0 && errno == ENOENT) { - // On kernels where this file doesn't exist, writing "deny" to it isn't - // necessary to write to gid_map. - return; - } - TEST_PCHECK(fd >= 0); - MaybeSave(); - char deny[] = "deny"; - TEST_PCHECK(write(fd, deny, sizeof(deny)) == sizeof(deny)); - MaybeSave(); - TEST_PCHECK(close(fd) == 0); -} - -void DenySelfSetgroups() { DenySetgroupsByPath("/proc/self/setgroups"); } - -void DenyPidSetgroups(pid_t pid) { - DenySetgroupsByPath(absl::StrCat("/proc/", pid, "/setgroups").c_str()); -} - -// Returns a valid UID/GID that isn't id. -uint32_t another_id(uint32_t id) { return (id + 1) % 65535; } - -struct TestParam { - std::string desc; - int cap; - std::function<std::string(absl::string_view)> get_map_filename; - std::function<uint32_t()> get_current_id; -}; - -std::string DescribeTestParam(const ::testing::TestParamInfo<TestParam>& info) { - return info.param.desc; -} - -std::vector<TestParam> UidGidMapTestParams() { - return {TestParam{"UID", CAP_SETUID, - [](absl::string_view pid) { - return absl::StrCat("/proc/", pid, "/uid_map"); - }, - []() -> uint32_t { return getuid(); }}, - TestParam{"GID", CAP_SETGID, - [](absl::string_view pid) { - return absl::StrCat("/proc/", pid, "/gid_map"); - }, - []() -> uint32_t { return getgid(); }}}; -} - -class ProcUidGidMapTest : public ::testing::TestWithParam<TestParam> { - protected: - uint32_t CurrentID() { return GetParam().get_current_id(); } -}; - -class ProcSelfUidGidMapTest : public ProcUidGidMapTest { - protected: - PosixErrorOr<int> InNewUserNamespaceWithMapFD( - const std::function<void(int)>& fn) { - std::string map_filename = GetParam().get_map_filename("self"); - return InNewUserNamespace([&] { - int fd = open(map_filename.c_str(), O_RDWR); - TEST_PCHECK(fd >= 0); - MaybeSave(); - fn(fd); - TEST_PCHECK(close(fd) == 0); - }); - } -}; - -class ProcPidUidGidMapTest : public ProcUidGidMapTest { - protected: - PosixErrorOr<bool> HaveSetIDCapability() { - return HaveCapability(GetParam().cap); - } - - // Returns true if the caller is running in a user namespace with all IDs - // mapped. This matters for tests that expect to successfully map arbitrary - // IDs into a child user namespace, since even with CAP_SET*ID this is only - // possible if those IDs are mapped into the current one. - PosixErrorOr<bool> AllIDsMapped() { - ASSIGN_OR_RETURN_ERRNO(std::string id_map, - GetContents(GetParam().get_map_filename("self"))); - absl::StripTrailingAsciiWhitespace(&id_map); - std::vector<std::string> id_map_parts = - absl::StrSplit(id_map, ' ', absl::SkipEmpty()); - return id_map_parts == std::vector<std::string>({"0", "0", "4294967295"}); - } - - PosixErrorOr<FileDescriptor> OpenMapFile(pid_t pid) { - return Open(GetParam().get_map_filename(absl::StrCat(pid)), O_RDWR); - } -}; - -TEST_P(ProcSelfUidGidMapTest, IsInitiallyEmpty) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace())); - EXPECT_THAT(InNewUserNamespaceWithMapFD([](int fd) { - char buf[64]; - TEST_PCHECK(read(fd, buf, sizeof(buf)) == 0); - }), - IsPosixErrorOkAndHolds(0)); -} - -TEST_P(ProcSelfUidGidMapTest, IdentityMapOwnID) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace())); - uint32_t id = CurrentID(); - std::string line = absl::StrCat(id, " ", id, " 1"); - EXPECT_THAT( - InNewUserNamespaceWithMapFD([&](int fd) { - DenySelfSetgroups(); - TEST_PCHECK(write(fd, line.c_str(), line.size()) == line.size()); - }), - IsPosixErrorOkAndHolds(0)); -} - -TEST_P(ProcSelfUidGidMapTest, TrailingNewlineAndNULIgnored) { - // This is identical to IdentityMapOwnID, except that a trailing newline, NUL, - // and an invalid (incomplete) map entry are appended to the valid entry. The - // newline should be accepted, and everything after the NUL should be ignored. - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace())); - uint32_t id = CurrentID(); - std::string line = absl::StrCat(id, " ", id, " 1\n\0 4 3"); - EXPECT_THAT( - InNewUserNamespaceWithMapFD([&](int fd) { - DenySelfSetgroups(); - // The write should return the full size of the write, even though - // characters after the NUL were ignored. - TEST_PCHECK(write(fd, line.c_str(), line.size()) == line.size()); - }), - IsPosixErrorOkAndHolds(0)); -} - -TEST_P(ProcSelfUidGidMapTest, NonIdentityMapOwnID) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace())); - uint32_t id = CurrentID(); - uint32_t id2 = another_id(id); - std::string line = absl::StrCat(id2, " ", id, " 1"); - EXPECT_THAT( - InNewUserNamespaceWithMapFD([&](int fd) { - DenySelfSetgroups(); - TEST_PCHECK(write(fd, line.c_str(), line.size()) == line.size()); - }), - IsPosixErrorOkAndHolds(0)); -} - -TEST_P(ProcSelfUidGidMapTest, MapOtherID) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace())); - // Whether or not we have CAP_SET*ID is irrelevant: the process running in the - // new (child) user namespace won't have any capabilities in the current - // (parent) user namespace, which is needed. - uint32_t id = CurrentID(); - uint32_t id2 = another_id(id); - std::string line = absl::StrCat(id, " ", id2, " 1"); - EXPECT_THAT(InNewUserNamespaceWithMapFD([&](int fd) { - DenySelfSetgroups(); - TEST_PCHECK(write(fd, line.c_str(), line.size()) < 0); - TEST_CHECK(errno == EPERM); - }), - IsPosixErrorOkAndHolds(0)); -} - -INSTANTIATE_TEST_SUITE_P(All, ProcSelfUidGidMapTest, - ::testing::ValuesIn(UidGidMapTestParams()), - DescribeTestParam); - -TEST_P(ProcPidUidGidMapTest, MapOtherIDPrivileged) { - // Like ProcSelfUidGidMapTest_MapOtherID, but since we have CAP_SET*ID in the - // parent user namespace (this one), we can map IDs that aren't ours. - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace())); - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveSetIDCapability())); - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(AllIDsMapped())); - - pid_t child_pid; - Cleanup cleanup_child; - std::tie(child_pid, cleanup_child) = - ASSERT_NO_ERRNO_AND_VALUE(CreateProcessInNewUserNamespace()); - - uint32_t id = CurrentID(); - uint32_t id2 = another_id(id); - std::string line = absl::StrCat(id, " ", id2, " 1"); - DenyPidSetgroups(child_pid); - auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenMapFile(child_pid)); - EXPECT_THAT(write(fd.get(), line.c_str(), line.size()), - SyscallSucceedsWithValue(line.size())); -} - -TEST_P(ProcPidUidGidMapTest, MapAnyIDsPrivileged) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace())); - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveSetIDCapability())); - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(AllIDsMapped())); - - pid_t child_pid; - Cleanup cleanup_child; - std::tie(child_pid, cleanup_child) = - ASSERT_NO_ERRNO_AND_VALUE(CreateProcessInNewUserNamespace()); - - // Test all of: - // - // - Mapping ranges of length > 1 - // - // - Mapping multiple ranges - // - // - Non-identity mappings - char entries[] = "2 0 2\n4 6 2"; - DenyPidSetgroups(child_pid); - auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenMapFile(child_pid)); - EXPECT_THAT(write(fd.get(), entries, sizeof(entries)), - SyscallSucceedsWithValue(sizeof(entries))); -} - -INSTANTIATE_TEST_SUITE_P(All, ProcPidUidGidMapTest, - ::testing::ValuesIn(UidGidMapTestParams()), - DescribeTestParam); - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/pselect.cc b/test/syscalls/linux/pselect.cc deleted file mode 100644 index 4e43c4d7f..000000000 --- a/test/syscalls/linux/pselect.cc +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <signal.h> -#include <sys/select.h> - -#include "gtest/gtest.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/base_poll_test.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -struct MaskWithSize { - sigset_t* mask; - size_t mask_size; -}; - -// Linux and glibc have a different idea of the sizeof sigset_t. When calling -// the syscall directly, use what the kernel expects. -unsigned kSigsetSize = SIGRTMAX / 8; - -// Linux pselect(2) differs from the glibc wrapper function in that Linux -// updates the timeout with the amount of time remaining. In order to test this -// behavior we need to use the syscall directly. -int syscallPselect6(int nfds, fd_set* readfds, fd_set* writefds, - fd_set* exceptfds, struct timespec* timeout, - const MaskWithSize* mask_with_size) { - return syscall(SYS_pselect6, nfds, readfds, writefds, exceptfds, timeout, - mask_with_size); -} - -class PselectTest : public BasePollTest { - protected: - void SetUp() override { BasePollTest::SetUp(); } - void TearDown() override { BasePollTest::TearDown(); } -}; - -// See that when there are no FD sets, pselect behaves like sleep. -TEST_F(PselectTest, NullFds) { - struct timespec timeout = absl::ToTimespec(absl::Milliseconds(10)); - ASSERT_THAT(syscallPselect6(0, nullptr, nullptr, nullptr, &timeout, nullptr), - SyscallSucceeds()); - EXPECT_EQ(timeout.tv_sec, 0); - EXPECT_EQ(timeout.tv_nsec, 0); - - timeout = absl::ToTimespec(absl::Milliseconds(10)); - ASSERT_THAT(syscallPselect6(1, nullptr, nullptr, nullptr, &timeout, nullptr), - SyscallSucceeds()); - EXPECT_EQ(timeout.tv_sec, 0); - EXPECT_EQ(timeout.tv_nsec, 0); -} - -TEST_F(PselectTest, ClosedFds) { - fd_set read_set; - FD_ZERO(&read_set); - int fd; - ASSERT_THAT(fd = dup(1), SyscallSucceeds()); - ASSERT_THAT(close(fd), SyscallSucceeds()); - FD_SET(fd, &read_set); - struct timespec timeout = absl::ToTimespec(absl::Milliseconds(10)); - EXPECT_THAT( - syscallPselect6(fd + 1, &read_set, nullptr, nullptr, &timeout, nullptr), - SyscallFailsWithErrno(EBADF)); -} - -TEST_F(PselectTest, ZeroTimeout) { - struct timespec timeout = {}; - ASSERT_THAT(syscallPselect6(1, nullptr, nullptr, nullptr, &timeout, nullptr), - SyscallSucceeds()); - EXPECT_EQ(timeout.tv_sec, 0); - EXPECT_EQ(timeout.tv_nsec, 0); -} - -// If random S/R interrupts the pselect, SIGALRM may be delivered before pselect -// restarts, causing the pselect to hang forever. -TEST_F(PselectTest, NoTimeout_NoRandomSave) { - // When there's no timeout, pselect may never return so set a timer. - SetTimer(absl::Milliseconds(100)); - // See that we get interrupted by the timer. - ASSERT_THAT(syscallPselect6(1, nullptr, nullptr, nullptr, nullptr, nullptr), - SyscallFailsWithErrno(EINTR)); - EXPECT_TRUE(TimerFired()); -} - -TEST_F(PselectTest, InvalidTimeoutNegative) { - struct timespec timeout = absl::ToTimespec(absl::Seconds(-1)); - ASSERT_THAT(syscallPselect6(1, nullptr, nullptr, nullptr, &timeout, nullptr), - SyscallFailsWithErrno(EINVAL)); - EXPECT_EQ(timeout.tv_sec, -1); - EXPECT_EQ(timeout.tv_nsec, 0); -} - -TEST_F(PselectTest, InvalidTimeoutNotNormalized) { - struct timespec timeout = {0, 1000000001}; - ASSERT_THAT(syscallPselect6(1, nullptr, nullptr, nullptr, &timeout, nullptr), - SyscallFailsWithErrno(EINVAL)); - EXPECT_EQ(timeout.tv_sec, 0); - EXPECT_EQ(timeout.tv_nsec, 1000000001); -} - -TEST_F(PselectTest, EmptySigMaskInvalidMaskSize) { - struct timespec timeout = {}; - MaskWithSize invalid = {nullptr, 7}; - EXPECT_THAT(syscallPselect6(0, nullptr, nullptr, nullptr, &timeout, &invalid), - SyscallSucceeds()); -} - -TEST_F(PselectTest, EmptySigMaskValidMaskSize) { - struct timespec timeout = {}; - MaskWithSize invalid = {nullptr, 8}; - EXPECT_THAT(syscallPselect6(0, nullptr, nullptr, nullptr, &timeout, &invalid), - SyscallSucceeds()); -} - -TEST_F(PselectTest, InvalidMaskSize) { - struct timespec timeout = {}; - sigset_t sigmask; - ASSERT_THAT(sigemptyset(&sigmask), SyscallSucceeds()); - MaskWithSize invalid = {&sigmask, 7}; - EXPECT_THAT(syscallPselect6(1, nullptr, nullptr, nullptr, &timeout, &invalid), - SyscallFailsWithErrno(EINVAL)); -} - -// Verify that signals blocked by the pselect mask (that would otherwise be -// allowed) do not interrupt pselect. -TEST_F(PselectTest, SignalMaskBlocksSignal) { - absl::Duration duration(absl::Seconds(30)); - struct timespec timeout = absl::ToTimespec(duration); - absl::Duration timer_duration(absl::Seconds(10)); - - // Call with a mask that blocks SIGALRM. See that pselect is not interrupted - // (i.e. returns 0) and that upon completion, the timer has fired. - sigset_t mask; - ASSERT_THAT(sigprocmask(0, nullptr, &mask), SyscallSucceeds()); - ASSERT_THAT(sigaddset(&mask, SIGALRM), SyscallSucceeds()); - MaskWithSize mask_with_size = {&mask, kSigsetSize}; - SetTimer(timer_duration); - MaybeSave(); - ASSERT_FALSE(TimerFired()); - ASSERT_THAT( - syscallPselect6(1, nullptr, nullptr, nullptr, &timeout, &mask_with_size), - SyscallSucceeds()); - EXPECT_TRUE(TimerFired()); - EXPECT_EQ(absl::DurationFromTimespec(timeout), absl::Duration()); -} - -// Verify that signals allowed by the pselect mask (that would otherwise be -// blocked) interrupt pselect. -TEST_F(PselectTest, SignalMaskAllowsSignal) { - absl::Duration duration = absl::Seconds(30); - struct timespec timeout = absl::ToTimespec(duration); - absl::Duration timer_duration = absl::Seconds(10); - - sigset_t mask; - ASSERT_THAT(sigprocmask(0, nullptr, &mask), SyscallSucceeds()); - - // Block SIGALRM. - auto cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, SIGALRM)); - - // Call with a mask that unblocks SIGALRM. See that pselect is interrupted. - MaskWithSize mask_with_size = {&mask, kSigsetSize}; - SetTimer(timer_duration); - MaybeSave(); - ASSERT_FALSE(TimerFired()); - ASSERT_THAT( - syscallPselect6(1, nullptr, nullptr, nullptr, &timeout, &mask_with_size), - SyscallFailsWithErrno(EINTR)); - EXPECT_TRUE(TimerFired()); - EXPECT_GT(absl::DurationFromTimespec(timeout), absl::Duration()); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc deleted file mode 100644 index bfe3e2603..000000000 --- a/test/syscalls/linux/ptrace.cc +++ /dev/null @@ -1,1212 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <elf.h> -#include <signal.h> -#include <stddef.h> -#include <sys/ptrace.h> -#include <sys/time.h> -#include <sys/types.h> -#include <sys/user.h> -#include <sys/wait.h> -#include <unistd.h> - -#include <iostream> -#include <utility> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/flags/flag.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/logging.h" -#include "test/util/multiprocess_util.h" -#include "test/util/platform_util.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" -#include "test/util/time_util.h" - -ABSL_FLAG(bool, ptrace_test_execve_child, false, - "If true, run the " - "PtraceExecveTest_Execve_GetRegs_PeekUser_SIGKILL_TraceClone_" - "TraceExit child workload."); - -namespace gvisor { -namespace testing { - -namespace { - -// PTRACE_GETSIGMASK and PTRACE_SETSIGMASK are not defined until glibc 2.23 -// (fb53a27c5741 "Add new header definitions from Linux 4.4 (plus older ptrace -// definitions)"). -constexpr auto kPtraceGetSigMask = static_cast<__ptrace_request>(0x420a); -constexpr auto kPtraceSetSigMask = static_cast<__ptrace_request>(0x420b); - -// PTRACE_SYSEMU is not defined until glibc 2.27 (c48831d0eebf "linux/x86: sync -// sys/ptrace.h with Linux 4.14 [BZ #22433]"). -constexpr auto kPtraceSysemu = static_cast<__ptrace_request>(31); - -// PTRACE_EVENT_STOP is not defined until glibc 2.26 (3f67d1a7021e "Add Linux -// PTRACE_EVENT_STOP"). -constexpr int kPtraceEventStop = 128; - -// Sends sig to the current process with tgkill(2). -// -// glibc's raise(2) may change the signal mask before sending the signal. These -// extra syscalls make tests of syscall, signal interception, etc. difficult to -// write. -void RaiseSignal(int sig) { - pid_t pid = getpid(); - TEST_PCHECK(pid > 0); - pid_t tid = gettid(); - TEST_PCHECK(tid > 0); - TEST_PCHECK(tgkill(pid, tid, sig) == 0); -} - -// Returns the Yama ptrace scope. -PosixErrorOr<int> YamaPtraceScope() { - constexpr char kYamaPtraceScopePath[] = "/proc/sys/kernel/yama/ptrace_scope"; - - ASSIGN_OR_RETURN_ERRNO(bool exists, Exists(kYamaPtraceScopePath)); - if (!exists) { - // File doesn't exist means no Yama, so the scope is disabled -> 0. - return 0; - } - - std::string contents; - RETURN_IF_ERRNO(GetContents(kYamaPtraceScopePath, &contents)); - - int scope; - if (!absl::SimpleAtoi(contents, &scope)) { - return PosixError(EINVAL, absl::StrCat(contents, ": not a valid number")); - } - - return scope; -} - -TEST(PtraceTest, AttachSelf) { - EXPECT_THAT(ptrace(PTRACE_ATTACH, gettid(), 0, 0), - SyscallFailsWithErrno(EPERM)); -} - -TEST(PtraceTest, AttachSameThreadGroup) { - pid_t const tid = gettid(); - ScopedThread([&] { - EXPECT_THAT(ptrace(PTRACE_ATTACH, tid, 0, 0), SyscallFailsWithErrno(EPERM)); - }); -} - -TEST(PtraceTest, AttachParent_PeekData_PokeData_SignalSuppression) { - // Yama prevents attaching to a parent. Skip the test if the scope is anything - // except disabled. - SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) > 0); - - constexpr long kBeforePokeDataValue = 10; - constexpr long kAfterPokeDataValue = 20; - - volatile long word = kBeforePokeDataValue; - - pid_t const child_pid = fork(); - if (child_pid == 0) { - // In child process. - - // Attach to the parent. - pid_t const parent_pid = getppid(); - TEST_PCHECK(ptrace(PTRACE_ATTACH, parent_pid, 0, 0) == 0); - MaybeSave(); - - // Block until the parent enters signal-delivery-stop as a result of the - // SIGSTOP sent by PTRACE_ATTACH. - int status; - TEST_PCHECK(waitpid(parent_pid, &status, 0) == parent_pid); - MaybeSave(); - TEST_CHECK(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP); - - // Replace the value of word in the parent process with kAfterPokeDataValue. - long const parent_word = ptrace(PTRACE_PEEKDATA, parent_pid, &word, 0); - MaybeSave(); - TEST_CHECK(parent_word == kBeforePokeDataValue); - TEST_PCHECK( - ptrace(PTRACE_POKEDATA, parent_pid, &word, kAfterPokeDataValue) == 0); - MaybeSave(); - - // Detach from the parent and suppress the SIGSTOP. If the SIGSTOP is not - // suppressed, the parent will hang in group-stop, causing the test to time - // out. - TEST_PCHECK(ptrace(PTRACE_DETACH, parent_pid, 0, 0) == 0); - MaybeSave(); - _exit(0); - } - // In parent process. - ASSERT_THAT(child_pid, SyscallSucceeds()); - - // Wait for the child to complete. - int status; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << " status " << status; - - // Check that the child's PTRACE_POKEDATA was effective. - EXPECT_EQ(kAfterPokeDataValue, word); -} - -TEST(PtraceTest, GetSigMask) { - // glibc and the Linux kernel define a sigset_t with different sizes. To avoid - // creating a kernel_sigset_t and recreating all the modification functions - // (sigemptyset, etc), we just hardcode the kernel sigset size. - constexpr int kSizeofKernelSigset = 8; - constexpr int kBlockSignal = SIGUSR1; - sigset_t blocked; - sigemptyset(&blocked); - sigaddset(&blocked, kBlockSignal); - - pid_t const child_pid = fork(); - if (child_pid == 0) { - // In child process. - - // Install a signal handler for kBlockSignal to avoid termination and block - // it. - TEST_PCHECK(signal( - kBlockSignal, +[](int signo) {}) != SIG_ERR); - MaybeSave(); - TEST_PCHECK(sigprocmask(SIG_SETMASK, &blocked, nullptr) == 0); - MaybeSave(); - - // Enable tracing. - TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0); - MaybeSave(); - - // This should be blocked. - RaiseSignal(kBlockSignal); - - // This should be suppressed by parent, who will change signal mask in the - // meantime, which means kBlockSignal should be delivered once this resumes. - RaiseSignal(SIGSTOP); - - _exit(0); - } - // In parent process. - ASSERT_THAT(child_pid, SyscallSucceeds()); - - // Wait for the child to send itself SIGSTOP and enter signal-delivery-stop. - int status; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP) - << " status " << status; - - // Get current signal mask. - sigset_t set; - EXPECT_THAT(ptrace(kPtraceGetSigMask, child_pid, kSizeofKernelSigset, &set), - SyscallSucceeds()); - EXPECT_THAT(blocked, EqualsSigset(set)); - - // Try to get current signal mask with bad size argument. - EXPECT_THAT(ptrace(kPtraceGetSigMask, child_pid, 0, nullptr), - SyscallFailsWithErrno(EINVAL)); - - // Try to set bad signal mask. - sigset_t* bad_addr = reinterpret_cast<sigset_t*>(-1); - EXPECT_THAT( - ptrace(kPtraceSetSigMask, child_pid, kSizeofKernelSigset, bad_addr), - SyscallFailsWithErrno(EFAULT)); - - // Set signal mask to empty set. - sigset_t set1; - sigemptyset(&set1); - EXPECT_THAT(ptrace(kPtraceSetSigMask, child_pid, kSizeofKernelSigset, &set1), - SyscallSucceeds()); - - // Suppress SIGSTOP and resume the child. It should re-enter - // signal-delivery-stop for kBlockSignal. - ASSERT_THAT(ptrace(PTRACE_CONT, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == kBlockSignal) - << " status " << status; - - ASSERT_THAT(ptrace(PTRACE_CONT, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - // Let's see that process exited normally. - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << " status " << status; -} - -TEST(PtraceTest, GetSiginfo_SetSiginfo_SignalInjection) { - constexpr int kOriginalSigno = SIGUSR1; - constexpr int kInjectedSigno = SIGUSR2; - - pid_t const child_pid = fork(); - if (child_pid == 0) { - // In child process. - - // Override all signal handlers. - struct sigaction sa = {}; - sa.sa_handler = +[](int signo) { _exit(signo); }; - TEST_PCHECK(sigfillset(&sa.sa_mask) == 0); - for (int signo = 1; signo < 32; signo++) { - if (signo == SIGKILL || signo == SIGSTOP) { - continue; - } - TEST_PCHECK(sigaction(signo, &sa, nullptr) == 0); - } - for (int signo = SIGRTMIN; signo <= SIGRTMAX; signo++) { - TEST_PCHECK(sigaction(signo, &sa, nullptr) == 0); - } - - // Unblock all signals. - TEST_PCHECK(sigprocmask(SIG_UNBLOCK, &sa.sa_mask, nullptr) == 0); - MaybeSave(); - - // Send ourselves kOriginalSignal while ptraced and exit with the signal we - // actually receive via the signal handler, if any, or 0 if we don't receive - // a signal. - TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0); - MaybeSave(); - RaiseSignal(kOriginalSigno); - _exit(0); - } - // In parent process. - ASSERT_THAT(child_pid, SyscallSucceeds()); - - // Wait for the child to send itself kOriginalSigno and enter - // signal-delivery-stop. - int status; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == kOriginalSigno) - << " status " << status; - - siginfo_t siginfo = {}; - ASSERT_THAT(ptrace(PTRACE_GETSIGINFO, child_pid, 0, &siginfo), - SyscallSucceeds()); - EXPECT_EQ(kOriginalSigno, siginfo.si_signo); - EXPECT_EQ(SI_TKILL, siginfo.si_code); - - // Replace the signal with kInjectedSigno, and check that the child exits - // with kInjectedSigno, indicating that signal injection was successful. - siginfo.si_signo = kInjectedSigno; - ASSERT_THAT(ptrace(PTRACE_SETSIGINFO, child_pid, 0, &siginfo), - SyscallSucceeds()); - ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, kInjectedSigno), - SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == kInjectedSigno) - << " status " << status; -} - -TEST(PtraceTest, SIGKILLDoesNotCauseSignalDeliveryStop) { - pid_t const child_pid = fork(); - if (child_pid == 0) { - // In child process. - TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0); - MaybeSave(); - RaiseSignal(SIGKILL); - TEST_CHECK_MSG(false, "Survived SIGKILL?"); - _exit(1); - } - // In parent process. - ASSERT_THAT(child_pid, SyscallSucceeds()); - - // Expect the child to die to SIGKILL without entering signal-delivery-stop. - int status; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL) - << " status " << status; -} - -TEST(PtraceTest, PtraceKill) { - constexpr int kOriginalSigno = SIGUSR1; - - pid_t const child_pid = fork(); - if (child_pid == 0) { - // In child process. - TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0); - MaybeSave(); - - // PTRACE_KILL only works if tracee has entered signal-delivery-stop. - RaiseSignal(kOriginalSigno); - TEST_CHECK_MSG(false, "Failed to kill the process?"); - _exit(0); - } - // In parent process. - ASSERT_THAT(child_pid, SyscallSucceeds()); - - // Wait for the child to send itself kOriginalSigno and enter - // signal-delivery-stop. - int status; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == kOriginalSigno) - << " status " << status; - - ASSERT_THAT(ptrace(PTRACE_KILL, child_pid, 0, 0), SyscallSucceeds()); - - // Expect the child to die with SIGKILL. - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL) - << " status " << status; -} - -TEST(PtraceTest, GetRegSet) { - pid_t const child_pid = fork(); - if (child_pid == 0) { - // In child process. - - // Enable tracing. - TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0); - MaybeSave(); - - // Use kill explicitly because we check the syscall argument register below. - kill(getpid(), SIGSTOP); - - _exit(0); - } - // In parent process. - ASSERT_THAT(child_pid, SyscallSucceeds()); - - // Wait for the child to send itself SIGSTOP and enter signal-delivery-stop. - int status; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP) - << " status " << status; - - // Get the general registers. - struct user_regs_struct regs; - struct iovec iov; - iov.iov_base = ®s; - iov.iov_len = sizeof(regs); - EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov), - SyscallSucceeds()); - - // Read exactly the full register set. - EXPECT_EQ(iov.iov_len, sizeof(regs)); - -#ifdef __x86_64__ - // Child called kill(2), with SIGSTOP as arg 2. - EXPECT_EQ(regs.rsi, SIGSTOP); -#endif - - // Suppress SIGSTOP and resume the child. - ASSERT_THAT(ptrace(PTRACE_CONT, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - // Let's see that process exited normally. - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << " status " << status; -} - -TEST(PtraceTest, AttachingConvertsGroupStopToPtraceStop) { - pid_t const child_pid = fork(); - if (child_pid == 0) { - // In child process. - while (true) { - pause(); - } - } - // In parent process. - ASSERT_THAT(child_pid, SyscallSucceeds()); - - // SIGSTOP the child and wait for it to stop. - ASSERT_THAT(kill(child_pid, SIGSTOP), SyscallSucceeds()); - int status; - ASSERT_THAT(waitpid(child_pid, &status, WUNTRACED), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP) - << " status " << status; - - // Attach to the child and expect it to re-enter a traced group-stop despite - // already being stopped. - ASSERT_THAT(ptrace(PTRACE_ATTACH, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP) - << " status " << status; - - // Verify that the child is ptrace-stopped by checking that it can receive - // ptrace commands requiring a ptrace-stop. - EXPECT_THAT(ptrace(PTRACE_SETOPTIONS, child_pid, 0, 0), SyscallSucceeds()); - - // Group-stop is distinguished from signal-delivery-stop by PTRACE_GETSIGINFO - // failing with EINVAL. - siginfo_t siginfo = {}; - EXPECT_THAT(ptrace(PTRACE_GETSIGINFO, child_pid, 0, &siginfo), - SyscallFailsWithErrno(EINVAL)); - - // Detach from the child and expect it to stay stopped without a notification. - ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, WUNTRACED | WNOHANG), - SyscallSucceedsWithValue(0)); - - // Sending it SIGCONT should cause it to leave its stop. - ASSERT_THAT(kill(child_pid, SIGCONT), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, WCONTINUED), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFCONTINUED(status)) << " status " << status; - - // Clean up the child. - ASSERT_THAT(kill(child_pid, SIGKILL), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL) - << " status " << status; -} - -// Fixture for tests parameterized by whether or not to use PTRACE_O_TRACEEXEC. -class PtraceExecveTest : public ::testing::TestWithParam<bool> { - protected: - bool TraceExec() const { return GetParam(); } -}; - -TEST_P(PtraceExecveTest, Execve_GetRegs_PeekUser_SIGKILL_TraceClone_TraceExit) { - ExecveArray const owned_child_argv = {"/proc/self/exe", - "--ptrace_test_execve_child"}; - char* const* const child_argv = owned_child_argv.get(); - - pid_t const child_pid = fork(); - if (child_pid == 0) { - // In child process. The test relies on calling execve() in a non-leader - // thread; pthread_create() isn't async-signal-safe, so the safest way to - // do this is to execve() first, then enable tracing and run the expected - // child process behavior in the new subprocess. - execve(child_argv[0], child_argv, /* envp = */ nullptr); - TEST_PCHECK_MSG(false, "Survived execve to test child"); - } - // In parent process. - ASSERT_THAT(child_pid, SyscallSucceeds()); - - // Wait for the child to send itself SIGSTOP and enter signal-delivery-stop. - int status; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP) - << " status " << status; - - // Enable PTRACE_O_TRACECLONE so we can get the ID of the child's non-leader - // thread, PTRACE_O_TRACEEXIT so we can observe the leader's death, and - // PTRACE_O_TRACEEXEC if required by the test. (The leader doesn't call - // execve, but options should be inherited across clone.) - long opts = PTRACE_O_TRACECLONE | PTRACE_O_TRACEEXIT; - if (TraceExec()) { - opts |= PTRACE_O_TRACEEXEC; - } - ASSERT_THAT(ptrace(PTRACE_SETOPTIONS, child_pid, 0, opts), SyscallSucceeds()); - - // Suppress the SIGSTOP and wait for the child's leader thread to report - // PTRACE_EVENT_CLONE. Get the new thread's ID from the event. - ASSERT_THAT(ptrace(PTRACE_CONT, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_EQ(SIGTRAP | (PTRACE_EVENT_CLONE << 8), status >> 8); - unsigned long eventmsg; - ASSERT_THAT(ptrace(PTRACE_GETEVENTMSG, child_pid, 0, &eventmsg), - SyscallSucceeds()); - pid_t const nonleader_tid = eventmsg; - pid_t const leader_tid = child_pid; - - // The new thread should be ptraced and in signal-delivery-stop by SIGSTOP due - // to PTRACE_O_TRACECLONE. - // - // Before bf959931ddb88c4e4366e96dd22e68fa0db9527c "wait/ptrace: assume __WALL - // if the child is traced" (4.7) , waiting on it requires __WCLONE since, as a - // non-leader, its termination signal is 0. After, a standard wait is - // sufficient. - ASSERT_THAT(waitpid(nonleader_tid, &status, __WCLONE), - SyscallSucceedsWithValue(nonleader_tid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP) - << " status " << status; - - // Resume both child threads. - for (pid_t const tid : {leader_tid, nonleader_tid}) { - ASSERT_THAT(ptrace(PTRACE_CONT, tid, 0, 0), SyscallSucceeds()); - } - - // The non-leader child thread should call execve, causing the leader thread - // to enter PTRACE_EVENT_EXIT with an apparent exit code of 0. At this point, - // the leader has not yet exited, so the non-leader should be blocked in - // execve. - ASSERT_THAT(waitpid(leader_tid, &status, 0), - SyscallSucceedsWithValue(leader_tid)); - EXPECT_EQ(SIGTRAP | (PTRACE_EVENT_EXIT << 8), status >> 8); - ASSERT_THAT(ptrace(PTRACE_GETEVENTMSG, leader_tid, 0, &eventmsg), - SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(eventmsg) && WEXITSTATUS(eventmsg) == 0) - << " eventmsg " << eventmsg; - EXPECT_THAT(waitpid(nonleader_tid, &status, __WCLONE | WNOHANG), - SyscallSucceedsWithValue(0)); - - // Allow the leader to continue exiting. This should allow the non-leader to - // complete its execve, causing the original leader to be reaped without - // further notice and the non-leader to steal its ID. - ASSERT_THAT(ptrace(PTRACE_CONT, leader_tid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(leader_tid, &status, 0), - SyscallSucceedsWithValue(leader_tid)); - if (TraceExec()) { - // If PTRACE_O_TRACEEXEC was enabled, the execing thread should be in - // PTRACE_EVENT_EXEC-stop, with the event message set to its old thread ID. - EXPECT_EQ(SIGTRAP | (PTRACE_EVENT_EXEC << 8), status >> 8); - ASSERT_THAT(ptrace(PTRACE_GETEVENTMSG, leader_tid, 0, &eventmsg), - SyscallSucceeds()); - EXPECT_EQ(nonleader_tid, eventmsg); - } else { - // Otherwise, the execing thread should have received SIGTRAP and should now - // be in signal-delivery-stop. - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP) - << " status " << status; - } - -#ifdef __x86_64__ - { - // CS should be 0x33, indicating an 64-bit binary. - constexpr uint64_t kAMD64UserCS = 0x33; - EXPECT_THAT(ptrace(PTRACE_PEEKUSER, leader_tid, - offsetof(struct user_regs_struct, cs), 0), - SyscallSucceedsWithValue(kAMD64UserCS)); - struct user_regs_struct regs = {}; - ASSERT_THAT(ptrace(PTRACE_GETREGS, leader_tid, 0, ®s), - SyscallSucceeds()); - EXPECT_EQ(kAMD64UserCS, regs.cs); - } -#endif // defined(__x86_64__) - - // PTRACE_O_TRACEEXIT should have been inherited across execve. Send SIGKILL, - // which should end the PTRACE_EVENT_EXEC-stop or signal-delivery-stop and - // leave the child in PTRACE_EVENT_EXIT-stop. - ASSERT_THAT(kill(leader_tid, SIGKILL), SyscallSucceeds()); - ASSERT_THAT(waitpid(leader_tid, &status, 0), - SyscallSucceedsWithValue(leader_tid)); - EXPECT_EQ(SIGTRAP | (PTRACE_EVENT_EXIT << 8), status >> 8); - ASSERT_THAT(ptrace(PTRACE_GETEVENTMSG, leader_tid, 0, &eventmsg), - SyscallSucceeds()); - EXPECT_TRUE(WIFSIGNALED(eventmsg) && WTERMSIG(eventmsg) == SIGKILL) - << " eventmsg " << eventmsg; - - // End the PTRACE_EVENT_EXIT stop, allowing the child to exit. - ASSERT_THAT(ptrace(PTRACE_CONT, leader_tid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(leader_tid, &status, 0), - SyscallSucceedsWithValue(leader_tid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL) - << " status " << status; -} - -[[noreturn]] void RunExecveChild() { - // Enable tracing, then raise SIGSTOP and expect our parent to suppress it. - TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0); - MaybeSave(); - RaiseSignal(SIGSTOP); - MaybeSave(); - - // Call execve() in a non-leader thread. As long as execve() succeeds, what - // exactly we execve() shouldn't really matter, since the tracer should kill - // us after execve() completes. - ScopedThread t([&] { - ExecveArray const owned_child_argv = {"/proc/self/exe", - "--this_flag_shouldnt_exist"}; - char* const* const child_argv = owned_child_argv.get(); - execve(child_argv[0], child_argv, /* envp = */ nullptr); - TEST_PCHECK_MSG(false, "Survived execve? (thread)"); - }); - t.Join(); - TEST_CHECK_MSG(false, "Survived execve? (main)"); - _exit(1); -} - -INSTANTIATE_TEST_SUITE_P(TraceExec, PtraceExecveTest, ::testing::Bool()); - -// This test has expectations on when syscall-enter/exit-stops occur that are -// violated if saving occurs, since saving interrupts all syscalls, causing -// premature syscall-exit. -TEST(PtraceTest, - ExitWhenParentIsNotTracer_Syscall_TraceVfork_TraceVforkDone_NoRandomSave) { - constexpr int kExitTraceeExitCode = 99; - - pid_t const child_pid = fork(); - if (child_pid == 0) { - // In child process. - - // Block SIGCHLD so it doesn't interrupt wait4. - sigset_t mask; - TEST_PCHECK(sigemptyset(&mask) == 0); - TEST_PCHECK(sigaddset(&mask, SIGCHLD) == 0); - TEST_PCHECK(sigprocmask(SIG_SETMASK, &mask, nullptr) == 0); - MaybeSave(); - - // Enable tracing, then raise SIGSTOP and expect our parent to suppress it. - TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0); - MaybeSave(); - RaiseSignal(SIGSTOP); - MaybeSave(); - - // Spawn a vfork child that exits immediately, and reap it. Don't save - // after vfork since the parent expects to see wait4 as the next syscall. - pid_t const pid = vfork(); - if (pid == 0) { - _exit(kExitTraceeExitCode); - } - TEST_PCHECK_MSG(pid > 0, "vfork failed"); - - int status; - TEST_PCHECK(wait4(pid, &status, 0, nullptr) > 0); - MaybeSave(); - TEST_CHECK(WIFEXITED(status) && WEXITSTATUS(status) == kExitTraceeExitCode); - _exit(0); - } - // In parent process. - ASSERT_THAT(child_pid, SyscallSucceeds()); - - // Wait for the child to send itself SIGSTOP and enter signal-delivery-stop. - int status; - ASSERT_THAT(child_pid, SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP) - << " status " << status; - - // Enable PTRACE_O_TRACEVFORK so we can get the ID of the grandchild, - // PTRACE_O_TRACEVFORKDONE so we can observe PTRACE_EVENT_VFORK_DONE, and - // PTRACE_O_TRACESYSGOOD so syscall-enter/exit-stops are unambiguously - // indicated by a stop signal of SIGTRAP|0x80 rather than just SIGTRAP. - ASSERT_THAT(ptrace(PTRACE_SETOPTIONS, child_pid, 0, - PTRACE_O_TRACEVFORK | PTRACE_O_TRACEVFORKDONE | - PTRACE_O_TRACESYSGOOD), - SyscallSucceeds()); - - // Suppress the SIGSTOP and wait for the child to report PTRACE_EVENT_VFORK. - // Get the new process' ID from the event. - ASSERT_THAT(ptrace(PTRACE_CONT, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_EQ(SIGTRAP | (PTRACE_EVENT_VFORK << 8), status >> 8); - unsigned long eventmsg; - ASSERT_THAT(ptrace(PTRACE_GETEVENTMSG, child_pid, 0, &eventmsg), - SyscallSucceeds()); - pid_t const grandchild_pid = eventmsg; - - // The grandchild should be traced by us and in signal-delivery-stop by - // SIGSTOP due to PTRACE_O_TRACEVFORK. This allows us to wait on it even - // though we're not its parent. - ASSERT_THAT(waitpid(grandchild_pid, &status, 0), - SyscallSucceedsWithValue(grandchild_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP) - << " status " << status; - - // Resume the child with PTRACE_SYSCALL. Since the grandchild is still in - // signal-delivery-stop, the child should remain in vfork() waiting for the - // grandchild to exec or exit. - ASSERT_THAT(ptrace(PTRACE_SYSCALL, child_pid, 0, 0), SyscallSucceeds()); - absl::SleepFor(absl::Seconds(1)); - ASSERT_THAT(waitpid(child_pid, &status, WNOHANG), - SyscallSucceedsWithValue(0)); - - // Suppress the grandchild's SIGSTOP and wait for the grandchild to exit. Pass - // WNOWAIT to waitid() so that we don't acknowledge the grandchild's exit yet. - ASSERT_THAT(ptrace(PTRACE_CONT, grandchild_pid, 0, 0), SyscallSucceeds()); - siginfo_t siginfo = {}; - ASSERT_THAT(waitid(P_PID, grandchild_pid, &siginfo, WEXITED | WNOWAIT), - SyscallSucceeds()); - EXPECT_EQ(SIGCHLD, siginfo.si_signo); - EXPECT_EQ(CLD_EXITED, siginfo.si_code); - EXPECT_EQ(kExitTraceeExitCode, siginfo.si_status); - EXPECT_EQ(grandchild_pid, siginfo.si_pid); - EXPECT_EQ(getuid(), siginfo.si_uid); - - // The child should now be in PTRACE_EVENT_VFORK_DONE stop. The event - // message should still be the grandchild's PID. - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_EQ(SIGTRAP | (PTRACE_EVENT_VFORK_DONE << 8), status >> 8); - ASSERT_THAT(ptrace(PTRACE_GETEVENTMSG, child_pid, 0, &eventmsg), - SyscallSucceeds()); - EXPECT_EQ(grandchild_pid, eventmsg); - - // Resume the child with PTRACE_SYSCALL again and expect it to enter - // syscall-exit-stop for vfork() or clone(), either of which should return the - // grandchild's PID from the syscall. Aside from PTRACE_O_TRACESYSGOOD, - // syscall-stops are distinguished from signal-delivery-stop by - // PTRACE_GETSIGINFO returning a siginfo for which si_code == SIGTRAP or - // SIGTRAP|0x80. - ASSERT_THAT(ptrace(PTRACE_SYSCALL, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80)) - << " status " << status; - ASSERT_THAT(ptrace(PTRACE_GETSIGINFO, child_pid, 0, &siginfo), - SyscallSucceeds()); - EXPECT_TRUE(siginfo.si_code == SIGTRAP || siginfo.si_code == (SIGTRAP | 0x80)) - << "si_code = " << siginfo.si_code; -#ifdef __x86_64__ - { - struct user_regs_struct regs = {}; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, ®s), SyscallSucceeds()); - EXPECT_TRUE(regs.orig_rax == SYS_vfork || regs.orig_rax == SYS_clone) - << "orig_rax = " << regs.orig_rax; - EXPECT_EQ(grandchild_pid, regs.rax); - } -#endif // defined(__x86_64__) - - // After this point, the child will be making wait4 syscalls that will be - // interrupted by saving, so saving is not permitted. Note that this is - // explicitly released below once the grandchild exits. - DisableSave ds; - - // Resume the child with PTRACE_SYSCALL again and expect it to enter - // syscall-enter-stop for wait4(). - ASSERT_THAT(ptrace(PTRACE_SYSCALL, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80)) - << " status " << status; - ASSERT_THAT(ptrace(PTRACE_GETSIGINFO, child_pid, 0, &siginfo), - SyscallSucceeds()); - EXPECT_TRUE(siginfo.si_code == SIGTRAP || siginfo.si_code == (SIGTRAP | 0x80)) - << "si_code = " << siginfo.si_code; -#ifdef __x86_64__ - { - EXPECT_THAT(ptrace(PTRACE_PEEKUSER, child_pid, - offsetof(struct user_regs_struct, orig_rax), 0), - SyscallSucceedsWithValue(SYS_wait4)); - } -#endif // defined(__x86_64__) - - // Resume the child with PTRACE_SYSCALL again. Since the grandchild is - // waiting for the tracer (us) to acknowledge its exit first, wait4 should - // block. - ASSERT_THAT(ptrace(PTRACE_SYSCALL, child_pid, 0, 0), SyscallSucceeds()); - absl::SleepFor(absl::Seconds(1)); - ASSERT_THAT(waitpid(child_pid, &status, WNOHANG), - SyscallSucceedsWithValue(0)); - - // Acknowledge the grandchild's exit. - ASSERT_THAT(waitpid(grandchild_pid, &status, 0), - SyscallSucceedsWithValue(grandchild_pid)); - ds.reset(); - - // Now the child should enter syscall-exit-stop for wait4, returning with the - // grandchild's PID. - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80)) - << " status " << status; -#ifdef __x86_64__ - { - struct user_regs_struct regs = {}; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, ®s), SyscallSucceeds()); - EXPECT_EQ(SYS_wait4, regs.orig_rax); - EXPECT_EQ(grandchild_pid, regs.rax); - } -#endif // defined(__x86_64__) - - // Detach from the child and wait for it to exit. - ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << " status " << status; -} - -// These tests requires knowledge of architecture-specific syscall convention. -#ifdef __x86_64__ -TEST(PtraceTest, Int3) { - SKIP_IF(PlatformSupportInt3() == PlatformSupport::NotSupported); - - pid_t const child_pid = fork(); - if (child_pid == 0) { - // In child process. - - // Enable tracing. - TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0); - - // Interrupt 3 - trap to debugger - asm("int3"); - - _exit(56); - } - // In parent process. - ASSERT_THAT(child_pid, SyscallSucceeds()); - - int status; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP) - << " status " << status; - - ASSERT_THAT(ptrace(PTRACE_CONT, child_pid, 0, 0), SyscallSucceeds()); - - // The child should validate the injected return value and then exit normally. - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 56) - << " status " << status; -} - -TEST(PtraceTest, Sysemu_PokeUser) { - constexpr int kSysemuHelperFirstExitCode = 126; - constexpr uint64_t kSysemuInjectedExitGroupReturn = 42; - - pid_t const child_pid = fork(); - if (child_pid == 0) { - // In child process. - - // Enable tracing, then raise SIGSTOP and expect our parent to suppress it. - TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0); - RaiseSignal(SIGSTOP); - - // Try to exit_group, expecting the tracer to skip the syscall and set its - // own return value. - int const rv = syscall(SYS_exit_group, kSysemuHelperFirstExitCode); - TEST_PCHECK_MSG(rv == kSysemuInjectedExitGroupReturn, - "exit_group returned incorrect value"); - - _exit(0); - } - // In parent process. - ASSERT_THAT(child_pid, SyscallSucceeds()); - - // Wait for the child to send itself SIGSTOP and enter signal-delivery-stop. - int status; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP) - << " status " << status; - - // Suppress the SIGSTOP and wait for the child to enter syscall-enter-stop - // for its first exit_group syscall. - ASSERT_THAT(ptrace(kPtraceSysemu, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP) - << " status " << status; - - struct user_regs_struct regs = {}; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, ®s), SyscallSucceeds()); - EXPECT_EQ(SYS_exit_group, regs.orig_rax); - EXPECT_EQ(-ENOSYS, regs.rax); - EXPECT_EQ(kSysemuHelperFirstExitCode, regs.rdi); - - // Replace the exit_group return value, then resume the child, which should - // automatically skip the syscall. - ASSERT_THAT( - ptrace(PTRACE_POKEUSER, child_pid, offsetof(struct user_regs_struct, rax), - kSysemuInjectedExitGroupReturn), - SyscallSucceeds()); - ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds()); - - // The child should validate the injected return value and then exit normally. - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << " status " << status; -} - -// This test also cares about syscall-exit-stop. -TEST(PtraceTest, ERESTART_NoRandomSave) { - constexpr int kSigno = SIGUSR1; - - pid_t const child_pid = fork(); - if (child_pid == 0) { - // In child process. - - // Ignore, but unblock, kSigno. - struct sigaction sa = {}; - sa.sa_handler = SIG_IGN; - TEST_PCHECK(sigfillset(&sa.sa_mask) == 0); - TEST_PCHECK(sigaction(kSigno, &sa, nullptr) == 0); - MaybeSave(); - TEST_PCHECK(sigprocmask(SIG_UNBLOCK, &sa.sa_mask, nullptr) == 0); - MaybeSave(); - - // Enable tracing, then raise SIGSTOP and expect our parent to suppress it. - TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0); - RaiseSignal(SIGSTOP); - - // Invoke the pause syscall, which normally should not return until we - // receive a signal that "either terminates the process or causes the - // invocation of a signal-catching function". - pause(); - - _exit(0); - } - ASSERT_THAT(child_pid, SyscallSucceeds()); - - // Wait for the child to send itself SIGSTOP and enter signal-delivery-stop. - int status; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP) - << " status " << status; - - // After this point, the child's pause syscall will be interrupted by saving, - // so saving is not permitted. Note that this is explicitly released below - // once the child is stopped. - DisableSave ds; - - // Suppress the SIGSTOP and wait for the child to enter syscall-enter-stop for - // its pause syscall. - ASSERT_THAT(ptrace(PTRACE_SYSCALL, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP) - << " status " << status; - - struct user_regs_struct regs = {}; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, ®s), SyscallSucceeds()); - EXPECT_EQ(SYS_pause, regs.orig_rax); - EXPECT_EQ(-ENOSYS, regs.rax); - - // Resume the child with PTRACE_SYSCALL and expect it to block in the pause - // syscall. - ASSERT_THAT(ptrace(PTRACE_SYSCALL, child_pid, 0, 0), SyscallSucceeds()); - absl::SleepFor(absl::Seconds(1)); - ASSERT_THAT(waitpid(child_pid, &status, WNOHANG), - SyscallSucceedsWithValue(0)); - - // Send the child kSigno, causing it to return ERESTARTNOHAND and enter - // syscall-exit-stop from the pause syscall. - constexpr int ERESTARTNOHAND = 514; - ASSERT_THAT(kill(child_pid, kSigno), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP) - << " status " << status; - ds.reset(); - - ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, ®s), SyscallSucceeds()); - EXPECT_EQ(SYS_pause, regs.orig_rax); - EXPECT_EQ(-ERESTARTNOHAND, regs.rax); - - // Replace the return value from pause with 0, causing pause to not be - // restarted despite kSigno being ignored. - ASSERT_THAT(ptrace(PTRACE_POKEUSER, child_pid, - offsetof(struct user_regs_struct, rax), 0), - SyscallSucceeds()); - - // Detach from the child and wait for it to exit. - ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << " status " << status; -} -#endif // defined(__x86_64__) - -TEST(PtraceTest, Seize_Interrupt_Listen) { - volatile long child_should_spin = 1; - pid_t const child_pid = fork(); - if (child_pid == 0) { - // In child process. - while (child_should_spin) { - SleepSafe(absl::Seconds(1)); - } - _exit(1); - } - - // In parent process. - ASSERT_THAT(child_pid, SyscallSucceeds()); - - // Attach to the child with PTRACE_SEIZE; doing so should not stop the child. - ASSERT_THAT(ptrace(PTRACE_SEIZE, child_pid, 0, 0), SyscallSucceeds()); - int status; - EXPECT_THAT(waitpid(child_pid, &status, WNOHANG), - SyscallSucceedsWithValue(0)); - - // Stop the child with PTRACE_INTERRUPT. - ASSERT_THAT(ptrace(PTRACE_INTERRUPT, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_EQ(SIGTRAP | (kPtraceEventStop << 8), status >> 8); - - // Unset child_should_spin to verify that the child never leaves the spin - // loop. - ASSERT_THAT(ptrace(PTRACE_POKEDATA, child_pid, &child_should_spin, 0), - SyscallSucceeds()); - - // Send SIGSTOP to the child, then resume it, allowing it to proceed to - // signal-delivery-stop. - ASSERT_THAT(kill(child_pid, SIGSTOP), SyscallSucceeds()); - ASSERT_THAT(ptrace(PTRACE_CONT, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP) - << " status " << status; - - // Release the child from signal-delivery-stop without suppressing the - // SIGSTOP, causing it to enter group-stop. - ASSERT_THAT(ptrace(PTRACE_CONT, child_pid, 0, SIGSTOP), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_EQ(SIGSTOP | (kPtraceEventStop << 8), status >> 8); - - // "The state of the tracee after PTRACE_LISTEN is somewhat of a gray area: it - // is not in any ptrace-stop (ptrace commands won't work on it, and it will - // deliver waitpid(2) notifications), but it also may be considered 'stopped' - // because it is not executing instructions (is not scheduled), and if it was - // in group-stop before PTRACE_LISTEN, it will not respond to signals until - // SIGCONT is received." - ptrace(2). - ASSERT_THAT(ptrace(PTRACE_LISTEN, child_pid, 0, 0), SyscallSucceeds()); - EXPECT_THAT(ptrace(PTRACE_CONT, child_pid, 0, 0), - SyscallFailsWithErrno(ESRCH)); - EXPECT_THAT(waitpid(child_pid, &status, WNOHANG), - SyscallSucceedsWithValue(0)); - EXPECT_THAT(kill(child_pid, SIGTERM), SyscallSucceeds()); - absl::SleepFor(absl::Seconds(1)); - EXPECT_THAT(waitpid(child_pid, &status, WNOHANG), - SyscallSucceedsWithValue(0)); - - // Send SIGCONT to the child, causing it to leave group-stop and re-trap due - // to PTRACE_LISTEN. - EXPECT_THAT(kill(child_pid, SIGCONT), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_EQ(SIGTRAP | (kPtraceEventStop << 8), status >> 8); - - // Detach the child and expect it to exit due to the SIGTERM we sent while - // it was stopped by PTRACE_LISTEN. - ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGTERM) - << " status " << status; -} - -TEST(PtraceTest, Interrupt_Listen_RequireSeize) { - pid_t const child_pid = fork(); - if (child_pid == 0) { - // In child process. - TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0); - MaybeSave(); - raise(SIGSTOP); - _exit(0); - } - // In parent process. - ASSERT_THAT(child_pid, SyscallSucceeds()); - - // Wait for the child to send itself SIGSTOP and enter signal-delivery-stop. - int status; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP) - << " status " << status; - - // PTRACE_INTERRUPT and PTRACE_LISTEN should fail since the child wasn't - // attached with PTRACE_SEIZE, leaving the child in signal-delivery-stop. - EXPECT_THAT(ptrace(PTRACE_INTERRUPT, child_pid, 0, 0), - SyscallFailsWithErrno(EIO)); - EXPECT_THAT(ptrace(PTRACE_LISTEN, child_pid, 0, 0), - SyscallFailsWithErrno(EIO)); - - // Suppress SIGSTOP and detach from the child, expecting it to exit normally. - ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << " status " << status; -} - -TEST(PtraceTest, SeizeSetOptions) { - pid_t const child_pid = fork(); - if (child_pid == 0) { - // In child process. - while (true) { - SleepSafe(absl::Seconds(1)); - } - } - - // In parent process. - ASSERT_THAT(child_pid, SyscallSucceeds()); - - // Attach to the child with PTRACE_SEIZE while setting PTRACE_O_TRACESYSGOOD. - ASSERT_THAT(ptrace(PTRACE_SEIZE, child_pid, 0, PTRACE_O_TRACESYSGOOD), - SyscallSucceeds()); - - // Stop the child with PTRACE_INTERRUPT. - ASSERT_THAT(ptrace(PTRACE_INTERRUPT, child_pid, 0, 0), SyscallSucceeds()); - int status; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_EQ(SIGTRAP | (kPtraceEventStop << 8), status >> 8); - - // Resume the child with PTRACE_SYSCALL and wait for it to enter - // syscall-enter-stop. The stop signal status from the syscall stop should be - // SIGTRAP|0x80, reflecting PTRACE_O_TRACESYSGOOD. - ASSERT_THAT(ptrace(PTRACE_SYSCALL, child_pid, 0, 0), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80)) - << " status " << status; - - // Clean up the child. - ASSERT_THAT(kill(child_pid, SIGKILL), SyscallSucceeds()); - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - if (WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80)) { - // "SIGKILL kills even within system calls (syscall-exit-stop is not - // generated prior to death by SIGKILL). The net effect is that SIGKILL - // always kills the process (all its threads), even if some threads of the - // process are ptraced." - ptrace(2). This is technically true, but... - // - // When we send SIGKILL to the child, kernel/signal.c:complete_signal() => - // signal_wake_up(resume=1) kicks the tracee out of the syscall-enter-stop. - // The pending SIGKILL causes the syscall to be skipped, but the child - // thread still reports syscall-exit before checking for pending signals; in - // current kernels, this is - // arch/x86/entry/common.c:syscall_return_slowpath() => - // syscall_slow_exit_work() => - // include/linux/tracehook.h:tracehook_report_syscall_exit() => - // ptrace_report_syscall() => kernel/signal.c:ptrace_notify() => - // ptrace_do_notify() => ptrace_stop(). - // - // ptrace_stop() sets the task's state to TASK_TRACED and the task's - // exit_code to SIGTRAP|0x80 (passed by ptrace_report_syscall()), then calls - // freezable_schedule(). freezable_schedule() eventually reaches - // __schedule(), which detects signal_pending_state() due to the pending - // SIGKILL, sets the task's state back to TASK_RUNNING, and returns without - // descheduling. Thus, the task never enters syscall-exit-stop. However, if - // our wait4() => kernel/exit.c:wait_task_stopped() racily observes the - // TASK_TRACED state and the non-zero exit code set by ptrace_stop() before - // __schedule() sets the state back to TASK_RUNNING, it will return the - // task's exit_code as status W_STOPCODE(SIGTRAP|0x80). So we get a spurious - // syscall-exit-stop notification, and need to wait4() again for task exit. - // - // gVisor is not susceptible to this race because - // kernel.Task.waitCollectTraceeStopLocked() checks specifically for an - // active ptraceStop, which is not initiated if SIGKILL is pending. - std::cout << "Observed syscall-exit after SIGKILL"; - ASSERT_THAT(waitpid(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - } - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL) - << " status " << status; -} - -} // namespace - -} // namespace testing -} // namespace gvisor - -int main(int argc, char** argv) { - gvisor::testing::TestInit(&argc, &argv); - - if (absl::GetFlag(FLAGS_ptrace_test_execve_child)) { - gvisor::testing::RunExecveChild(); - } - - return gvisor::testing::RunAllTests(); -} diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc deleted file mode 100644 index dafe64d20..000000000 --- a/test/syscalls/linux/pty.cc +++ /dev/null @@ -1,1616 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <linux/capability.h> -#include <linux/major.h> -#include <poll.h> -#include <sched.h> -#include <signal.h> -#include <sys/ioctl.h> -#include <sys/mman.h> -#include <sys/stat.h> -#include <sys/sysmacros.h> -#include <sys/types.h> -#include <sys/wait.h> -#include <termios.h> -#include <unistd.h> - -#include <iostream> - -#include "gtest/gtest.h" -#include "absl/base/macros.h" -#include "absl/strings/str_cat.h" -#include "absl/synchronization/notification.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/pty_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -using ::testing::AnyOf; -using ::testing::Contains; -using ::testing::Eq; -using ::testing::Not; - -// Tests Unix98 pseudoterminals. -// -// These tests assume that /dev/ptmx exists and is associated with a devpts -// filesystem mounted at /dev/pts/. While a Linux distribution could -// theoretically place those anywhere, glibc expects those locations, so they -// are effectively fixed. - -// Minor device number for an unopened ptmx file. -constexpr int kPtmxMinor = 2; - -// The timeout when polling for data from a pty. When data is written to one end -// of a pty, Linux asynchronously makes it available to the other end, so we -// have to wait. -constexpr absl::Duration kTimeout = absl::Seconds(20); - -// The maximum line size in bytes returned per read from a pty file. -constexpr int kMaxLineSize = 4096; - -constexpr char kMasterPath[] = "/dev/ptmx"; - -// glibc defines its own, different, version of struct termios. We care about -// what the kernel does, not glibc. -#define KERNEL_NCCS 19 -struct kernel_termios { - tcflag_t c_iflag; - tcflag_t c_oflag; - tcflag_t c_cflag; - tcflag_t c_lflag; - cc_t c_line; - cc_t c_cc[KERNEL_NCCS]; -}; - -bool operator==(struct kernel_termios const& a, - struct kernel_termios const& b) { - return memcmp(&a, &b, sizeof(a)) == 0; -} - -// Returns the termios-style control character for the passed character. -// -// e.g., for Ctrl-C, i.e., ^C, call ControlCharacter('C'). -// -// Standard control characters are ASCII bytes 0 through 31. -constexpr char ControlCharacter(char c) { - // A is 1, B is 2, etc. - return c - 'A' + 1; -} - -// Returns the printable character the given control character represents. -constexpr char FromControlCharacter(char c) { return c + 'A' - 1; } - -// Returns true if c is a control character. -// -// Standard control characters are ASCII bytes 0 through 31. -constexpr bool IsControlCharacter(char c) { return c <= 31; } - -struct Field { - const char* name; - uint64_t mask; - uint64_t value; -}; - -// ParseFields returns a string representation of value, using the names in -// fields. -std::string ParseFields(const Field* fields, size_t len, uint64_t value) { - bool first = true; - std::string s; - for (size_t i = 0; i < len; i++) { - const Field f = fields[i]; - if ((value & f.mask) == f.value) { - if (!first) { - s += "|"; - } - s += f.name; - first = false; - value &= ~f.mask; - } - } - - if (value) { - if (!first) { - s += "|"; - } - absl::StrAppend(&s, value); - } - - return s; -} - -const Field kIflagFields[] = { - {"IGNBRK", IGNBRK, IGNBRK}, {"BRKINT", BRKINT, BRKINT}, - {"IGNPAR", IGNPAR, IGNPAR}, {"PARMRK", PARMRK, PARMRK}, - {"INPCK", INPCK, INPCK}, {"ISTRIP", ISTRIP, ISTRIP}, - {"INLCR", INLCR, INLCR}, {"IGNCR", IGNCR, IGNCR}, - {"ICRNL", ICRNL, ICRNL}, {"IUCLC", IUCLC, IUCLC}, - {"IXON", IXON, IXON}, {"IXANY", IXANY, IXANY}, - {"IXOFF", IXOFF, IXOFF}, {"IMAXBEL", IMAXBEL, IMAXBEL}, - {"IUTF8", IUTF8, IUTF8}, -}; - -const Field kOflagFields[] = { - {"OPOST", OPOST, OPOST}, {"OLCUC", OLCUC, OLCUC}, - {"ONLCR", ONLCR, ONLCR}, {"OCRNL", OCRNL, OCRNL}, - {"ONOCR", ONOCR, ONOCR}, {"ONLRET", ONLRET, ONLRET}, - {"OFILL", OFILL, OFILL}, {"OFDEL", OFDEL, OFDEL}, - {"NL0", NLDLY, NL0}, {"NL1", NLDLY, NL1}, - {"CR0", CRDLY, CR0}, {"CR1", CRDLY, CR1}, - {"CR2", CRDLY, CR2}, {"CR3", CRDLY, CR3}, - {"TAB0", TABDLY, TAB0}, {"TAB1", TABDLY, TAB1}, - {"TAB2", TABDLY, TAB2}, {"TAB3", TABDLY, TAB3}, - {"BS0", BSDLY, BS0}, {"BS1", BSDLY, BS1}, - {"FF0", FFDLY, FF0}, {"FF1", FFDLY, FF1}, - {"VT0", VTDLY, VT0}, {"VT1", VTDLY, VT1}, - {"XTABS", XTABS, XTABS}, -}; - -#ifndef IBSHIFT -// Shift from CBAUD to CIBAUD. -#define IBSHIFT 16 -#endif - -const Field kCflagFields[] = { - {"B0", CBAUD, B0}, - {"B50", CBAUD, B50}, - {"B75", CBAUD, B75}, - {"B110", CBAUD, B110}, - {"B134", CBAUD, B134}, - {"B150", CBAUD, B150}, - {"B200", CBAUD, B200}, - {"B300", CBAUD, B300}, - {"B600", CBAUD, B600}, - {"B1200", CBAUD, B1200}, - {"B1800", CBAUD, B1800}, - {"B2400", CBAUD, B2400}, - {"B4800", CBAUD, B4800}, - {"B9600", CBAUD, B9600}, - {"B19200", CBAUD, B19200}, - {"B38400", CBAUD, B38400}, - {"CS5", CSIZE, CS5}, - {"CS6", CSIZE, CS6}, - {"CS7", CSIZE, CS7}, - {"CS8", CSIZE, CS8}, - {"CSTOPB", CSTOPB, CSTOPB}, - {"CREAD", CREAD, CREAD}, - {"PARENB", PARENB, PARENB}, - {"PARODD", PARODD, PARODD}, - {"HUPCL", HUPCL, HUPCL}, - {"CLOCAL", CLOCAL, CLOCAL}, - {"B57600", CBAUD, B57600}, - {"B115200", CBAUD, B115200}, - {"B230400", CBAUD, B230400}, - {"B460800", CBAUD, B460800}, - {"B500000", CBAUD, B500000}, - {"B576000", CBAUD, B576000}, - {"B921600", CBAUD, B921600}, - {"B1000000", CBAUD, B1000000}, - {"B1152000", CBAUD, B1152000}, - {"B1500000", CBAUD, B1500000}, - {"B2000000", CBAUD, B2000000}, - {"B2500000", CBAUD, B2500000}, - {"B3000000", CBAUD, B3000000}, - {"B3500000", CBAUD, B3500000}, - {"B4000000", CBAUD, B4000000}, - {"CMSPAR", CMSPAR, CMSPAR}, - {"CRTSCTS", CRTSCTS, CRTSCTS}, - {"IB0", CIBAUD, B0 << IBSHIFT}, - {"IB50", CIBAUD, B50 << IBSHIFT}, - {"IB75", CIBAUD, B75 << IBSHIFT}, - {"IB110", CIBAUD, B110 << IBSHIFT}, - {"IB134", CIBAUD, B134 << IBSHIFT}, - {"IB150", CIBAUD, B150 << IBSHIFT}, - {"IB200", CIBAUD, B200 << IBSHIFT}, - {"IB300", CIBAUD, B300 << IBSHIFT}, - {"IB600", CIBAUD, B600 << IBSHIFT}, - {"IB1200", CIBAUD, B1200 << IBSHIFT}, - {"IB1800", CIBAUD, B1800 << IBSHIFT}, - {"IB2400", CIBAUD, B2400 << IBSHIFT}, - {"IB4800", CIBAUD, B4800 << IBSHIFT}, - {"IB9600", CIBAUD, B9600 << IBSHIFT}, - {"IB19200", CIBAUD, B19200 << IBSHIFT}, - {"IB38400", CIBAUD, B38400 << IBSHIFT}, - {"IB57600", CIBAUD, B57600 << IBSHIFT}, - {"IB115200", CIBAUD, B115200 << IBSHIFT}, - {"IB230400", CIBAUD, B230400 << IBSHIFT}, - {"IB460800", CIBAUD, B460800 << IBSHIFT}, - {"IB500000", CIBAUD, B500000 << IBSHIFT}, - {"IB576000", CIBAUD, B576000 << IBSHIFT}, - {"IB921600", CIBAUD, B921600 << IBSHIFT}, - {"IB1000000", CIBAUD, B1000000 << IBSHIFT}, - {"IB1152000", CIBAUD, B1152000 << IBSHIFT}, - {"IB1500000", CIBAUD, B1500000 << IBSHIFT}, - {"IB2000000", CIBAUD, B2000000 << IBSHIFT}, - {"IB2500000", CIBAUD, B2500000 << IBSHIFT}, - {"IB3000000", CIBAUD, B3000000 << IBSHIFT}, - {"IB3500000", CIBAUD, B3500000 << IBSHIFT}, - {"IB4000000", CIBAUD, B4000000 << IBSHIFT}, -}; - -const Field kLflagFields[] = { - {"ISIG", ISIG, ISIG}, {"ICANON", ICANON, ICANON}, - {"XCASE", XCASE, XCASE}, {"ECHO", ECHO, ECHO}, - {"ECHOE", ECHOE, ECHOE}, {"ECHOK", ECHOK, ECHOK}, - {"ECHONL", ECHONL, ECHONL}, {"NOFLSH", NOFLSH, NOFLSH}, - {"TOSTOP", TOSTOP, TOSTOP}, {"ECHOCTL", ECHOCTL, ECHOCTL}, - {"ECHOPRT", ECHOPRT, ECHOPRT}, {"ECHOKE", ECHOKE, ECHOKE}, - {"FLUSHO", FLUSHO, FLUSHO}, {"PENDIN", PENDIN, PENDIN}, - {"IEXTEN", IEXTEN, IEXTEN}, {"EXTPROC", EXTPROC, EXTPROC}, -}; - -std::string FormatCC(char c) { - if (isgraph(c)) { - return std::string(1, c); - } else if (c == ' ') { - return " "; - } else if (c == '\t') { - return "\\t"; - } else if (c == '\r') { - return "\\r"; - } else if (c == '\n') { - return "\\n"; - } else if (c == '\0') { - return "\\0"; - } else if (IsControlCharacter(c)) { - return absl::StrCat("^", std::string(1, FromControlCharacter(c))); - } - return absl::StrCat("\\x", absl::Hex(c)); -} - -std::ostream& operator<<(std::ostream& os, struct kernel_termios const& a) { - os << "{ c_iflag = " - << ParseFields(kIflagFields, ABSL_ARRAYSIZE(kIflagFields), a.c_iflag); - os << ", c_oflag = " - << ParseFields(kOflagFields, ABSL_ARRAYSIZE(kOflagFields), a.c_oflag); - os << ", c_cflag = " - << ParseFields(kCflagFields, ABSL_ARRAYSIZE(kCflagFields), a.c_cflag); - os << ", c_lflag = " - << ParseFields(kLflagFields, ABSL_ARRAYSIZE(kLflagFields), a.c_lflag); - os << ", c_line = " << a.c_line; - os << ", c_cc = { [VINTR] = '" << FormatCC(a.c_cc[VINTR]); - os << "', [VQUIT] = '" << FormatCC(a.c_cc[VQUIT]); - os << "', [VERASE] = '" << FormatCC(a.c_cc[VERASE]); - os << "', [VKILL] = '" << FormatCC(a.c_cc[VKILL]); - os << "', [VEOF] = '" << FormatCC(a.c_cc[VEOF]); - os << "', [VTIME] = '" << static_cast<int>(a.c_cc[VTIME]); - os << "', [VMIN] = " << static_cast<int>(a.c_cc[VMIN]); - os << ", [VSWTC] = '" << FormatCC(a.c_cc[VSWTC]); - os << "', [VSTART] = '" << FormatCC(a.c_cc[VSTART]); - os << "', [VSTOP] = '" << FormatCC(a.c_cc[VSTOP]); - os << "', [VSUSP] = '" << FormatCC(a.c_cc[VSUSP]); - os << "', [VEOL] = '" << FormatCC(a.c_cc[VEOL]); - os << "', [VREPRINT] = '" << FormatCC(a.c_cc[VREPRINT]); - os << "', [VDISCARD] = '" << FormatCC(a.c_cc[VDISCARD]); - os << "', [VWERASE] = '" << FormatCC(a.c_cc[VWERASE]); - os << "', [VLNEXT] = '" << FormatCC(a.c_cc[VLNEXT]); - os << "', [VEOL2] = '" << FormatCC(a.c_cc[VEOL2]); - os << "'}"; - return os; -} - -// Return the default termios settings for a new terminal. -struct kernel_termios DefaultTermios() { - struct kernel_termios t = {}; - t.c_iflag = IXON | ICRNL; - t.c_oflag = OPOST | ONLCR; - t.c_cflag = B38400 | CSIZE | CS8 | CREAD; - t.c_lflag = ISIG | ICANON | ECHO | ECHOE | ECHOK | ECHOCTL | ECHOKE | IEXTEN; - t.c_line = 0; - t.c_cc[VINTR] = ControlCharacter('C'); - t.c_cc[VQUIT] = ControlCharacter('\\'); - t.c_cc[VERASE] = '\x7f'; - t.c_cc[VKILL] = ControlCharacter('U'); - t.c_cc[VEOF] = ControlCharacter('D'); - t.c_cc[VTIME] = '\0'; - t.c_cc[VMIN] = 1; - t.c_cc[VSWTC] = '\0'; - t.c_cc[VSTART] = ControlCharacter('Q'); - t.c_cc[VSTOP] = ControlCharacter('S'); - t.c_cc[VSUSP] = ControlCharacter('Z'); - t.c_cc[VEOL] = '\0'; - t.c_cc[VREPRINT] = ControlCharacter('R'); - t.c_cc[VDISCARD] = ControlCharacter('O'); - t.c_cc[VWERASE] = ControlCharacter('W'); - t.c_cc[VLNEXT] = ControlCharacter('V'); - t.c_cc[VEOL2] = '\0'; - return t; -} - -// PollAndReadFd tries to read count bytes from buf within timeout. -// -// Returns a partial read if some bytes were read. -// -// fd must be non-blocking. -PosixErrorOr<size_t> PollAndReadFd(int fd, void* buf, size_t count, - absl::Duration timeout) { - absl::Time end = absl::Now() + timeout; - - size_t completed = 0; - absl::Duration remaining; - while ((remaining = end - absl::Now()) > absl::ZeroDuration()) { - struct pollfd pfd = {fd, POLLIN, 0}; - int ret = RetryEINTR(poll)(&pfd, 1, absl::ToInt64Milliseconds(remaining)); - if (ret < 0) { - return PosixError(errno, "poll failed"); - } else if (ret == 0) { - // Timed out. - continue; - } else if (ret != 1) { - return PosixError(EINVAL, absl::StrCat("Bad poll ret ", ret)); - } - - ssize_t n = - ReadFd(fd, static_cast<char*>(buf) + completed, count - completed); - if (n < 0) { - return PosixError(errno, "read failed"); - } - completed += n; - if (completed >= count) { - return completed; - } - } - - if (completed) { - return completed; - } - return PosixError(ETIMEDOUT, "Poll timed out"); -} - -TEST(PtyTrunc, Truncate) { - // Opening PTYs with O_TRUNC shouldn't cause an error, but calls to - // (f)truncate should. - FileDescriptor master = - ASSERT_NO_ERRNO_AND_VALUE(Open(kMasterPath, O_RDWR | O_TRUNC)); - int n = ASSERT_NO_ERRNO_AND_VALUE(SlaveID(master)); - std::string spath = absl::StrCat("/dev/pts/", n); - FileDescriptor slave = - ASSERT_NO_ERRNO_AND_VALUE(Open(spath, O_RDWR | O_NONBLOCK | O_TRUNC)); - - EXPECT_THAT(truncate(kMasterPath, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(truncate(spath.c_str(), 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(ftruncate(master.get(), 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(ftruncate(slave.get(), 0), SyscallFailsWithErrno(EINVAL)); -} - -TEST(BasicPtyTest, StatUnopenedMaster) { - struct stat s; - ASSERT_THAT(stat(kMasterPath, &s), SyscallSucceeds()); - - EXPECT_EQ(s.st_rdev, makedev(TTYAUX_MAJOR, kPtmxMinor)); - EXPECT_EQ(s.st_size, 0); - EXPECT_EQ(s.st_blocks, 0); - - // ptmx attached to a specific devpts mount uses block size 1024. See - // fs/devpts/inode.c:devpts_fill_super. - // - // The global ptmx device uses the block size of the filesystem it is created - // on (which is usually 4096 for disk filesystems). - EXPECT_THAT(s.st_blksize, AnyOf(Eq(1024), Eq(4096))); -} - -// Waits for count bytes to be readable from fd. Unlike poll, which can return -// before all data is moved into a pty's read buffer, this function waits for -// all count bytes to become readable. -PosixErrorOr<int> WaitUntilReceived(int fd, int count) { - int buffered = -1; - absl::Duration remaining; - absl::Time end = absl::Now() + kTimeout; - while ((remaining = end - absl::Now()) > absl::ZeroDuration()) { - if (ioctl(fd, FIONREAD, &buffered) < 0) { - return PosixError(errno, "failed FIONREAD ioctl"); - } - if (buffered >= count) { - return buffered; - } - absl::SleepFor(absl::Milliseconds(500)); - } - return PosixError( - ETIMEDOUT, - absl::StrFormat( - "FIONREAD timed out, receiving only %d of %d expected bytes", - buffered, count)); -} - -// Verifies that there is nothing left to read from fd. -void ExpectFinished(const FileDescriptor& fd) { - // Nothing more to read. - char c; - EXPECT_THAT(ReadFd(fd.get(), &c, 1), SyscallFailsWithErrno(EAGAIN)); -} - -// Verifies that we can read expected bytes from fd into buf. -void ExpectReadable(const FileDescriptor& fd, int expected, char* buf) { - size_t n = ASSERT_NO_ERRNO_AND_VALUE( - PollAndReadFd(fd.get(), buf, expected, kTimeout)); - EXPECT_EQ(expected, n); -} - -TEST(BasicPtyTest, OpenMasterSlave) { - FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); - FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); -} - -// The slave entry in /dev/pts/ disappears when the master is closed, even if -// the slave is still open. -TEST(BasicPtyTest, SlaveEntryGoneAfterMasterClose) { - FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); - FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); - - // Get pty index. - int index = -1; - ASSERT_THAT(ioctl(master.get(), TIOCGPTN, &index), SyscallSucceeds()); - - std::string path = absl::StrCat("/dev/pts/", index); - - struct stat st; - EXPECT_THAT(stat(path.c_str(), &st), SyscallSucceeds()); - - master.reset(); - - EXPECT_THAT(stat(path.c_str(), &st), SyscallFailsWithErrno(ENOENT)); -} - -TEST(BasicPtyTest, Getdents) { - FileDescriptor master1 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); - int index1 = -1; - ASSERT_THAT(ioctl(master1.get(), TIOCGPTN, &index1), SyscallSucceeds()); - FileDescriptor slave1 = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master1)); - - FileDescriptor master2 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); - int index2 = -1; - ASSERT_THAT(ioctl(master2.get(), TIOCGPTN, &index2), SyscallSucceeds()); - FileDescriptor slave2 = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master2)); - - // The directory contains ptmx, index1, and index2. (Plus any additional PTYs - // unrelated to this test.) - - std::vector<std::string> contents = - ASSERT_NO_ERRNO_AND_VALUE(ListDir("/dev/pts/", true)); - EXPECT_THAT(contents, Contains(absl::StrCat(index1))); - EXPECT_THAT(contents, Contains(absl::StrCat(index2))); - - master2.reset(); - - // The directory contains ptmx and index1, but not index2 since the master is - // closed. (Plus any additional PTYs unrelated to this test.) - - contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/dev/pts/", true)); - EXPECT_THAT(contents, Contains(absl::StrCat(index1))); - EXPECT_THAT(contents, Not(Contains(absl::StrCat(index2)))); - - // N.B. devpts supports legacy "single-instance" mode and new "multi-instance" - // mode. In legacy mode, devpts does not contain a "ptmx" device (the distro - // must use mknod to create it somewhere, presumably /dev/ptmx). - // Multi-instance mode does include a "ptmx" device tied to that mount. - // - // We don't check for the presence or absence of "ptmx", as distros vary in - // their usage of the two modes. -} - -class PtyTest : public ::testing::Test { - protected: - void SetUp() override { - master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); - slave_ = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master_)); - } - - void DisableCanonical() { - struct kernel_termios t = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds()); - t.c_lflag &= ~ICANON; - EXPECT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); - } - - void EnableCanonical() { - struct kernel_termios t = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds()); - t.c_lflag |= ICANON; - EXPECT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); - } - - // Master and slave ends of the PTY. Non-blocking. - FileDescriptor master_; - FileDescriptor slave_; -}; - -// Master to slave sanity test. -TEST_F(PtyTest, WriteMasterToSlave) { - // N.B. by default, the slave reads nothing until the master writes a newline. - constexpr char kBuf[] = "hello\n"; - - EXPECT_THAT(WriteFd(master_.get(), kBuf, sizeof(kBuf) - 1), - SyscallSucceedsWithValue(sizeof(kBuf) - 1)); - - // Linux moves data from the master to the slave via async work scheduled via - // tty_flip_buffer_push. Since it is asynchronous, the data may not be - // available for reading immediately. Instead we must poll and assert that it - // becomes available "soon". - - char buf[sizeof(kBuf)] = {}; - ExpectReadable(slave_, sizeof(buf) - 1, buf); - - EXPECT_EQ(memcmp(buf, kBuf, sizeof(kBuf)), 0); -} - -// Slave to master sanity test. -TEST_F(PtyTest, WriteSlaveToMaster) { - // N.B. by default, the master reads nothing until the slave writes a newline, - // and the master gets a carriage return. - constexpr char kInput[] = "hello\n"; - constexpr char kExpected[] = "hello\r\n"; - - EXPECT_THAT(WriteFd(slave_.get(), kInput, sizeof(kInput) - 1), - SyscallSucceedsWithValue(sizeof(kInput) - 1)); - - // Linux moves data from the master to the slave via async work scheduled via - // tty_flip_buffer_push. Since it is asynchronous, the data may not be - // available for reading immediately. Instead we must poll and assert that it - // becomes available "soon". - - char buf[sizeof(kExpected)] = {}; - ExpectReadable(master_, sizeof(buf) - 1, buf); - - EXPECT_EQ(memcmp(buf, kExpected, sizeof(kExpected)), 0); -} - -TEST_F(PtyTest, WriteInvalidUTF8) { - char c = 0xff; - ASSERT_THAT(syscall(__NR_write, master_.get(), &c, sizeof(c)), - SyscallSucceedsWithValue(sizeof(c))); -} - -// Both the master and slave report the standard default termios settings. -// -// Note that TCGETS on the master actually redirects to the slave (see comment -// on MasterTermiosUnchangable). -TEST_F(PtyTest, DefaultTermios) { - struct kernel_termios t = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds()); - EXPECT_EQ(t, DefaultTermios()); - - EXPECT_THAT(ioctl(master_.get(), TCGETS, &t), SyscallSucceeds()); - EXPECT_EQ(t, DefaultTermios()); -} - -// Changing termios from the master actually affects the slave. -// -// TCSETS on the master actually redirects to the slave (see comment on -// MasterTermiosUnchangable). -TEST_F(PtyTest, TermiosAffectsSlave) { - struct kernel_termios master_termios = {}; - EXPECT_THAT(ioctl(master_.get(), TCGETS, &master_termios), SyscallSucceeds()); - master_termios.c_lflag ^= ICANON; - EXPECT_THAT(ioctl(master_.get(), TCSETS, &master_termios), SyscallSucceeds()); - - struct kernel_termios slave_termios = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &slave_termios), SyscallSucceeds()); - EXPECT_EQ(master_termios, slave_termios); -} - -// The master end of the pty has termios: -// -// struct kernel_termios t = { -// .c_iflag = 0; -// .c_oflag = 0; -// .c_cflag = B38400 | CS8 | CREAD; -// .c_lflag = 0; -// .c_cc = /* same as DefaultTermios */ -// } -// -// (From drivers/tty/pty.c:unix98_pty_init) -// -// All termios control ioctls on the master actually redirect to the slave -// (drivers/tty/tty_ioctl.c:tty_mode_ioctl), making it impossible to change the -// master termios. -// -// Verify this by setting ICRNL (which rewrites input \r to \n) and verify that -// it has no effect on the master. -TEST_F(PtyTest, MasterTermiosUnchangable) { - char c = '\r'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); - - ExpectReadable(master_, 1, &c); - EXPECT_EQ(c, '\r'); // ICRNL had no effect! - - ExpectFinished(master_); -} - -// ICRNL rewrites input \r to \n. -TEST_F(PtyTest, TermiosICRNL) { - struct kernel_termios t = DefaultTermios(); - t.c_iflag |= ICRNL; - t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); - - char c = '\r'; - ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1)); - - ExpectReadable(slave_, 1, &c); - EXPECT_EQ(c, '\n'); - - ExpectFinished(slave_); -} - -// ONLCR rewrites output \n to \r\n. -TEST_F(PtyTest, TermiosONLCR) { - struct kernel_termios t = DefaultTermios(); - t.c_oflag |= ONLCR; - t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); - - char c = '\n'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); - - // Extra byte for NUL for EXPECT_STREQ. - char buf[3] = {}; - ExpectReadable(master_, 2, buf); - EXPECT_STREQ(buf, "\r\n"); - - ExpectFinished(slave_); -} - -TEST_F(PtyTest, TermiosIGNCR) { - struct kernel_termios t = DefaultTermios(); - t.c_iflag |= IGNCR; - t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); - - char c = '\r'; - ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1)); - - // Nothing to read. - ASSERT_THAT(PollAndReadFd(slave_.get(), &c, 1, kTimeout), - PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); -} - -// Test that we can successfully poll for readable data from the slave. -TEST_F(PtyTest, TermiosPollSlave) { - struct kernel_termios t = DefaultTermios(); - t.c_iflag |= IGNCR; - t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); - - absl::Notification notify; - int sfd = slave_.get(); - ScopedThread th([sfd, ¬ify]() { - notify.Notify(); - - // Poll on the reader fd with POLLIN event. - struct pollfd poll_fd = {sfd, POLLIN, 0}; - EXPECT_THAT( - RetryEINTR(poll)(&poll_fd, 1, absl::ToInt64Milliseconds(kTimeout)), - SyscallSucceedsWithValue(1)); - - // Should trigger POLLIN event. - EXPECT_EQ(poll_fd.revents & POLLIN, POLLIN); - }); - - notify.WaitForNotification(); - // Sleep ensures that poll begins waiting before we write to the FD. - absl::SleepFor(absl::Seconds(1)); - - char s[] = "foo\n"; - ASSERT_THAT(WriteFd(master_.get(), s, strlen(s) + 1), SyscallSucceeds()); -} - -// Test that we can successfully poll for readable data from the master. -TEST_F(PtyTest, TermiosPollMaster) { - struct kernel_termios t = DefaultTermios(); - t.c_iflag |= IGNCR; - t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(master_.get(), TCSETS, &t), SyscallSucceeds()); - - absl::Notification notify; - int mfd = master_.get(); - ScopedThread th([mfd, ¬ify]() { - notify.Notify(); - - // Poll on the reader fd with POLLIN event. - struct pollfd poll_fd = {mfd, POLLIN, 0}; - EXPECT_THAT( - RetryEINTR(poll)(&poll_fd, 1, absl::ToInt64Milliseconds(kTimeout)), - SyscallSucceedsWithValue(1)); - - // Should trigger POLLIN event. - EXPECT_EQ(poll_fd.revents & POLLIN, POLLIN); - }); - - notify.WaitForNotification(); - // Sleep ensures that poll begins waiting before we write to the FD. - absl::SleepFor(absl::Seconds(1)); - - char s[] = "foo\n"; - ASSERT_THAT(WriteFd(slave_.get(), s, strlen(s) + 1), SyscallSucceeds()); -} - -TEST_F(PtyTest, TermiosINLCR) { - struct kernel_termios t = DefaultTermios(); - t.c_iflag |= INLCR; - t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); - - char c = '\n'; - ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1)); - - ExpectReadable(slave_, 1, &c); - EXPECT_EQ(c, '\r'); - - ExpectFinished(slave_); -} - -TEST_F(PtyTest, TermiosONOCR) { - struct kernel_termios t = DefaultTermios(); - t.c_oflag |= ONOCR; - t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); - - // The terminal is at column 0, so there should be no CR to read. - char c = '\r'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); - - // Nothing to read. - ASSERT_THAT(PollAndReadFd(master_.get(), &c, 1, kTimeout), - PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); - - // This time the column is greater than 0, so we should be able to read the CR - // out of the other end. - constexpr char kInput[] = "foo\r"; - constexpr int kInputSize = sizeof(kInput) - 1; - ASSERT_THAT(WriteFd(slave_.get(), kInput, kInputSize), - SyscallSucceedsWithValue(kInputSize)); - - char buf[kInputSize] = {}; - ExpectReadable(master_, kInputSize, buf); - - EXPECT_EQ(memcmp(buf, kInput, kInputSize), 0); - - ExpectFinished(master_); - - // Terminal should be at column 0 again, so no CR can be read. - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); - - // Nothing to read. - ASSERT_THAT(PollAndReadFd(master_.get(), &c, 1, kTimeout), - PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); -} - -TEST_F(PtyTest, TermiosOCRNL) { - struct kernel_termios t = DefaultTermios(); - t.c_oflag |= OCRNL; - t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); - - // The terminal is at column 0, so there should be no CR to read. - char c = '\r'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); - - ExpectReadable(master_, 1, &c); - EXPECT_EQ(c, '\n'); - - ExpectFinished(master_); -} - -// Tests that VEOL is disabled when we start, and that we can set it to enable -// it. -TEST_F(PtyTest, VEOLTermination) { - // Write a few bytes ending with '\0', and confirm that we can't read. - constexpr char kInput[] = "hello"; - ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput)), - SyscallSucceedsWithValue(sizeof(kInput))); - char buf[sizeof(kInput)] = {}; - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(kInput), kTimeout), - PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); - - // Set the EOL character to '=' and write it. - constexpr char delim = '='; - struct kernel_termios t = DefaultTermios(); - t.c_cc[VEOL] = delim; - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); - ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); - - // Now we can read, as sending EOL caused the line to become available. - ExpectReadable(slave_, sizeof(kInput), buf); - EXPECT_EQ(memcmp(buf, kInput, sizeof(kInput)), 0); - - ExpectReadable(slave_, 1, buf); - EXPECT_EQ(buf[0], '='); - - ExpectFinished(slave_); -} - -// Tests that we can write more than the 4096 character limit, then a -// terminating character, then read out just the first 4095 bytes plus the -// terminator. -TEST_F(PtyTest, CanonBigWrite) { - constexpr int kWriteLen = kMaxLineSize + 4; - char input[kWriteLen]; - memset(input, 'M', kWriteLen - 1); - input[kWriteLen - 1] = '\n'; - ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen), - SyscallSucceedsWithValue(kWriteLen)); - - // We can read the line. - char buf[kMaxLineSize] = {}; - ExpectReadable(slave_, kMaxLineSize, buf); - - ExpectFinished(slave_); -} - -// Tests that data written in canonical mode can be read immediately once -// switched to noncanonical mode. -TEST_F(PtyTest, SwitchCanonToNoncanon) { - // Write a few bytes without a terminating character, switch to noncanonical - // mode, and read them. - constexpr char kInput[] = "hello"; - ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput)), - SyscallSucceedsWithValue(sizeof(kInput))); - - // Nothing available yet. - char buf[sizeof(kInput)] = {}; - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(kInput), kTimeout), - PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); - - DisableCanonical(); - - ExpectReadable(slave_, sizeof(kInput), buf); - EXPECT_STREQ(buf, kInput); - - ExpectFinished(slave_); -} - -TEST_F(PtyTest, SwitchCanonToNonCanonNewline) { - // Write a few bytes with a terminating character. - constexpr char kInput[] = "hello\n"; - ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput)), - SyscallSucceedsWithValue(sizeof(kInput))); - - DisableCanonical(); - - // We can read the line. - char buf[sizeof(kInput)] = {}; - ExpectReadable(slave_, sizeof(kInput), buf); - EXPECT_STREQ(buf, kInput); - - ExpectFinished(slave_); -} - -TEST_F(PtyTest, SwitchNoncanonToCanonNewlineBig) { - DisableCanonical(); - - // Write more than the maximum line size, then write a delimiter. - constexpr int kWriteLen = 4100; - char input[kWriteLen]; - memset(input, 'M', kWriteLen); - ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen), - SyscallSucceedsWithValue(kWriteLen)); - // Wait for the input queue to fill. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1)); - constexpr char delim = '\n'; - ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); - - EnableCanonical(); - - // We can read the line. - char buf[kMaxLineSize] = {}; - ExpectReadable(slave_, kMaxLineSize - 1, buf); - - // We can also read the remaining characters. - ExpectReadable(slave_, 6, buf); - - ExpectFinished(slave_); -} - -TEST_F(PtyTest, SwitchNoncanonToCanonNoNewline) { - DisableCanonical(); - - // Write a few bytes without a terminating character. - // mode, and read them. - constexpr char kInput[] = "hello"; - ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput) - 1), - SyscallSucceedsWithValue(sizeof(kInput) - 1)); - - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(kInput) - 1)); - EnableCanonical(); - - // We can read the line. - char buf[sizeof(kInput)] = {}; - ExpectReadable(slave_, sizeof(kInput) - 1, buf); - EXPECT_STREQ(buf, kInput); - - ExpectFinished(slave_); -} - -TEST_F(PtyTest, SwitchNoncanonToCanonNoNewlineBig) { - DisableCanonical(); - - // Write a few bytes without a terminating character. - // mode, and read them. - constexpr int kWriteLen = 4100; - char input[kWriteLen]; - memset(input, 'M', kWriteLen); - ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen), - SyscallSucceedsWithValue(kWriteLen)); - - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1)); - EnableCanonical(); - - // We can read the line. - char buf[kMaxLineSize] = {}; - ExpectReadable(slave_, kMaxLineSize - 1, buf); - - ExpectFinished(slave_); -} - -// Tests that we can write over the 4095 noncanonical limit, then read out -// everything. -TEST_F(PtyTest, NoncanonBigWrite) { - DisableCanonical(); - - // Write well over the 4095 internal buffer limit. - constexpr char kInput = 'M'; - constexpr int kInputSize = kMaxLineSize * 2; - for (int i = 0; i < kInputSize; i++) { - // This makes too many syscalls for save/restore. - const DisableSave ds; - ASSERT_THAT(WriteFd(master_.get(), &kInput, sizeof(kInput)), - SyscallSucceedsWithValue(sizeof(kInput))); - } - - // We should be able to read out everything. Sleep a bit so that Linux has a - // chance to move data from the master to the slave. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1)); - for (int i = 0; i < kInputSize; i++) { - // This makes too many syscalls for save/restore. - const DisableSave ds; - char c; - ExpectReadable(slave_, 1, &c); - ASSERT_EQ(c, kInput); - } - - ExpectFinished(slave_); -} - -// ICANON doesn't make input available until a line delimiter is typed. -// -// Test newline. -TEST_F(PtyTest, TermiosICANONNewline) { - char input[3] = {'a', 'b', 'c'}; - ASSERT_THAT(WriteFd(master_.get(), input, sizeof(input)), - SyscallSucceedsWithValue(sizeof(input))); - - // Extra bytes for newline (written later) and NUL for EXPECT_STREQ. - char buf[5] = {}; - - // Nothing available yet. - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(input), kTimeout), - PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); - - char delim = '\n'; - ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); - - // Now it is available. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(input) + 1)); - ExpectReadable(slave_, sizeof(input) + 1, buf); - EXPECT_STREQ(buf, "abc\n"); - - ExpectFinished(slave_); -} - -// ICANON doesn't make input available until a line delimiter is typed. -// -// Test EOF (^D). -TEST_F(PtyTest, TermiosICANONEOF) { - char input[3] = {'a', 'b', 'c'}; - ASSERT_THAT(WriteFd(master_.get(), input, sizeof(input)), - SyscallSucceedsWithValue(sizeof(input))); - - // Extra byte for NUL for EXPECT_STREQ. - char buf[4] = {}; - - // Nothing available yet. - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(input), kTimeout), - PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); - char delim = ControlCharacter('D'); - ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); - - // Now it is available. Note that ^D is not included. - ExpectReadable(slave_, sizeof(input), buf); - EXPECT_STREQ(buf, "abc"); - - ExpectFinished(slave_); -} - -// ICANON limits us to 4096 bytes including a terminating character. Anything -// after and 4095th character is discarded (although still processed for -// signals and echoing). -TEST_F(PtyTest, CanonDiscard) { - constexpr char kInput = 'M'; - constexpr int kInputSize = 4100; - constexpr int kIter = 3; - - // A few times write more than the 4096 character maximum, then a newline. - constexpr char delim = '\n'; - for (int i = 0; i < kIter; i++) { - // This makes too many syscalls for save/restore. - const DisableSave ds; - for (int i = 0; i < kInputSize; i++) { - ASSERT_THAT(WriteFd(master_.get(), &kInput, sizeof(kInput)), - SyscallSucceedsWithValue(sizeof(kInput))); - } - ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); - } - - // There should be multiple truncated lines available to read. - for (int i = 0; i < kIter; i++) { - char buf[kInputSize] = {}; - ExpectReadable(slave_, kMaxLineSize, buf); - EXPECT_EQ(buf[kMaxLineSize - 1], delim); - EXPECT_EQ(buf[kMaxLineSize - 2], kInput); - } - - ExpectFinished(slave_); -} - -TEST_F(PtyTest, CanonMultiline) { - constexpr char kInput1[] = "GO\n"; - constexpr char kInput2[] = "BLUE\n"; - - // Write both lines. - ASSERT_THAT(WriteFd(master_.get(), kInput1, sizeof(kInput1) - 1), - SyscallSucceedsWithValue(sizeof(kInput1) - 1)); - ASSERT_THAT(WriteFd(master_.get(), kInput2, sizeof(kInput2) - 1), - SyscallSucceedsWithValue(sizeof(kInput2) - 1)); - - // Get the first line. - char line1[8] = {}; - ExpectReadable(slave_, sizeof(kInput1) - 1, line1); - EXPECT_STREQ(line1, kInput1); - - // Get the second line. - char line2[8] = {}; - ExpectReadable(slave_, sizeof(kInput2) - 1, line2); - EXPECT_STREQ(line2, kInput2); - - ExpectFinished(slave_); -} - -TEST_F(PtyTest, SwitchNoncanonToCanonMultiline) { - DisableCanonical(); - - constexpr char kInput1[] = "GO\n"; - constexpr char kInput2[] = "BLUE\n"; - constexpr char kExpected[] = "GO\nBLUE\n"; - - // Write both lines. - ASSERT_THAT(WriteFd(master_.get(), kInput1, sizeof(kInput1) - 1), - SyscallSucceedsWithValue(sizeof(kInput1) - 1)); - ASSERT_THAT(WriteFd(master_.get(), kInput2, sizeof(kInput2) - 1), - SyscallSucceedsWithValue(sizeof(kInput2) - 1)); - - ASSERT_NO_ERRNO( - WaitUntilReceived(slave_.get(), sizeof(kInput1) + sizeof(kInput2) - 2)); - EnableCanonical(); - - // Get all together as one line. - char line[9] = {}; - ExpectReadable(slave_, 8, line); - EXPECT_STREQ(line, kExpected); - - ExpectFinished(slave_); -} - -TEST_F(PtyTest, SwitchTwiceMultiline) { - std::string kInputs[] = {"GO\n", "BLUE\n", "!"}; - std::string kExpected = "GO\nBLUE\n!"; - - // Write each line. - for (std::string input : kInputs) { - ASSERT_THAT(WriteFd(master_.get(), input.c_str(), input.size()), - SyscallSucceedsWithValue(input.size())); - } - - DisableCanonical(); - // All written characters have to make it into the input queue before - // canonical mode is re-enabled. If the final '!' character hasn't been - // enqueued before canonical mode is re-enabled, it won't be readable. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kExpected.size())); - EnableCanonical(); - - // Get all together as one line. - char line[10] = {}; - ExpectReadable(slave_, 9, line); - EXPECT_STREQ(line, kExpected.c_str()); - - ExpectFinished(slave_); -} - -TEST_F(PtyTest, QueueSize) { - // Write the line. - constexpr char kInput1[] = "GO\n"; - ASSERT_THAT(WriteFd(master_.get(), kInput1, sizeof(kInput1) - 1), - SyscallSucceedsWithValue(sizeof(kInput1) - 1)); - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(kInput1) - 1)); - - // Ensure that writing more (beyond what is readable) does not impact the - // readable size. - char input[kMaxLineSize]; - memset(input, 'M', kMaxLineSize); - ASSERT_THAT(WriteFd(master_.get(), input, kMaxLineSize), - SyscallSucceedsWithValue(kMaxLineSize)); - int inputBufSize = ASSERT_NO_ERRNO_AND_VALUE( - WaitUntilReceived(slave_.get(), sizeof(kInput1) - 1)); - EXPECT_EQ(inputBufSize, sizeof(kInput1) - 1); -} - -TEST_F(PtyTest, PartialBadBuffer) { - // Allocate 2 pages. - void* addr = mmap(nullptr, 2 * kPageSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - ASSERT_NE(addr, MAP_FAILED); - char* buf = reinterpret_cast<char*>(addr); - - // Guard the 2nd page for our read to run into. - ASSERT_THAT( - mprotect(reinterpret_cast<void*>(buf + kPageSize), kPageSize, PROT_NONE), - SyscallSucceeds()); - - // Leave only one free byte in the buffer. - char* bad_buffer = buf + kPageSize - 1; - - // Write to the master. - constexpr char kBuf[] = "hello\n"; - constexpr size_t size = sizeof(kBuf) - 1; - EXPECT_THAT(WriteFd(master_.get(), kBuf, size), - SyscallSucceedsWithValue(size)); - - // Read from the slave into bad_buffer. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), size)); - EXPECT_THAT(ReadFd(slave_.get(), bad_buffer, size), - SyscallFailsWithErrno(EFAULT)); - - EXPECT_THAT(munmap(addr, 2 * kPageSize), SyscallSucceeds()) << addr; -} - -TEST_F(PtyTest, SimpleEcho) { - constexpr char kInput[] = "Mr. Eko"; - EXPECT_THAT(WriteFd(master_.get(), kInput, strlen(kInput)), - SyscallSucceedsWithValue(strlen(kInput))); - - char buf[100] = {}; - ExpectReadable(master_, strlen(kInput), buf); - - EXPECT_STREQ(buf, kInput); - ExpectFinished(master_); -} - -TEST_F(PtyTest, GetWindowSize) { - struct winsize ws; - ASSERT_THAT(ioctl(slave_.get(), TIOCGWINSZ, &ws), SyscallSucceeds()); - EXPECT_EQ(ws.ws_row, 0); - EXPECT_EQ(ws.ws_col, 0); -} - -TEST_F(PtyTest, SetSlaveWindowSize) { - constexpr uint16_t kRows = 343; - constexpr uint16_t kCols = 2401; - struct winsize ws = {.ws_row = kRows, .ws_col = kCols}; - ASSERT_THAT(ioctl(slave_.get(), TIOCSWINSZ, &ws), SyscallSucceeds()); - - struct winsize retrieved_ws = {}; - ASSERT_THAT(ioctl(master_.get(), TIOCGWINSZ, &retrieved_ws), - SyscallSucceeds()); - EXPECT_EQ(retrieved_ws.ws_row, kRows); - EXPECT_EQ(retrieved_ws.ws_col, kCols); -} - -TEST_F(PtyTest, SetMasterWindowSize) { - constexpr uint16_t kRows = 343; - constexpr uint16_t kCols = 2401; - struct winsize ws = {.ws_row = kRows, .ws_col = kCols}; - ASSERT_THAT(ioctl(master_.get(), TIOCSWINSZ, &ws), SyscallSucceeds()); - - struct winsize retrieved_ws = {}; - ASSERT_THAT(ioctl(slave_.get(), TIOCGWINSZ, &retrieved_ws), - SyscallSucceeds()); - EXPECT_EQ(retrieved_ws.ws_row, kRows); - EXPECT_EQ(retrieved_ws.ws_col, kCols); -} - -class JobControlTest : public ::testing::Test { - protected: - void SetUp() override { - master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); - slave_ = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master_)); - - // Make this a session leader, which also drops the controlling terminal. - // In the gVisor test environment, this test will be run as the session - // leader already (as the sentry init process). - if (!IsRunningOnGvisor()) { - ASSERT_THAT(setsid(), SyscallSucceeds()); - } - } - - // Master and slave ends of the PTY. Non-blocking. - FileDescriptor master_; - FileDescriptor slave_; -}; - -TEST_F(JobControlTest, SetTTYMaster) { - ASSERT_THAT(ioctl(master_.get(), TIOCSCTTY, 0), SyscallSucceeds()); -} - -TEST_F(JobControlTest, SetTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); -} - -TEST_F(JobControlTest, SetTTYNonLeader) { - // Fork a process that won't be the session leader. - pid_t child = fork(); - if (!child) { - // We shouldn't be able to set the terminal. - TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 0)); - _exit(0); - } - - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); -} - -TEST_F(JobControlTest, SetTTYBadArg) { - // Despite the man page saying arg should be 0 here, Linux doesn't actually - // check. - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 1), SyscallSucceeds()); -} - -TEST_F(JobControlTest, SetTTYDifferentSession) { - SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Fork, join a new session, and try to steal the parent's controlling - // terminal, which should fail. - pid_t child = fork(); - if (!child) { - TEST_PCHECK(setsid() >= 0); - // We shouldn't be able to steal the terminal. - TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 1)); - _exit(0); - } - - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); -} - -TEST_F(JobControlTest, ReleaseTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Make sure we're ignoring SIGHUP, which will be sent to this process once we - // disconnect they TTY. - struct sigaction sa = {}; - sa.sa_handler = SIG_IGN; - sa.sa_flags = 0; - sigemptyset(&sa.sa_mask); - struct sigaction old_sa; - EXPECT_THAT(sigaction(SIGHUP, &sa, &old_sa), SyscallSucceeds()); - EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds()); - EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds()); -} - -TEST_F(JobControlTest, ReleaseUnsetTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); -} - -TEST_F(JobControlTest, ReleaseWrongTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - ASSERT_THAT(ioctl(master_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); -} - -TEST_F(JobControlTest, ReleaseTTYNonLeader) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - pid_t child = fork(); - if (!child) { - TEST_PCHECK(!ioctl(slave_.get(), TIOCNOTTY)); - _exit(0); - } - - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); -} - -TEST_F(JobControlTest, ReleaseTTYDifferentSession) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - pid_t child = fork(); - if (!child) { - // Join a new session, then try to disconnect. - TEST_PCHECK(setsid() >= 0); - TEST_PCHECK(ioctl(slave_.get(), TIOCNOTTY)); - _exit(0); - } - - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); -} - -// Used by the child process spawned in ReleaseTTYSignals to track received -// signals. -static int received; - -void sig_handler(int signum) { received |= signum; } - -// When the session leader releases its controlling terminal, the foreground -// process group gets SIGHUP, then SIGCONT. This test: -// - Spawns 2 threads -// - Has thread 1 return 0 if it gets both SIGHUP and SIGCONT -// - Has thread 2 leave the foreground process group, and return non-zero if it -// receives any signals. -// - Has the parent thread release its controlling terminal -// - Checks that thread 1 got both signals -// - Checks that thread 2 didn't get any signals. -TEST_F(JobControlTest, ReleaseTTYSignals) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - received = 0; - struct sigaction sa = {}; - sa.sa_handler = sig_handler; - sa.sa_flags = 0; - sigemptyset(&sa.sa_mask); - sigaddset(&sa.sa_mask, SIGHUP); - sigaddset(&sa.sa_mask, SIGCONT); - sigprocmask(SIG_BLOCK, &sa.sa_mask, NULL); - - pid_t same_pgrp_child = fork(); - if (!same_pgrp_child) { - // The child will wait for SIGHUP and SIGCONT, then return 0. It begins with - // SIGHUP and SIGCONT blocked. We install signal handlers for those signals, - // then use sigsuspend to wait for those specific signals. - TEST_PCHECK(!sigaction(SIGHUP, &sa, NULL)); - TEST_PCHECK(!sigaction(SIGCONT, &sa, NULL)); - sigset_t mask; - sigfillset(&mask); - sigdelset(&mask, SIGHUP); - sigdelset(&mask, SIGCONT); - while (received != (SIGHUP | SIGCONT)) { - sigsuspend(&mask); - } - _exit(0); - } - - // We don't want to block these anymore. - sigprocmask(SIG_UNBLOCK, &sa.sa_mask, NULL); - - // This child will return non-zero if either SIGHUP or SIGCONT are received. - pid_t diff_pgrp_child = fork(); - if (!diff_pgrp_child) { - TEST_PCHECK(!setpgid(0, 0)); - TEST_PCHECK(pause()); - _exit(1); - } - - EXPECT_THAT(setpgid(diff_pgrp_child, diff_pgrp_child), SyscallSucceeds()); - - // Make sure we're ignoring SIGHUP, which will be sent to this process once we - // disconnect they TTY. - struct sigaction sighup_sa = {}; - sighup_sa.sa_handler = SIG_IGN; - sighup_sa.sa_flags = 0; - sigemptyset(&sighup_sa.sa_mask); - struct sigaction old_sa; - EXPECT_THAT(sigaction(SIGHUP, &sighup_sa, &old_sa), SyscallSucceeds()); - - // Release the controlling terminal, sending SIGHUP and SIGCONT to all other - // processes in this process group. - EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds()); - - EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds()); - - // The child in the same process group will get signaled. - int wstatus; - EXPECT_THAT(waitpid(same_pgrp_child, &wstatus, 0), - SyscallSucceedsWithValue(same_pgrp_child)); - EXPECT_EQ(wstatus, 0); - - // The other child will not get signaled. - EXPECT_THAT(waitpid(diff_pgrp_child, &wstatus, WNOHANG), - SyscallSucceedsWithValue(0)); - EXPECT_THAT(kill(diff_pgrp_child, SIGKILL), SyscallSucceeds()); -} - -TEST_F(JobControlTest, GetForegroundProcessGroup) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - pid_t foreground_pgid; - pid_t pid; - ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid), - SyscallSucceeds()); - ASSERT_THAT(pid = getpid(), SyscallSucceeds()); - - ASSERT_EQ(foreground_pgid, pid); -} - -TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) { - // At this point there's no controlling terminal, so TIOCGPGRP should fail. - pid_t foreground_pgid; - ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid), - SyscallFailsWithErrno(ENOTTY)); -} - -// This test: -// - sets itself as the foreground process group -// - creates a child process in a new process group -// - sets that child as the foreground process group -// - kills its child and sets itself as the foreground process group. -TEST_F(JobControlTest, SetForegroundProcessGroup) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp. - struct sigaction sa = {}; - sa.sa_handler = SIG_IGN; - sa.sa_flags = 0; - sigemptyset(&sa.sa_mask); - sigaction(SIGTTOU, &sa, NULL); - - // Set ourself as the foreground process group. - ASSERT_THAT(tcsetpgrp(slave_.get(), getpgid(0)), SyscallSucceeds()); - - // Create a new process that just waits to be signaled. - pid_t child = fork(); - if (!child) { - TEST_PCHECK(!pause()); - // We should never reach this. - _exit(1); - } - - // Make the child its own process group, then make it the controlling process - // group of the terminal. - ASSERT_THAT(setpgid(child, child), SyscallSucceeds()); - ASSERT_THAT(tcsetpgrp(slave_.get(), child), SyscallSucceeds()); - - // Sanity check - we're still the controlling session. - ASSERT_EQ(getsid(0), getsid(child)); - - // Signal the child, wait for it to exit, then retake the terminal. - ASSERT_THAT(kill(child, SIGTERM), SyscallSucceeds()); - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_TRUE(WIFSIGNALED(wstatus)); - ASSERT_EQ(WTERMSIG(wstatus), SIGTERM); - - // Set ourself as the foreground process. - pid_t pgid; - ASSERT_THAT(pgid = getpgid(0), SyscallSucceeds()); - ASSERT_THAT(tcsetpgrp(slave_.get(), pgid), SyscallSucceeds()); -} - -TEST_F(JobControlTest, SetForegroundProcessGroupWrongTTY) { - pid_t pid = getpid(); - ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid), - SyscallFailsWithErrno(ENOTTY)); -} - -TEST_F(JobControlTest, SetForegroundProcessGroupNegPgid) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - pid_t pid = -1; - ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Create a new process, put it in a new process group, make that group the - // foreground process group, then have the process wait. - pid_t child = fork(); - if (!child) { - TEST_PCHECK(!setpgid(0, 0)); - _exit(0); - } - - // Wait for the child to exit. - int wstatus; - EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - // The child's process group doesn't exist anymore - this should fail. - ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), - SyscallFailsWithErrno(ESRCH)); -} - -TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - int sync_setsid[2]; - int sync_exit[2]; - ASSERT_THAT(pipe(sync_setsid), SyscallSucceeds()); - ASSERT_THAT(pipe(sync_exit), SyscallSucceeds()); - - // Create a new process and put it in a new session. - pid_t child = fork(); - if (!child) { - TEST_PCHECK(setsid() >= 0); - // Tell the parent we're in a new session. - char c = 'c'; - TEST_PCHECK(WriteFd(sync_setsid[1], &c, 1) == 1); - TEST_PCHECK(ReadFd(sync_exit[0], &c, 1) == 1); - _exit(0); - } - - // Wait for the child to tell us it's in a new session. - char c = 'c'; - ASSERT_THAT(ReadFd(sync_setsid[0], &c, 1), SyscallSucceedsWithValue(1)); - - // Child is in a new session, so we can't make it the foregroup process group. - EXPECT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), - SyscallFailsWithErrno(EPERM)); - - EXPECT_THAT(WriteFd(sync_exit[1], &c, 1), SyscallSucceedsWithValue(1)); - - int wstatus; - EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - EXPECT_TRUE(WIFEXITED(wstatus)); - EXPECT_EQ(WEXITSTATUS(wstatus), 0); -} - -// Verify that we don't hang when creating a new session from an orphaned -// process group (b/139968068). Calling setsid() creates an orphaned process -// group, as process groups that contain the session's leading process are -// orphans. -// -// We create 2 sessions in this test. The init process in gVisor is considered -// not to be an orphan (see sessions.go), so we have to create a session from -// which to create a session. The latter session is being created from an -// orphaned process group. -TEST_F(JobControlTest, OrphanRegression) { - pid_t session_2_leader = fork(); - if (!session_2_leader) { - TEST_PCHECK(setsid() >= 0); - - pid_t session_3_leader = fork(); - if (!session_3_leader) { - TEST_PCHECK(setsid() >= 0); - - _exit(0); - } - - int wstatus; - TEST_PCHECK(waitpid(session_3_leader, &wstatus, 0) == session_3_leader); - TEST_PCHECK(wstatus == 0); - - _exit(0); - } - - int wstatus; - ASSERT_THAT(waitpid(session_2_leader, &wstatus, 0), - SyscallSucceedsWithValue(session_2_leader)); - ASSERT_EQ(wstatus, 0); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/pty_root.cc b/test/syscalls/linux/pty_root.cc deleted file mode 100644 index 14a4af980..000000000 --- a/test/syscalls/linux/pty_root.cc +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/ioctl.h> -#include <termios.h> - -#include "gtest/gtest.h" -#include "absl/base/macros.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/pty_util.h" - -namespace gvisor { -namespace testing { - -// These tests should be run as root. -namespace { - -TEST(JobControlRootTest, StealTTY) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - // Make this a session leader, which also drops the controlling terminal. - // In the gVisor test environment, this test will be run as the session - // leader already (as the sentry init process). - if (!IsRunningOnGvisor()) { - ASSERT_THAT(setsid(), SyscallSucceeds()); - } - - FileDescriptor master = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); - FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); - - // Make slave the controlling terminal. - ASSERT_THAT(ioctl(slave.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Fork, join a new session, and try to steal the parent's controlling - // terminal, which should succeed when we have CAP_SYS_ADMIN and pass an arg - // of 1. - pid_t child = fork(); - if (!child) { - ASSERT_THAT(setsid(), SyscallSucceeds()); - // We shouldn't be able to steal the terminal with the wrong arg value. - TEST_PCHECK(ioctl(slave.get(), TIOCSCTTY, 0)); - // We should be able to steal it here. - TEST_PCHECK(!ioctl(slave.get(), TIOCSCTTY, 1)); - _exit(0); - } - - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/pwrite64.cc b/test/syscalls/linux/pwrite64.cc deleted file mode 100644 index b48fe540d..000000000 --- a/test/syscalls/linux/pwrite64.cc +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// This test is currently very rudimentary. -// -// TODO(edahlgren): -// * bad buffer states (EFAULT). -// * bad fds (wrong permission, wrong type of file, EBADF). -// * check offset is not incremented. -// * check for EOF. -// * writing to pipes, symlinks, special files. -class Pwrite64 : public ::testing::Test { - void SetUp() override { - name_ = NewTempAbsPath(); - int fd; - ASSERT_THAT(fd = open(name_.c_str(), O_CREAT, 0644), SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - } - - void TearDown() override { unlink(name_.c_str()); } - - public: - std::string name_; -}; - -TEST_F(Pwrite64, AppendOnly) { - int fd; - ASSERT_THAT(fd = open(name_.c_str(), O_APPEND | O_RDWR), SyscallSucceeds()); - constexpr int64_t kBufSize = 1024; - std::vector<char> buf(kBufSize); - std::fill(buf.begin(), buf.end(), 'a'); - EXPECT_THAT(PwriteFd(fd, buf.data(), buf.size(), 0), - SyscallSucceedsWithValue(buf.size())); - EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(0)); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST_F(Pwrite64, InvalidArgs) { - int fd; - ASSERT_THAT(fd = open(name_.c_str(), O_APPEND | O_RDWR), SyscallSucceeds()); - constexpr int64_t kBufSize = 1024; - std::vector<char> buf(kBufSize); - std::fill(buf.begin(), buf.end(), 'a'); - EXPECT_THAT(PwriteFd(fd, buf.data(), buf.size(), -1), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/pwritev2.cc b/test/syscalls/linux/pwritev2.cc deleted file mode 100644 index 3fe5a600f..000000000 --- a/test/syscalls/linux/pwritev2.cc +++ /dev/null @@ -1,346 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <sys/syscall.h> -#include <sys/types.h> -#include <sys/uio.h> - -#include <string> -#include <vector> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/file_base.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -#ifndef SYS_pwritev2 -#if defined(__x86_64__) -#define SYS_pwritev2 328 -#elif defined(__aarch64__) -#define SYS_pwritev2 287 -#else -#error "Unknown architecture" -#endif -#endif // SYS_pwrite2 - -#ifndef RWF_HIPRI -#define RWF_HIPRI 0x1 -#endif // RWF_HIPRI - -#ifndef RWF_DSYNC -#define RWF_DSYNC 0x2 -#endif // RWF_DSYNC - -#ifndef RWF_SYNC -#define RWF_SYNC 0x4 -#endif // RWF_SYNC - -constexpr int kBufSize = 1024; - -void SetContent(std::vector<char>& content) { - for (uint i = 0; i < content.size(); i++) { - content[i] = static_cast<char>((i % 10) + '0'); - } -} - -ssize_t pwritev2(unsigned long fd, const struct iovec* iov, - unsigned long iovcnt, off_t offset, unsigned long flags) { - // syscall on pwritev2 does some weird things (see man syscall and search - // pwritev2), so we insert a 0 to word align the flags argument on native. - return syscall(SYS_pwritev2, fd, iov, iovcnt, offset, 0, flags); -} - -// This test is the base case where we call pwritev (no offset, no flags). -TEST(Writev2Test, TestBaseCall) { - SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); - - std::vector<char> content(kBufSize); - SetContent(content); - struct iovec iov[2]; - iov[0].iov_base = content.data(); - iov[0].iov_len = content.size() / 2; - iov[1].iov_base = static_cast<char*>(iov[0].iov_base) + (content.size() / 2); - iov[1].iov_len = content.size() / 2; - - ASSERT_THAT(pwritev2(fd.get(), iov, /*iovcnt=*/2, - /*offset=*/0, /*flags=*/0), - SyscallSucceedsWithValue(kBufSize)); - - std::vector<char> buf(kBufSize); - EXPECT_THAT(read(fd.get(), buf.data(), kBufSize), - SyscallSucceedsWithValue(kBufSize)); - - EXPECT_EQ(content, buf); -} - -// This test is where we call pwritev2 with a positive offset and no flags. -TEST(Pwritev2Test, TestValidPositiveOffset) { - SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - std::string prefix(kBufSize, '0'); - - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), prefix, TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); - - std::vector<char> content(kBufSize); - SetContent(content); - struct iovec iov; - iov.iov_base = content.data(); - iov.iov_len = content.size(); - - ASSERT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, - /*offset=*/prefix.size(), /*flags=*/0), - SyscallSucceedsWithValue(content.size())); - - std::vector<char> buf(prefix.size() + content.size()); - EXPECT_THAT(read(fd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - - std::vector<char> want(prefix.begin(), prefix.end()); - want.insert(want.end(), content.begin(), content.end()); - EXPECT_EQ(want, buf); -} - -// This test is the base case where we call writev by using -1 as the offset. -// The write should use the file offset, so the test increments the file offset -// prior to call pwritev2. -TEST(Pwritev2Test, TestNegativeOneOffset) { - SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - const std::string prefix = "00"; - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), prefix.data(), TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); - ASSERT_THAT(lseek(fd.get(), prefix.size(), SEEK_SET), - SyscallSucceedsWithValue(prefix.size())); - - std::vector<char> content(kBufSize); - SetContent(content); - struct iovec iov; - iov.iov_base = content.data(); - iov.iov_len = content.size(); - - ASSERT_THAT(pwritev2(fd.get(), &iov, /*iovcnt*/ 1, - /*offset=*/static_cast<off_t>(-1), /*flags=*/0), - SyscallSucceedsWithValue(content.size())); - - ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), - SyscallSucceedsWithValue(prefix.size() + content.size())); - - std::vector<char> buf(prefix.size() + content.size()); - EXPECT_THAT(pread(fd.get(), buf.data(), buf.size(), /*offset=*/0), - SyscallSucceedsWithValue(buf.size())); - - std::vector<char> want(prefix.begin(), prefix.end()); - want.insert(want.end(), content.begin(), content.end()); - EXPECT_EQ(want, buf); -} - -// pwritev2 requires if the RWF_HIPRI flag is passed, the fd must be opened with -// O_DIRECT. This test implements a correct call with the RWF_HIPRI flag. -TEST(Pwritev2Test, TestCallWithRWF_HIPRI) { - SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); - - std::vector<char> content(kBufSize); - SetContent(content); - struct iovec iov; - iov.iov_base = content.data(); - iov.iov_len = content.size(); - - EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, - /*offset=*/0, /*flags=*/RWF_HIPRI), - SyscallSucceedsWithValue(kBufSize)); - - std::vector<char> buf(content.size()); - EXPECT_THAT(read(fd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - - EXPECT_EQ(buf, content); -} - -// This test checks that pwritev2 can be called with valid flags -TEST(Pwritev2Test, TestCallWithValidFlags) { - SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); - - std::vector<char> content(kBufSize, '0'); - struct iovec iov; - iov.iov_base = content.data(); - iov.iov_len = content.size(); - - EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, - /*offset=*/0, /*flags=*/RWF_DSYNC), - SyscallSucceedsWithValue(kBufSize)); - - std::vector<char> buf(content.size()); - EXPECT_THAT(read(fd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - - EXPECT_EQ(buf, content); - - SetContent(content); - - EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, - /*offset=*/0, /*flags=*/0x4), - SyscallSucceedsWithValue(kBufSize)); - - ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), - SyscallSucceedsWithValue(content.size())); - - EXPECT_THAT(pread(fd.get(), buf.data(), buf.size(), /*offset=*/0), - SyscallSucceedsWithValue(buf.size())); - - EXPECT_EQ(buf, content); -} - -// This test calls pwritev2 with a bad file descriptor. -TEST(Writev2Test, TestBadFile) { - SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - ASSERT_THAT(pwritev2(/*fd=*/-1, /*iov=*/nullptr, /*iovcnt=*/0, - /*offset=*/0, /*flags=*/0), - SyscallFailsWithErrno(EBADF)); -} - -// This test calls pwrite2 with an invalid offset. -TEST(Pwritev2Test, TestInvalidOffset) { - SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); - - char buf[16]; - struct iovec iov; - iov.iov_base = buf; - iov.iov_len = sizeof(buf); - - EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, - /*offset=*/static_cast<off_t>(-8), /*flags=*/0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(Pwritev2Test, TestUnseekableFileValid) { - SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - int pipe_fds[2]; - - ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds()); - - std::vector<char> content(32, '0'); - SetContent(content); - struct iovec iov; - iov.iov_base = content.data(); - iov.iov_len = content.size(); - - EXPECT_THAT(pwritev2(pipe_fds[1], &iov, /*iovcnt=*/1, - /*offset=*/static_cast<off_t>(-1), /*flags=*/0), - SyscallSucceedsWithValue(content.size())); - - std::vector<char> buf(content.size()); - EXPECT_THAT(read(pipe_fds[0], buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - - EXPECT_EQ(content, buf); - - EXPECT_THAT(close(pipe_fds[0]), SyscallSucceeds()); - EXPECT_THAT(close(pipe_fds[1]), SyscallSucceeds()); -} - -// Calling pwritev2 with a non-negative offset calls pwritev. Calling pwritev -// with an unseekable file is not allowed. A pipe is used for an unseekable -// file. -TEST(Pwritev2Test, TestUnseekableFileInValid) { - SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - int pipe_fds[2]; - char buf[16]; - struct iovec iov; - iov.iov_base = buf; - iov.iov_len = sizeof(buf); - - ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds()); - - EXPECT_THAT(pwritev2(pipe_fds[1], &iov, /*iovcnt=*/1, - /*offset=*/2, /*flags=*/0), - SyscallFailsWithErrno(ESPIPE)); - - EXPECT_THAT(close(pipe_fds[0]), SyscallSucceeds()); - EXPECT_THAT(close(pipe_fds[1]), SyscallSucceeds()); -} - -TEST(Pwritev2Test, TestReadOnlyFile) { - SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); - - char buf[16]; - struct iovec iov; - iov.iov_base = buf; - iov.iov_len = sizeof(buf); - - EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, - /*offset=*/0, /*flags=*/0), - SyscallFailsWithErrno(EBADF)); -} - -// This test calls pwritev2 with an invalid flag. -TEST(Pwritev2Test, TestInvalidFlag) { - SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR | O_DIRECT)); - - char buf[16]; - struct iovec iov; - iov.iov_base = buf; - iov.iov_len = sizeof(buf); - - EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, - /*offset=*/0, /*flags=*/0xF0), - SyscallFailsWithErrno(EOPNOTSUPP)); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc deleted file mode 100644 index 0a27506aa..000000000 --- a/test/syscalls/linux/raw_socket_hdrincl.cc +++ /dev/null @@ -1,383 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <linux/capability.h> -#include <netinet/in.h> -#include <netinet/ip.h> -#include <netinet/ip_icmp.h> -#include <netinet/udp.h> -#include <poll.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include <algorithm> -#include <cstring> - -#include "gtest/gtest.h" -#include "absl/base/internal/endian.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Tests for IPPROTO_RAW raw sockets, which implies IP_HDRINCL. -class RawHDRINCL : public ::testing::Test { - protected: - // Creates a socket to be used in tests. - void SetUp() override; - - // Closes the socket created by SetUp(). - void TearDown() override; - - // Returns a valid looback IP header with no payload. - struct iphdr LoopbackHeader(); - - // Fills in buf with an IP header, UDP header, and payload. Returns false if - // buf_size isn't large enough to hold everything. - bool FillPacket(char* buf, size_t buf_size, int port, const char* payload, - uint16_t payload_size); - - // The socket used for both reading and writing. - int socket_; - - // The loopback address. - struct sockaddr_in addr_; -}; - -void RawHDRINCL::SetUp() { - if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - ASSERT_THAT(socket(AF_INET, SOCK_RAW, IPPROTO_RAW), - SyscallFailsWithErrno(EPERM)); - GTEST_SKIP(); - } - - ASSERT_THAT(socket_ = socket(AF_INET, SOCK_RAW, IPPROTO_RAW), - SyscallSucceeds()); - - addr_ = {}; - - addr_.sin_port = IPPROTO_IP; - addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - addr_.sin_family = AF_INET; -} - -void RawHDRINCL::TearDown() { - // TearDown will be run even if we skip the test. - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - EXPECT_THAT(close(socket_), SyscallSucceeds()); - } -} - -struct iphdr RawHDRINCL::LoopbackHeader() { - struct iphdr hdr = {}; - hdr.ihl = 5; - hdr.version = 4; - hdr.tos = 0; - hdr.tot_len = absl::gbswap_16(sizeof(hdr)); - hdr.id = 0; - hdr.frag_off = 0; - hdr.ttl = 7; - hdr.protocol = 1; - hdr.daddr = htonl(INADDR_LOOPBACK); - // hdr.check is set by the network stack. - // hdr.tot_len is set by the network stack. - // hdr.saddr is set by the network stack. - return hdr; -} - -bool RawHDRINCL::FillPacket(char* buf, size_t buf_size, int port, - const char* payload, uint16_t payload_size) { - if (buf_size < sizeof(struct iphdr) + sizeof(struct udphdr) + payload_size) { - return false; - } - - struct iphdr ip = LoopbackHeader(); - ip.protocol = IPPROTO_UDP; - - struct udphdr udp = {}; - udp.source = absl::gbswap_16(port); - udp.dest = absl::gbswap_16(port); - udp.len = absl::gbswap_16(sizeof(udp) + payload_size); - udp.check = 0; - - memcpy(buf, reinterpret_cast<char*>(&ip), sizeof(ip)); - memcpy(buf + sizeof(ip), reinterpret_cast<char*>(&udp), sizeof(udp)); - memcpy(buf + sizeof(ip) + sizeof(udp), payload, payload_size); - - return true; -} - -// We should be able to create multiple IPPROTO_RAW sockets. RawHDRINCL::Setup -// creates the first one, so we only have to create one more here. -TEST_F(RawHDRINCL, MultipleCreation) { - int s2; - ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, IPPROTO_RAW), SyscallSucceeds()); - - ASSERT_THAT(close(s2), SyscallSucceeds()); -} - -// Test that shutting down an unconnected socket fails. -TEST_F(RawHDRINCL, FailShutdownWithoutConnect) { - ASSERT_THAT(shutdown(socket_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); - ASSERT_THAT(shutdown(socket_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); -} - -// Test that listen() fails. -TEST_F(RawHDRINCL, FailListen) { - ASSERT_THAT(listen(socket_, 1), SyscallFailsWithErrno(ENOTSUP)); -} - -// Test that accept() fails. -TEST_F(RawHDRINCL, FailAccept) { - struct sockaddr saddr; - socklen_t addrlen; - ASSERT_THAT(accept(socket_, &saddr, &addrlen), - SyscallFailsWithErrno(ENOTSUP)); -} - -// Test that the socket is writable immediately. -TEST_F(RawHDRINCL, PollWritableImmediately) { - struct pollfd pfd = {}; - pfd.fd = socket_; - pfd.events = POLLOUT; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 0), SyscallSucceedsWithValue(1)); -} - -// Test that the socket isn't readable. -TEST_F(RawHDRINCL, NotReadable) { - // Try to receive data with MSG_DONTWAIT, which returns immediately if there's - // nothing to be read. - char buf[117]; - ASSERT_THAT(RetryEINTR(recv)(socket_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EINVAL)); -} - -// Test that we can connect() to a valid IP (loopback). -TEST_F(RawHDRINCL, ConnectToLoopback) { - ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_), - sizeof(addr_)), - SyscallSucceeds()); -} - -TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) { - struct iphdr hdr = LoopbackHeader(); - ASSERT_THAT(send(socket_, &hdr, sizeof(hdr), 0), - SyscallSucceedsWithValue(sizeof(hdr))); -} - -// HDRINCL implies write-only. Verify that we can't read a packet sent to -// loopback. -TEST_F(RawHDRINCL, NotReadableAfterWrite) { - ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_), - sizeof(addr_)), - SyscallSucceeds()); - - // Construct a packet with an IP header, UDP header, and payload. - constexpr char kPayload[] = "odst"; - char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)]; - ASSERT_TRUE(FillPacket(packet, sizeof(packet), 40000 /* port */, kPayload, - sizeof(kPayload))); - - socklen_t addrlen = sizeof(addr_); - ASSERT_NO_FATAL_FAILURE( - sendto(socket_, reinterpret_cast<void*>(&packet), sizeof(packet), 0, - reinterpret_cast<struct sockaddr*>(&addr_), addrlen)); - - struct pollfd pfd = {}; - pfd.fd = socket_; - pfd.events = POLLIN; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); -} - -TEST_F(RawHDRINCL, WriteTooSmall) { - ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_), - sizeof(addr_)), - SyscallSucceeds()); - - // This is smaller than the size of an IP header. - constexpr char kBuf[] = "JP5"; - ASSERT_THAT(send(socket_, kBuf, sizeof(kBuf), 0), - SyscallFailsWithErrno(EINVAL)); -} - -// Bind to localhost. -TEST_F(RawHDRINCL, BindToLocalhost) { - ASSERT_THAT( - bind(socket_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); -} - -// Bind to a different address. -TEST_F(RawHDRINCL, BindToInvalid) { - struct sockaddr_in bind_addr = {}; - bind_addr.sin_family = AF_INET; - bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to. - ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr)), - SyscallFailsWithErrno(EADDRNOTAVAIL)); -} - -// Send and receive a packet. -TEST_F(RawHDRINCL, SendAndReceive) { - int port = 40000; - if (!IsRunningOnGvisor()) { - port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE( - PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false))); - } - - // IPPROTO_RAW sockets are write-only. We'll have to open another socket to - // read what we write. - FileDescriptor udp_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); - - // Construct a packet with an IP header, UDP header, and payload. - constexpr char kPayload[] = "toto"; - char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)]; - ASSERT_TRUE( - FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload))); - - socklen_t addrlen = sizeof(addr_); - ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0, - reinterpret_cast<struct sockaddr*>(&addr_), - addrlen)); - - // Receive the payload. - char recv_buf[sizeof(packet)]; - struct sockaddr_in src; - socklen_t src_size = sizeof(src); - ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), 0, - reinterpret_cast<struct sockaddr*>(&src), &src_size), - SyscallSucceedsWithValue(sizeof(packet))); - EXPECT_EQ( - memcmp(kPayload, recv_buf + sizeof(struct iphdr) + sizeof(struct udphdr), - sizeof(kPayload)), - 0); - // The network stack should have set the source address. - EXPECT_EQ(src.sin_family, AF_INET); - EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK); - // The packet ID should be 0, as the packet is less than 68 bytes. - struct iphdr iphdr = {}; - memcpy(&iphdr, recv_buf, sizeof(iphdr)); - EXPECT_EQ(iphdr.id, 0); -} - -// Send and receive a packet with nonzero IP ID. -TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) { - int port = 40000; - if (!IsRunningOnGvisor()) { - port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE( - PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false))); - } - - // IPPROTO_RAW sockets are write-only. We'll have to open another socket to - // read what we write. - FileDescriptor udp_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); - - // Construct a packet with an IP header, UDP header, and payload. Make the - // payload large enough to force an IP ID to be assigned. - constexpr char kPayload[128] = {}; - char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)]; - ASSERT_TRUE( - FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload))); - - socklen_t addrlen = sizeof(addr_); - ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0, - reinterpret_cast<struct sockaddr*>(&addr_), - addrlen)); - - // Receive the payload. - char recv_buf[sizeof(packet)]; - struct sockaddr_in src; - socklen_t src_size = sizeof(src); - ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), 0, - reinterpret_cast<struct sockaddr*>(&src), &src_size), - SyscallSucceedsWithValue(sizeof(packet))); - EXPECT_EQ( - memcmp(kPayload, recv_buf + sizeof(struct iphdr) + sizeof(struct udphdr), - sizeof(kPayload)), - 0); - // The network stack should have set the source address. - EXPECT_EQ(src.sin_family, AF_INET); - EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK); - // The packet ID should not be 0, as the packet was more than 68 bytes. - struct iphdr* iphdr = reinterpret_cast<struct iphdr*>(recv_buf); - EXPECT_NE(iphdr->id, 0); -} - -// Send and receive a packet where the sendto address is not the same as the -// provided destination. -TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) { - int port = 40000; - if (!IsRunningOnGvisor()) { - port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE( - PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false))); - } - - // IPPROTO_RAW sockets are write-only. We'll have to open another socket to - // read what we write. - FileDescriptor udp_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); - - // Construct a packet with an IP header, UDP header, and payload. - constexpr char kPayload[] = "toto"; - char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)]; - ASSERT_TRUE( - FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload))); - // Overwrite the IP destination address with an IP we can't get to. - struct iphdr iphdr = {}; - memcpy(&iphdr, packet, sizeof(iphdr)); - iphdr.daddr = 42; - memcpy(packet, &iphdr, sizeof(iphdr)); - - socklen_t addrlen = sizeof(addr_); - ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0, - reinterpret_cast<struct sockaddr*>(&addr_), - addrlen)); - - // Receive the payload, since sendto should replace the bad destination with - // localhost. - char recv_buf[sizeof(packet)]; - struct sockaddr_in src; - socklen_t src_size = sizeof(src); - ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), 0, - reinterpret_cast<struct sockaddr*>(&src), &src_size), - SyscallSucceedsWithValue(sizeof(packet))); - EXPECT_EQ( - memcmp(kPayload, recv_buf + sizeof(struct iphdr) + sizeof(struct udphdr), - sizeof(kPayload)), - 0); - // The network stack should have set the source address. - EXPECT_EQ(src.sin_family, AF_INET); - EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK); - // The packet ID should be 0, as the packet is less than 68 bytes. - struct iphdr recv_iphdr = {}; - memcpy(&recv_iphdr, recv_buf, sizeof(recv_iphdr)); - EXPECT_EQ(recv_iphdr.id, 0); - // The destination address should be localhost, not the bad IP we set - // initially. - EXPECT_EQ(absl::gbswap_32(recv_iphdr.daddr), INADDR_LOOPBACK); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc deleted file mode 100644 index 3de898df7..000000000 --- a/test/syscalls/linux/raw_socket_icmp.cc +++ /dev/null @@ -1,514 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <linux/capability.h> -#include <netinet/in.h> -#include <netinet/ip.h> -#include <netinet/ip_icmp.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include <algorithm> -#include <cstdint> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// The size of an empty ICMP packet and IP header together. -constexpr size_t kEmptyICMPSize = 28; - -// ICMP raw sockets get their own special tests because Linux automatically -// responds to ICMP echo requests, and thus a single echo request sent via -// loopback leads to 2 received ICMP packets. - -class RawSocketICMPTest : public ::testing::Test { - protected: - // Creates a socket to be used in tests. - void SetUp() override; - - // Closes the socket created by SetUp(). - void TearDown() override; - - // Checks that both an ICMP echo request and reply are received. Calls should - // be wrapped in ASSERT_NO_FATAL_FAILURE. - void ExpectICMPSuccess(const struct icmphdr& icmp); - - // Sends icmp via s_. - void SendEmptyICMP(const struct icmphdr& icmp); - - // Sends icmp via s_ to the given address. - void SendEmptyICMPTo(int sock, const struct sockaddr_in& addr, - const struct icmphdr& icmp); - - // Reads from s_ into recv_buf. - void ReceiveICMP(char* recv_buf, size_t recv_buf_len, size_t expected_size, - struct sockaddr_in* src); - - // Reads from sock into recv_buf. - void ReceiveICMPFrom(char* recv_buf, size_t recv_buf_len, - size_t expected_size, struct sockaddr_in* src, int sock); - - // The socket used for both reading and writing. - int s_; - - // The loopback address. - struct sockaddr_in addr_; -}; - -void RawSocketICMPTest::SetUp() { - if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - ASSERT_THAT(socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), - SyscallFailsWithErrno(EPERM)); - GTEST_SKIP(); - } - - ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds()); - - addr_ = {}; - - // "On raw sockets sin_port is set to the IP protocol." - ip(7). - addr_.sin_port = IPPROTO_IP; - addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - addr_.sin_family = AF_INET; -} - -void RawSocketICMPTest::TearDown() { - // TearDown will be run even if we skip the test. - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - EXPECT_THAT(close(s_), SyscallSucceeds()); - } -} - -// We'll only read an echo in this case, as the kernel won't respond to the -// malformed ICMP checksum. -TEST_F(RawSocketICMPTest, SendAndReceiveBadChecksum) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - // Prepare and send an ICMP packet. Use arbitrary junk for checksum, sequence, - // and ID. None of that should matter for raw sockets - the kernel should - // still give us the packet. - struct icmphdr icmp; - icmp.type = ICMP_ECHO; - icmp.code = 0; - icmp.checksum = 0; - icmp.un.echo.sequence = 2012; - icmp.un.echo.id = 2014; - ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); - - // Veryify that we get the echo, then that there's nothing else to read. - char recv_buf[kEmptyICMPSize]; - struct sockaddr_in src; - ASSERT_NO_FATAL_FAILURE( - ReceiveICMP(recv_buf, sizeof(recv_buf), sizeof(struct icmphdr), &src)); - EXPECT_EQ(memcmp(&src, &addr_, sizeof(src)), 0); - // The packet should be identical to what we sent. - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), &icmp, sizeof(icmp)), 0); - - // And there should be nothing left to read. - EXPECT_THAT(RetryEINTR(recv)(s_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// Send and receive an ICMP packet. -TEST_F(RawSocketICMPTest, SendAndReceive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - // Prepare and send an ICMP packet. Use arbitrary junk for sequence and ID. - // None of that should matter for raw sockets - the kernel should still give - // us the packet. - struct icmphdr icmp; - icmp.type = ICMP_ECHO; - icmp.code = 0; - icmp.checksum = 0; - icmp.un.echo.sequence = 2012; - icmp.un.echo.id = 2014; - icmp.checksum = ICMPChecksum(icmp, NULL, 0); - ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); - - ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); -} - -// We should be able to create multiple raw sockets for the same protocol and -// receive the same packet on both. -TEST_F(RawSocketICMPTest, MultipleSocketReceive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - FileDescriptor s2 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_ICMP)); - - // Prepare and send an ICMP packet. Use arbitrary junk for sequence and ID. - // None of that should matter for raw sockets - the kernel should still give - // us the packet. - struct icmphdr icmp; - icmp.type = ICMP_ECHO; - icmp.code = 0; - icmp.checksum = 0; - icmp.un.echo.sequence = 2016; - icmp.un.echo.id = 2018; - icmp.checksum = ICMPChecksum(icmp, NULL, 0); - ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); - - // Both sockets will receive the echo request and reply in indeterminate - // order, so we'll need to read 2 packets from each. - - // Receive on socket 1. - constexpr int kBufSize = kEmptyICMPSize; - char recv_buf1[2][kBufSize]; - struct sockaddr_in src; - for (int i = 0; i < 2; i++) { - ASSERT_NO_FATAL_FAILURE(ReceiveICMP(recv_buf1[i], - ABSL_ARRAYSIZE(recv_buf1[i]), - sizeof(struct icmphdr), &src)); - EXPECT_EQ(memcmp(&src, &addr_, sizeof(src)), 0); - } - - // Receive on socket 2. - char recv_buf2[2][kBufSize]; - for (int i = 0; i < 2; i++) { - ASSERT_NO_FATAL_FAILURE( - ReceiveICMPFrom(recv_buf2[i], ABSL_ARRAYSIZE(recv_buf2[i]), - sizeof(struct icmphdr), &src, s2.get())); - EXPECT_EQ(memcmp(&src, &addr_, sizeof(src)), 0); - } - - // Ensure both sockets receive identical packets. - int types[] = {ICMP_ECHO, ICMP_ECHOREPLY}; - for (int type : types) { - auto match_type = [=](char buf[kBufSize]) { - struct icmphdr* icmp = - reinterpret_cast<struct icmphdr*>(buf + sizeof(struct iphdr)); - return icmp->type == type; - }; - auto icmp1_it = - std::find_if(std::begin(recv_buf1), std::end(recv_buf1), match_type); - auto icmp2_it = - std::find_if(std::begin(recv_buf2), std::end(recv_buf2), match_type); - ASSERT_NE(icmp1_it, std::end(recv_buf1)); - ASSERT_NE(icmp2_it, std::end(recv_buf2)); - EXPECT_EQ(memcmp(*icmp1_it + sizeof(struct iphdr), - *icmp2_it + sizeof(struct iphdr), sizeof(icmp)), - 0); - } -} - -// A raw ICMP socket and ping socket should both receive the ICMP packets -// intended for the ping socket. -TEST_F(RawSocketICMPTest, RawAndPingSockets) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - FileDescriptor ping_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP)); - - // Ping sockets take care of the ICMP ID and checksum. - struct icmphdr icmp; - icmp.type = ICMP_ECHO; - icmp.code = 0; - icmp.un.echo.sequence = *static_cast<unsigned short*>(&icmp.un.echo.sequence); - ASSERT_THAT(RetryEINTR(sendto)(ping_sock.get(), &icmp, sizeof(icmp), 0, - reinterpret_cast<struct sockaddr*>(&addr_), - sizeof(addr_)), - SyscallSucceedsWithValue(sizeof(icmp))); - - // Receive on socket 1, which receives the echo request and reply in - // indeterminate order. - constexpr int kBufSize = kEmptyICMPSize; - char recv_buf1[2][kBufSize]; - struct sockaddr_in src; - for (int i = 0; i < 2; i++) { - ASSERT_NO_FATAL_FAILURE( - ReceiveICMP(recv_buf1[i], kBufSize, sizeof(struct icmphdr), &src)); - EXPECT_EQ(memcmp(&src, &addr_, sizeof(src)), 0); - } - - // Receive on socket 2. Ping sockets only get the echo reply, not the initial - // echo. - char ping_recv_buf[kBufSize]; - ASSERT_THAT(RetryEINTR(recv)(ping_sock.get(), ping_recv_buf, kBufSize, 0), - SyscallSucceedsWithValue(sizeof(struct icmphdr))); - - // Ensure both sockets receive identical echo reply packets. - auto match_type_raw = [=](char buf[kBufSize]) { - struct icmphdr* icmp = - reinterpret_cast<struct icmphdr*>(buf + sizeof(struct iphdr)); - return icmp->type == ICMP_ECHOREPLY; - }; - auto raw_reply_it = - std::find_if(std::begin(recv_buf1), std::end(recv_buf1), match_type_raw); - ASSERT_NE(raw_reply_it, std::end(recv_buf1)); - EXPECT_EQ( - memcmp(*raw_reply_it + sizeof(struct iphdr), ping_recv_buf, sizeof(icmp)), - 0); -} - -// A raw ICMP socket should be able to send a malformed short ICMP Echo Request, -// while ping socket should not. -// Neither should be able to receieve a short malformed packet. -TEST_F(RawSocketICMPTest, ShortEchoRawAndPingSockets) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - FileDescriptor ping_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP)); - - struct icmphdr icmp; - icmp.type = ICMP_ECHO; - icmp.code = 0; - icmp.un.echo.sequence = 0; - icmp.un.echo.id = 6789; - icmp.checksum = 0; - icmp.checksum = ICMPChecksum(icmp, NULL, 0); - - // Omit 2 bytes from ICMP packet. - constexpr int kShortICMPSize = sizeof(icmp) - 2; - - // Sending a malformed short ICMP message to a ping socket should fail. - ASSERT_THAT(RetryEINTR(sendto)(ping_sock.get(), &icmp, kShortICMPSize, 0, - reinterpret_cast<struct sockaddr*>(&addr_), - sizeof(addr_)), - SyscallFailsWithErrno(EINVAL)); - - // Sending a malformed short ICMP message to a raw socket should not fail. - ASSERT_THAT(RetryEINTR(sendto)(s_, &icmp, kShortICMPSize, 0, - reinterpret_cast<struct sockaddr*>(&addr_), - sizeof(addr_)), - SyscallSucceedsWithValue(kShortICMPSize)); - - // Neither Ping nor Raw socket should have anything to read. - char recv_buf[kEmptyICMPSize]; - EXPECT_THAT(RetryEINTR(recv)(ping_sock.get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); - EXPECT_THAT(RetryEINTR(recv)(s_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// A raw ICMP socket should be able to send a malformed short ICMP Echo Reply, -// while ping socket should not. -// Neither should be able to receieve a short malformed packet. -TEST_F(RawSocketICMPTest, ShortEchoReplyRawAndPingSockets) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - FileDescriptor ping_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP)); - - struct icmphdr icmp; - icmp.type = ICMP_ECHOREPLY; - icmp.code = 0; - icmp.un.echo.sequence = 0; - icmp.un.echo.id = 6789; - icmp.checksum = 0; - icmp.checksum = ICMPChecksum(icmp, NULL, 0); - - // Omit 2 bytes from ICMP packet. - constexpr int kShortICMPSize = sizeof(icmp) - 2; - - // Sending a malformed short ICMP message to a ping socket should fail. - ASSERT_THAT(RetryEINTR(sendto)(ping_sock.get(), &icmp, kShortICMPSize, 0, - reinterpret_cast<struct sockaddr*>(&addr_), - sizeof(addr_)), - SyscallFailsWithErrno(EINVAL)); - - // Sending a malformed short ICMP message to a raw socket should not fail. - ASSERT_THAT(RetryEINTR(sendto)(s_, &icmp, kShortICMPSize, 0, - reinterpret_cast<struct sockaddr*>(&addr_), - sizeof(addr_)), - SyscallSucceedsWithValue(kShortICMPSize)); - - // Neither Ping nor Raw socket should have anything to read. - char recv_buf[kEmptyICMPSize]; - EXPECT_THAT(RetryEINTR(recv)(ping_sock.get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); - EXPECT_THAT(RetryEINTR(recv)(s_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// Test that connect() sends packets to the right place. -TEST_F(RawSocketICMPTest, SendAndReceiveViaConnect) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - - // Prepare and send an ICMP packet. Use arbitrary junk for sequence and ID. - // None of that should matter for raw sockets - the kernel should still give - // us the packet. - struct icmphdr icmp; - icmp.type = ICMP_ECHO; - icmp.code = 0; - icmp.checksum = 0; - icmp.un.echo.sequence = 2003; - icmp.un.echo.id = 2004; - icmp.checksum = ICMPChecksum(icmp, NULL, 0); - ASSERT_THAT(send(s_, &icmp, sizeof(icmp), 0), - SyscallSucceedsWithValue(sizeof(icmp))); - - ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); -} - -// Bind to localhost, then send and receive packets. -TEST_F(RawSocketICMPTest, BindSendAndReceive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - - // Prepare and send an ICMP packet. Use arbitrary junk for checksum, sequence, - // and ID. None of that should matter for raw sockets - the kernel should - // still give us the packet. - struct icmphdr icmp; - icmp.type = ICMP_ECHO; - icmp.code = 0; - icmp.checksum = 0; - icmp.un.echo.sequence = 2004; - icmp.un.echo.id = 2007; - icmp.checksum = ICMPChecksum(icmp, NULL, 0); - ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); - - ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); -} - -// Bind and connect to localhost and send/receive packets. -TEST_F(RawSocketICMPTest, BindConnectSendAndReceive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - - // Prepare and send an ICMP packet. Use arbitrary junk for checksum, sequence, - // and ID. None of that should matter for raw sockets - the kernel should - // still give us the packet. - struct icmphdr icmp; - icmp.type = ICMP_ECHO; - icmp.code = 0; - icmp.checksum = 0; - icmp.un.echo.sequence = 2010; - icmp.un.echo.id = 7; - icmp.checksum = ICMPChecksum(icmp, NULL, 0); - ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); - - ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); -} - -void RawSocketICMPTest::ExpectICMPSuccess(const struct icmphdr& icmp) { - // We're going to receive both the echo request and reply, but the order is - // indeterminate. - char recv_buf[kEmptyICMPSize]; - struct sockaddr_in src; - bool received_request = false; - bool received_reply = false; - - for (int i = 0; i < 2; i++) { - // Receive the packet. - ASSERT_NO_FATAL_FAILURE(ReceiveICMP(recv_buf, ABSL_ARRAYSIZE(recv_buf), - sizeof(struct icmphdr), &src)); - EXPECT_EQ(memcmp(&src, &addr_, sizeof(src)), 0); - struct icmphdr* recvd_icmp = - reinterpret_cast<struct icmphdr*>(recv_buf + sizeof(struct iphdr)); - switch (recvd_icmp->type) { - case ICMP_ECHO: - EXPECT_FALSE(received_request); - received_request = true; - // The packet should be identical to what we sent. - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), &icmp, sizeof(icmp)), - 0); - break; - - case ICMP_ECHOREPLY: - EXPECT_FALSE(received_reply); - received_reply = true; - // Most fields should be the same. - EXPECT_EQ(recvd_icmp->code, icmp.code); - EXPECT_EQ(recvd_icmp->un.echo.sequence, icmp.un.echo.sequence); - EXPECT_EQ(recvd_icmp->un.echo.id, icmp.un.echo.id); - // A couple are different. - EXPECT_EQ(recvd_icmp->type, ICMP_ECHOREPLY); - // The checksum computed over the reply should still be valid. - EXPECT_EQ(ICMPChecksum(*recvd_icmp, NULL, 0), 0); - break; - } - } - - ASSERT_TRUE(received_request); - ASSERT_TRUE(received_reply); -} - -void RawSocketICMPTest::SendEmptyICMP(const struct icmphdr& icmp) { - ASSERT_NO_FATAL_FAILURE(SendEmptyICMPTo(s_, addr_, icmp)); -} - -void RawSocketICMPTest::SendEmptyICMPTo(int sock, - const struct sockaddr_in& addr, - const struct icmphdr& icmp) { - // It's safe to use const_cast here because sendmsg won't modify the iovec or - // address. - struct iovec iov = {}; - iov.iov_base = static_cast<void*>(const_cast<struct icmphdr*>(&icmp)); - iov.iov_len = sizeof(icmp); - struct msghdr msg = {}; - msg.msg_name = static_cast<void*>(const_cast<struct sockaddr_in*>(&addr)); - msg.msg_namelen = sizeof(addr); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = NULL; - msg.msg_controllen = 0; - msg.msg_flags = 0; - ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallSucceedsWithValue(sizeof(icmp))); -} - -void RawSocketICMPTest::ReceiveICMP(char* recv_buf, size_t recv_buf_len, - size_t expected_size, - struct sockaddr_in* src) { - ASSERT_NO_FATAL_FAILURE( - ReceiveICMPFrom(recv_buf, recv_buf_len, expected_size, src, s_)); -} - -void RawSocketICMPTest::ReceiveICMPFrom(char* recv_buf, size_t recv_buf_len, - size_t expected_size, - struct sockaddr_in* src, int sock) { - struct iovec iov = {}; - iov.iov_base = recv_buf; - iov.iov_len = recv_buf_len; - struct msghdr msg = {}; - msg.msg_name = src; - msg.msg_namelen = sizeof(*src); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = NULL; - msg.msg_controllen = 0; - msg.msg_flags = 0; - // We should receive the ICMP packet plus 20 bytes of IP header. - ASSERT_THAT(recvmsg(sock, &msg, 0), - SyscallSucceedsWithValue(expected_size + sizeof(struct iphdr))); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket_ipv4.cc deleted file mode 100644 index cde2f07c9..000000000 --- a/test/syscalls/linux/raw_socket_ipv4.cc +++ /dev/null @@ -1,392 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <linux/capability.h> -#include <netinet/in.h> -#include <netinet/ip.h> -#include <netinet/ip_icmp.h> -#include <poll.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include <algorithm> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -// Note: in order to run these tests, /proc/sys/net/ipv4/ping_group_range will -// need to be configured to let the superuser create ping sockets (see icmp(7)). - -namespace gvisor { -namespace testing { - -namespace { - -// Fixture for tests parameterized by protocol. -class RawSocketTest : public ::testing::TestWithParam<int> { - protected: - // Creates a socket to be used in tests. - void SetUp() override; - - // Closes the socket created by SetUp(). - void TearDown() override; - - // Sends buf via s_. - void SendBuf(const char* buf, int buf_len); - - // Sends buf to the provided address via the provided socket. - void SendBufTo(int sock, const struct sockaddr_in& addr, const char* buf, - int buf_len); - - // Reads from s_ into recv_buf. - void ReceiveBuf(char* recv_buf, size_t recv_buf_len); - - int Protocol() { return GetParam(); } - - // The socket used for both reading and writing. - int s_; - - // The loopback address. - struct sockaddr_in addr_; -}; - -void RawSocketTest::SetUp() { - if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - ASSERT_THAT(socket(AF_INET, SOCK_RAW, Protocol()), - SyscallFailsWithErrno(EPERM)); - GTEST_SKIP(); - } - - ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds()); - - addr_ = {}; - - // We don't set ports because raw sockets don't have a notion of ports. - addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - addr_.sin_family = AF_INET; -} - -void RawSocketTest::TearDown() { - // TearDown will be run even if we skip the test. - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - EXPECT_THAT(close(s_), SyscallSucceeds()); - } -} - -// We should be able to create multiple raw sockets for the same protocol. -// BasicRawSocket::Setup creates the first one, so we only have to create one -// more here. -TEST_P(RawSocketTest, MultipleCreation) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - int s2; - ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds()); - - ASSERT_THAT(close(s2), SyscallSucceeds()); -} - -// Test that shutting down an unconnected socket fails. -TEST_P(RawSocketTest, FailShutdownWithoutConnect) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); - ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); -} - -// Shutdown is a no-op for raw sockets (and datagram sockets in general). -TEST_P(RawSocketTest, ShutdownWriteNoop) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds()); - - // Arbitrary. - constexpr char kBuf[] = "noop"; - ASSERT_THAT(RetryEINTR(write)(s_, kBuf, sizeof(kBuf)), - SyscallSucceedsWithValue(sizeof(kBuf))); -} - -// Shutdown is a no-op for raw sockets (and datagram sockets in general). -TEST_P(RawSocketTest, ShutdownReadNoop) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - // Arbitrary. - constexpr char kBuf[] = "gdg"; - ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); - - constexpr size_t kReadSize = sizeof(kBuf) + sizeof(struct iphdr); - char c[kReadSize]; - ASSERT_THAT(read(s_, &c, sizeof(c)), SyscallSucceedsWithValue(kReadSize)); -} - -// Test that listen() fails. -TEST_P(RawSocketTest, FailListen) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT(listen(s_, 1), SyscallFailsWithErrno(ENOTSUP)); -} - -// Test that accept() fails. -TEST_P(RawSocketTest, FailAccept) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - struct sockaddr saddr; - socklen_t addrlen; - ASSERT_THAT(accept(s_, &saddr, &addrlen), SyscallFailsWithErrno(ENOTSUP)); -} - -// Test that getpeername() returns nothing before connect(). -TEST_P(RawSocketTest, FailGetPeerNameBeforeConnect) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - struct sockaddr saddr; - socklen_t addrlen = sizeof(saddr); - ASSERT_THAT(getpeername(s_, &saddr, &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -// Test that getpeername() returns something after connect(). -TEST_P(RawSocketTest, GetPeerName) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - struct sockaddr saddr; - socklen_t addrlen = sizeof(saddr); - ASSERT_THAT(getpeername(s_, &saddr, &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - ASSERT_GT(addrlen, 0); -} - -// Test that the socket is writable immediately. -TEST_P(RawSocketTest, PollWritableImmediately) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - struct pollfd pfd = {}; - pfd.fd = s_; - pfd.events = POLLOUT; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 10000), SyscallSucceedsWithValue(1)); -} - -// Test that the socket isn't readable before receiving anything. -TEST_P(RawSocketTest, PollNotReadableInitially) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - // Try to receive data with MSG_DONTWAIT, which returns immediately if there's - // nothing to be read. - char buf[117]; - ASSERT_THAT(RetryEINTR(recv)(s_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// Test that the socket becomes readable once something is written to it. -TEST_P(RawSocketTest, PollTriggeredOnWrite) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - // Write something so that there's data to be read. - // Arbitrary. - constexpr char kBuf[] = "JP5"; - ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); - - struct pollfd pfd = {}; - pfd.fd = s_; - pfd.events = POLLIN; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 10000), SyscallSucceedsWithValue(1)); -} - -// Test that we can connect() to a valid IP (loopback). -TEST_P(RawSocketTest, ConnectToLoopback) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); -} - -// Test that calling send() without connect() fails. -TEST_P(RawSocketTest, SendWithoutConnectFails) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - // Arbitrary. - constexpr char kBuf[] = "Endgame was good"; - ASSERT_THAT(send(s_, kBuf, sizeof(kBuf), 0), - SyscallFailsWithErrno(EDESTADDRREQ)); -} - -// Bind to localhost. -TEST_P(RawSocketTest, BindToLocalhost) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); -} - -// Bind to a different address. -TEST_P(RawSocketTest, BindToInvalid) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - struct sockaddr_in bind_addr = {}; - bind_addr.sin_family = AF_INET; - bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to. - ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr)), - SyscallFailsWithErrno(EADDRNOTAVAIL)); -} - -// Send and receive an packet. -TEST_P(RawSocketTest, SendAndReceive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - // Arbitrary. - constexpr char kBuf[] = "TB12"; - ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); - - // Receive the packet and make sure it's identical. - char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf))); - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); -} - -// We should be able to create multiple raw sockets for the same protocol and -// receive the same packet on both. -TEST_P(RawSocketTest, MultipleSocketReceive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - int s2; - ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds()); - - // Arbitrary. - constexpr char kBuf[] = "TB10"; - ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); - - // Receive it on socket 1. - char recv_buf1[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf1, sizeof(recv_buf1))); - - // Receive it on socket 2. - char recv_buf2[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s2, recv_buf2, sizeof(recv_buf2))); - - EXPECT_EQ(memcmp(recv_buf1 + sizeof(struct iphdr), - recv_buf2 + sizeof(struct iphdr), sizeof(kBuf)), - 0); - - ASSERT_THAT(close(s2), SyscallSucceeds()); -} - -// Test that connect sends packets to the right place. -TEST_P(RawSocketTest, SendAndReceiveViaConnect) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - - // Arbitrary. - constexpr char kBuf[] = "JH4"; - ASSERT_THAT(send(s_, kBuf, sizeof(kBuf), 0), - SyscallSucceedsWithValue(sizeof(kBuf))); - - // Receive the packet and make sure it's identical. - char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf))); - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); -} - -// Bind to localhost, then send and receive packets. -TEST_P(RawSocketTest, BindSendAndReceive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - - // Arbitrary. - constexpr char kBuf[] = "DR16"; - ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); - - // Receive the packet and make sure it's identical. - char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf))); - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); -} - -// Bind and connect to localhost and send/receive packets. -TEST_P(RawSocketTest, BindConnectSendAndReceive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - - // Arbitrary. - constexpr char kBuf[] = "DG88"; - ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); - - // Receive the packet and make sure it's identical. - char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf))); - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); -} - -void RawSocketTest::SendBuf(const char* buf, int buf_len) { - ASSERT_NO_FATAL_FAILURE(SendBufTo(s_, addr_, buf, buf_len)); -} - -void RawSocketTest::SendBufTo(int sock, const struct sockaddr_in& addr, - const char* buf, int buf_len) { - // It's safe to use const_cast here because sendmsg won't modify the iovec or - // address. - struct iovec iov = {}; - iov.iov_base = static_cast<void*>(const_cast<char*>(buf)); - iov.iov_len = static_cast<size_t>(buf_len); - struct msghdr msg = {}; - msg.msg_name = static_cast<void*>(const_cast<struct sockaddr_in*>(&addr)); - msg.msg_namelen = sizeof(addr); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = NULL; - msg.msg_controllen = 0; - msg.msg_flags = 0; - ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallSucceedsWithValue(buf_len)); -} - -void RawSocketTest::ReceiveBuf(char* recv_buf, size_t recv_buf_len) { - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, recv_buf_len)); -} - -INSTANTIATE_TEST_SUITE_P(AllInetTests, RawSocketTest, - ::testing::Values(IPPROTO_TCP, IPPROTO_UDP)); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/read.cc b/test/syscalls/linux/read.cc deleted file mode 100644 index 2633ba31b..000000000 --- a/test/syscalls/linux/read.cc +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <unistd.h> - -#include <vector> - -#include "gtest/gtest.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -class ReadTest : public ::testing::Test { - void SetUp() override { - name_ = NewTempAbsPath(); - int fd; - ASSERT_THAT(fd = open(name_.c_str(), O_CREAT, 0644), SyscallSucceeds()); - ASSERT_THAT(close(fd), SyscallSucceeds()); - } - - void TearDown() override { unlink(name_.c_str()); } - - public: - std::string name_; -}; - -TEST_F(ReadTest, ZeroBuffer) { - int fd; - ASSERT_THAT(fd = open(name_.c_str(), O_RDWR), SyscallSucceeds()); - - char msg[] = "hello world"; - EXPECT_THAT(PwriteFd(fd, msg, strlen(msg), 0), - SyscallSucceedsWithValue(strlen(msg))); - - char buf[10]; - EXPECT_THAT(ReadFd(fd, buf, 0), SyscallSucceedsWithValue(0)); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST_F(ReadTest, EmptyFileReturnsZeroAtEOF) { - int fd; - ASSERT_THAT(fd = open(name_.c_str(), O_RDWR), SyscallSucceeds()); - - char eof_buf[10]; - EXPECT_THAT(ReadFd(fd, eof_buf, 10), SyscallSucceedsWithValue(0)); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST_F(ReadTest, EofAfterRead) { - int fd; - ASSERT_THAT(fd = open(name_.c_str(), O_RDWR), SyscallSucceeds()); - - // Write some bytes to be read. - constexpr char kMessage[] = "hello world"; - EXPECT_THAT(PwriteFd(fd, kMessage, sizeof(kMessage), 0), - SyscallSucceedsWithValue(sizeof(kMessage))); - - // Read all of the bytes at once. - char buf[sizeof(kMessage)]; - EXPECT_THAT(ReadFd(fd, buf, sizeof(kMessage)), - SyscallSucceedsWithValue(sizeof(kMessage))); - - // Read again with a non-zero buffer and expect EOF. - char eof_buf[10]; - EXPECT_THAT(ReadFd(fd, eof_buf, 10), SyscallSucceedsWithValue(0)); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST_F(ReadTest, DevNullReturnsEof) { - int fd; - ASSERT_THAT(fd = open("/dev/null", O_RDONLY), SyscallSucceeds()); - std::vector<char> buf(1); - EXPECT_THAT(ReadFd(fd, buf.data(), 1), SyscallSucceedsWithValue(0)); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -const int kReadSize = 128 * 1024; - -// Do not allow random save as it could lead to partial reads. -TEST_F(ReadTest, CanReadFullyFromDevZero_NoRandomSave) { - int fd; - ASSERT_THAT(fd = open("/dev/zero", O_RDONLY), SyscallSucceeds()); - - std::vector<char> buf(kReadSize, 1); - EXPECT_THAT(ReadFd(fd, buf.data(), kReadSize), - SyscallSucceedsWithValue(kReadSize)); - EXPECT_THAT(close(fd), SyscallSucceeds()); - EXPECT_EQ(std::vector<char>(kReadSize, 0), buf); -} - -TEST_F(ReadTest, ReadDirectoryFails) { - const FileDescriptor file = - ASSERT_NO_ERRNO_AND_VALUE(Open(GetAbsoluteTestTmpdir(), O_RDONLY)); - std::vector<char> buf(1); - EXPECT_THAT(ReadFd(file.get(), buf.data(), 1), SyscallFailsWithErrno(EISDIR)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/readahead.cc b/test/syscalls/linux/readahead.cc deleted file mode 100644 index 09703b5c1..000000000 --- a/test/syscalls/linux/readahead.cc +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> - -#include "gtest/gtest.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(ReadaheadTest, InvalidFD) { - EXPECT_THAT(readahead(-1, 1, 1), SyscallFailsWithErrno(EBADF)); -} - -TEST(ReadaheadTest, InvalidOffset) { - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); - EXPECT_THAT(readahead(fd.get(), -1, 1), SyscallFailsWithErrno(EINVAL)); -} - -TEST(ReadaheadTest, ValidOffset) { - constexpr char kData[] = "123"; - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); - - // N.B. The implementation of readahead is filesystem-specific, and a file - // backed by ram may return EINVAL because there is nothing to be read. - EXPECT_THAT(readahead(fd.get(), 1, 1), AnyOf(SyscallSucceedsWithValue(0), - SyscallFailsWithErrno(EINVAL))); -} - -TEST(ReadaheadTest, PastEnd) { - constexpr char kData[] = "123"; - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); - // See above. - EXPECT_THAT(readahead(fd.get(), 2, 2), AnyOf(SyscallSucceedsWithValue(0), - SyscallFailsWithErrno(EINVAL))); -} - -TEST(ReadaheadTest, CrossesEnd) { - constexpr char kData[] = "123"; - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); - // See above. - EXPECT_THAT(readahead(fd.get(), 4, 2), AnyOf(SyscallSucceedsWithValue(0), - SyscallFailsWithErrno(EINVAL))); -} - -TEST(ReadaheadTest, WriteOnly) { - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_WRONLY)); - EXPECT_THAT(readahead(fd.get(), 0, 1), SyscallFailsWithErrno(EBADF)); -} - -TEST(ReadaheadTest, InvalidSize) { - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); - EXPECT_THAT(readahead(fd.get(), 0, -1), SyscallFailsWithErrno(EINVAL)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/readv.cc b/test/syscalls/linux/readv.cc deleted file mode 100644 index baaf9f757..000000000 --- a/test/syscalls/linux/readv.cc +++ /dev/null @@ -1,294 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <limits.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/file_base.h" -#include "test/syscalls/linux/readv_common.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/timer_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -class ReadvTest : public FileTest { - void SetUp() override { - FileTest::SetUp(); - - ASSERT_THAT(write(test_file_fd_.get(), kReadvTestData, kReadvTestDataSize), - SyscallSucceedsWithValue(kReadvTestDataSize)); - ASSERT_THAT(lseek(test_file_fd_.get(), 0, SEEK_SET), - SyscallSucceedsWithValue(0)); - ASSERT_THAT(write(test_pipe_[1], kReadvTestData, kReadvTestDataSize), - SyscallSucceedsWithValue(kReadvTestDataSize)); - } -}; - -TEST_F(ReadvTest, ReadOneBufferPerByte_File) { - ReadOneBufferPerByte(test_file_fd_.get()); -} - -TEST_F(ReadvTest, ReadOneBufferPerByte_Pipe) { - ReadOneBufferPerByte(test_pipe_[0]); -} - -TEST_F(ReadvTest, ReadOneHalfAtATime_File) { - ReadOneHalfAtATime(test_file_fd_.get()); -} - -TEST_F(ReadvTest, ReadOneHalfAtATime_Pipe) { - ReadOneHalfAtATime(test_pipe_[0]); -} - -TEST_F(ReadvTest, ReadAllOneBuffer_File) { - ReadAllOneBuffer(test_file_fd_.get()); -} - -TEST_F(ReadvTest, ReadAllOneBuffer_Pipe) { ReadAllOneBuffer(test_pipe_[0]); } - -TEST_F(ReadvTest, ReadAllOneLargeBuffer_File) { - ReadAllOneLargeBuffer(test_file_fd_.get()); -} - -TEST_F(ReadvTest, ReadAllOneLargeBuffer_Pipe) { - ReadAllOneLargeBuffer(test_pipe_[0]); -} - -TEST_F(ReadvTest, ReadBuffersOverlapping_File) { - ReadBuffersOverlapping(test_file_fd_.get()); -} - -TEST_F(ReadvTest, ReadBuffersOverlapping_Pipe) { - ReadBuffersOverlapping(test_pipe_[0]); -} - -TEST_F(ReadvTest, ReadBuffersDiscontinuous_File) { - ReadBuffersDiscontinuous(test_file_fd_.get()); -} - -TEST_F(ReadvTest, ReadBuffersDiscontinuous_Pipe) { - ReadBuffersDiscontinuous(test_pipe_[0]); -} - -TEST_F(ReadvTest, ReadIovecsCompletelyFilled_File) { - ReadIovecsCompletelyFilled(test_file_fd_.get()); -} - -TEST_F(ReadvTest, ReadIovecsCompletelyFilled_Pipe) { - ReadIovecsCompletelyFilled(test_pipe_[0]); -} - -TEST_F(ReadvTest, BadFileDescriptor) { - char buffer[1024]; - struct iovec iov[1]; - iov[0].iov_base = buffer; - iov[0].iov_len = 1024; - - ASSERT_THAT(readv(-1, iov, 1024), SyscallFailsWithErrno(EBADF)); -} - -TEST_F(ReadvTest, BadIovecsPointer_File) { - ASSERT_THAT(readv(test_file_fd_.get(), nullptr, 1), - SyscallFailsWithErrno(EFAULT)); -} - -TEST_F(ReadvTest, BadIovecsPointer_Pipe) { - ASSERT_THAT(readv(test_pipe_[0], nullptr, 1), SyscallFailsWithErrno(EFAULT)); -} - -TEST_F(ReadvTest, BadIovecBase_File) { - struct iovec iov[1]; - iov[0].iov_base = nullptr; - iov[0].iov_len = 1024; - ASSERT_THAT(readv(test_file_fd_.get(), iov, 1), - SyscallFailsWithErrno(EFAULT)); -} - -TEST_F(ReadvTest, BadIovecBase_Pipe) { - struct iovec iov[1]; - iov[0].iov_base = nullptr; - iov[0].iov_len = 1024; - ASSERT_THAT(readv(test_pipe_[0], iov, 1), SyscallFailsWithErrno(EFAULT)); -} - -TEST_F(ReadvTest, ZeroIovecs_File) { - struct iovec iov[1]; - iov[0].iov_base = 0; - iov[0].iov_len = 0; - ASSERT_THAT(readv(test_file_fd_.get(), iov, 1), SyscallSucceeds()); -} - -TEST_F(ReadvTest, ZeroIovecs_Pipe) { - struct iovec iov[1]; - iov[0].iov_base = 0; - iov[0].iov_len = 0; - ASSERT_THAT(readv(test_pipe_[0], iov, 1), SyscallSucceeds()); -} - -TEST_F(ReadvTest, NotReadable_File) { - char buffer[1024]; - struct iovec iov[1]; - iov[0].iov_base = buffer; - iov[0].iov_len = 1024; - - std::string wronly_file = NewTempAbsPath(); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( - Open(wronly_file, O_CREAT | O_WRONLY, S_IRUSR | S_IWUSR)); - ASSERT_THAT(readv(fd.get(), iov, 1), SyscallFailsWithErrno(EBADF)); - fd.reset(); // Close before unlinking. - ASSERT_THAT(unlink(wronly_file.c_str()), SyscallSucceeds()); -} - -TEST_F(ReadvTest, NotReadable_Pipe) { - char buffer[1024]; - struct iovec iov[1]; - iov[0].iov_base = buffer; - iov[0].iov_len = 1024; - ASSERT_THAT(readv(test_pipe_[1], iov, 1), SyscallFailsWithErrno(EBADF)); -} - -TEST_F(ReadvTest, DirNotReadable) { - char buffer[1024]; - struct iovec iov[1]; - iov[0].iov_base = buffer; - iov[0].iov_len = 1024; - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(GetAbsoluteTestTmpdir(), O_RDONLY)); - ASSERT_THAT(readv(fd.get(), iov, 1), SyscallFailsWithErrno(EISDIR)); -} - -TEST_F(ReadvTest, OffsetIncremented) { - char* buffer = reinterpret_cast<char*>(malloc(kReadvTestDataSize)); - struct iovec iov[1]; - iov[0].iov_base = buffer; - iov[0].iov_len = kReadvTestDataSize; - - ASSERT_THAT(readv(test_file_fd_.get(), iov, 1), - SyscallSucceedsWithValue(kReadvTestDataSize)); - ASSERT_THAT(lseek(test_file_fd_.get(), 0, SEEK_CUR), - SyscallSucceedsWithValue(kReadvTestDataSize)); - - free(buffer); -} - -TEST_F(ReadvTest, EndOfFile) { - char* buffer = reinterpret_cast<char*>(malloc(kReadvTestDataSize)); - struct iovec iov[1]; - iov[0].iov_base = buffer; - iov[0].iov_len = kReadvTestDataSize; - ASSERT_THAT(readv(test_file_fd_.get(), iov, 1), - SyscallSucceedsWithValue(kReadvTestDataSize)); - free(buffer); - - buffer = reinterpret_cast<char*>(malloc(kReadvTestDataSize)); - iov[0].iov_base = buffer; - iov[0].iov_len = kReadvTestDataSize; - ASSERT_THAT(readv(test_file_fd_.get(), iov, 1), SyscallSucceedsWithValue(0)); - free(buffer); -} - -TEST_F(ReadvTest, WouldBlock_Pipe) { - struct iovec iov[1]; - iov[0].iov_base = reinterpret_cast<char*>(malloc(kReadvTestDataSize)); - iov[0].iov_len = kReadvTestDataSize; - ASSERT_THAT(readv(test_pipe_[0], iov, 1), - SyscallSucceedsWithValue(kReadvTestDataSize)); - free(iov[0].iov_base); - - iov[0].iov_base = reinterpret_cast<char*>(malloc(kReadvTestDataSize)); - ASSERT_THAT(readv(test_pipe_[0], iov, 1), SyscallFailsWithErrno(EAGAIN)); - free(iov[0].iov_base); -} - -TEST_F(ReadvTest, ZeroBuffer) { - char buf[10]; - struct iovec iov[1]; - iov[0].iov_base = buf; - iov[0].iov_len = 0; - ASSERT_THAT(readv(test_pipe_[0], iov, 1), SyscallSucceedsWithValue(0)); -} - -TEST_F(ReadvTest, NullIovecInNonemptyArray) { - std::vector<char> buf(kReadvTestDataSize); - struct iovec iov[2]; - iov[0].iov_base = nullptr; - iov[0].iov_len = 0; - iov[1].iov_base = buf.data(); - iov[1].iov_len = kReadvTestDataSize; - ASSERT_THAT(readv(test_file_fd_.get(), iov, 2), - SyscallSucceedsWithValue(kReadvTestDataSize)); -} - -TEST_F(ReadvTest, IovecOutsideTaskAddressRangeInNonemptyArray) { - std::vector<char> buf(kReadvTestDataSize); - struct iovec iov[2]; - iov[0].iov_base = reinterpret_cast<void*>(~static_cast<uintptr_t>(0)); - iov[0].iov_len = 0; - iov[1].iov_base = buf.data(); - iov[1].iov_len = kReadvTestDataSize; - ASSERT_THAT(readv(test_file_fd_.get(), iov, 2), - SyscallFailsWithErrno(EFAULT)); -} - -// This test depends on the maximum extent of a single readv() syscall, so -// we can't tolerate interruption from saving. -TEST(ReadvTestNoFixture, TruncatedAtMax_NoRandomSave) { - // Ensure that we won't be interrupted by ITIMER_PROF. This is particularly - // important in environments where automated profiling tools may start - // ITIMER_PROF automatically. - struct itimerval itv = {}; - auto const cleanup_itimer = - ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_PROF, itv)); - - // From Linux's include/linux/fs.h. - size_t const MAX_RW_COUNT = INT_MAX & ~(kPageSize - 1); - - // Create an iovec array with 3 segments pointing to consecutive parts of a - // buffer. The first covers all but the last three pages, and should be - // written to in its entirety. The second covers the last page before - // MAX_RW_COUNT and the first page after; only the first page should be - // written to. The third covers the last page of the buffer, and should be - // skipped entirely. - size_t const kBufferSize = MAX_RW_COUNT + 2 * kPageSize; - size_t const kFirstOffset = MAX_RW_COUNT - kPageSize; - size_t const kSecondOffset = MAX_RW_COUNT + kPageSize; - // The buffer is too big to fit on the stack. - std::vector<char> buf(kBufferSize); - struct iovec iov[3]; - iov[0].iov_base = buf.data(); - iov[0].iov_len = kFirstOffset; - iov[1].iov_base = buf.data() + kFirstOffset; - iov[1].iov_len = kSecondOffset - kFirstOffset; - iov[2].iov_base = buf.data() + kSecondOffset; - iov[2].iov_len = kBufferSize - kSecondOffset; - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDONLY)); - EXPECT_THAT(readv(fd.get(), iov, 3), SyscallSucceedsWithValue(MAX_RW_COUNT)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/readv_common.cc b/test/syscalls/linux/readv_common.cc deleted file mode 100644 index 2694dc64f..000000000 --- a/test/syscalls/linux/readv_common.cc +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -// MatchesStringLength checks that a tuple argument of (struct iovec *, int) -// corresponding to an iovec array and its length, contains data that matches -// the string length strlen. -MATCHER_P(MatchesStringLength, strlen, "") { - struct iovec* iovs = arg.first; - int niov = arg.second; - int offset = 0; - for (int i = 0; i < niov; i++) { - offset += iovs[i].iov_len; - } - if (offset != static_cast<int>(strlen)) { - *result_listener << offset; - return false; - } - return true; -} - -// MatchesStringValue checks that a tuple argument of (struct iovec *, int) -// corresponding to an iovec array and its length, contains data that matches -// the string value str. -MATCHER_P(MatchesStringValue, str, "") { - struct iovec* iovs = arg.first; - int len = strlen(str); - int niov = arg.second; - int offset = 0; - for (int i = 0; i < niov; i++) { - struct iovec iov = iovs[i]; - if (len < offset) { - *result_listener << "strlen " << len << " < offset " << offset; - return false; - } - if (strncmp(static_cast<char*>(iov.iov_base), &str[offset], iov.iov_len)) { - absl::string_view iovec_string(static_cast<char*>(iov.iov_base), - iov.iov_len); - *result_listener << iovec_string << " @offset " << offset; - return false; - } - offset += iov.iov_len; - } - return true; -} - -extern const char kReadvTestData[] = - "127.0.0.1 localhost" - "" - "# The following lines are desirable for IPv6 capable hosts" - "::1 ip6-localhost ip6-loopback" - "fe00::0 ip6-localnet" - "ff00::0 ip6-mcastprefix" - "ff02::1 ip6-allnodes" - "ff02::2 ip6-allrouters" - "ff02::3 ip6-allhosts" - "192.168.1.100 a" - "93.184.216.34 foo.bar.example.com xcpu"; -extern const size_t kReadvTestDataSize = sizeof(kReadvTestData); - -static void ReadAllOneProvidedBuffer(int fd, std::vector<char>* buffer) { - struct iovec iovs[1]; - iovs[0].iov_base = buffer->data(); - iovs[0].iov_len = kReadvTestDataSize; - - ASSERT_THAT(readv(fd, iovs, 1), SyscallSucceedsWithValue(kReadvTestDataSize)); - - std::pair<struct iovec*, int> iovec_desc(iovs, 1); - EXPECT_THAT(iovec_desc, MatchesStringLength(kReadvTestDataSize)); - EXPECT_THAT(iovec_desc, MatchesStringValue(kReadvTestData)); -} - -void ReadAllOneBuffer(int fd) { - std::vector<char> buffer(kReadvTestDataSize); - ReadAllOneProvidedBuffer(fd, &buffer); -} - -void ReadAllOneLargeBuffer(int fd) { - std::vector<char> buffer(10 * kReadvTestDataSize); - ReadAllOneProvidedBuffer(fd, &buffer); -} - -void ReadOneHalfAtATime(int fd) { - int len0 = kReadvTestDataSize / 2; - int len1 = kReadvTestDataSize - len0; - std::vector<char> buffer0(len0); - std::vector<char> buffer1(len1); - - struct iovec iovs[2]; - iovs[0].iov_base = buffer0.data(); - iovs[0].iov_len = len0; - iovs[1].iov_base = buffer1.data(); - iovs[1].iov_len = len1; - - ASSERT_THAT(readv(fd, iovs, 2), SyscallSucceedsWithValue(kReadvTestDataSize)); - - std::pair<struct iovec*, int> iovec_desc(iovs, 2); - EXPECT_THAT(iovec_desc, MatchesStringLength(kReadvTestDataSize)); - EXPECT_THAT(iovec_desc, MatchesStringValue(kReadvTestData)); -} - -void ReadOneBufferPerByte(int fd) { - std::vector<char> buffer(kReadvTestDataSize); - std::vector<struct iovec> iovs(kReadvTestDataSize); - char* buffer_ptr = buffer.data(); - struct iovec* iovs_ptr = iovs.data(); - - for (int i = 0; i < static_cast<int>(kReadvTestDataSize); i++) { - struct iovec iov = { - .iov_base = &buffer_ptr[i], - .iov_len = 1, - }; - iovs_ptr[i] = iov; - } - - ASSERT_THAT(readv(fd, iovs_ptr, kReadvTestDataSize), - SyscallSucceedsWithValue(kReadvTestDataSize)); - - std::pair<struct iovec*, int> iovec_desc(iovs.data(), kReadvTestDataSize); - EXPECT_THAT(iovec_desc, MatchesStringLength(kReadvTestDataSize)); - EXPECT_THAT(iovec_desc, MatchesStringValue(kReadvTestData)); -} - -void ReadBuffersOverlapping(int fd) { - // overlap the first overlap_bytes. - int overlap_bytes = 8; - std::vector<char> buffer(kReadvTestDataSize); - - // overlapping causes us to get more data. - int expected_size = kReadvTestDataSize + overlap_bytes; - std::vector<char> expected(expected_size); - char* expected_ptr = expected.data(); - memcpy(expected_ptr, &kReadvTestData[overlap_bytes], overlap_bytes); - memcpy(&expected_ptr[overlap_bytes], &kReadvTestData[overlap_bytes], - kReadvTestDataSize - overlap_bytes); - - struct iovec iovs[2]; - iovs[0].iov_base = buffer.data(); - iovs[0].iov_len = overlap_bytes; - iovs[1].iov_base = buffer.data(); - iovs[1].iov_len = kReadvTestDataSize; - - ASSERT_THAT(readv(fd, iovs, 2), SyscallSucceedsWithValue(kReadvTestDataSize)); - - std::pair<struct iovec*, int> iovec_desc(iovs, 2); - EXPECT_THAT(iovec_desc, MatchesStringLength(expected_size)); - EXPECT_THAT(iovec_desc, MatchesStringValue(expected_ptr)); -} - -void ReadBuffersDiscontinuous(int fd) { - // Each iov is 1 byte separated by 1 byte. - std::vector<char> buffer(kReadvTestDataSize * 2); - std::vector<struct iovec> iovs(kReadvTestDataSize); - - char* buffer_ptr = buffer.data(); - struct iovec* iovs_ptr = iovs.data(); - - for (int i = 0; i < static_cast<int>(kReadvTestDataSize); i++) { - struct iovec iov = { - .iov_base = &buffer_ptr[i * 2], - .iov_len = 1, - }; - iovs_ptr[i] = iov; - } - - ASSERT_THAT(readv(fd, iovs_ptr, kReadvTestDataSize), - SyscallSucceedsWithValue(kReadvTestDataSize)); - - std::pair<struct iovec*, int> iovec_desc(iovs.data(), kReadvTestDataSize); - EXPECT_THAT(iovec_desc, MatchesStringLength(kReadvTestDataSize)); - EXPECT_THAT(iovec_desc, MatchesStringValue(kReadvTestData)); -} - -void ReadIovecsCompletelyFilled(int fd) { - int half = kReadvTestDataSize / 2; - std::vector<char> buffer(kReadvTestDataSize); - char* buffer_ptr = buffer.data(); - memset(buffer.data(), '\0', kReadvTestDataSize); - - struct iovec iovs[2]; - iovs[0].iov_base = buffer.data(); - iovs[0].iov_len = half; - iovs[1].iov_base = &buffer_ptr[half]; - iovs[1].iov_len = half; - - ASSERT_THAT(readv(fd, iovs, 2), SyscallSucceedsWithValue(half * 2)); - - std::pair<struct iovec*, int> iovec_desc(iovs, 2); - EXPECT_THAT(iovec_desc, MatchesStringLength(half * 2)); - EXPECT_THAT(iovec_desc, MatchesStringValue(kReadvTestData)); - - char* str = static_cast<char*>(iovs[0].iov_base); - str[iovs[0].iov_len - 1] = '\0'; - ASSERT_EQ(half - 1, strlen(str)); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/readv_common.h b/test/syscalls/linux/readv_common.h deleted file mode 100644 index 2fa40c35f..000000000 --- a/test/syscalls/linux/readv_common.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_READV_COMMON_H_ -#define GVISOR_TEST_SYSCALLS_READV_COMMON_H_ - -#include <stddef.h> - -namespace gvisor { -namespace testing { - -// A NUL-terminated string containing the data used by tests using the following -// test helpers. -extern const char kReadvTestData[]; - -// The size of kReadvTestData, including the terminating NUL. -extern const size_t kReadvTestDataSize; - -// ReadAllOneBuffer asserts that it can read kReadvTestData from an fd using -// exactly one iovec. -void ReadAllOneBuffer(int fd); - -// ReadAllOneLargeBuffer asserts that it can read kReadvTestData from an fd -// using exactly one iovec containing an overly large buffer. -void ReadAllOneLargeBuffer(int fd); - -// ReadOneHalfAtATime asserts that it can read test_data_from an fd using -// exactly two iovecs that are roughly equivalent in size. -void ReadOneHalfAtATime(int fd); - -// ReadOneBufferPerByte asserts that it can read kReadvTestData from an fd -// using one iovec per byte. -void ReadOneBufferPerByte(int fd); - -// ReadBuffersOverlapping asserts that it can read kReadvTestData from an fd -// where two iovecs are overlapping. -void ReadBuffersOverlapping(int fd); - -// ReadBuffersDiscontinuous asserts that it can read kReadvTestData from an fd -// where each iovec is discontinuous from the next by 1 byte. -void ReadBuffersDiscontinuous(int fd); - -// ReadIovecsCompletelyFilled asserts that the previous iovec is completely -// filled before moving onto the next. -void ReadIovecsCompletelyFilled(int fd); - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_READV_COMMON_H_ diff --git a/test/syscalls/linux/readv_socket.cc b/test/syscalls/linux/readv_socket.cc deleted file mode 100644 index dd6fb7008..000000000 --- a/test/syscalls/linux/readv_socket.cc +++ /dev/null @@ -1,212 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/readv_common.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -class ReadvSocketTest : public ::testing::Test { - public: - void SetUp() override { - test_unix_stream_socket_[0] = -1; - test_unix_stream_socket_[1] = -1; - test_unix_dgram_socket_[0] = -1; - test_unix_dgram_socket_[1] = -1; - test_unix_seqpacket_socket_[0] = -1; - test_unix_seqpacket_socket_[1] = -1; - - ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, test_unix_stream_socket_), - SyscallSucceeds()); - ASSERT_THAT(fcntl(test_unix_stream_socket_[0], F_SETFL, O_NONBLOCK), - SyscallSucceeds()); - ASSERT_THAT(socketpair(AF_UNIX, SOCK_DGRAM, 0, test_unix_dgram_socket_), - SyscallSucceeds()); - ASSERT_THAT(fcntl(test_unix_dgram_socket_[0], F_SETFL, O_NONBLOCK), - SyscallSucceeds()); - ASSERT_THAT( - socketpair(AF_UNIX, SOCK_SEQPACKET, 0, test_unix_seqpacket_socket_), - SyscallSucceeds()); - ASSERT_THAT(fcntl(test_unix_seqpacket_socket_[0], F_SETFL, O_NONBLOCK), - SyscallSucceeds()); - - ASSERT_THAT( - write(test_unix_stream_socket_[1], kReadvTestData, kReadvTestDataSize), - SyscallSucceedsWithValue(kReadvTestDataSize)); - ASSERT_THAT( - write(test_unix_dgram_socket_[1], kReadvTestData, kReadvTestDataSize), - SyscallSucceedsWithValue(kReadvTestDataSize)); - ASSERT_THAT(write(test_unix_seqpacket_socket_[1], kReadvTestData, - kReadvTestDataSize), - SyscallSucceedsWithValue(kReadvTestDataSize)); - } - - void TearDown() override { - close(test_unix_stream_socket_[0]); - close(test_unix_stream_socket_[1]); - - close(test_unix_dgram_socket_[0]); - close(test_unix_dgram_socket_[1]); - - close(test_unix_seqpacket_socket_[0]); - close(test_unix_seqpacket_socket_[1]); - } - - int test_unix_stream_socket_[2]; - int test_unix_dgram_socket_[2]; - int test_unix_seqpacket_socket_[2]; -}; - -TEST_F(ReadvSocketTest, ReadOneBufferPerByte_StreamSocket) { - ReadOneBufferPerByte(test_unix_stream_socket_[0]); -} - -TEST_F(ReadvSocketTest, ReadOneBufferPerByte_DgramSocket) { - ReadOneBufferPerByte(test_unix_dgram_socket_[0]); -} - -TEST_F(ReadvSocketTest, ReadOneBufferPerByte_SeqPacketSocket) { - ReadOneBufferPerByte(test_unix_seqpacket_socket_[0]); -} - -TEST_F(ReadvSocketTest, ReadOneHalfAtATime_StreamSocket) { - ReadOneHalfAtATime(test_unix_stream_socket_[0]); -} - -TEST_F(ReadvSocketTest, ReadOneHalfAtATime_DgramSocket) { - ReadOneHalfAtATime(test_unix_dgram_socket_[0]); -} - -TEST_F(ReadvSocketTest, ReadAllOneBuffer_StreamSocket) { - ReadAllOneBuffer(test_unix_stream_socket_[0]); -} - -TEST_F(ReadvSocketTest, ReadAllOneBuffer_DgramSocket) { - ReadAllOneBuffer(test_unix_dgram_socket_[0]); -} - -TEST_F(ReadvSocketTest, ReadAllOneLargeBuffer_StreamSocket) { - ReadAllOneLargeBuffer(test_unix_stream_socket_[0]); -} - -TEST_F(ReadvSocketTest, ReadAllOneLargeBuffer_DgramSocket) { - ReadAllOneLargeBuffer(test_unix_dgram_socket_[0]); -} - -TEST_F(ReadvSocketTest, ReadBuffersOverlapping_StreamSocket) { - ReadBuffersOverlapping(test_unix_stream_socket_[0]); -} - -TEST_F(ReadvSocketTest, ReadBuffersOverlapping_DgramSocket) { - ReadBuffersOverlapping(test_unix_dgram_socket_[0]); -} - -TEST_F(ReadvSocketTest, ReadBuffersDiscontinuous_StreamSocket) { - ReadBuffersDiscontinuous(test_unix_stream_socket_[0]); -} - -TEST_F(ReadvSocketTest, ReadBuffersDiscontinuous_DgramSocket) { - ReadBuffersDiscontinuous(test_unix_dgram_socket_[0]); -} - -TEST_F(ReadvSocketTest, ReadIovecsCompletelyFilled_StreamSocket) { - ReadIovecsCompletelyFilled(test_unix_stream_socket_[0]); -} - -TEST_F(ReadvSocketTest, ReadIovecsCompletelyFilled_DgramSocket) { - ReadIovecsCompletelyFilled(test_unix_dgram_socket_[0]); -} - -TEST_F(ReadvSocketTest, BadIovecsPointer_StreamSocket) { - ASSERT_THAT(readv(test_unix_stream_socket_[0], nullptr, 1), - SyscallFailsWithErrno(EFAULT)); -} - -TEST_F(ReadvSocketTest, BadIovecsPointer_DgramSocket) { - ASSERT_THAT(readv(test_unix_dgram_socket_[0], nullptr, 1), - SyscallFailsWithErrno(EFAULT)); -} - -TEST_F(ReadvSocketTest, BadIovecBase_StreamSocket) { - struct iovec iov[1]; - iov[0].iov_base = nullptr; - iov[0].iov_len = 1024; - ASSERT_THAT(readv(test_unix_stream_socket_[0], iov, 1), - SyscallFailsWithErrno(EFAULT)); -} - -TEST_F(ReadvSocketTest, BadIovecBase_DgramSocket) { - struct iovec iov[1]; - iov[0].iov_base = nullptr; - iov[0].iov_len = 1024; - ASSERT_THAT(readv(test_unix_dgram_socket_[0], iov, 1), - SyscallFailsWithErrno(EFAULT)); -} - -TEST_F(ReadvSocketTest, ZeroIovecs_StreamSocket) { - struct iovec iov[1]; - iov[0].iov_base = 0; - iov[0].iov_len = 0; - ASSERT_THAT(readv(test_unix_stream_socket_[0], iov, 1), SyscallSucceeds()); -} - -TEST_F(ReadvSocketTest, ZeroIovecs_DgramSocket) { - struct iovec iov[1]; - iov[0].iov_base = 0; - iov[0].iov_len = 0; - ASSERT_THAT(readv(test_unix_dgram_socket_[0], iov, 1), SyscallSucceeds()); -} - -TEST_F(ReadvSocketTest, WouldBlock_StreamSocket) { - struct iovec iov[1]; - iov[0].iov_base = reinterpret_cast<char*>(malloc(kReadvTestDataSize)); - iov[0].iov_len = kReadvTestDataSize; - ASSERT_THAT(readv(test_unix_stream_socket_[0], iov, 1), - SyscallSucceedsWithValue(kReadvTestDataSize)); - free(iov[0].iov_base); - - iov[0].iov_base = reinterpret_cast<char*>(malloc(kReadvTestDataSize)); - ASSERT_THAT(readv(test_unix_stream_socket_[0], iov, 1), - SyscallFailsWithErrno(EAGAIN)); - free(iov[0].iov_base); -} - -TEST_F(ReadvSocketTest, WouldBlock_DgramSocket) { - struct iovec iov[1]; - iov[0].iov_base = reinterpret_cast<char*>(malloc(kReadvTestDataSize)); - iov[0].iov_len = kReadvTestDataSize; - ASSERT_THAT(readv(test_unix_dgram_socket_[0], iov, 1), - SyscallSucceedsWithValue(kReadvTestDataSize)); - free(iov[0].iov_base); - - iov[0].iov_base = reinterpret_cast<char*>(malloc(kReadvTestDataSize)); - ASSERT_THAT(readv(test_unix_dgram_socket_[0], iov, 1), - SyscallFailsWithErrno(EAGAIN)); - free(iov[0].iov_base); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/rename.cc b/test/syscalls/linux/rename.cc deleted file mode 100644 index 833c0dc4f..000000000 --- a/test/syscalls/linux/rename.cc +++ /dev/null @@ -1,394 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <stdio.h> - -#include <string> - -#include "gtest/gtest.h" -#include "absl/strings/string_view.h" -#include "test/util/capability_util.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(RenameTest, RootToAnything) { - ASSERT_THAT(rename("/", "/bin"), SyscallFailsWithErrno(EBUSY)); -} - -TEST(RenameTest, AnythingToRoot) { - ASSERT_THAT(rename("/bin", "/"), SyscallFailsWithErrno(EBUSY)); -} - -TEST(RenameTest, SourceIsAncestorOfTarget) { - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto subdir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir.path())); - ASSERT_THAT(rename(dir.path().c_str(), subdir.path().c_str()), - SyscallFailsWithErrno(EINVAL)); - - // Try an even deeper directory. - auto deep_subdir = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(subdir.path())); - ASSERT_THAT(rename(dir.path().c_str(), deep_subdir.path().c_str()), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(RenameTest, TargetIsAncestorOfSource) { - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto subdir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir.path())); - ASSERT_THAT(rename(subdir.path().c_str(), dir.path().c_str()), - SyscallFailsWithErrno(ENOTEMPTY)); - - // Try an even deeper directory. - auto deep_subdir = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(subdir.path())); - ASSERT_THAT(rename(deep_subdir.path().c_str(), dir.path().c_str()), - SyscallFailsWithErrno(ENOTEMPTY)); -} - -TEST(RenameTest, FileToSelf) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - EXPECT_THAT(rename(f.path().c_str(), f.path().c_str()), SyscallSucceeds()); -} - -TEST(RenameTest, DirectoryToSelf) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(rename(f.path().c_str(), f.path().c_str()), SyscallSucceeds()); -} - -TEST(RenameTest, FileToSameDirectory) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - std::string const newpath = NewTempAbsPath(); - ASSERT_THAT(rename(f.path().c_str(), newpath.c_str()), SyscallSucceeds()); - std::string const oldpath = f.release(); - f.reset(newpath); - EXPECT_THAT(Exists(oldpath), IsPosixErrorOkAndHolds(false)); - EXPECT_THAT(Exists(newpath), IsPosixErrorOkAndHolds(true)); -} - -TEST(RenameTest, DirectoryToSameDirectory) { - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - std::string const newpath = NewTempAbsPath(); - ASSERT_THAT(rename(dir.path().c_str(), newpath.c_str()), SyscallSucceeds()); - std::string const oldpath = dir.release(); - dir.reset(newpath); - EXPECT_THAT(Exists(oldpath), IsPosixErrorOkAndHolds(false)); - EXPECT_THAT(Exists(newpath), IsPosixErrorOkAndHolds(true)); -} - -TEST(RenameTest, FileToParentDirectory) { - auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir1.path())); - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir2.path())); - std::string const newpath = NewTempAbsPathInDir(dir1.path()); - ASSERT_THAT(rename(f.path().c_str(), newpath.c_str()), SyscallSucceeds()); - std::string const oldpath = f.release(); - f.reset(newpath); - EXPECT_THAT(Exists(oldpath), IsPosixErrorOkAndHolds(false)); - EXPECT_THAT(Exists(newpath), IsPosixErrorOkAndHolds(true)); -} - -TEST(RenameTest, DirectoryToParentDirectory) { - auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir1.path())); - auto dir3 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir2.path())); - EXPECT_THAT(IsDirectory(dir3.path()), IsPosixErrorOkAndHolds(true)); - std::string const newpath = NewTempAbsPathInDir(dir1.path()); - ASSERT_THAT(rename(dir3.path().c_str(), newpath.c_str()), SyscallSucceeds()); - std::string const oldpath = dir3.release(); - dir3.reset(newpath); - EXPECT_THAT(Exists(oldpath), IsPosixErrorOkAndHolds(false)); - EXPECT_THAT(Exists(newpath), IsPosixErrorOkAndHolds(true)); - EXPECT_THAT(IsDirectory(newpath), IsPosixErrorOkAndHolds(true)); -} - -TEST(RenameTest, FileToChildDirectory) { - auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir1.path())); - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); - std::string const newpath = NewTempAbsPathInDir(dir2.path()); - ASSERT_THAT(rename(f.path().c_str(), newpath.c_str()), SyscallSucceeds()); - std::string const oldpath = f.release(); - f.reset(newpath); - EXPECT_THAT(Exists(oldpath), IsPosixErrorOkAndHolds(false)); - EXPECT_THAT(Exists(newpath), IsPosixErrorOkAndHolds(true)); -} - -TEST(RenameTest, DirectoryToChildDirectory) { - auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir1.path())); - auto dir3 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir1.path())); - std::string const newpath = NewTempAbsPathInDir(dir2.path()); - ASSERT_THAT(rename(dir3.path().c_str(), newpath.c_str()), SyscallSucceeds()); - std::string const oldpath = dir3.release(); - dir3.reset(newpath); - EXPECT_THAT(Exists(oldpath), IsPosixErrorOkAndHolds(false)); - EXPECT_THAT(Exists(newpath), IsPosixErrorOkAndHolds(true)); - EXPECT_THAT(IsDirectory(newpath), IsPosixErrorOkAndHolds(true)); -} - -TEST(RenameTest, DirectoryToOwnChildDirectory) { - auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir1.path())); - std::string const newpath = NewTempAbsPathInDir(dir2.path()); - ASSERT_THAT(rename(dir1.path().c_str(), newpath.c_str()), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(RenameTest, FileOverwritesFile) { - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto f1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - dir.path(), "first", TempPath::kDefaultFileMode)); - auto f2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - dir.path(), "second", TempPath::kDefaultFileMode)); - ASSERT_THAT(rename(f1.path().c_str(), f2.path().c_str()), SyscallSucceeds()); - EXPECT_THAT(Exists(f1.path()), IsPosixErrorOkAndHolds(false)); - - f1.release(); - std::string f2_contents; - ASSERT_NO_ERRNO(GetContents(f2.path(), &f2_contents)); - EXPECT_EQ("first", f2_contents); -} - -TEST(RenameTest, DirectoryOverwritesDirectoryLinkCount) { - auto parent1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(Links(parent1.path()), IsPosixErrorOkAndHolds(2)); - - auto parent2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(Links(parent2.path()), IsPosixErrorOkAndHolds(2)); - - auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(parent1.path())); - auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(parent2.path())); - - EXPECT_THAT(Links(parent1.path()), IsPosixErrorOkAndHolds(3)); - EXPECT_THAT(Links(parent2.path()), IsPosixErrorOkAndHolds(3)); - - ASSERT_THAT(rename(dir1.path().c_str(), dir2.path().c_str()), - SyscallSucceeds()); - - EXPECT_THAT(Links(parent1.path()), IsPosixErrorOkAndHolds(2)); - EXPECT_THAT(Links(parent2.path()), IsPosixErrorOkAndHolds(3)); -} - -TEST(RenameTest, FileDoesNotExist) { - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const std::string source = JoinPath(dir.path(), "source"); - const std::string dest = JoinPath(dir.path(), "dest"); - ASSERT_THAT(rename(source.c_str(), dest.c_str()), - SyscallFailsWithErrno(ENOENT)); -} - -TEST(RenameTest, FileDoesNotOverwriteDirectory) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - ASSERT_THAT(rename(f.path().c_str(), dir.path().c_str()), - SyscallFailsWithErrno(EISDIR)); -} - -TEST(RenameTest, DirectoryDoesNotOverwriteFile) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - ASSERT_THAT(rename(dir.path().c_str(), f.path().c_str()), - SyscallFailsWithErrno(ENOTDIR)); -} - -TEST(RenameTest, DirectoryOverwritesEmptyDirectory) { - auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); - auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(rename(dir1.path().c_str(), dir2.path().c_str()), - SyscallSucceeds()); - EXPECT_THAT(Exists(dir1.path()), IsPosixErrorOkAndHolds(false)); - dir1.release(); - EXPECT_THAT(Exists(JoinPath(dir2.path(), Basename(f.path()))), - IsPosixErrorOkAndHolds(true)); - f.release(); -} - -TEST(RenameTest, FailsWithDots) { - auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto dir1_dot = absl::StrCat(dir1.path(), "/."); - auto dir2_dot = absl::StrCat(dir2.path(), "/."); - auto dir1_dot_dot = absl::StrCat(dir1.path(), "/.."); - auto dir2_dot_dot = absl::StrCat(dir2.path(), "/.."); - - // Try with dot paths in the first argument - EXPECT_THAT(rename(dir1_dot.c_str(), dir2.path().c_str()), - SyscallFailsWithErrno(EBUSY)); - EXPECT_THAT(rename(dir1_dot_dot.c_str(), dir2.path().c_str()), - SyscallFailsWithErrno(EBUSY)); - - // Try with dot paths in the second argument - EXPECT_THAT(rename(dir1.path().c_str(), dir2_dot.c_str()), - SyscallFailsWithErrno(EBUSY)); - EXPECT_THAT(rename(dir1.path().c_str(), dir2_dot_dot.c_str()), - SyscallFailsWithErrno(EBUSY)); -} - -TEST(RenameTest, DirectoryDoesNotOverwriteNonemptyDirectory) { - auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto f1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); - auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto f2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir2.path())); - ASSERT_THAT(rename(dir1.path().c_str(), dir2.path().c_str()), - SyscallFailsWithErrno(ENOTEMPTY)); -} - -TEST(RenameTest, FailsWhenOldParentNotWritable) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto f1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); - auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - // dir1 is not writable. - ASSERT_THAT(chmod(dir1.path().c_str(), 0555), SyscallSucceeds()); - - std::string const newpath = NewTempAbsPathInDir(dir2.path()); - EXPECT_THAT(rename(f1.path().c_str(), newpath.c_str()), - SyscallFailsWithErrno(EACCES)); -} - -TEST(RenameTest, FailsWhenNewParentNotWritable) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto f1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); - // dir2 is not writable. - auto dir2 = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0555)); - - std::string const newpath = NewTempAbsPathInDir(dir2.path()); - EXPECT_THAT(rename(f1.path().c_str(), newpath.c_str()), - SyscallFailsWithErrno(EACCES)); -} - -// Equivalent to FailsWhenNewParentNotWritable, but with a destination file -// to overwrite. -TEST(RenameTest, OverwriteFailsWhenNewParentNotWritable) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto f1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); - - // dir2 is not writable. - auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto f2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir2.path())); - ASSERT_THAT(chmod(dir2.path().c_str(), 0555), SyscallSucceeds()); - - EXPECT_THAT(rename(f1.path().c_str(), f2.path().c_str()), - SyscallFailsWithErrno(EACCES)); -} - -// If the parent directory of source is not accessible, rename returns EACCES -// because the user cannot determine if source exists. -TEST(RenameTest, FileDoesNotExistWhenNewParentNotExecutable) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - // No execute permission. - auto dir = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0400)); - - const std::string source = JoinPath(dir.path(), "source"); - const std::string dest = JoinPath(dir.path(), "dest"); - ASSERT_THAT(rename(source.c_str(), dest.c_str()), - SyscallFailsWithErrno(EACCES)); -} - -TEST(RenameTest, DirectoryWithOpenFdOverwritesEmptyDirectory) { - auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); - auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - // Get an fd on dir1 - int fd; - ASSERT_THAT(fd = open(dir1.path().c_str(), O_DIRECTORY), SyscallSucceeds()); - auto close_f = Cleanup([fd] { - // Close the fd on f. - EXPECT_THAT(close(fd), SyscallSucceeds()); - }); - - EXPECT_THAT(rename(dir1.path().c_str(), dir2.path().c_str()), - SyscallSucceeds()); - - const std::string new_f_path = JoinPath(dir2.path(), Basename(f.path())); - - auto remove_f = Cleanup([&] { - // Delete f in its new location. - ASSERT_NO_ERRNO(Delete(new_f_path)); - f.release(); - }); - - EXPECT_THAT(Exists(dir1.path()), IsPosixErrorOkAndHolds(false)); - dir1.release(); - EXPECT_THAT(Exists(new_f_path), IsPosixErrorOkAndHolds(true)); -} - -TEST(RenameTest, FileWithOpenFd) { - TempPath root_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - TempPath dir1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root_dir.path())); - TempPath dir2 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root_dir.path())); - TempPath dir3 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root_dir.path())); - - // Create file in dir1. - constexpr char kContents[] = "foo"; - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - dir1.path(), kContents, TempPath::kDefaultFileMode)); - - // Get fd on file. - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDWR)); - - // Move f to dir2. - const std::string path2 = NewTempAbsPathInDir(dir2.path()); - ASSERT_THAT(rename(f.path().c_str(), path2.c_str()), SyscallSucceeds()); - - // Read f's kContents. - char buf[sizeof(kContents)]; - EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(kContents), 0), - SyscallSucceedsWithValue(sizeof(kContents) - 1)); - EXPECT_EQ(absl::string_view(buf, sizeof(buf) - 1), kContents); - - // Move f to dir3. - const std::string path3 = NewTempAbsPathInDir(dir3.path()); - ASSERT_THAT(rename(path2.c_str(), path3.c_str()), SyscallSucceeds()); - - // Read f's kContents. - EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(kContents), 0), - SyscallSucceedsWithValue(sizeof(kContents) - 1)); - EXPECT_EQ(absl::string_view(buf, sizeof(buf) - 1), kContents); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/rlimits.cc b/test/syscalls/linux/rlimits.cc deleted file mode 100644 index 860f0f688..000000000 --- a/test/syscalls/linux/rlimits.cc +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/resource.h> -#include <sys/time.h> - -#include "test/util/capability_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(RlimitTest, SetRlimitHigher) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_RESOURCE))); - - struct rlimit rl = {}; - EXPECT_THAT(getrlimit(RLIMIT_NOFILE, &rl), SyscallSucceeds()); - - // Lower the rlimit first, as it may be equal to /proc/sys/fs/nr_open, in - // which case even users with CAP_SYS_RESOURCE can't raise it. - rl.rlim_cur--; - rl.rlim_max--; - ASSERT_THAT(setrlimit(RLIMIT_NOFILE, &rl), SyscallSucceeds()); - - rl.rlim_max++; - EXPECT_THAT(setrlimit(RLIMIT_NOFILE, &rl), SyscallSucceeds()); -} - -TEST(RlimitTest, UnprivilegedSetRlimit) { - // Drop privileges if necessary. - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_RESOURCE))) { - EXPECT_NO_ERRNO(SetCapability(CAP_SYS_RESOURCE, false)); - } - - struct rlimit rl = {}; - rl.rlim_cur = 1000; - rl.rlim_max = 20000; - EXPECT_THAT(setrlimit(RLIMIT_NOFILE, &rl), SyscallSucceeds()); - - struct rlimit rl2 = {}; - EXPECT_THAT(getrlimit(RLIMIT_NOFILE, &rl2), SyscallSucceeds()); - EXPECT_EQ(rl.rlim_cur, rl2.rlim_cur); - EXPECT_EQ(rl.rlim_max, rl2.rlim_max); - - rl.rlim_max = 100000; - EXPECT_THAT(setrlimit(RLIMIT_NOFILE, &rl), SyscallFailsWithErrno(EPERM)); -} - -TEST(RlimitTest, SetSoftRlimitAboveHard) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_RESOURCE))); - - struct rlimit rl = {}; - EXPECT_THAT(getrlimit(RLIMIT_NOFILE, &rl), SyscallSucceeds()); - - rl.rlim_cur = rl.rlim_max + 1; - EXPECT_THAT(setrlimit(RLIMIT_NOFILE, &rl), SyscallFailsWithErrno(EINVAL)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/rseq.cc b/test/syscalls/linux/rseq.cc deleted file mode 100644 index 4bfb1ff56..000000000 --- a/test/syscalls/linux/rseq.cc +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <signal.h> -#include <sys/syscall.h> -#include <sys/types.h> -#include <sys/wait.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/rseq/test.h" -#include "test/syscalls/linux/rseq/uapi.h" -#include "test/util/logging.h" -#include "test/util/multiprocess_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Syscall test for rseq (restartable sequences). -// -// We must be very careful about how these tests are written. Each thread may -// only have one struct rseq registration, which may be done automatically at -// thread start (as of 2019-11-13, glibc does *not* support rseq and thus does -// not do so, but other libraries do). -// -// Testing of rseq is thus done primarily in a child process with no -// registration. This means exec'ing a nostdlib binary, as rseq registration can -// only be cleared by execve (or knowing the old rseq address), and glibc (based -// on the current unmerged patches) register rseq before calling main()). - -int RSeq(struct rseq* rseq, uint32_t rseq_len, int flags, uint32_t sig) { - return syscall(kRseqSyscall, rseq, rseq_len, flags, sig); -} - -// Returns true if this kernel supports the rseq syscall. -PosixErrorOr<bool> RSeqSupported() { - // We have to be careful here, there are three possible cases: - // - // 1. rseq is not supported -> ENOSYS - // 2. rseq is supported and not registered -> success, but we should - // unregister. - // 3. rseq is supported and registered -> EINVAL (most likely). - - // The only validation done on new registrations is that rseq is aligned and - // writable. - rseq rseq = {}; - int ret = RSeq(&rseq, sizeof(rseq), 0, 0); - if (ret == 0) { - // Successfully registered, rseq is supported. Unregister. - ret = RSeq(&rseq, sizeof(rseq), kRseqFlagUnregister, 0); - if (ret != 0) { - return PosixError(errno); - } - return true; - } - - switch (errno) { - case ENOSYS: - // Not supported. - return false; - case EINVAL: - // Supported, but already registered. EINVAL returned because we provided - // a different address. - return true; - default: - // Unknown error. - return PosixError(errno); - } -} - -constexpr char kRseqBinary[] = "test/syscalls/linux/rseq/rseq"; - -void RunChildTest(std::string test_case, int want_status) { - std::string path = RunfilePath(kRseqBinary); - - pid_t child_pid = -1; - int execve_errno = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(path, {path, test_case}, {}, &child_pid, &execve_errno)); - - ASSERT_GT(child_pid, 0); - ASSERT_EQ(execve_errno, 0); - - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - ASSERT_EQ(status, want_status); -} - -// Test that rseq must be aligned. -TEST(RseqTest, Unaligned) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); - - RunChildTest(kRseqTestUnaligned, 0); -} - -// Sanity test that registration works. -TEST(RseqTest, Register) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); - - RunChildTest(kRseqTestRegister, 0); -} - -// Registration can't be done twice. -TEST(RseqTest, DoubleRegister) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); - - RunChildTest(kRseqTestDoubleRegister, 0); -} - -// Registration can be done again after unregister. -TEST(RseqTest, RegisterUnregister) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); - - RunChildTest(kRseqTestRegisterUnregister, 0); -} - -// The pointer to rseq must match on register/unregister. -TEST(RseqTest, UnregisterDifferentPtr) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); - - RunChildTest(kRseqTestUnregisterDifferentPtr, 0); -} - -// The signature must match on register/unregister. -TEST(RseqTest, UnregisterDifferentSignature) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); - - RunChildTest(kRseqTestUnregisterDifferentSignature, 0); -} - -// The CPU ID is initialized. -TEST(RseqTest, CPU) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); - - RunChildTest(kRseqTestCPU, 0); -} - -// Critical section is eventually aborted. -TEST(RseqTest, Abort) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); - - RunChildTest(kRseqTestAbort, 0); -} - -// Abort may be before the critical section. -TEST(RseqTest, AbortBefore) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); - - RunChildTest(kRseqTestAbortBefore, 0); -} - -// Signature must match. -TEST(RseqTest, AbortSignature) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); - - RunChildTest(kRseqTestAbortSignature, SIGSEGV); -} - -// Abort must not be in the critical section. -TEST(RseqTest, AbortPreCommit) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); - - RunChildTest(kRseqTestAbortPreCommit, SIGSEGV); -} - -// rseq.rseq_cs is cleared on abort. -TEST(RseqTest, AbortClearsCS) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); - - RunChildTest(kRseqTestAbortClearsCS, 0); -} - -// rseq.rseq_cs is cleared on abort outside of critical section. -TEST(RseqTest, InvalidAbortClearsCS) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); - - RunChildTest(kRseqTestInvalidAbortClearsCS, 0); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/rseq/BUILD b/test/syscalls/linux/rseq/BUILD deleted file mode 100644 index ed488dbc2..000000000 --- a/test/syscalls/linux/rseq/BUILD +++ /dev/null @@ -1,58 +0,0 @@ -# This package contains a standalone rseq test binary. This binary must not -# depend on libc, which might use rseq itself. - -load("//tools:defs.bzl", "cc_flags_supplier", "cc_library", "cc_toolchain") - -package(licenses = ["notice"]) - -genrule( - name = "rseq_binary", - srcs = [ - "critical.h", - "critical.S", - "rseq.cc", - "syscalls.h", - "start.S", - "test.h", - "types.h", - "uapi.h", - ], - outs = ["rseq"], - cmd = " ".join([ - "$(CC)", - "$(CC_FLAGS) ", - "-I.", - "-Wall", - "-Werror", - "-O2", - "-std=c++17", - "-static", - "-nostdlib", - "-ffreestanding", - "-o", - "$(location rseq)", - "$(location critical.S)", - "$(location rseq.cc)", - "$(location start.S)", - ]), - toolchains = [ - cc_toolchain, - ":no_pie_cc_flags", - ], - visibility = ["//:sandbox"], -) - -cc_flags_supplier( - name = "no_pie_cc_flags", - features = ["-pie"], -) - -cc_library( - name = "lib", - testonly = 1, - hdrs = [ - "test.h", - "uapi.h", - ], - visibility = ["//:sandbox"], -) diff --git a/test/syscalls/linux/rseq/critical.S b/test/syscalls/linux/rseq/critical.S deleted file mode 100644 index 8c0687e6d..000000000 --- a/test/syscalls/linux/rseq/critical.S +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Restartable sequences critical sections. - -// Loops continuously until aborted. -// -// void rseq_loop(struct rseq* r, struct rseq_cs* cs) - - .text - .globl rseq_loop - .type rseq_loop, @function - -rseq_loop: - jmp begin - - // Abort block before the critical section. - // Abort signature is 4 nops for simplicity. - .byte 0x90, 0x90, 0x90, 0x90 - .globl rseq_loop_early_abort -rseq_loop_early_abort: - ret - -begin: - // r->rseq_cs = cs - movq %rsi, 8(%rdi) - - // N.B. rseq_cs will be cleared by any preempt, even outside the critical - // section. Thus it must be set in or immediately before the critical section - // to ensure it is not cleared before the section begins. - .globl rseq_loop_start -rseq_loop_start: - jmp rseq_loop_start - - // "Pre-commit": extra instructions inside the critical section. These are - // used as the abort point in TestAbortPreCommit, which is not valid. - .globl rseq_loop_pre_commit -rseq_loop_pre_commit: - // Extra abort signature + nop for TestAbortPostCommit. - .byte 0x90, 0x90, 0x90, 0x90 - nop - - // "Post-commit": never reached in this case. - .globl rseq_loop_post_commit -rseq_loop_post_commit: - - // Abort signature is 4 nops for simplicity. - .byte 0x90, 0x90, 0x90, 0x90 - - .globl rseq_loop_abort -rseq_loop_abort: - ret - - .size rseq_loop,.-rseq_loop - .section .note.GNU-stack,"",@progbits diff --git a/test/syscalls/linux/rseq/critical.h b/test/syscalls/linux/rseq/critical.h deleted file mode 100644 index ac987a25e..000000000 --- a/test/syscalls/linux/rseq/critical.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_ - -#include "test/syscalls/linux/rseq/types.h" -#include "test/syscalls/linux/rseq/uapi.h" - -constexpr uint32_t kRseqSignature = 0x90909090; - -extern "C" { - -extern void rseq_loop(struct rseq* r, struct rseq_cs* cs); -extern void* rseq_loop_early_abort; -extern void* rseq_loop_start; -extern void* rseq_loop_pre_commit; -extern void* rseq_loop_post_commit; -extern void* rseq_loop_abort; - -extern int rseq_getpid(struct rseq* r, struct rseq_cs* cs); -extern void* rseq_getpid_start; -extern void* rseq_getpid_post_commit; -extern void* rseq_getpid_abort; - -} // extern "C" - -#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_ diff --git a/test/syscalls/linux/rseq/rseq.cc b/test/syscalls/linux/rseq/rseq.cc deleted file mode 100644 index f036db26d..000000000 --- a/test/syscalls/linux/rseq/rseq.cc +++ /dev/null @@ -1,366 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/rseq/critical.h" -#include "test/syscalls/linux/rseq/syscalls.h" -#include "test/syscalls/linux/rseq/test.h" -#include "test/syscalls/linux/rseq/types.h" -#include "test/syscalls/linux/rseq/uapi.h" - -namespace gvisor { -namespace testing { - -extern "C" int main(int argc, char** argv, char** envp); - -// Standalone initialization before calling main(). -extern "C" void __init(uintptr_t* sp) { - int argc = sp[0]; - char** argv = reinterpret_cast<char**>(&sp[1]); - char** envp = &argv[argc + 1]; - - // Call main() and exit. - sys_exit_group(main(argc, argv, envp)); - - // sys_exit_group does not return -} - -int strcmp(const char* s1, const char* s2) { - const unsigned char* p1 = reinterpret_cast<const unsigned char*>(s1); - const unsigned char* p2 = reinterpret_cast<const unsigned char*>(s2); - - while (*p1 == *p2) { - if (!*p1) { - return 0; - } - ++p1; - ++p2; - } - return static_cast<int>(*p1) - static_cast<int>(*p2); -} - -int sys_rseq(struct rseq* rseq, uint32_t rseq_len, int flags, uint32_t sig) { - return raw_syscall(kRseqSyscall, rseq, rseq_len, flags, sig); -} - -// Test that rseq must be aligned. -int TestUnaligned() { - constexpr uintptr_t kRequiredAlignment = alignof(rseq); - - char buf[2 * kRequiredAlignment] = {}; - uintptr_t ptr = reinterpret_cast<uintptr_t>(&buf[0]); - if ((ptr & (kRequiredAlignment - 1)) == 0) { - // buf is already aligned. Misalign it. - ptr++; - } - - int ret = sys_rseq(reinterpret_cast<rseq*>(ptr), sizeof(rseq), 0, 0); - if (sys_errno(ret) != EINVAL) { - return 1; - } - return 0; -} - -// Sanity test that registration works. -int TestRegister() { - struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { - return 1; - } - return 0; -}; - -// Registration can't be done twice. -int TestDoubleRegister() { - struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { - return 1; - } - - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != EBUSY) { - return 1; - } - - return 0; -}; - -// Registration can be done again after unregister. -int TestRegisterUnregister() { - struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { - return 1; - } - - if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, 0); - sys_errno(ret) != 0) { - return 1; - } - - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { - return 1; - } - - return 0; -}; - -// The pointer to rseq must match on register/unregister. -int TestUnregisterDifferentPtr() { - struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { - return 1; - } - - struct rseq r2 = {}; - if (int ret = sys_rseq(&r2, sizeof(r2), kRseqFlagUnregister, 0); - sys_errno(ret) != EINVAL) { - return 1; - } - - return 0; -}; - -// The signature must match on register/unregister. -int TestUnregisterDifferentSignature() { - constexpr int kSignature = 0; - - struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kSignature); sys_errno(ret) != 0) { - return 1; - } - - if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, kSignature + 1); - sys_errno(ret) != EPERM) { - return 1; - } - - return 0; -}; - -// The CPU ID is initialized. -int TestCPU() { - struct rseq r = {}; - r.cpu_id = kRseqCPUIDUninitialized; - - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { - return 1; - } - - if (__atomic_load_n(&r.cpu_id, __ATOMIC_RELAXED) < 0) { - return 1; - } - if (__atomic_load_n(&r.cpu_id_start, __ATOMIC_RELAXED) < 0) { - return 1; - } - - return 0; -}; - -// Critical section is eventually aborted. -int TestAbort() { - struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); - sys_errno(ret) != 0) { - return 1; - } - - struct rseq_cs cs = {}; - cs.version = 0; - cs.flags = 0; - cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); - cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - - reinterpret_cast<uint64_t>(&rseq_loop_start); - cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort); - - // Loops until abort. If this returns then abort occurred. - rseq_loop(&r, &cs); - - return 0; -}; - -// Abort may be before the critical section. -int TestAbortBefore() { - struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); - sys_errno(ret) != 0) { - return 1; - } - - struct rseq_cs cs = {}; - cs.version = 0; - cs.flags = 0; - cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); - cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - - reinterpret_cast<uint64_t>(&rseq_loop_start); - cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_early_abort); - - // Loops until abort. If this returns then abort occurred. - rseq_loop(&r, &cs); - - return 0; -}; - -// Signature must match. -int TestAbortSignature() { - struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); - sys_errno(ret) != 0) { - return 1; - } - - struct rseq_cs cs = {}; - cs.version = 0; - cs.flags = 0; - cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); - cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - - reinterpret_cast<uint64_t>(&rseq_loop_start); - cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort); - - // Loops until abort. This should SIGSEGV on abort. - rseq_loop(&r, &cs); - - return 1; -}; - -// Abort must not be in the critical section. -int TestAbortPreCommit() { - struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); - sys_errno(ret) != 0) { - return 1; - } - - struct rseq_cs cs = {}; - cs.version = 0; - cs.flags = 0; - cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); - cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - - reinterpret_cast<uint64_t>(&rseq_loop_start); - cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_pre_commit); - - // Loops until abort. This should SIGSEGV on abort. - rseq_loop(&r, &cs); - - return 1; -}; - -// rseq.rseq_cs is cleared on abort. -int TestAbortClearsCS() { - struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); - sys_errno(ret) != 0) { - return 1; - } - - struct rseq_cs cs = {}; - cs.version = 0; - cs.flags = 0; - cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); - cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - - reinterpret_cast<uint64_t>(&rseq_loop_start); - cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort); - - // Loops until abort. If this returns then abort occurred. - rseq_loop(&r, &cs); - - if (__atomic_load_n(&r.rseq_cs, __ATOMIC_RELAXED)) { - return 1; - } - - return 0; -}; - -// rseq.rseq_cs is cleared on abort outside of critical section. -int TestInvalidAbortClearsCS() { - struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); - sys_errno(ret) != 0) { - return 1; - } - - struct rseq_cs cs = {}; - cs.version = 0; - cs.flags = 0; - cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); - cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - - reinterpret_cast<uint64_t>(&rseq_loop_start); - cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort); - - __atomic_store_n(&r.rseq_cs, &cs, __ATOMIC_RELAXED); - - // When the next abort condition occurs, the kernel will clear cs once it - // determines we aren't in the critical section. - while (1) { - if (!__atomic_load_n(&r.rseq_cs, __ATOMIC_RELAXED)) { - break; - } - } - - return 0; -}; - -// Exit codes: -// 0 - Pass -// 1 - Fail -// 2 - Missing argument -// 3 - Unknown test case -extern "C" int main(int argc, char** argv, char** envp) { - if (argc != 2) { - // Usage: rseq <test case> - return 2; - } - - if (strcmp(argv[1], kRseqTestUnaligned) == 0) { - return TestUnaligned(); - } - if (strcmp(argv[1], kRseqTestRegister) == 0) { - return TestRegister(); - } - if (strcmp(argv[1], kRseqTestDoubleRegister) == 0) { - return TestDoubleRegister(); - } - if (strcmp(argv[1], kRseqTestRegisterUnregister) == 0) { - return TestRegisterUnregister(); - } - if (strcmp(argv[1], kRseqTestUnregisterDifferentPtr) == 0) { - return TestUnregisterDifferentPtr(); - } - if (strcmp(argv[1], kRseqTestUnregisterDifferentSignature) == 0) { - return TestUnregisterDifferentSignature(); - } - if (strcmp(argv[1], kRseqTestCPU) == 0) { - return TestCPU(); - } - if (strcmp(argv[1], kRseqTestAbort) == 0) { - return TestAbort(); - } - if (strcmp(argv[1], kRseqTestAbortBefore) == 0) { - return TestAbortBefore(); - } - if (strcmp(argv[1], kRseqTestAbortSignature) == 0) { - return TestAbortSignature(); - } - if (strcmp(argv[1], kRseqTestAbortPreCommit) == 0) { - return TestAbortPreCommit(); - } - if (strcmp(argv[1], kRseqTestAbortClearsCS) == 0) { - return TestAbortClearsCS(); - } - if (strcmp(argv[1], kRseqTestInvalidAbortClearsCS) == 0) { - return TestInvalidAbortClearsCS(); - } - - return 3; -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/rseq/start.S b/test/syscalls/linux/rseq/start.S deleted file mode 100644 index b9611b276..000000000 --- a/test/syscalls/linux/rseq/start.S +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - - - .text - .align 4 - .type _start,@function - .globl _start - -_start: - movq %rsp,%rdi - call __init - hlt - - .size _start,.-_start - .section .note.GNU-stack,"",@progbits - - .text - .globl raw_syscall - .type raw_syscall, @function - -raw_syscall: - mov %rdi,%rax // syscall # - mov %rsi,%rdi // arg0 - mov %rdx,%rsi // arg1 - mov %rcx,%rdx // arg2 - mov %r8,%r10 // arg3 (goes in r10 instead of rcx for system calls) - mov %r9,%r8 // arg4 - mov 0x8(%rsp),%r9 // arg5 - syscall - ret - - .size raw_syscall,.-raw_syscall - .section .note.GNU-stack,"",@progbits diff --git a/test/syscalls/linux/rseq/syscalls.h b/test/syscalls/linux/rseq/syscalls.h deleted file mode 100644 index e5299c188..000000000 --- a/test/syscalls/linux/rseq/syscalls.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_ - -#include "test/syscalls/linux/rseq/types.h" - -#ifdef __x86_64__ -// Syscall numbers. -constexpr int kGetpid = 39; -constexpr int kExitGroup = 231; -#else -#error "Unknown architecture" -#endif - -namespace gvisor { -namespace testing { - -// Standalone system call interfaces. -// Note that these are all "raw" system call interfaces which encode -// errors by setting the return value to a small negative number. -// Use sys_errno() to check system call return values for errors. - -// Maximum Linux error number. -constexpr int kMaxErrno = 4095; - -// Errno values. -#define EPERM 1 -#define EFAULT 14 -#define EBUSY 16 -#define EINVAL 22 - -// Get the error number from a raw system call return value. -// Returns a positive error number or 0 if there was no error. -static inline int sys_errno(uintptr_t rval) { - if (rval >= static_cast<uintptr_t>(-kMaxErrno)) { - return -static_cast<int>(rval); - } - return 0; -} - -extern "C" uintptr_t raw_syscall(int number, ...); - -static inline void sys_exit_group(int status) { - raw_syscall(kExitGroup, status); -} -static inline int sys_getpid() { - return static_cast<int>(raw_syscall(kGetpid)); -} - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_ diff --git a/test/syscalls/linux/rseq/test.h b/test/syscalls/linux/rseq/test.h deleted file mode 100644 index 3b7bb74b1..000000000 --- a/test/syscalls/linux/rseq/test.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_ - -namespace gvisor { -namespace testing { - -// Test cases supported by rseq binary. - -inline constexpr char kRseqTestUnaligned[] = "unaligned"; -inline constexpr char kRseqTestRegister[] = "register"; -inline constexpr char kRseqTestDoubleRegister[] = "double-register"; -inline constexpr char kRseqTestRegisterUnregister[] = "register-unregister"; -inline constexpr char kRseqTestUnregisterDifferentPtr[] = - "unregister-different-ptr"; -inline constexpr char kRseqTestUnregisterDifferentSignature[] = - "unregister-different-signature"; -inline constexpr char kRseqTestCPU[] = "cpu"; -inline constexpr char kRseqTestAbort[] = "abort"; -inline constexpr char kRseqTestAbortBefore[] = "abort-before"; -inline constexpr char kRseqTestAbortSignature[] = "abort-signature"; -inline constexpr char kRseqTestAbortPreCommit[] = "abort-precommit"; -inline constexpr char kRseqTestAbortClearsCS[] = "abort-clears-cs"; -inline constexpr char kRseqTestInvalidAbortClearsCS[] = - "invalid-abort-clears-cs"; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_ diff --git a/test/syscalls/linux/rseq/types.h b/test/syscalls/linux/rseq/types.h deleted file mode 100644 index b6afe9817..000000000 --- a/test/syscalls/linux/rseq/types.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_ - -using size_t = __SIZE_TYPE__; -using uintptr_t = __UINTPTR_TYPE__; - -using uint8_t = __UINT8_TYPE__; -using uint16_t = __UINT16_TYPE__; -using uint32_t = __UINT32_TYPE__; -using uint64_t = __UINT64_TYPE__; - -using int8_t = __INT8_TYPE__; -using int16_t = __INT16_TYPE__; -using int32_t = __INT32_TYPE__; -using int64_t = __INT64_TYPE__; - -#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_ diff --git a/test/syscalls/linux/rseq/uapi.h b/test/syscalls/linux/rseq/uapi.h deleted file mode 100644 index ca1d67691..000000000 --- a/test/syscalls/linux/rseq/uapi.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_ - -#include <stdint.h> - -// User-kernel ABI for restartable sequences. - -#ifdef __x86_64__ -// Syscall numbers. -constexpr int kRseqSyscall = 334; -#else -#error "Unknown architecture" -#endif // __x86_64__ - -struct rseq_cs { - uint32_t version; - uint32_t flags; - uint64_t start_ip; - uint64_t post_commit_offset; - uint64_t abort_ip; -} __attribute__((aligned(4 * sizeof(uint64_t)))); - -// N.B. alignment is enforced by the kernel. -struct rseq { - uint32_t cpu_id_start; - uint32_t cpu_id; - struct rseq_cs* rseq_cs; - uint32_t flags; -} __attribute__((aligned(4 * sizeof(uint64_t)))); - -constexpr int kRseqFlagUnregister = 1 << 0; - -constexpr int kRseqCPUIDUninitialized = -1; - -#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_ diff --git a/test/syscalls/linux/rtsignal.cc b/test/syscalls/linux/rtsignal.cc deleted file mode 100644 index ed27e2566..000000000 --- a/test/syscalls/linux/rtsignal.cc +++ /dev/null @@ -1,171 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/syscall.h> -#include <sys/types.h> -#include <unistd.h> - -#include <cerrno> -#include <csignal> - -#include "gtest/gtest.h" -#include "test/util/cleanup.h" -#include "test/util/logging.h" -#include "test/util/posix_error.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// saved_info is set by the handler. -siginfo_t saved_info; - -// has_saved_info is set to true by the handler. -volatile bool has_saved_info; - -void SigHandler(int sig, siginfo_t* info, void* context) { - // Copy to the given info. - saved_info = *info; - has_saved_info = true; -} - -void ClearSavedInfo() { - // Clear the cached info. - memset(&saved_info, 0, sizeof(saved_info)); - has_saved_info = false; -} - -PosixErrorOr<Cleanup> SetupSignalHandler(int sig) { - struct sigaction sa; - sa.sa_sigaction = SigHandler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_SIGINFO; - return ScopedSigaction(sig, sa); -} - -class RtSignalTest : public ::testing::Test { - protected: - void SetUp() override { - action_cleanup_ = ASSERT_NO_ERRNO_AND_VALUE(SetupSignalHandler(SIGUSR1)); - mask_cleanup_ = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGUSR1)); - } - - void TearDown() override { ClearSavedInfo(); } - - private: - Cleanup action_cleanup_; - Cleanup mask_cleanup_; -}; - -static int rt_sigqueueinfo(pid_t tgid, int sig, siginfo_t* uinfo) { - int ret; - do { - // NOTE(b/25434735): rt_sigqueueinfo(2) could return EAGAIN for RT signals. - ret = syscall(SYS_rt_sigqueueinfo, tgid, sig, uinfo); - } while (ret == -1 && errno == EAGAIN); - return ret; -} - -TEST_F(RtSignalTest, InvalidTID) { - siginfo_t uinfo; - // Depending on the kernel version, these calls may fail with - // ESRCH (goobunutu machines) or EPERM (production machines). Thus, - // the test simply ensures that they do fail. - EXPECT_THAT(rt_sigqueueinfo(-1, SIGUSR1, &uinfo), SyscallFails()); - EXPECT_FALSE(has_saved_info); - EXPECT_THAT(rt_sigqueueinfo(0, SIGUSR1, &uinfo), SyscallFails()); - EXPECT_FALSE(has_saved_info); -} - -TEST_F(RtSignalTest, InvalidCodes) { - siginfo_t uinfo; - - // We need a child for the code checks to apply. If the process is delivering - // to itself, then it can use whatever codes it wants and they will go - // through. - pid_t child = fork(); - if (child == 0) { - _exit(1); - } - ASSERT_THAT(child, SyscallSucceeds()); - - // These are not allowed for child processes. - uinfo.si_code = 0; // SI_USER. - EXPECT_THAT(rt_sigqueueinfo(child, SIGUSR1, &uinfo), - SyscallFailsWithErrno(EPERM)); - uinfo.si_code = 0x80; // SI_KERNEL. - EXPECT_THAT(rt_sigqueueinfo(child, SIGUSR1, &uinfo), - SyscallFailsWithErrno(EPERM)); - uinfo.si_code = -6; // SI_TKILL. - EXPECT_THAT(rt_sigqueueinfo(child, SIGUSR1, &uinfo), - SyscallFailsWithErrno(EPERM)); - uinfo.si_code = -1; // SI_QUEUE (allowed). - EXPECT_THAT(rt_sigqueueinfo(child, SIGUSR1, &uinfo), SyscallSucceeds()); - - // Join the child process. - EXPECT_THAT(waitpid(child, nullptr, 0), SyscallSucceeds()); -} - -TEST_F(RtSignalTest, ValueDelivered) { - siginfo_t uinfo; - uinfo.si_code = -1; // SI_QUEUE (allowed). - uinfo.si_errno = 0x1234; - - EXPECT_EQ(saved_info.si_errno, 0x0); - EXPECT_THAT(rt_sigqueueinfo(getpid(), SIGUSR1, &uinfo), SyscallSucceeds()); - EXPECT_TRUE(has_saved_info); - EXPECT_EQ(saved_info.si_errno, 0x1234); -} - -TEST_F(RtSignalTest, SignoMatch) { - auto action2_cleanup = ASSERT_NO_ERRNO_AND_VALUE(SetupSignalHandler(SIGUSR2)); - auto mask2_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGUSR2)); - - siginfo_t uinfo; - uinfo.si_code = -1; // SI_QUEUE (allowed). - - EXPECT_THAT(rt_sigqueueinfo(getpid(), SIGUSR1, &uinfo), SyscallSucceeds()); - EXPECT_TRUE(has_saved_info); - EXPECT_EQ(saved_info.si_signo, SIGUSR1); - - ClearSavedInfo(); - - EXPECT_THAT(rt_sigqueueinfo(getpid(), SIGUSR2, &uinfo), SyscallSucceeds()); - EXPECT_TRUE(has_saved_info); - EXPECT_EQ(saved_info.si_signo, SIGUSR2); -} - -} // namespace - -} // namespace testing -} // namespace gvisor - -int main(int argc, char** argv) { - // These tests depend on delivering SIGUSR1/2 to the main thread (so they can - // synchronously check has_saved_info). Block these so that any other threads - // created by TestInit will also have them blocked. - sigset_t set; - sigemptyset(&set); - sigaddset(&set, SIGUSR1); - sigaddset(&set, SIGUSR2); - TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); - - gvisor::testing::TestInit(&argc, &argv); - return gvisor::testing::RunAllTests(); -} diff --git a/test/syscalls/linux/sched.cc b/test/syscalls/linux/sched.cc deleted file mode 100644 index 735e99411..000000000 --- a/test/syscalls/linux/sched.cc +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <sched.h> - -#include "gtest/gtest.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// In linux, pid is limited to 29 bits because how futex is implemented. -constexpr int kImpossiblePID = (1 << 29) + 1; - -TEST(SchedGetparamTest, ReturnsZero) { - struct sched_param param; - EXPECT_THAT(sched_getparam(getpid(), ¶m), SyscallSucceeds()); - EXPECT_EQ(param.sched_priority, 0); - EXPECT_THAT(sched_getparam(/*pid=*/0, ¶m), SyscallSucceeds()); - EXPECT_EQ(param.sched_priority, 0); -} - -TEST(SchedGetparamTest, InvalidPIDReturnsEINVAL) { - struct sched_param param; - EXPECT_THAT(sched_getparam(/*pid=*/-1, ¶m), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(SchedGetparamTest, ImpossiblePIDReturnsESRCH) { - struct sched_param param; - EXPECT_THAT(sched_getparam(kImpossiblePID, ¶m), - SyscallFailsWithErrno(ESRCH)); -} - -TEST(SchedGetparamTest, NullParamReturnsEINVAL) { - EXPECT_THAT(sched_getparam(0, nullptr), SyscallFailsWithErrno(EINVAL)); -} - -TEST(SchedGetschedulerTest, ReturnsSchedOther) { - EXPECT_THAT(sched_getscheduler(getpid()), - SyscallSucceedsWithValue(SCHED_OTHER)); - EXPECT_THAT(sched_getscheduler(/*pid=*/0), - SyscallSucceedsWithValue(SCHED_OTHER)); -} - -TEST(SchedGetschedulerTest, ReturnsEINVAL) { - EXPECT_THAT(sched_getscheduler(/*pid=*/-1), SyscallFailsWithErrno(EINVAL)); -} - -TEST(SchedGetschedulerTest, ReturnsESRCH) { - EXPECT_THAT(sched_getscheduler(kImpossiblePID), SyscallFailsWithErrno(ESRCH)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/sched_yield.cc b/test/syscalls/linux/sched_yield.cc deleted file mode 100644 index 5d24f5b58..000000000 --- a/test/syscalls/linux/sched_yield.cc +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sched.h> - -#include "gtest/gtest.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(SchedYieldTest, Success) { - EXPECT_THAT(sched_yield(), SyscallSucceeds()); - EXPECT_THAT(sched_yield(), SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/seccomp.cc b/test/syscalls/linux/seccomp.cc deleted file mode 100644 index 8e0fc9acc..000000000 --- a/test/syscalls/linux/seccomp.cc +++ /dev/null @@ -1,415 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <linux/audit.h> -#include <linux/filter.h> -#include <linux/seccomp.h> -#include <pthread.h> -#include <sched.h> -#include <signal.h> -#include <string.h> -#include <sys/prctl.h> -#include <sys/syscall.h> -#include <time.h> -#include <ucontext.h> -#include <unistd.h> - -#include <atomic> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/base/macros.h" -#include "test/util/logging.h" -#include "test/util/memory_util.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/proc_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -#ifndef SYS_SECCOMP -#define SYS_SECCOMP 1 -#endif - -namespace gvisor { -namespace testing { - -namespace { - -// A syscall not implemented by Linux that we don't expect to be called. -#ifdef __x86_64__ -constexpr uint32_t kFilteredSyscall = SYS_vserver; -#elif __aarch64__ -// Use the last of arch_specific_syscalls which are not implemented on arm64. -constexpr uint32_t kFilteredSyscall = __NR_arch_specific_syscall + 15; -#endif - -// Applies a seccomp-bpf filter that returns `filtered_result` for -// `sysno` and allows all other syscalls. Async-signal-safe. -void ApplySeccompFilter(uint32_t sysno, uint32_t filtered_result, - uint32_t flags = 0) { - // "Prior to [PR_SET_SECCOMP], the task must call prctl(PR_SET_NO_NEW_PRIVS, - // 1) or run with CAP_SYS_ADMIN privileges in its namespace." - - // Documentation/prctl/seccomp_filter.txt - // - // prctl(PR_SET_NO_NEW_PRIVS, 1) may be called repeatedly; calls after the - // first are no-ops. - TEST_PCHECK(prctl(PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0) == 0); - MaybeSave(); - - struct sock_filter filter[] = { - // A = seccomp_data.arch - BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 4), - // if (A != AUDIT_ARCH_X86_64) goto kill - BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_X86_64, 0, 4), - // A = seccomp_data.nr - BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 0), - // if (A != sysno) goto allow - BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, sysno, 0, 1), - // return filtered_result - BPF_STMT(BPF_RET | BPF_K, filtered_result), - // allow: return SECCOMP_RET_ALLOW - BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_ALLOW), - // kill: return SECCOMP_RET_KILL - BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_KILL), - }; - struct sock_fprog prog; - prog.len = ABSL_ARRAYSIZE(filter); - prog.filter = filter; - if (flags) { - TEST_CHECK(syscall(__NR_seccomp, SECCOMP_SET_MODE_FILTER, flags, &prog) == - 0); - } else { - TEST_PCHECK(prctl(PR_SET_SECCOMP, SECCOMP_MODE_FILTER, &prog, 0, 0) == 0); - } - MaybeSave(); -} - -// Wrapper for sigaction. Async-signal-safe. -void RegisterSignalHandler(int signum, - void (*handler)(int, siginfo_t*, void*)) { - struct sigaction sa = {}; - sa.sa_sigaction = handler; - sigemptyset(&sa.sa_mask); - sa.sa_flags = SA_SIGINFO; - TEST_PCHECK(sigaction(signum, &sa, nullptr) == 0); - MaybeSave(); -} - -// All of the following tests execute in a subprocess to ensure that each test -// is run in a separate process. This avoids cross-contamination of seccomp -// state between tests, and is necessary to ensure that test processes killed -// by SECCOMP_RET_KILL are single-threaded (since SECCOMP_RET_KILL only kills -// the offending thread, not the whole thread group). - -TEST(SeccompTest, RetKillCausesDeathBySIGSYS) { - pid_t const pid = fork(); - if (pid == 0) { - // Register a signal handler for SIGSYS that we don't expect to be invoked. - RegisterSignalHandler( - SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); }); - ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_KILL); - syscall(kFilteredSyscall); - TEST_CHECK_MSG(false, "Survived invocation of test syscall"); - } - ASSERT_THAT(pid, SyscallSucceeds()); - int status; - ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSYS) - << "status " << status; -} - -TEST(SeccompTest, RetKillOnlyKillsOneThread) { - Mapping stack = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - - pid_t const pid = fork(); - if (pid == 0) { - // Register a signal handler for SIGSYS that we don't expect to be invoked. - RegisterSignalHandler( - SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); }); - ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_KILL); - // Pass CLONE_VFORK to block the original thread in the child process until - // the clone thread exits with SIGSYS. - // - // N.B. clone(2) is not officially async-signal-safe, but at minimum glibc's - // x86_64 implementation is safe. See glibc - // sysdeps/unix/sysv/linux/x86_64/clone.S. - clone( - +[](void* arg) { - syscall(kFilteredSyscall); // should kill the thread - _exit(1); // should be unreachable - return 2; // should be very unreachable, shut up the compiler - }, - stack.endptr(), - CLONE_FILES | CLONE_FS | CLONE_SIGHAND | CLONE_THREAD | CLONE_VM | - CLONE_VFORK, - nullptr); - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - int status; - ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "status " << status; -} - -TEST(SeccompTest, RetTrapCausesSIGSYS) { - pid_t const pid = fork(); - if (pid == 0) { - constexpr uint16_t kTrapValue = 0xdead; - RegisterSignalHandler( - SIGSYS, +[](int signo, siginfo_t* info, void* ucv) { - ucontext_t* uc = static_cast<ucontext_t*>(ucv); - // This is a signal handler, so we must stay async-signal-safe. - TEST_CHECK(info->si_signo == SIGSYS); - TEST_CHECK(info->si_code == SYS_SECCOMP); - TEST_CHECK(info->si_errno == kTrapValue); - TEST_CHECK(info->si_call_addr != nullptr); - TEST_CHECK(info->si_syscall == kFilteredSyscall); -#ifdef __x86_64__ - TEST_CHECK(info->si_arch == AUDIT_ARCH_X86_64); - TEST_CHECK(uc->uc_mcontext.gregs[REG_RAX] == kFilteredSyscall); -#endif // defined(__x86_64__) - _exit(0); - }); - ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_TRAP | kTrapValue); - syscall(kFilteredSyscall); - TEST_CHECK_MSG(false, "Survived invocation of test syscall"); - } - ASSERT_THAT(pid, SyscallSucceeds()); - int status; - ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "status " << status; -} - -#ifdef __x86_64__ - -constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400; - -time_t vsyscall_time(time_t* t) { - return reinterpret_cast<time_t (*)(time_t*)>(kVsyscallTimeEntry)(t); -} - -TEST(SeccompTest, SeccompAppliesToVsyscall) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsVsyscallEnabled())); - - pid_t const pid = fork(); - if (pid == 0) { - constexpr uint16_t kTrapValue = 0xdead; - RegisterSignalHandler( - SIGSYS, +[](int signo, siginfo_t* info, void* ucv) { - ucontext_t* uc = static_cast<ucontext_t*>(ucv); - // This is a signal handler, so we must stay async-signal-safe. - TEST_CHECK(info->si_signo == SIGSYS); - TEST_CHECK(info->si_code == SYS_SECCOMP); - TEST_CHECK(info->si_errno == kTrapValue); - TEST_CHECK(info->si_call_addr != nullptr); - TEST_CHECK(info->si_syscall == SYS_time); - TEST_CHECK(info->si_arch == AUDIT_ARCH_X86_64); - TEST_CHECK(uc->uc_mcontext.gregs[REG_RAX] == SYS_time); - _exit(0); - }); - ApplySeccompFilter(SYS_time, SECCOMP_RET_TRAP | kTrapValue); - vsyscall_time(nullptr); // Should result in death. - TEST_CHECK_MSG(false, "Survived invocation of test syscall"); - } - ASSERT_THAT(pid, SyscallSucceeds()); - int status; - ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "status " << status; -} - -TEST(SeccompTest, RetKillVsyscallCausesDeathBySIGSYS) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsVsyscallEnabled())); - - pid_t const pid = fork(); - if (pid == 0) { - // Register a signal handler for SIGSYS that we don't expect to be invoked. - RegisterSignalHandler( - SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); }); - ApplySeccompFilter(SYS_time, SECCOMP_RET_KILL); - vsyscall_time(nullptr); // Should result in death. - TEST_CHECK_MSG(false, "Survived invocation of test syscall"); - } - ASSERT_THAT(pid, SyscallSucceeds()); - int status; - ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSYS) - << "status " << status; -} - -#endif // defined(__x86_64__) - -TEST(SeccompTest, RetTraceWithoutPtracerReturnsENOSYS) { - pid_t const pid = fork(); - if (pid == 0) { - ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_TRACE); - TEST_CHECK(syscall(kFilteredSyscall) == -1 && errno == ENOSYS); - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - int status; - ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "status " << status; -} - -TEST(SeccompTest, RetErrnoReturnsErrno) { - pid_t const pid = fork(); - if (pid == 0) { - // ENOTNAM: "Not a XENIX named type file" - ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_ERRNO | ENOTNAM); - TEST_CHECK(syscall(kFilteredSyscall) == -1 && errno == ENOTNAM); - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - int status; - ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "status " << status; -} - -TEST(SeccompTest, RetAllowAllowsSyscall) { - pid_t const pid = fork(); - if (pid == 0) { - ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_ALLOW); - TEST_CHECK(syscall(kFilteredSyscall) == -1 && errno == ENOSYS); - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - int status; - ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "status " << status; -} - -// This test will validate that TSYNC will apply to all threads. -TEST(SeccompTest, TsyncAppliesToAllThreads) { - Mapping stack = ASSERT_NO_ERRNO_AND_VALUE( - MmapAnon(2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - - // We don't want to apply this policy to other test runner threads, so fork. - const pid_t pid = fork(); - - if (pid == 0) { - // First check that we receive a ENOSYS before the policy is applied. - TEST_CHECK(syscall(kFilteredSyscall) == -1 && errno == ENOSYS); - - // N.B. clone(2) is not officially async-signal-safe, but at minimum glibc's - // x86_64 implementation is safe. See glibc - // sysdeps/unix/sysv/linux/x86_64/clone.S. - clone( - +[](void* arg) { - ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_ERRNO | ENOTNAM, - SECCOMP_FILTER_FLAG_TSYNC); - return 0; - }, - stack.endptr(), - CLONE_FILES | CLONE_FS | CLONE_SIGHAND | CLONE_THREAD | CLONE_VM | - CLONE_VFORK, - nullptr); - - // Because we're using CLONE_VFORK this thread will be blocked until - // the second thread has released resources to our virtual memory, since - // we're not execing that will happen on _exit. - - // Now verify that the policy applied to this thread too. - TEST_CHECK(syscall(kFilteredSyscall) == -1 && errno == ENOTNAM); - _exit(0); - } - - ASSERT_THAT(pid, SyscallSucceeds()); - int status = 0; - ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "status " << status; -} - -// This test will validate that seccomp(2) rejects unsupported flags. -TEST(SeccompTest, SeccompRejectsUnknownFlags) { - constexpr uint32_t kInvalidFlag = 123; - ASSERT_THAT( - syscall(__NR_seccomp, SECCOMP_SET_MODE_FILTER, kInvalidFlag, nullptr), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(SeccompTest, LeastPermissiveFilterReturnValueApplies) { - // This is RetKillCausesDeathBySIGSYS, plus extra filters before and after the - // one that causes the kill that should be ignored. - pid_t const pid = fork(); - if (pid == 0) { - RegisterSignalHandler( - SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); }); - ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_TRACE); - ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_KILL); - ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_ERRNO | ENOTNAM); - syscall(kFilteredSyscall); - TEST_CHECK_MSG(false, "Survived invocation of test syscall"); - } - ASSERT_THAT(pid, SyscallSucceeds()); - int status; - ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSYS) - << "status " << status; -} - -// Passed as argv[1] to cause the test binary to invoke kFilteredSyscall and -// exit. Not a real flag since flag parsing happens during initialization, -// which may create threads. -constexpr char kInvokeFilteredSyscallFlag[] = "--seccomp_test_child"; - -TEST(SeccompTest, FiltersPreservedAcrossForkAndExecve) { - ExecveArray const grandchild_argv( - {"/proc/self/exe", kInvokeFilteredSyscallFlag}); - - pid_t const pid = fork(); - if (pid == 0) { - ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_KILL); - pid_t const grandchild_pid = fork(); - if (grandchild_pid == 0) { - execve(grandchild_argv.get()[0], grandchild_argv.get(), - /* envp = */ nullptr); - TEST_PCHECK_MSG(false, "execve failed"); - } - int status; - TEST_PCHECK(waitpid(grandchild_pid, &status, 0) == grandchild_pid); - TEST_CHECK(WIFSIGNALED(status) && WTERMSIG(status) == SIGSYS); - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - int status; - ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "status " << status; -} - -} // namespace - -} // namespace testing -} // namespace gvisor - -int main(int argc, char** argv) { - if (argc >= 2 && - strcmp(argv[1], gvisor::testing::kInvokeFilteredSyscallFlag) == 0) { - syscall(gvisor::testing::kFilteredSyscall); - exit(0); - } - - gvisor::testing::TestInit(&argc, &argv); - return gvisor::testing::RunAllTests(); -} diff --git a/test/syscalls/linux/select.cc b/test/syscalls/linux/select.cc deleted file mode 100644 index be2364fb8..000000000 --- a/test/syscalls/linux/select.cc +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <sys/resource.h> -#include <sys/select.h> -#include <sys/time.h> - -#include <climits> -#include <csignal> -#include <cstdio> - -#include "gtest/gtest.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/base_poll_test.h" -#include "test/util/file_descriptor.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/rlimit_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -class SelectTest : public BasePollTest { - protected: - void SetUp() override { BasePollTest::SetUp(); } - void TearDown() override { BasePollTest::TearDown(); } -}; - -// See that when there are no FD sets, select behaves like sleep. -TEST_F(SelectTest, NullFds) { - struct timeval timeout = absl::ToTimeval(absl::Milliseconds(10)); - ASSERT_THAT(select(0, nullptr, nullptr, nullptr, &timeout), - SyscallSucceeds()); - EXPECT_EQ(timeout.tv_sec, 0); - EXPECT_EQ(timeout.tv_usec, 0); - - timeout = absl::ToTimeval(absl::Milliseconds(10)); - ASSERT_THAT(select(1, nullptr, nullptr, nullptr, &timeout), - SyscallSucceeds()); - EXPECT_EQ(timeout.tv_sec, 0); - EXPECT_EQ(timeout.tv_usec, 0); -} - -TEST_F(SelectTest, NegativeNfds) { - EXPECT_THAT(select(-1, nullptr, nullptr, nullptr, nullptr), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(select(-100000, nullptr, nullptr, nullptr, nullptr), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(select(INT_MIN, nullptr, nullptr, nullptr, nullptr), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_F(SelectTest, ClosedFds) { - auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file.path(), O_RDONLY)); - - // We can't rely on a file descriptor being closed in a multi threaded - // application so fork to get a clean process. - EXPECT_THAT(InForkedProcess([&] { - int fd_num = fd.get(); - fd.reset(); - - fd_set read_set; - FD_ZERO(&read_set); - FD_SET(fd_num, &read_set); - - struct timeval timeout = - absl::ToTimeval(absl::Milliseconds(10)); - TEST_PCHECK(select(fd_num + 1, &read_set, nullptr, nullptr, - &timeout) != 0); - TEST_PCHECK(errno == EBADF); - }), - IsPosixErrorOkAndHolds(0)); -} - -TEST_F(SelectTest, ZeroTimeout) { - struct timeval timeout = {}; - EXPECT_THAT(select(1, nullptr, nullptr, nullptr, &timeout), - SyscallSucceeds()); - // Ignore timeout as its value is now undefined. -} - -// If random S/R interrupts the select, SIGALRM may be delivered before select -// restarts, causing the select to hang forever. -TEST_F(SelectTest, NoTimeout_NoRandomSave) { - // When there's no timeout, select may never return so set a timer. - SetTimer(absl::Milliseconds(100)); - // See that we get interrupted by the timer. - ASSERT_THAT(select(1, nullptr, nullptr, nullptr, nullptr), - SyscallFailsWithErrno(EINTR)); - EXPECT_TRUE(TimerFired()); -} - -TEST_F(SelectTest, InvalidTimeoutNegative) { - struct timeval timeout = absl::ToTimeval(absl::Microseconds(-1)); - EXPECT_THAT(select(1, nullptr, nullptr, nullptr, &timeout), - SyscallFailsWithErrno(EINVAL)); - // Ignore timeout as its value is now undefined. -} - -// Verify that a signal interrupts select. -// -// If random S/R interrupts the select, SIGALRM may be delivered before select -// restarts, causing the select to hang forever. -TEST_F(SelectTest, InterruptedBySignal_NoRandomSave) { - absl::Duration duration(absl::Seconds(5)); - struct timeval timeout = absl::ToTimeval(duration); - SetTimer(absl::Milliseconds(100)); - ASSERT_FALSE(TimerFired()); - ASSERT_THAT(select(1, nullptr, nullptr, nullptr, &timeout), - SyscallFailsWithErrno(EINTR)); - EXPECT_TRUE(TimerFired()); - // Ignore timeout as its value is now undefined. -} - -TEST_F(SelectTest, IgnoreBitsAboveNfds) { - // fd_set is a bit array with at least FD_SETSIZE bits. Test that bits - // corresponding to file descriptors above nfds are ignored. - fd_set read_set; - FD_ZERO(&read_set); - constexpr int kNfds = 1; - for (int fd = kNfds; fd < FD_SETSIZE; fd++) { - FD_SET(fd, &read_set); - } - // Pass a zero timeout so that select returns immediately. - struct timeval timeout = {}; - EXPECT_THAT(select(kNfds, &read_set, nullptr, nullptr, &timeout), - SyscallSucceedsWithValue(0)); -} - -// This test illustrates Linux's behavior of 'select' calls passing after -// setrlimit RLIMIT_NOFILE is called. In particular, versions of sshd rely on -// this behavior. See b/122318458. -TEST_F(SelectTest, SetrlimitCallNOFILE) { - fd_set read_set; - FD_ZERO(&read_set); - timeval timeout = {}; - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( - Open(NewTempAbsPath(), O_RDONLY | O_CREAT, S_IRUSR)); - - Cleanup reset_rlimit = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_NOFILE, 0)); - - FD_SET(fd.get(), &read_set); - // this call with zero timeout should return immediately - EXPECT_THAT(select(fd.get() + 1, &read_set, nullptr, nullptr, &timeout), - SyscallSucceeds()); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc deleted file mode 100644 index e9b131ca9..000000000 --- a/test/syscalls/linux/semaphore.cc +++ /dev/null @@ -1,491 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/ipc.h> -#include <sys/sem.h> -#include <sys/types.h> - -#include <atomic> -#include <cerrno> -#include <ctime> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/base/macros.h" -#include "absl/memory/memory.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/clock.h" -#include "test/util/capability_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { -namespace { - -class AutoSem { - public: - explicit AutoSem(int id) : id_(id) {} - ~AutoSem() { - if (id_ >= 0) { - EXPECT_THAT(semctl(id_, 0, IPC_RMID), SyscallSucceeds()); - } - } - - int release() { - int old = id_; - id_ = -1; - return old; - } - - int get() { return id_; } - - private: - int id_ = -1; -}; - -TEST(SemaphoreTest, SemGet) { - // Test creation and lookup. - AutoSem sem(semget(1, 10, IPC_CREAT)); - ASSERT_THAT(sem.get(), SyscallSucceeds()); - EXPECT_THAT(semget(1, 10, IPC_CREAT), SyscallSucceedsWithValue(sem.get())); - EXPECT_THAT(semget(1, 9, IPC_CREAT), SyscallSucceedsWithValue(sem.get())); - - // Creation and lookup failure cases. - EXPECT_THAT(semget(1, 11, IPC_CREAT), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(semget(1, -1, IPC_CREAT), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(semget(1, 10, IPC_CREAT | IPC_EXCL), - SyscallFailsWithErrno(EEXIST)); - EXPECT_THAT(semget(2, 1, 0), SyscallFailsWithErrno(ENOENT)); - EXPECT_THAT(semget(2, 0, IPC_CREAT), SyscallFailsWithErrno(EINVAL)); - - // Private semaphores never conflict. - AutoSem sem2(semget(IPC_PRIVATE, 1, 0)); - AutoSem sem3(semget(IPC_PRIVATE, 1, 0)); - ASSERT_THAT(sem2.get(), SyscallSucceeds()); - EXPECT_NE(sem.get(), sem2.get()); - ASSERT_THAT(sem3.get(), SyscallSucceeds()); - EXPECT_NE(sem3.get(), sem2.get()); -} - -// Tests simple operations that shouldn't block in a single-thread. -TEST(SemaphoreTest, SemOpSingleNoBlock) { - AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); - ASSERT_THAT(sem.get(), SyscallSucceeds()); - - struct sembuf buf = {}; - buf.sem_op = 1; - ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds()); - - buf.sem_op = -1; - ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds()); - - buf.sem_op = 0; - ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds()); - - // Error cases with invalid values. - ASSERT_THAT(semop(sem.get() + 1, &buf, 1), SyscallFailsWithErrno(EINVAL)); - - buf.sem_num = 1; - ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallFailsWithErrno(EFBIG)); - - ASSERT_THAT(semop(sem.get(), nullptr, 0), SyscallFailsWithErrno(EINVAL)); -} - -// Tests multiple operations that shouldn't block in a single-thread. -TEST(SemaphoreTest, SemOpMultiNoBlock) { - AutoSem sem(semget(IPC_PRIVATE, 4, 0600 | IPC_CREAT)); - ASSERT_THAT(sem.get(), SyscallSucceeds()); - - struct sembuf bufs[5] = {}; - bufs[0].sem_num = 0; - bufs[0].sem_op = 10; - bufs[0].sem_flg = 0; - - bufs[1].sem_num = 1; - bufs[1].sem_op = 2; - bufs[1].sem_flg = 0; - - bufs[2].sem_num = 2; - bufs[2].sem_op = 3; - bufs[2].sem_flg = 0; - - bufs[3].sem_num = 0; - bufs[3].sem_op = -5; - bufs[3].sem_flg = 0; - - bufs[4].sem_num = 2; - bufs[4].sem_op = 2; - bufs[4].sem_flg = 0; - - ASSERT_THAT(semop(sem.get(), bufs, ABSL_ARRAYSIZE(bufs)), SyscallSucceeds()); - - ASSERT_THAT(semctl(sem.get(), 0, GETVAL), SyscallSucceedsWithValue(5)); - ASSERT_THAT(semctl(sem.get(), 1, GETVAL), SyscallSucceedsWithValue(2)); - ASSERT_THAT(semctl(sem.get(), 2, GETVAL), SyscallSucceedsWithValue(5)); - ASSERT_THAT(semctl(sem.get(), 3, GETVAL), SyscallSucceedsWithValue(0)); - - for (auto& b : bufs) { - b.sem_op = -b.sem_op; - } - // 0 and 3 order must be reversed, otherwise it will block. - std::swap(bufs[0].sem_op, bufs[3].sem_op); - ASSERT_THAT(RetryEINTR(semop)(sem.get(), bufs, ABSL_ARRAYSIZE(bufs)), - SyscallSucceeds()); - - // All semaphores should be back to 0 now. - for (size_t i = 0; i < 4; ++i) { - ASSERT_THAT(semctl(sem.get(), i, GETVAL), SyscallSucceedsWithValue(0)); - } -} - -// Makes a best effort attempt to ensure that operation would block. -TEST(SemaphoreTest, SemOpBlock) { - AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); - ASSERT_THAT(sem.get(), SyscallSucceeds()); - - std::atomic<int> blocked = ATOMIC_VAR_INIT(1); - ScopedThread th([&sem, &blocked] { - absl::SleepFor(absl::Milliseconds(100)); - ASSERT_EQ(blocked.load(), 1); - - struct sembuf buf = {}; - buf.sem_op = 1; - ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds()); - }); - - struct sembuf buf = {}; - buf.sem_op = -1; - ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds()); - blocked.store(0); -} - -// Tests that IPC_NOWAIT returns with no wait. -TEST(SemaphoreTest, SemOpNoBlock) { - AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); - ASSERT_THAT(sem.get(), SyscallSucceeds()); - - struct sembuf buf = {}; - buf.sem_flg = IPC_NOWAIT; - - buf.sem_op = -1; - ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallFailsWithErrno(EAGAIN)); - - buf.sem_op = 1; - ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds()); - - buf.sem_op = 0; - ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallFailsWithErrno(EAGAIN)); -} - -// Test runs 2 threads, one signals the other waits the same number of times. -TEST(SemaphoreTest, SemOpSimple) { - AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); - ASSERT_THAT(sem.get(), SyscallSucceeds()); - - constexpr size_t kLoops = 100; - ScopedThread th([&sem] { - struct sembuf buf = {}; - buf.sem_op = 1; - for (size_t i = 0; i < kLoops; i++) { - // Sleep to prevent making all increments in one shot without letting - // the waiter wait. - absl::SleepFor(absl::Milliseconds(1)); - ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds()); - } - }); - - struct sembuf buf = {}; - buf.sem_op = -1; - for (size_t i = 0; i < kLoops; i++) { - ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds()); - } -} - -// Tests that semaphore can be removed while there are waiters. -// NoRandomSave: Test relies on timing that random save throws off. -TEST(SemaphoreTest, SemOpRemoveWithWaiter_NoRandomSave) { - AutoSem sem(semget(IPC_PRIVATE, 2, 0600 | IPC_CREAT)); - ASSERT_THAT(sem.get(), SyscallSucceeds()); - - ScopedThread th([&sem] { - absl::SleepFor(absl::Milliseconds(250)); - ASSERT_THAT(semctl(sem.release(), 0, IPC_RMID), SyscallSucceeds()); - }); - - // This must happen before IPC_RMID runs above. Otherwise it fails with EINVAL - // instead because the semaphore has already been removed. - struct sembuf buf = {}; - buf.sem_op = -1; - ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), - SyscallFailsWithErrno(EIDRM)); -} - -// Semaphore isn't fair. It will execute any waiter that can satisfy the -// request even if it gets in front of other waiters. -TEST(SemaphoreTest, SemOpBestFitExecution) { - AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); - ASSERT_THAT(sem.get(), SyscallSucceeds()); - - ScopedThread th([&sem] { - struct sembuf buf = {}; - buf.sem_op = -2; - ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallFails()); - // Ensure that wait will only unblock when the semaphore is removed. On - // EINTR retry it may race with deletion and return EINVAL. - ASSERT_TRUE(errno == EIDRM || errno == EINVAL) << "errno=" << errno; - }); - - // Ensures that '-1' below will unblock even though '-10' above is waiting - // for the same semaphore. - for (size_t i = 0; i < 10; ++i) { - struct sembuf buf = {}; - buf.sem_op = 1; - ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds()); - - absl::SleepFor(absl::Milliseconds(10)); - - buf.sem_op = -1; - ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds()); - } - - ASSERT_THAT(semctl(sem.release(), 0, IPC_RMID), SyscallSucceeds()); -} - -// Executes random operations in multiple threads and verify correctness. -TEST(SemaphoreTest, SemOpRandom) { - // Don't do cooperative S/R tests because there are too many syscalls in - // this test, - const DisableSave ds; - - AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); - ASSERT_THAT(sem.get(), SyscallSucceeds()); - - // Protects the seed below. - absl::Mutex mutex; - uint32_t seed = time(nullptr); - - int count = 0; // Tracks semaphore value. - bool done = false; // Tells waiters to stop after signal threads are done. - - // These threads will wait in a loop. - std::unique_ptr<ScopedThread> decs[5]; - for (auto& dec : decs) { - dec = absl::make_unique<ScopedThread>([&sem, &mutex, &count, &seed, &done] { - for (size_t i = 0; i < 500; ++i) { - int16_t val; - { - absl::MutexLock l(&mutex); - if (done) { - return; - } - val = (rand_r(&seed) % 10 + 1); // Rand between 1 and 10. - count -= val; - } - struct sembuf buf = {}; - buf.sem_op = -val; - ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds()); - absl::SleepFor(absl::Milliseconds(val * 2)); - } - }); - } - - // These threads will wait for zero in a loop. - std::unique_ptr<ScopedThread> zeros[5]; - for (auto& zero : zeros) { - zero = absl::make_unique<ScopedThread>([&sem, &mutex, &done] { - for (size_t i = 0; i < 500; ++i) { - { - absl::MutexLock l(&mutex); - if (done) { - return; - } - } - struct sembuf buf = {}; - buf.sem_op = 0; - ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds()); - absl::SleepFor(absl::Milliseconds(10)); - } - }); - } - - // These threads will signal in a loop. - std::unique_ptr<ScopedThread> incs[5]; - for (auto& inc : incs) { - inc = absl::make_unique<ScopedThread>([&sem, &mutex, &count, &seed] { - for (size_t i = 0; i < 500; ++i) { - int16_t val; - { - absl::MutexLock l(&mutex); - val = (rand_r(&seed) % 10 + 1); // Rand between 1 and 10. - count += val; - } - struct sembuf buf = {}; - buf.sem_op = val; - ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds()); - absl::SleepFor(absl::Milliseconds(val * 2)); - } - }); - } - - // First wait for signal threads to be done. - for (auto& inc : incs) { - inc->Join(); - } - - // Now there could be waiters blocked (remember operations are random). - // Notify waiters that we're done and signal semaphore just the right amount. - { - absl::MutexLock l(&mutex); - done = true; - struct sembuf buf = {}; - buf.sem_op = -count; - ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds()); - } - - // Now all waiters should unblock and exit. - for (auto& dec : decs) { - dec->Join(); - } - for (auto& zero : zeros) { - zero->Join(); - } -} - -TEST(SemaphoreTest, SemOpNamespace) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - AutoSem sem(semget(123, 1, 0600 | IPC_CREAT | IPC_EXCL)); - ASSERT_THAT(sem.get(), SyscallSucceeds()); - - ScopedThread([]() { - EXPECT_THAT(unshare(CLONE_NEWIPC), SyscallSucceeds()); - AutoSem sem(semget(123, 1, 0600 | IPC_CREAT | IPC_EXCL)); - ASSERT_THAT(sem.get(), SyscallSucceeds()); - }); -} - -TEST(SemaphoreTest, SemCtlVal) { - AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); - ASSERT_THAT(sem.get(), SyscallSucceeds()); - - // Semaphore must start with 0. - EXPECT_THAT(semctl(sem.get(), 0, GETVAL), SyscallSucceedsWithValue(0)); - - // Increase value and ensure waiters are woken up. - ScopedThread th([&sem] { - struct sembuf buf = {}; - buf.sem_op = -10; - ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds()); - }); - - ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 9), SyscallSucceeds()); - EXPECT_THAT(semctl(sem.get(), 0, GETVAL), SyscallSucceedsWithValue(9)); - - ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 20), SyscallSucceeds()); - const int value = semctl(sem.get(), 0, GETVAL); - // 10 or 20 because it could have raced with waiter above. - EXPECT_TRUE(value == 10 || value == 20) << "value=" << value; - th.Join(); - - // Set it back to 0 and ensure that waiters are woken up. - ScopedThread thZero([&sem] { - struct sembuf buf = {}; - buf.sem_op = 0; - ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds()); - }); - ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 0), SyscallSucceeds()); - EXPECT_THAT(semctl(sem.get(), 0, GETVAL), SyscallSucceedsWithValue(0)); - thZero.Join(); -} - -TEST(SemaphoreTest, SemCtlValAll) { - AutoSem sem(semget(IPC_PRIVATE, 3, 0600 | IPC_CREAT)); - ASSERT_THAT(sem.get(), SyscallSucceeds()); - - // Semaphores must start with 0. - uint16_t get[3] = {10, 10, 10}; - EXPECT_THAT(semctl(sem.get(), 1, GETALL, get), SyscallSucceedsWithValue(0)); - for (auto v : get) { - EXPECT_EQ(v, 0); - } - - // SetAll and check that they were set. - uint16_t vals[3] = {0, 10, 20}; - EXPECT_THAT(semctl(sem.get(), 1, SETALL, vals), SyscallSucceedsWithValue(0)); - EXPECT_THAT(semctl(sem.get(), 1, GETALL, get), SyscallSucceedsWithValue(0)); - for (size_t i = 0; i < ABSL_ARRAYSIZE(vals); ++i) { - EXPECT_EQ(get[i], vals[i]); - } - - EXPECT_THAT(semctl(sem.get(), 1, SETALL, nullptr), - SyscallFailsWithErrno(EFAULT)); -} - -TEST(SemaphoreTest, SemCtlGetPid) { - AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); - ASSERT_THAT(sem.get(), SyscallSucceeds()); - - ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 1), SyscallSucceeds()); - EXPECT_THAT(semctl(sem.get(), 0, GETPID), SyscallSucceedsWithValue(getpid())); -} - -TEST(SemaphoreTest, SemCtlGetPidFork) { - AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); - ASSERT_THAT(sem.get(), SyscallSucceeds()); - - const pid_t child_pid = fork(); - if (child_pid == 0) { - TEST_PCHECK(semctl(sem.get(), 0, SETVAL, 1) == 0); - TEST_PCHECK(semctl(sem.get(), 0, GETPID) == getpid()); - - _exit(0); - } - ASSERT_THAT(child_pid, SyscallSucceeds()); - - int status; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << " status " << status; -} - -TEST(SemaphoreTest, SemIpcSet) { - // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); - - AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); - ASSERT_THAT(sem.get(), SyscallSucceeds()); - - struct semid_ds semid = {}; - semid.sem_perm.uid = getuid(); - semid.sem_perm.gid = getgid(); - - // Make semaphore readonly and check that signal fails. - semid.sem_perm.mode = 0400; - EXPECT_THAT(semctl(sem.get(), 0, IPC_SET, &semid), SyscallSucceeds()); - struct sembuf buf = {}; - buf.sem_op = 1; - ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallFailsWithErrno(EACCES)); - - // Make semaphore writeonly and check that wait for zero fails. - semid.sem_perm.mode = 0200; - EXPECT_THAT(semctl(sem.get(), 0, IPC_SET, &semid), SyscallSucceeds()); - buf.sem_op = 0; - ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallFailsWithErrno(EACCES)); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc deleted file mode 100644 index 580ab5193..000000000 --- a/test/syscalls/linux/sendfile.cc +++ /dev/null @@ -1,536 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <sys/eventfd.h> -#include <sys/sendfile.h> -#include <unistd.h> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/string_view.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/eventfd_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(SendFileTest, SendZeroBytes) { - // Create temp files. - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // Open the input file as read only. - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); - - // Open the output file as write only. - const FileDescriptor outf = - ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); - - // Send data and verify that sendfile returns the correct value. - EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, 0), - SyscallSucceedsWithValue(0)); -} - -TEST(SendFileTest, InvalidOffset) { - // Create temp files. - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // Open the input file as read only. - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); - - // Open the output file as write only. - const FileDescriptor outf = - ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); - - // Send data and verify that sendfile returns the correct value. - off_t offset = -1; - EXPECT_THAT(sendfile(outf.get(), inf.get(), &offset, 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(SendFileTest, SendTrivially) { - // Create temp files. - constexpr char kData[] = "To be, or not to be, that is the question:"; - constexpr int kDataSize = sizeof(kData) - 1; - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // Open the input file as read only. - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); - - // Open the output file as write only. - FileDescriptor outf; - outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); - - // Send data and verify that sendfile returns the correct value. - int bytes_sent; - EXPECT_THAT(bytes_sent = sendfile(outf.get(), inf.get(), nullptr, kDataSize), - SyscallSucceedsWithValue(kDataSize)); - - // Close outf to avoid leak. - outf.reset(); - - // Open the output file as read only. - outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY)); - - // Verify that the output file has the correct data. - char actual[kDataSize]; - ASSERT_THAT(read(outf.get(), &actual, bytes_sent), - SyscallSucceedsWithValue(kDataSize)); - EXPECT_EQ(kData, absl::string_view(actual, bytes_sent)); -} - -TEST(SendFileTest, SendTriviallyWithBothFilesReadWrite) { - // Create temp files. - constexpr char kData[] = "Whether 'tis nobler in the mind to suffer"; - constexpr int kDataSize = sizeof(kData) - 1; - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // Open the input file as readwrite. - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); - - // Open the output file as readwrite. - FileDescriptor outf; - outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); - - // Send data and verify that sendfile returns the correct value. - int bytes_sent; - EXPECT_THAT(bytes_sent = sendfile(outf.get(), inf.get(), nullptr, kDataSize), - SyscallSucceedsWithValue(kDataSize)); - - // Close outf to avoid leak. - outf.reset(); - - // Open the output file as read only. - outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY)); - - // Verify that the output file has the correct data. - char actual[kDataSize]; - ASSERT_THAT(read(outf.get(), &actual, bytes_sent), - SyscallSucceedsWithValue(kDataSize)); - EXPECT_EQ(kData, absl::string_view(actual, bytes_sent)); -} - -TEST(SendFileTest, SendAndUpdateFileOffset) { - // Create temp files. - // Test input string length must be > 2 AND even. - constexpr char kData[] = "The slings and arrows of outrageous fortune,"; - constexpr int kDataSize = sizeof(kData) - 1; - constexpr int kHalfDataSize = kDataSize / 2; - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // Open the input file as read only. - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); - - // Open the output file as write only. - FileDescriptor outf; - outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); - - // Send data and verify that sendfile returns the correct value. - int bytes_sent; - EXPECT_THAT( - bytes_sent = sendfile(outf.get(), inf.get(), nullptr, kHalfDataSize), - SyscallSucceedsWithValue(kHalfDataSize)); - - // Close outf to avoid leak. - outf.reset(); - - // Open the output file as read only. - outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY)); - - // Verify that the output file has the correct data. - char actual[kHalfDataSize]; - ASSERT_THAT(read(outf.get(), &actual, bytes_sent), - SyscallSucceedsWithValue(kHalfDataSize)); - EXPECT_EQ(absl::string_view(kData, kHalfDataSize), - absl::string_view(actual, bytes_sent)); - - // Verify that the input file offset has been updated - ASSERT_THAT(read(inf.get(), &actual, kDataSize - bytes_sent), - SyscallSucceedsWithValue(kHalfDataSize)); - EXPECT_EQ( - absl::string_view(kData + kDataSize - bytes_sent, kDataSize - bytes_sent), - absl::string_view(actual, kHalfDataSize)); -} - -TEST(SendFileTest, SendAndUpdateFileOffsetFromNonzeroStartingPoint) { - // Create temp files. - // Test input string length must be > 2 AND divisible by 4. - constexpr char kData[] = "The slings and arrows of outrageous fortune,"; - constexpr int kDataSize = sizeof(kData) - 1; - constexpr int kHalfDataSize = kDataSize / 2; - constexpr int kQuarterDataSize = kHalfDataSize / 2; - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // Open the input file as read only. - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); - - // Open the output file as write only. - FileDescriptor outf; - outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); - - // Read a quarter of the data from the infile which should update the file - // offset, we don't actually care about the data so it goes into the garbage. - char garbage[kQuarterDataSize]; - ASSERT_THAT(read(inf.get(), &garbage, kQuarterDataSize), - SyscallSucceedsWithValue(kQuarterDataSize)); - - // Send data and verify that sendfile returns the correct value. - int bytes_sent; - EXPECT_THAT( - bytes_sent = sendfile(outf.get(), inf.get(), nullptr, kHalfDataSize), - SyscallSucceedsWithValue(kHalfDataSize)); - - // Close out_fd to avoid leak. - outf.reset(); - - // Open the output file as read only. - outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY)); - - // Verify that the output file has the correct data. - char actual[kHalfDataSize]; - ASSERT_THAT(read(outf.get(), &actual, bytes_sent), - SyscallSucceedsWithValue(kHalfDataSize)); - EXPECT_EQ(absl::string_view(kData + kQuarterDataSize, kHalfDataSize), - absl::string_view(actual, bytes_sent)); - - // Verify that the input file offset has been updated - ASSERT_THAT(read(inf.get(), &actual, kQuarterDataSize), - SyscallSucceedsWithValue(kQuarterDataSize)); - - EXPECT_EQ( - absl::string_view(kData + kDataSize - kQuarterDataSize, kQuarterDataSize), - absl::string_view(actual, kQuarterDataSize)); -} - -TEST(SendFileTest, SendAndUpdateGivenOffset) { - // Create temp files. - // Test input string length must be >= 4 AND divisible by 4. - constexpr char kData[] = "Or to take Arms against a Sea of troubles,"; - constexpr int kDataSize = sizeof(kData) + 1; - constexpr int kHalfDataSize = kDataSize / 2; - constexpr int kQuarterDataSize = kHalfDataSize / 2; - constexpr int kThreeFourthsDataSize = 3 * kDataSize / 4; - - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // Open the input file as read only. - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); - - // Open the output file as write only. - FileDescriptor outf; - outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); - - // Create offset for sending. - off_t offset = kQuarterDataSize; - - // Send data and verify that sendfile returns the correct value. - int bytes_sent; - EXPECT_THAT( - bytes_sent = sendfile(outf.get(), inf.get(), &offset, kHalfDataSize), - SyscallSucceedsWithValue(kHalfDataSize)); - - // Close out_fd to avoid leak. - outf.reset(); - - // Open the output file as read only. - outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY)); - - // Verify that the output file has the correct data. - char actual[kHalfDataSize]; - ASSERT_THAT(read(outf.get(), &actual, bytes_sent), - SyscallSucceedsWithValue(kHalfDataSize)); - EXPECT_EQ(absl::string_view(kData + kQuarterDataSize, kHalfDataSize), - absl::string_view(actual, bytes_sent)); - - // Verify that the input file offset has NOT been updated. - ASSERT_THAT(read(inf.get(), &actual, kHalfDataSize), - SyscallSucceedsWithValue(kHalfDataSize)); - EXPECT_EQ(absl::string_view(kData, kHalfDataSize), - absl::string_view(actual, kHalfDataSize)); - - // Verify that the offset pointer has been updated. - EXPECT_EQ(offset, kThreeFourthsDataSize); -} - -TEST(SendFileTest, DoNotSendfileIfOutfileIsAppendOnly) { - // Create temp files. - constexpr char kData[] = "And by opposing end them: to die, to sleep"; - constexpr int kDataSize = sizeof(kData) - 1; - - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // Open the input file as read only. - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); - - // Open the output file as append only. - const FileDescriptor outf = - ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY | O_APPEND)); - - // Send data and verify that sendfile returns the correct errno. - EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, kDataSize), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(SendFileTest, AppendCheckOrdering) { - constexpr char kData[] = "And by opposing end them: to die, to sleep"; - constexpr int kDataSize = sizeof(kData) - 1; - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - - const FileDescriptor read = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); - const FileDescriptor write = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY)); - const FileDescriptor append = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_APPEND)); - - // Check that read/write file mode is verified before append. - EXPECT_THAT(sendfile(append.get(), read.get(), nullptr, kDataSize), - SyscallFailsWithErrno(EBADF)); - EXPECT_THAT(sendfile(write.get(), write.get(), nullptr, kDataSize), - SyscallFailsWithErrno(EBADF)); -} - -TEST(SendFileTest, DoNotSendfileIfOutfileIsNotWritable) { - // Create temp files. - constexpr char kData[] = "No more; and by a sleep, to say we end"; - constexpr int kDataSize = sizeof(kData) - 1; - - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // Open the input file as read only. - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); - - // Open the output file as read only. - const FileDescriptor outf = - ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY)); - - // Send data and verify that sendfile returns the correct errno. - EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, kDataSize), - SyscallFailsWithErrno(EBADF)); -} - -TEST(SendFileTest, DoNotSendfileIfInfileIsNotReadable) { - // Create temp files. - constexpr char kData[] = "the heart-ache, and the thousand natural shocks"; - constexpr int kDataSize = sizeof(kData) - 1; - - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // Open the input file as write only. - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_WRONLY)); - - // Open the output file as write only. - const FileDescriptor outf = - ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); - - // Send data and verify that sendfile returns the correct errno. - EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, kDataSize), - SyscallFailsWithErrno(EBADF)); -} - -TEST(SendFileTest, DoNotSendANegativeNumberOfBytes) { - // Create temp files. - constexpr char kData[] = "that Flesh is heir to? 'Tis a consummation"; - - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // Open the input file as read only. - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); - - // Open the output file as write only. - const FileDescriptor outf = - ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); - - // Send data and verify that sendfile returns the correct errno. - EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, -1), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(SendFileTest, SendTheCorrectNumberOfBytesEvenIfWeTryToSendTooManyBytes) { - // Create temp files. - constexpr char kData[] = "devoutly to be wished. To die, to sleep,"; - constexpr int kDataSize = sizeof(kData) - 1; - - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // Open the input file as read only. - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); - - // Open the output file as write only. - FileDescriptor outf; - outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); - - // Send data and verify that sendfile returns the correct value. - int bytes_sent; - EXPECT_THAT( - bytes_sent = sendfile(outf.get(), inf.get(), nullptr, kDataSize + 100), - SyscallSucceedsWithValue(kDataSize)); - - // Close outf to avoid leak. - outf.reset(); - - // Open the output file as read only. - outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY)); - - // Verify that the output file has the correct data. - char actual[kDataSize]; - ASSERT_THAT(read(outf.get(), &actual, bytes_sent), - SyscallSucceedsWithValue(kDataSize)); - EXPECT_EQ(kData, absl::string_view(actual, bytes_sent)); -} - -TEST(SendFileTest, SendToNotARegularFile) { - // Make temp input directory and open as read only. - const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY)); - - // Make temp output file and open as write only. - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor outf = - ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); - - // Receive an error since a directory is not a regular file. - EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(SendFileTest, SendPipeWouldBlock) { - // Create temp file. - constexpr char kData[] = - "The fool doth think he is wise, but the wise man knows himself to be a " - "fool."; - constexpr int kDataSize = sizeof(kData) - 1; - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - - // Open the input file as read only. - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); - - // Setup the output named pipe. - int fds[2]; - ASSERT_THAT(pipe2(fds, O_NONBLOCK), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Fill up the pipe's buffer. - int pipe_size = -1; - ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds()); - std::vector<char> buf(2 * pipe_size); - ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(pipe_size)); - - EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST(SendFileTest, SendPipeBlocks) { - // Create temp file. - constexpr char kData[] = - "The fault, dear Brutus, is not in our stars, but in ourselves."; - constexpr int kDataSize = sizeof(kData) - 1; - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - - // Open the input file as read only. - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); - - // Setup the output named pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Fill up the pipe's buffer. - int pipe_size = -1; - ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds()); - std::vector<char> buf(pipe_size); - ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(pipe_size)); - - ScopedThread t([&]() { - absl::SleepFor(absl::Milliseconds(100)); - ASSERT_THAT(read(rfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(pipe_size)); - }); - - EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize), - SyscallSucceedsWithValue(kDataSize)); -} - -TEST(SendFileTest, SendToSpecialFile) { - // Create temp file. - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); - - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); - constexpr int kSize = 0x7ff; - ASSERT_THAT(ftruncate(inf.get(), kSize), SyscallSucceeds()); - - auto eventfd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()); - - // eventfd can accept a number of bytes which is a multiple of 8. - EXPECT_THAT(sendfile(eventfd.get(), inf.get(), nullptr, 0xfffff), - SyscallSucceedsWithValue(kSize & (~7))); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/sendfile_socket.cc b/test/syscalls/linux/sendfile_socket.cc deleted file mode 100644 index 8f7ee4163..000000000 --- a/test/syscalls/linux/sendfile_socket.cc +++ /dev/null @@ -1,242 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <arpa/inet.h> -#include <netinet/in.h> -#include <sys/sendfile.h> -#include <sys/socket.h> -#include <unistd.h> - -#include <iostream> -#include <vector> - -#include "gtest/gtest.h" -#include "absl/strings/string_view.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { -namespace { - -class SendFileTest : public ::testing::TestWithParam<int> { - protected: - PosixErrorOr<std::tuple<int, int>> Sockets() { - // Bind a server socket. - int family = GetParam(); - struct sockaddr server_addr = {}; - switch (family) { - case AF_INET: { - struct sockaddr_in* server_addr_in = - reinterpret_cast<struct sockaddr_in*>(&server_addr); - server_addr_in->sin_family = family; - server_addr_in->sin_addr.s_addr = INADDR_ANY; - break; - } - case AF_UNIX: { - struct sockaddr_un* server_addr_un = - reinterpret_cast<struct sockaddr_un*>(&server_addr); - server_addr_un->sun_family = family; - server_addr_un->sun_path[0] = '\0'; - break; - } - default: - return PosixError(EINVAL); - } - int server = socket(family, SOCK_STREAM, 0); - if (bind(server, &server_addr, sizeof(server_addr)) < 0) { - return PosixError(errno); - } - if (listen(server, 1) < 0) { - close(server); - return PosixError(errno); - } - - // Fetch the address; both are anonymous. - socklen_t length = sizeof(server_addr); - if (getsockname(server, &server_addr, &length) < 0) { - close(server); - return PosixError(errno); - } - - // Connect the client. - int client = socket(family, SOCK_STREAM, 0); - if (connect(client, &server_addr, length) < 0) { - close(server); - close(client); - return PosixError(errno); - } - - // Accept on the server. - int server_client = accept(server, nullptr, 0); - if (server_client < 0) { - close(server); - close(client); - return PosixError(errno); - } - close(server); - return std::make_tuple(client, server_client); - } -}; - -// Sends large file to exercise the path that read and writes data multiple -// times, esp. when more data is read than can be written. -TEST_P(SendFileTest, SendMultiple) { - std::vector<char> data(5 * 1024 * 1024); - RandomizeBuffer(data.data(), data.size()); - - // Create temp files. - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::string_view(data.data(), data.size()), - TempPath::kDefaultFileMode)); - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // Create sockets. - std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets()); - const FileDescriptor server(std::get<0>(fds)); - FileDescriptor client(std::get<1>(fds)); // non-const, reset is used. - - // Thread that reads data from socket and dumps to a file. - ScopedThread th([&] { - FileDescriptor outf = - ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); - - // Read until socket is closed. - char buf[10240]; - for (int cnt = 0;; cnt++) { - int r = RetryEINTR(read)(server.get(), buf, sizeof(buf)); - // We cannot afford to save on every read() call. - if (cnt % 1000 == 0) { - ASSERT_THAT(r, SyscallSucceeds()); - } else { - const DisableSave ds; - ASSERT_THAT(r, SyscallSucceeds()); - } - if (r == 0) { - // EOF - break; - } - int w = RetryEINTR(write)(outf.get(), buf, r); - // We cannot afford to save on every write() call. - if (cnt % 1010 == 0) { - ASSERT_THAT(w, SyscallSucceedsWithValue(r)); - } else { - const DisableSave ds; - ASSERT_THAT(w, SyscallSucceedsWithValue(r)); - } - } - }); - - // Open the input file as read only. - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); - - int cnt = 0; - for (size_t sent = 0; sent < data.size(); cnt++) { - const size_t remain = data.size() - sent; - std::cout << "sendfile, size=" << data.size() << ", sent=" << sent - << ", remain=" << remain; - - // Send data and verify that sendfile returns the correct value. - int res = sendfile(client.get(), inf.get(), nullptr, remain); - // We cannot afford to save on every sendfile() call. - if (cnt % 120 == 0) { - MaybeSave(); - } - if (res == 0) { - // EOF - break; - } - if (res > 0) { - sent += res; - } else { - ASSERT_TRUE(errno == EINTR || errno == EAGAIN) << "errno=" << errno; - } - } - - // Close socket to stop thread. - client.reset(); - th.Join(); - - // Verify that the output file has the correct data. - const FileDescriptor outf = - ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY)); - std::vector<char> actual(data.size(), '\0'); - ASSERT_THAT(RetryEINTR(read)(outf.get(), actual.data(), actual.size()), - SyscallSucceedsWithValue(actual.size())); - ASSERT_EQ(memcmp(data.data(), actual.data(), data.size()), 0); -} - -TEST_P(SendFileTest, Shutdown) { - // Create a socket. - std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets()); - const FileDescriptor client(std::get<0>(fds)); - FileDescriptor server(std::get<1>(fds)); // non-const, reset below. - - // If this is a TCP socket, then turn off linger. - if (GetParam() == AF_INET) { - struct linger sl; - sl.l_onoff = 1; - sl.l_linger = 0; - ASSERT_THAT( - setsockopt(server.get(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), - SyscallSucceeds()); - } - - // Create a 1m file with random data. - std::vector<char> data(1024 * 1024); - RandomizeBuffer(data.data(), data.size()); - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::string_view(data.data(), data.size()), - TempPath::kDefaultFileMode)); - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); - - // Read some data, then shutdown the socket. We don't actually care about - // checking the contents (other tests do that), so we just re-use the same - // buffer as above. - ScopedThread t([&]() { - size_t done = 0; - while (done < data.size()) { - int n = RetryEINTR(read)(server.get(), data.data(), data.size()); - ASSERT_THAT(n, SyscallSucceeds()); - done += n; - } - // Close the server side socket. - server.reset(); - }); - - // Continuously stream from the file to the socket. Note we do not assert - // that a specific amount of data has been written at any time, just that some - // data is written. Eventually, we should get a connection reset error. - while (1) { - off_t offset = 0; // Always read from the start. - int n = sendfile(client.get(), inf.get(), &offset, data.size()); - EXPECT_THAT(n, AnyOf(SyscallFailsWithErrno(ECONNRESET), - SyscallFailsWithErrno(EPIPE), SyscallSucceeds())); - if (n <= 0) { - break; - } - } -} - -INSTANTIATE_TEST_SUITE_P(AddressFamily, SendFileTest, - ::testing::Values(AF_UNIX, AF_INET)); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/shm.cc b/test/syscalls/linux/shm.cc deleted file mode 100644 index c7fdbb924..000000000 --- a/test/syscalls/linux/shm.cc +++ /dev/null @@ -1,508 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdio.h> -#include <sys/ipc.h> -#include <sys/mman.h> -#include <sys/shm.h> -#include <sys/types.h> - -#include "absl/time/clock.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -using ::testing::_; - -const uint64_t kAllocSize = kPageSize * 128ULL; - -PosixErrorOr<char*> Shmat(int shmid, const void* shmaddr, int shmflg) { - const intptr_t addr = - reinterpret_cast<intptr_t>(shmat(shmid, shmaddr, shmflg)); - if (addr == -1) { - return PosixError(errno, "shmat() failed"); - } - return reinterpret_cast<char*>(addr); -} - -PosixError Shmdt(const char* shmaddr) { - const int ret = shmdt(shmaddr); - if (ret == -1) { - return PosixError(errno, "shmdt() failed"); - } - return NoError(); -} - -template <typename T> -PosixErrorOr<int> Shmctl(int shmid, int cmd, T* buf) { - int ret = shmctl(shmid, cmd, reinterpret_cast<struct shmid_ds*>(buf)); - if (ret == -1) { - return PosixError(errno, "shmctl() failed"); - } - return ret; -} - -// ShmSegment is a RAII object for automatically cleaning up shm segments. -class ShmSegment { - public: - explicit ShmSegment(int id) : id_(id) {} - - ~ShmSegment() { - if (id_ >= 0) { - EXPECT_NO_ERRNO(Rmid()); - id_ = -1; - } - } - - ShmSegment(ShmSegment&& other) : id_(other.release()) {} - - ShmSegment& operator=(ShmSegment&& other) { - id_ = other.release(); - return *this; - } - - ShmSegment(ShmSegment const& other) = delete; - ShmSegment& operator=(ShmSegment const& other) = delete; - - int id() const { return id_; } - - int release() { - int id = id_; - id_ = -1; - return id; - } - - PosixErrorOr<int> Rmid() { - RETURN_IF_ERRNO(Shmctl<void>(id_, IPC_RMID, nullptr)); - return release(); - } - - private: - int id_ = -1; -}; - -PosixErrorOr<int> ShmgetRaw(key_t key, size_t size, int shmflg) { - int id = shmget(key, size, shmflg); - if (id == -1) { - return PosixError(errno, "shmget() failed"); - } - return id; -} - -PosixErrorOr<ShmSegment> Shmget(key_t key, size_t size, int shmflg) { - ASSIGN_OR_RETURN_ERRNO(int id, ShmgetRaw(key, size, shmflg)); - return ShmSegment(id); -} - -TEST(ShmTest, AttachDetach) { - const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( - Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777)); - struct shmid_ds attr; - ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr)); - EXPECT_EQ(attr.shm_segsz, kAllocSize); - EXPECT_EQ(attr.shm_nattch, 0); - - const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); - ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr)); - EXPECT_EQ(attr.shm_nattch, 1); - - const char* addr2 = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); - ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr)); - EXPECT_EQ(attr.shm_nattch, 2); - - ASSERT_NO_ERRNO(Shmdt(addr)); - ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr)); - EXPECT_EQ(attr.shm_nattch, 1); - - ASSERT_NO_ERRNO(Shmdt(addr2)); - ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr)); - EXPECT_EQ(attr.shm_nattch, 0); -} - -TEST(ShmTest, LookupByKey) { - const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const key_t key = ftok(keyfile.path().c_str(), 1); - const ShmSegment shm = - ASSERT_NO_ERRNO_AND_VALUE(Shmget(key, kAllocSize, IPC_CREAT | 0777)); - const int id2 = ASSERT_NO_ERRNO_AND_VALUE(ShmgetRaw(key, kAllocSize, 0777)); - EXPECT_EQ(shm.id(), id2); -} - -TEST(ShmTest, DetachedSegmentsPersist) { - const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( - Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777)); - char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); - addr[0] = 'x'; - ASSERT_NO_ERRNO(Shmdt(addr)); - - // We should be able to re-attach to the same segment and get our data back. - addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); - EXPECT_EQ(addr[0], 'x'); - ASSERT_NO_ERRNO(Shmdt(addr)); -} - -TEST(ShmTest, MultipleDetachFails) { - const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( - Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777)); - const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); - ASSERT_NO_ERRNO(Shmdt(addr)); - EXPECT_THAT(Shmdt(addr), PosixErrorIs(EINVAL, _)); -} - -TEST(ShmTest, IpcStat) { - const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const key_t key = ftok(keyfile.path().c_str(), 1); - - const time_t start = time(nullptr); - - const ShmSegment shm = - ASSERT_NO_ERRNO_AND_VALUE(Shmget(key, kAllocSize, IPC_CREAT | 0777)); - - const uid_t uid = getuid(); - const gid_t gid = getgid(); - const pid_t pid = getpid(); - - struct shmid_ds attr; - ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr)); - - EXPECT_EQ(attr.shm_perm.__key, key); - EXPECT_EQ(attr.shm_perm.uid, uid); - EXPECT_EQ(attr.shm_perm.gid, gid); - EXPECT_EQ(attr.shm_perm.cuid, uid); - EXPECT_EQ(attr.shm_perm.cgid, gid); - EXPECT_EQ(attr.shm_perm.mode, 0777); - - EXPECT_EQ(attr.shm_segsz, kAllocSize); - - EXPECT_EQ(attr.shm_atime, 0); - EXPECT_EQ(attr.shm_dtime, 0); - - // Change time is set on creation. - EXPECT_GE(attr.shm_ctime, start); - - EXPECT_EQ(attr.shm_cpid, pid); - EXPECT_EQ(attr.shm_lpid, 0); - - EXPECT_EQ(attr.shm_nattch, 0); - - // The timestamps only have a resolution of seconds; slow down so we actually - // see the timestamps change. - absl::SleepFor(absl::Seconds(1)); - const time_t pre_attach = time(nullptr); - - const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); - ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr)); - - EXPECT_GE(attr.shm_atime, pre_attach); - EXPECT_EQ(attr.shm_dtime, 0); - EXPECT_LT(attr.shm_ctime, pre_attach); - EXPECT_EQ(attr.shm_lpid, pid); - EXPECT_EQ(attr.shm_nattch, 1); - - absl::SleepFor(absl::Seconds(1)); - const time_t pre_detach = time(nullptr); - - ASSERT_NO_ERRNO(Shmdt(addr)); - ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr)); - - EXPECT_LT(attr.shm_atime, pre_detach); - EXPECT_GE(attr.shm_dtime, pre_detach); - EXPECT_LT(attr.shm_ctime, pre_detach); - EXPECT_EQ(attr.shm_lpid, pid); - EXPECT_EQ(attr.shm_nattch, 0); -} - -TEST(ShmTest, ShmStat) { - // This test relies on the segment we create to be the first one on the - // system, causing it to occupy slot 1. We can't reasonably expect this on a - // general Linux host. - SKIP_IF(!IsRunningOnGvisor()); - - const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( - Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777)); - struct shmid_ds attr; - ASSERT_NO_ERRNO(Shmctl(1, SHM_STAT, &attr)); - // This does the same thing as IPC_STAT, so only test that the syscall - // succeeds here. -} - -TEST(ShmTest, IpcInfo) { - struct shminfo info; - ASSERT_NO_ERRNO(Shmctl(0, IPC_INFO, &info)); - - EXPECT_EQ(info.shmmin, 1); // This is always 1, according to the man page. - EXPECT_GT(info.shmmax, info.shmmin); - EXPECT_GT(info.shmmni, 0); - EXPECT_GT(info.shmseg, 0); - EXPECT_GT(info.shmall, 0); -} - -TEST(ShmTest, ShmInfo) { - struct shm_info info; - - // We generally can't know what other processes on a linux machine - // does with shared memory segments, so we can't test specific - // numbers on Linux. When running under gvisor, we're guaranteed to - // be the only ones using shm, so we can easily verify machine-wide - // numbers. - if (IsRunningOnGvisor()) { - ASSERT_NO_ERRNO(Shmctl(0, SHM_INFO, &info)); - EXPECT_EQ(info.used_ids, 0); - EXPECT_EQ(info.shm_tot, 0); - EXPECT_EQ(info.shm_rss, 0); - EXPECT_EQ(info.shm_swp, 0); - } - - const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( - Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777)); - const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); - - ASSERT_NO_ERRNO(Shmctl(1, SHM_INFO, &info)); - - if (IsRunningOnGvisor()) { - ASSERT_NO_ERRNO(Shmctl(shm.id(), SHM_INFO, &info)); - EXPECT_EQ(info.used_ids, 1); - EXPECT_EQ(info.shm_tot, kAllocSize / kPageSize); - EXPECT_EQ(info.shm_rss, kAllocSize / kPageSize); - EXPECT_EQ(info.shm_swp, 0); // Gvisor currently never swaps. - } - - ASSERT_NO_ERRNO(Shmdt(addr)); -} - -TEST(ShmTest, ShmCtlSet) { - const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( - Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777)); - const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); - - struct shmid_ds attr; - ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr)); - ASSERT_EQ(attr.shm_perm.mode, 0777); - - attr.shm_perm.mode = 0766; - ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_SET, &attr)); - - ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr)); - ASSERT_EQ(attr.shm_perm.mode, 0766); - - ASSERT_NO_ERRNO(Shmdt(addr)); -} - -TEST(ShmTest, RemovedSegmentsAreMarkedDeleted) { - ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( - Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777)); - const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); - const int id = ASSERT_NO_ERRNO_AND_VALUE(shm.Rmid()); - struct shmid_ds attr; - ASSERT_NO_ERRNO(Shmctl(id, IPC_STAT, &attr)); - EXPECT_NE(attr.shm_perm.mode & SHM_DEST, 0); - ASSERT_NO_ERRNO(Shmdt(addr)); -} - -TEST(ShmTest, RemovedSegmentsAreDestroyed) { - ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( - Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777)); - const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); - - const uint64_t alloc_pages = kAllocSize / kPageSize; - - struct shm_info info; - ASSERT_NO_ERRNO(Shmctl(0 /*ignored*/, SHM_INFO, &info)); - const uint64_t before = info.shm_tot; - - ASSERT_NO_ERRNO(shm.Rmid()); - ASSERT_NO_ERRNO(Shmdt(addr)); - - ASSERT_NO_ERRNO(Shmctl(0 /*ignored*/, SHM_INFO, &info)); - if (IsRunningOnGvisor()) { - // No guarantees on system-wide shm memory usage on a generic linux host. - const uint64_t after = info.shm_tot; - EXPECT_EQ(after, before - alloc_pages); - } -} - -TEST(ShmTest, AllowsAttachToRemovedSegmentWithRefs) { - ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( - Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777)); - const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); - const int id = ASSERT_NO_ERRNO_AND_VALUE(shm.Rmid()); - const char* addr2 = ASSERT_NO_ERRNO_AND_VALUE(Shmat(id, nullptr, 0)); - ASSERT_NO_ERRNO(Shmdt(addr)); - ASSERT_NO_ERRNO(Shmdt(addr2)); -} - -TEST(ShmTest, RemovedSegmentsAreNotDiscoverable) { - const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const key_t key = ftok(keyfile.path().c_str(), 1); - ShmSegment shm = - ASSERT_NO_ERRNO_AND_VALUE(Shmget(key, kAllocSize, IPC_CREAT | 0777)); - ASSERT_NO_ERRNO(shm.Rmid()); - EXPECT_THAT(Shmget(key, kAllocSize, 0777), PosixErrorIs(ENOENT, _)); -} - -TEST(ShmDeathTest, ReadonlySegment) { - SetupGvisorDeathTest(); - const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( - Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777)); - char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, SHM_RDONLY)); - // Reading succeeds. - static_cast<void>(addr[0]); - // Writing fails. - EXPECT_EXIT(addr[0] = 'x', ::testing::KilledBySignal(SIGSEGV), ""); -} - -TEST(ShmDeathTest, SegmentNotAccessibleAfterDetach) { - // This test is susceptible to races with concurrent mmaps running in parallel - // gtest threads since the test relies on the address freed during a shm - // segment destruction to remain unused. We run the test body in a forked - // child to guarantee a single-threaded context to avoid this. - - SetupGvisorDeathTest(); - - const auto rest = [&] { - ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( - Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777)); - char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); - - // Mark the segment as destroyed so it's automatically cleaned up when we - // crash below. We can't rely on the standard cleanup since the destructor - // will not run after the SIGSEGV. Note that this doesn't destroy the - // segment immediately since we're still attached to it. - ASSERT_NO_ERRNO(shm.Rmid()); - - addr[0] = 'x'; - ASSERT_NO_ERRNO(Shmdt(addr)); - - // This access should cause a SIGSEGV. - addr[0] = 'x'; - }; - - EXPECT_THAT(InForkedProcess(rest), - IsPosixErrorOkAndHolds(W_EXITCODE(0, SIGSEGV))); -} - -TEST(ShmTest, RequestingSegmentSmallerThanSHMMINFails) { - struct shminfo info; - ASSERT_NO_ERRNO(Shmctl(0, IPC_INFO, &info)); - const uint64_t size = info.shmmin - 1; - EXPECT_THAT(Shmget(IPC_PRIVATE, size, IPC_CREAT | 0777), - PosixErrorIs(EINVAL, _)); -} - -TEST(ShmTest, RequestingSegmentLargerThanSHMMAXFails) { - struct shminfo info; - ASSERT_NO_ERRNO(Shmctl(0, IPC_INFO, &info)); - const uint64_t size = info.shmmax + kPageSize; - EXPECT_THAT(Shmget(IPC_PRIVATE, size, IPC_CREAT | 0777), - PosixErrorIs(EINVAL, _)); -} - -TEST(ShmTest, RequestingUnalignedSizeSucceeds) { - EXPECT_NO_ERRNO(Shmget(IPC_PRIVATE, 4097, IPC_CREAT | 0777)); -} - -TEST(ShmTest, RequestingDuplicateCreationFails) { - const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const key_t key = ftok(keyfile.path().c_str(), 1); - const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( - Shmget(key, kAllocSize, IPC_CREAT | IPC_EXCL | 0777)); - EXPECT_THAT(Shmget(key, kAllocSize, IPC_CREAT | IPC_EXCL | 0777), - PosixErrorIs(EEXIST, _)); -} - -TEST(ShmTest, NonExistentSegmentsAreNotFound) { - const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const key_t key = ftok(keyfile.path().c_str(), 1); - // Do not request creation. - EXPECT_THAT(Shmget(key, kAllocSize, 0777), PosixErrorIs(ENOENT, _)); -} - -TEST(ShmTest, SegmentsSizeFixedOnCreation) { - const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const key_t key = ftok(keyfile.path().c_str(), 1); - - // Base segment. - const ShmSegment shm = - ASSERT_NO_ERRNO_AND_VALUE(Shmget(key, kAllocSize, IPC_CREAT | 0777)); - - // Ask for the same segment at half size. This succeeds. - const int id2 = - ASSERT_NO_ERRNO_AND_VALUE(ShmgetRaw(key, kAllocSize / 2, 0777)); - - // Ask for the same segment at double size. - EXPECT_THAT(Shmget(key, kAllocSize * 2, 0777), PosixErrorIs(EINVAL, _)); - - char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); - char* addr2 = ASSERT_NO_ERRNO_AND_VALUE(Shmat(id2, nullptr, 0)); - - // We have 2 different maps... - EXPECT_NE(addr, addr2); - - // ... And both maps are kAllocSize bytes; despite asking for a half-sized - // segment for the second map. - addr[kAllocSize - 1] = 'x'; - addr2[kAllocSize - 1] = 'x'; - - ASSERT_NO_ERRNO(Shmdt(addr)); - ASSERT_NO_ERRNO(Shmdt(addr2)); -} - -TEST(ShmTest, PartialUnmap) { - const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( - Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777)); - char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); - EXPECT_THAT(munmap(addr + (kAllocSize / 4), kAllocSize / 2), - SyscallSucceeds()); - ASSERT_NO_ERRNO(Shmdt(addr)); -} - -// Check that sentry does not panic when asked for a zero-length private shm -// segment. Regression test for b/110694797. -TEST(ShmTest, GracefullyFailOnZeroLenSegmentCreation) { - EXPECT_THAT(Shmget(IPC_PRIVATE, 0, 0), PosixErrorIs(EINVAL, _)); -} - -TEST(ShmTest, NoDestructionOfAttachedSegmentWithMultipleRmid) { - ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( - Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777)); - char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); - char* addr2 = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); - - // There should be 2 refs to the segment from the 2 attachments, and a single - // self-reference. Mark the segment as destroyed more than 3 times through - // shmctl(RMID). If there's a bug with the ref counting, this should cause the - // count to drop to zero. - int id = shm.release(); - for (int i = 0; i < 6; ++i) { - ASSERT_NO_ERRNO(Shmctl<void>(id, IPC_RMID, nullptr)); - } - - // Segment should remain accessible. - addr[0] = 'x'; - ASSERT_NO_ERRNO(Shmdt(addr)); - - // Segment should remain accessible even after one of the two attachments are - // detached. - addr2[0] = 'x'; - ASSERT_NO_ERRNO(Shmdt(addr2)); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/sigaction.cc b/test/syscalls/linux/sigaction.cc deleted file mode 100644 index 9d9dd57a8..000000000 --- a/test/syscalls/linux/sigaction.cc +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <signal.h> -#include <sys/syscall.h> - -#include "gtest/gtest.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(SigactionTest, GetLessThanOrEqualToZeroFails) { - struct sigaction act = {}; - ASSERT_THAT(sigaction(-1, nullptr, &act), SyscallFailsWithErrno(EINVAL)); - ASSERT_THAT(sigaction(0, nullptr, &act), SyscallFailsWithErrno(EINVAL)); -} - -TEST(SigactionTest, SetLessThanOrEqualToZeroFails) { - struct sigaction act = {}; - ASSERT_THAT(sigaction(0, &act, nullptr), SyscallFailsWithErrno(EINVAL)); - ASSERT_THAT(sigaction(0, &act, nullptr), SyscallFailsWithErrno(EINVAL)); -} - -TEST(SigactionTest, GetGreaterThanMaxFails) { - struct sigaction act = {}; - ASSERT_THAT(sigaction(SIGRTMAX + 1, nullptr, &act), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(SigactionTest, SetGreaterThanMaxFails) { - struct sigaction act = {}; - ASSERT_THAT(sigaction(SIGRTMAX + 1, &act, nullptr), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(SigactionTest, SetSigkillFails) { - struct sigaction act = {}; - ASSERT_THAT(sigaction(SIGKILL, nullptr, &act), SyscallSucceeds()); - ASSERT_THAT(sigaction(SIGKILL, &act, nullptr), SyscallFailsWithErrno(EINVAL)); -} - -TEST(SigactionTest, SetSigstopFails) { - struct sigaction act = {}; - ASSERT_THAT(sigaction(SIGSTOP, nullptr, &act), SyscallSucceeds()); - ASSERT_THAT(sigaction(SIGSTOP, &act, nullptr), SyscallFailsWithErrno(EINVAL)); -} - -TEST(SigactionTest, BadSigsetFails) { - constexpr size_t kWrongSigSetSize = 43; - - struct sigaction act = {}; - - // The syscall itself (rather than the libc wrapper) takes the sigset_t size. - ASSERT_THAT( - syscall(SYS_rt_sigaction, SIGTERM, nullptr, &act, kWrongSigSetSize), - SyscallFailsWithErrno(EINVAL)); - ASSERT_THAT( - syscall(SYS_rt_sigaction, SIGTERM, &act, nullptr, kWrongSigSetSize), - SyscallFailsWithErrno(EINVAL)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/sigaltstack.cc b/test/syscalls/linux/sigaltstack.cc deleted file mode 100644 index 24e7c4960..000000000 --- a/test/syscalls/linux/sigaltstack.cc +++ /dev/null @@ -1,268 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <signal.h> -#include <stdio.h> -#include <string.h> -#include <unistd.h> - -#include <functional> -#include <vector> - -#include "gtest/gtest.h" -#include "test/util/cleanup.h" -#include "test/util/fs_util.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -PosixErrorOr<Cleanup> ScopedSigaltstack(stack_t const& stack) { - stack_t old_stack; - int rc = sigaltstack(&stack, &old_stack); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "sigaltstack failed"); - } - return Cleanup([old_stack] { - EXPECT_THAT(sigaltstack(&old_stack, nullptr), SyscallSucceeds()); - }); -} - -volatile bool got_signal = false; -volatile int sigaltstack_errno = 0; -volatile int ss_flags = 0; - -void sigaltstack_handler(int sig, siginfo_t* siginfo, void* arg) { - got_signal = true; - - stack_t stack; - int ret = sigaltstack(nullptr, &stack); - MaybeSave(); - if (ret < 0) { - sigaltstack_errno = errno; - return; - } - ss_flags = stack.ss_flags; -} - -TEST(SigaltstackTest, Success) { - std::vector<char> stack_mem(SIGSTKSZ); - stack_t stack = {}; - stack.ss_sp = stack_mem.data(); - stack.ss_size = stack_mem.size(); - auto const cleanup_sigstack = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaltstack(stack)); - - struct sigaction sa = {}; - sa.sa_sigaction = sigaltstack_handler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_SIGINFO | SA_ONSTACK; - auto const cleanup_sa = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGUSR1, sa)); - - // Send signal to this thread, as sigaltstack is per-thread. - EXPECT_THAT(tgkill(getpid(), gettid(), SIGUSR1), SyscallSucceeds()); - - EXPECT_TRUE(got_signal); - EXPECT_EQ(sigaltstack_errno, 0); - EXPECT_NE(0, ss_flags & SS_ONSTACK); -} - -TEST(SigaltstackTest, ResetByExecve) { - std::vector<char> stack_mem(SIGSTKSZ); - stack_t stack = {}; - stack.ss_sp = stack_mem.data(); - stack.ss_size = stack_mem.size(); - auto const cleanup_sigstack = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaltstack(stack)); - - std::string full_path = RunfilePath("test/syscalls/linux/sigaltstack_check"); - - pid_t child_pid = -1; - int execve_errno = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(full_path, {"sigaltstack_check"}, {}, nullptr, &child_pid, - &execve_errno)); - - ASSERT_GT(child_pid, 0); - ASSERT_EQ(execve_errno, 0); - - int status = 0; - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - ASSERT_TRUE(WIFEXITED(status)); - ASSERT_EQ(WEXITSTATUS(status), 0); -} - -volatile bool badhandler_on_sigaltstack = true; // Set by the handler. -char* volatile badhandler_low_water_mark = nullptr; // Set by the handler. -volatile uint8_t badhandler_recursive_faults = 0; // Consumed by the handler. - -void badhandler(int sig, siginfo_t* siginfo, void* arg) { - char stack_var = 0; - char* current_ss = &stack_var; - - stack_t stack; - int ret = sigaltstack(nullptr, &stack); - if (ret < 0 || (stack.ss_flags & SS_ONSTACK) != SS_ONSTACK) { - // We should always be marked as being on the stack. Don't allow this to hit - // the bottom if this is ever not true (the main test will fail as a - // result, but we still need to unwind the recursive faults). - badhandler_on_sigaltstack = false; - } - if (current_ss < badhandler_low_water_mark) { - // Record the low point for the signal stack. We never expected this to be - // before stack bottom, but this is asserted in the actual test. - badhandler_low_water_mark = current_ss; - } - if (badhandler_recursive_faults > 0) { - badhandler_recursive_faults--; - Fault(); - } - FixupFault(reinterpret_cast<ucontext_t*>(arg)); -} - -TEST(SigaltstackTest, WalksOffBottom) { - // This test marks the upper half of the stack_mem array as the signal stack. - // It asserts that when a fault occurs in the handler (already on the signal - // stack), we eventually continue to fault our way off the stack. We should - // not revert to the top of the signal stack when we fall off the bottom and - // the signal stack should remain "in use". When we fall off the signal stack, - // we should have an unconditional signal delivered and not start using the - // first part of the stack_mem array. - std::vector<char> stack_mem(SIGSTKSZ * 2); - stack_t stack = {}; - stack.ss_sp = stack_mem.data() + SIGSTKSZ; // See above: upper half. - stack.ss_size = SIGSTKSZ; // Only one half the array. - auto const cleanup_sigstack = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaltstack(stack)); - - // Setup the handler: this must be for SIGSEGV, and it must allow proper - // nesting (no signal mask, no defer) so that we can trigger multiple times. - // - // When we walk off the bottom of the signal stack and force signal delivery - // of a SIGSEGV, the handler will revert to the default behavior (kill). - struct sigaction sa = {}; - sa.sa_sigaction = badhandler; - sa.sa_flags = SA_SIGINFO | SA_ONSTACK | SA_NODEFER; - auto const cleanup_sa = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGSEGV, sa)); - - // Trigger a single fault. - badhandler_low_water_mark = - static_cast<char*>(stack.ss_sp) + SIGSTKSZ; // Expected top. - badhandler_recursive_faults = 0; // Disable refault. - Fault(); - EXPECT_TRUE(badhandler_on_sigaltstack); - EXPECT_THAT(sigaltstack(nullptr, &stack), SyscallSucceeds()); - EXPECT_EQ(stack.ss_flags & SS_ONSTACK, 0); - EXPECT_LT(badhandler_low_water_mark, - reinterpret_cast<char*>(stack.ss_sp) + 2 * SIGSTKSZ); - EXPECT_GT(badhandler_low_water_mark, reinterpret_cast<char*>(stack.ss_sp)); - - // Trigger two faults. - char* prev_low_water_mark = badhandler_low_water_mark; // Previous top. - badhandler_recursive_faults = 1; // One refault. - Fault(); - ASSERT_TRUE(badhandler_on_sigaltstack); - EXPECT_THAT(sigaltstack(nullptr, &stack), SyscallSucceeds()); - EXPECT_EQ(stack.ss_flags & SS_ONSTACK, 0); - EXPECT_LT(badhandler_low_water_mark, prev_low_water_mark); - EXPECT_GT(badhandler_low_water_mark, reinterpret_cast<char*>(stack.ss_sp)); - - // Calculate the stack growth for a fault, and set the recursive faults to - // ensure that the signal handler stack required exceeds our marked stack area - // by a minimal amount. It should remain in the valid stack_mem area so that - // we can test the signal is forced merely by going out of the signal stack - // bounds, not by a genuine fault. - uintptr_t frame_size = - static_cast<uintptr_t>(prev_low_water_mark - badhandler_low_water_mark); - badhandler_recursive_faults = (SIGSTKSZ + frame_size) / frame_size; - EXPECT_EXIT(Fault(), ::testing::KilledBySignal(SIGSEGV), ""); -} - -volatile int setonstack_retval = 0; // Set by the handler. -volatile int setonstack_errno = 0; // Set by the handler. - -void setonstack(int sig, siginfo_t* siginfo, void* arg) { - char stack_mem[SIGSTKSZ]; - stack_t stack = {}; - stack.ss_sp = &stack_mem[0]; - stack.ss_size = SIGSTKSZ; - setonstack_retval = sigaltstack(&stack, nullptr); - setonstack_errno = errno; - FixupFault(reinterpret_cast<ucontext_t*>(arg)); -} - -TEST(SigaltstackTest, SetWhileOnStack) { - // Reserve twice as much stack here, since the handler will allocate a vector - // of size SIGTKSZ and attempt to set the sigaltstack to that value. - std::vector<char> stack_mem(2 * SIGSTKSZ); - stack_t stack = {}; - stack.ss_sp = stack_mem.data(); - stack.ss_size = stack_mem.size(); - auto const cleanup_sigstack = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaltstack(stack)); - - // See above. - struct sigaction sa = {}; - sa.sa_sigaction = setonstack; - sa.sa_flags = SA_SIGINFO | SA_ONSTACK; - auto const cleanup_sa = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGSEGV, sa)); - - // Trigger a fault. - Fault(); - - // The set should have failed. - EXPECT_EQ(setonstack_retval, -1); - EXPECT_EQ(setonstack_errno, EPERM); -} - -TEST(SigaltstackTest, SetCurrentStack) { - // This is executed as an exit test because once the signal stack is set to - // the local stack, there's no good way to unwind. We don't want to taint the - // test of any other tests that might run within this process. - EXPECT_EXIT( - { - char stack_value = 0; - stack_t stack = {}; - stack.ss_sp = &stack_value - kPageSize; // Lower than current level. - stack.ss_size = 2 * kPageSize; // => &stack_value +/- kPageSize. - TEST_CHECK(sigaltstack(&stack, nullptr) == 0); - TEST_CHECK(sigaltstack(nullptr, &stack) == 0); - TEST_CHECK((stack.ss_flags & SS_ONSTACK) != 0); - - // Should not be able to change the stack (even no-op). - TEST_CHECK(sigaltstack(&stack, nullptr) == -1 && errno == EPERM); - - // Should not be able to disable the stack. - stack.ss_flags = SS_DISABLE; - TEST_CHECK(sigaltstack(&stack, nullptr) == -1 && errno == EPERM); - exit(0); - }, - ::testing::ExitedWithCode(0), ""); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/sigaltstack_check.cc b/test/syscalls/linux/sigaltstack_check.cc deleted file mode 100644 index 5ac1b661d..000000000 --- a/test/syscalls/linux/sigaltstack_check.cc +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Checks that there is no alternate signal stack by default. -// -// Used by a test in sigaltstack.cc. -#include <errno.h> -#include <signal.h> -#include <stdio.h> -#include <string.h> -#include <unistd.h> - -#include "test/util/logging.h" - -int main(int /* argc */, char** /* argv */) { - stack_t stack; - TEST_CHECK(sigaltstack(nullptr, &stack) >= 0); - TEST_CHECK(stack.ss_flags == SS_DISABLE); - TEST_CHECK(stack.ss_sp == 0); - TEST_CHECK(stack.ss_size == 0); - return 0; -} diff --git a/test/syscalls/linux/sigiret.cc b/test/syscalls/linux/sigiret.cc deleted file mode 100644 index 6227774a4..000000000 --- a/test/syscalls/linux/sigiret.cc +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <signal.h> -#include <sys/types.h> -#include <sys/ucontext.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/logging.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" -#include "test/util/timer_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -constexpr uint64_t kOrigRcx = 0xdeadbeeffacefeed; -constexpr uint64_t kOrigR11 = 0xfacefeedbaad1dea; - -volatile int gotvtalrm, ready; - -void sigvtalrm(int sig, siginfo_t* siginfo, void* _uc) { - ucontext_t* uc = reinterpret_cast<ucontext_t*>(_uc); - - // Verify that: - // - test is in the busy-wait loop waiting for signal. - // - %rcx and %r11 values in mcontext_t match kOrigRcx and kOrigR11. - if (ready && - static_cast<uint64_t>(uc->uc_mcontext.gregs[REG_RCX]) == kOrigRcx && - static_cast<uint64_t>(uc->uc_mcontext.gregs[REG_R11]) == kOrigR11) { - // Modify the values %rcx and %r11 in the ucontext. These are the - // values seen by the application after the signal handler returns. - uc->uc_mcontext.gregs[REG_RCX] = ~kOrigRcx; - uc->uc_mcontext.gregs[REG_R11] = ~kOrigR11; - gotvtalrm = 1; - } -} - -TEST(SigIretTest, CheckRcxR11) { - // Setup signal handler for SIGVTALRM. - struct sigaction sa = {}; - sigfillset(&sa.sa_mask); - sa.sa_sigaction = sigvtalrm; - sa.sa_flags = SA_SIGINFO; - auto const action_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGVTALRM, sa)); - - auto const mask_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGVTALRM)); - - // Setup itimer to fire after 500 msecs. - struct itimerval itimer = {}; - itimer.it_value.tv_usec = 500 * 1000; // 500 msecs. - auto const timer_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_VIRTUAL, itimer)); - - // Initialize %rcx and %r11 and spin until the signal handler returns. - uint64_t rcx = kOrigRcx; - uint64_t r11 = kOrigR11; - asm volatile( - "movq %[rcx], %%rcx;" // %rcx = rcx - "movq %[r11], %%r11;" // %r11 = r11 - "movl $1, %[ready];" // ready = 1 - "1: pause; cmpl $0, %[gotvtalrm]; je 1b;" // while (!gotvtalrm); - "movq %%rcx, %[rcx];" // rcx = %rcx - "movq %%r11, %[r11];" // r11 = %r11 - : [ ready ] "=m"(ready), [ rcx ] "+m"(rcx), [ r11 ] "+m"(r11) - : [ gotvtalrm ] "m"(gotvtalrm) - : "cc", "memory", "rcx", "r11"); - - // If sigreturn(2) returns via 'sysret' then %rcx and %r11 will be - // clobbered and set to 'ptregs->rip' and 'ptregs->rflags' respectively. - // - // The following check verifies that %rcx and %r11 were not clobbered - // when returning from the signal handler (via sigreturn(2)). - EXPECT_EQ(rcx, ~kOrigRcx); - EXPECT_EQ(r11, ~kOrigR11); -} - -constexpr uint64_t kNonCanonicalRip = 0xCCCC000000000000; - -// Test that a non-canonical signal handler faults as expected. -TEST(SigIretTest, BadHandler) { - struct sigaction sa = {}; - sa.sa_sigaction = - reinterpret_cast<void (*)(int, siginfo_t*, void*)>(kNonCanonicalRip); - auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGUSR1, sa)); - - pid_t pid = fork(); - if (pid == 0) { - // Child, wait for signal. - while (1) { - pause(); - } - } - ASSERT_THAT(pid, SyscallSucceeds()); - - EXPECT_THAT(kill(pid, SIGUSR1), SyscallSucceeds()); - - int status; - EXPECT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV) - << "status = " << status; -} - -} // namespace - -} // namespace testing -} // namespace gvisor - -int main(int argc, char** argv) { - // SigIretTest.CheckRcxR11 depends on delivering SIGVTALRM to the main thread. - // Block SIGVTALRM so that any other threads created by TestInit will also - // have SIGVTALRM blocked. - sigset_t set; - sigemptyset(&set); - sigaddset(&set, SIGVTALRM); - TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); - - gvisor::testing::TestInit(&argc, &argv); - return gvisor::testing::RunAllTests(); -} diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc deleted file mode 100644 index 389e5fca2..000000000 --- a/test/syscalls/linux/signalfd.cc +++ /dev/null @@ -1,373 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <poll.h> -#include <signal.h> -#include <stdio.h> -#include <string.h> -#include <sys/signalfd.h> -#include <unistd.h> - -#include <functional> -#include <vector> - -#include "gtest/gtest.h" -#include "absl/synchronization/mutex.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -using ::testing::KilledBySignal; - -namespace gvisor { -namespace testing { - -namespace { - -constexpr int kSigno = SIGUSR1; -constexpr int kSignoMax = 64; // SIGRTMAX -constexpr int kSignoAlt = SIGUSR2; - -// Returns a new signalfd. -inline PosixErrorOr<FileDescriptor> NewSignalFD(sigset_t* mask, int flags = 0) { - int fd = signalfd(-1, mask, flags); - MaybeSave(); - if (fd < 0) { - return PosixError(errno, "signalfd"); - } - return FileDescriptor(fd); -} - -class SignalfdTest : public ::testing::TestWithParam<int> {}; - -TEST_P(SignalfdTest, Basic) { - int signo = GetParam(); - // Create the signalfd. - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, signo); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0)); - - // Deliver the blocked signal. - const auto scoped_sigmask = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo)); - ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds()); - - // We should now read the signal. - struct signalfd_siginfo rbuf; - ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), - SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(rbuf.ssi_signo, signo); -} - -TEST_P(SignalfdTest, MaskWorks) { - int signo = GetParam(); - // Create two signalfds with different masks. - sigset_t mask1, mask2; - sigemptyset(&mask1); - sigemptyset(&mask2); - sigaddset(&mask1, signo); - sigaddset(&mask2, kSignoAlt); - FileDescriptor fd1 = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask1, 0)); - FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask2, 0)); - - // Deliver the two signals. - const auto scoped_sigmask1 = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo)); - const auto scoped_sigmask2 = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSignoAlt)); - ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds()); - ASSERT_THAT(tgkill(getpid(), gettid(), kSignoAlt), SyscallSucceeds()); - - // We should see the signals on the appropriate signalfds. - // - // We read in the opposite order as the signals deliver above, to ensure that - // we don't happen to read the correct signal from the correct signalfd. - struct signalfd_siginfo rbuf1, rbuf2; - ASSERT_THAT(read(fd2.get(), &rbuf2, sizeof(rbuf2)), - SyscallSucceedsWithValue(sizeof(rbuf2))); - EXPECT_EQ(rbuf2.ssi_signo, kSignoAlt); - ASSERT_THAT(read(fd1.get(), &rbuf1, sizeof(rbuf1)), - SyscallSucceedsWithValue(sizeof(rbuf1))); - EXPECT_EQ(rbuf1.ssi_signo, signo); -} - -TEST(Signalfd, Cloexec) { - // Exec tests confirm that O_CLOEXEC has the intended effect. We just create a - // signalfd with the appropriate flag here and assert that the FD has it set. - sigset_t mask; - sigemptyset(&mask); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC)); - EXPECT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); -} - -TEST_P(SignalfdTest, Blocking) { - int signo = GetParam(); - // Create the signalfd in blocking mode. - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, signo); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0)); - - // Shared tid variable. - absl::Mutex mu; - bool has_tid; - pid_t tid; - - // Start a thread reading. - ScopedThread t([&] { - // Copy the tid and notify the caller. - { - absl::MutexLock ml(&mu); - tid = gettid(); - has_tid = true; - } - - // Read the signal from the signalfd. - struct signalfd_siginfo rbuf; - ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), - SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(rbuf.ssi_signo, signo); - }); - - // Wait until blocked. - absl::MutexLock ml(&mu); - mu.Await(absl::Condition(&has_tid)); - - // Deliver the signal to either the waiting thread, or - // to this thread. N.B. this is a bug in the core gVisor - // behavior for signalfd, and needs to be fixed. - // - // See gvisor.dev/issue/139. - if (IsRunningOnGvisor()) { - ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds()); - } else { - ASSERT_THAT(tgkill(getpid(), tid, signo), SyscallSucceeds()); - } - - // Ensure that it was received. - t.Join(); -} - -TEST_P(SignalfdTest, ThreadGroup) { - int signo = GetParam(); - // Create the signalfd in blocking mode. - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, signo); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0)); - - // Shared variable. - absl::Mutex mu; - bool first = false; - bool second = false; - - // Start a thread reading. - ScopedThread t([&] { - // Read the signal from the signalfd. - struct signalfd_siginfo rbuf; - ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), - SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(rbuf.ssi_signo, signo); - - // Wait for the other thread. - absl::MutexLock ml(&mu); - first = true; - mu.Await(absl::Condition(&second)); - }); - - // Deliver the signal to the threadgroup. - ASSERT_THAT(kill(getpid(), signo), SyscallSucceeds()); - - // Wait for the first thread to process. - { - absl::MutexLock ml(&mu); - mu.Await(absl::Condition(&first)); - } - - // Deliver to the thread group again (other thread still exists). - ASSERT_THAT(kill(getpid(), signo), SyscallSucceeds()); - - // Ensure that we can also receive it. - struct signalfd_siginfo rbuf; - ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), - SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(rbuf.ssi_signo, signo); - - // Mark the test as done. - { - absl::MutexLock ml(&mu); - second = true; - } - - // The other thread should be joinable. - t.Join(); -} - -TEST_P(SignalfdTest, Nonblock) { - int signo = GetParam(); - // Create the signalfd in non-blocking mode. - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, signo); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_NONBLOCK)); - - // We should return if we attempt to read. - struct signalfd_siginfo rbuf; - ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Block and deliver the signal. - const auto scoped_sigmask = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo)); - ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds()); - - // Ensure that a read actually works. - ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), - SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(rbuf.ssi_signo, signo); - - // Should block again. - EXPECT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(SignalfdTest, SetMask) { - int signo = GetParam(); - // Create the signalfd matching nothing. - sigset_t mask; - sigemptyset(&mask); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_NONBLOCK)); - - // Block and deliver a signal. - const auto scoped_sigmask = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo)); - ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds()); - - // We should have nothing. - struct signalfd_siginfo rbuf; - ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Change the signal mask. - sigaddset(&mask, signo); - ASSERT_THAT(signalfd(fd.get(), &mask, 0), SyscallSucceeds()); - - // We should now have the signal. - ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), - SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(rbuf.ssi_signo, signo); -} - -TEST_P(SignalfdTest, Poll) { - int signo = GetParam(); - // Create the signalfd. - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, signo); - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0)); - - // Block the signal, and start a thread to deliver it. - const auto scoped_sigmask = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo)); - pid_t orig_tid = gettid(); - ScopedThread t([&] { - absl::SleepFor(absl::Seconds(5)); - ASSERT_THAT(tgkill(getpid(), orig_tid, signo), SyscallSucceeds()); - }); - - // Start polling for the signal. We expect that it is not available at the - // outset, but then becomes available when the signal is sent. We give a - // timeout of 10000ms (or the delay above + 5 seconds of additional grace - // time). - struct pollfd poll_fd = {fd.get(), POLLIN, 0}; - EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), - SyscallSucceedsWithValue(1)); - - // Actually read the signal to prevent delivery. - struct signalfd_siginfo rbuf; - EXPECT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), - SyscallSucceedsWithValue(sizeof(rbuf))); -} - -std::string PrintSigno(::testing::TestParamInfo<int> info) { - switch (info.param) { - case kSigno: - return "kSigno"; - case kSignoMax: - return "kSignoMax"; - default: - return absl::StrCat(info.param); - } -} -INSTANTIATE_TEST_SUITE_P(Signalfd, SignalfdTest, - ::testing::Values(kSigno, kSignoMax), PrintSigno); - -TEST(Signalfd, Ppoll) { - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, SIGKILL); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC)); - - // Ensure that the given ppoll blocks. - struct pollfd pfd = {}; - pfd.fd = fd.get(); - pfd.events = POLLIN; - struct timespec timeout = {}; - timeout.tv_sec = 1; - EXPECT_THAT(RetryEINTR(ppoll)(&pfd, 1, &timeout, &mask), - SyscallSucceedsWithValue(0)); -} - -TEST(Signalfd, KillStillKills) { - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, SIGKILL); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC)); - - // Just because there is a signalfd, we shouldn't see any change in behavior - // for unblockable signals. It's easier to test this with SIGKILL. - const auto scoped_sigmask = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, SIGKILL)); - EXPECT_EXIT(tgkill(getpid(), gettid(), SIGKILL), KilledBySignal(SIGKILL), ""); -} - -} // namespace - -} // namespace testing -} // namespace gvisor - -int main(int argc, char** argv) { - // These tests depend on delivering signals. Block them up front so that all - // other threads created by TestInit will also have them blocked, and they - // will not interface with the rest of the test. - sigset_t set; - sigemptyset(&set); - sigaddset(&set, gvisor::testing::kSigno); - sigaddset(&set, gvisor::testing::kSignoMax); - sigaddset(&set, gvisor::testing::kSignoAlt); - TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); - - gvisor::testing::TestInit(&argc, &argv); - - return gvisor::testing::RunAllTests(); -} diff --git a/test/syscalls/linux/sigprocmask.cc b/test/syscalls/linux/sigprocmask.cc deleted file mode 100644 index a603fc1d1..000000000 --- a/test/syscalls/linux/sigprocmask.cc +++ /dev/null @@ -1,269 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <signal.h> -#include <stddef.h> -#include <sys/syscall.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Signals numbers used for testing. -static constexpr int kTestSignal1 = SIGUSR1; -static constexpr int kTestSignal2 = SIGUSR2; - -static int raw_sigprocmask(int how, const sigset_t* set, sigset_t* oldset) { - return syscall(SYS_rt_sigprocmask, how, set, oldset, _NSIG / 8); -} - -// count of the number of signals received -int signal_count[kMaxSignal + 1]; - -// signal handler increments the signal counter -void SigHandler(int sig, siginfo_t* info, void* context) { - TEST_CHECK(sig > 0 && sig <= kMaxSignal); - signal_count[sig] += 1; -} - -// The test fixture saves and restores the signal mask and -// sets up handlers for kTestSignal1 and kTestSignal2. -class SigProcMaskTest : public ::testing::Test { - protected: - void SetUp() override { - // Save the current signal mask. - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &mask_), - SyscallSucceeds()); - - // Setup signal handlers for kTestSignal1 and kTestSignal2. - struct sigaction sa; - sa.sa_sigaction = SigHandler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_SIGINFO; - EXPECT_THAT(sigaction(kTestSignal1, &sa, &sa_test_sig_1_), - SyscallSucceeds()); - EXPECT_THAT(sigaction(kTestSignal2, &sa, &sa_test_sig_2_), - SyscallSucceeds()); - - // Clear the signal counters. - memset(signal_count, 0, sizeof(signal_count)); - } - - void TearDown() override { - // Restore the signal mask. - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, &mask_, nullptr), - SyscallSucceeds()); - - // Restore the signal handlers for kTestSignal1 and kTestSignal2. - EXPECT_THAT(sigaction(kTestSignal1, &sa_test_sig_1_, nullptr), - SyscallSucceeds()); - EXPECT_THAT(sigaction(kTestSignal2, &sa_test_sig_2_, nullptr), - SyscallSucceeds()); - } - - private: - sigset_t mask_; - struct sigaction sa_test_sig_1_; - struct sigaction sa_test_sig_2_; -}; - -// Both sigsets nullptr should succeed and do nothing. -TEST_F(SigProcMaskTest, NullAddress) { - EXPECT_THAT(raw_sigprocmask(SIG_BLOCK, nullptr, NULL), SyscallSucceeds()); - EXPECT_THAT(raw_sigprocmask(SIG_UNBLOCK, nullptr, NULL), SyscallSucceeds()); - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, NULL), SyscallSucceeds()); -} - -// Bad address for either sigset should fail with EFAULT. -TEST_F(SigProcMaskTest, BadAddress) { - sigset_t* bad_addr = reinterpret_cast<sigset_t*>(-1); - - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, bad_addr, nullptr), - SyscallFailsWithErrno(EFAULT)); - - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, bad_addr), - SyscallFailsWithErrno(EFAULT)); -} - -// Bad value of the "how" parameter should fail with EINVAL. -TEST_F(SigProcMaskTest, BadParameter) { - int bad_param_1 = -1; - int bad_param_2 = 42; - - sigset_t set1; - sigemptyset(&set1); - - EXPECT_THAT(raw_sigprocmask(bad_param_1, &set1, nullptr), - SyscallFailsWithErrno(EINVAL)); - - EXPECT_THAT(raw_sigprocmask(bad_param_2, &set1, nullptr), - SyscallFailsWithErrno(EINVAL)); -} - -// Check that we can get the current signal mask. -TEST_F(SigProcMaskTest, GetMask) { - sigset_t set1; - sigset_t set2; - - sigemptyset(&set1); - sigfillset(&set2); - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &set1), SyscallSucceeds()); - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &set2), SyscallSucceeds()); - EXPECT_THAT(set1, EqualsSigset(set2)); -} - -// Check that we can set the signal mask. -TEST_F(SigProcMaskTest, SetMask) { - sigset_t actual; - sigset_t expected; - - // Try to mask all signals - sigfillset(&expected); - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, &expected, nullptr), - SyscallSucceeds()); - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &actual), - SyscallSucceeds()); - // sigprocmask() should have silently ignored SIGKILL and SIGSTOP. - sigdelset(&expected, SIGSTOP); - sigdelset(&expected, SIGKILL); - EXPECT_THAT(actual, EqualsSigset(expected)); - - // Try to clear the signal mask - sigemptyset(&expected); - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, &expected, nullptr), - SyscallSucceeds()); - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &actual), - SyscallSucceeds()); - EXPECT_THAT(actual, EqualsSigset(expected)); - - // Try to set a mask with one signal. - sigemptyset(&expected); - sigaddset(&expected, kTestSignal1); - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, &expected, nullptr), - SyscallSucceeds()); - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &actual), - SyscallSucceeds()); - EXPECT_THAT(actual, EqualsSigset(expected)); -} - -// Check that we can add and remove signals. -TEST_F(SigProcMaskTest, BlockUnblock) { - sigset_t actual; - sigset_t expected; - - // Try to set a mask with one signal. - sigemptyset(&expected); - sigaddset(&expected, kTestSignal1); - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, &expected, nullptr), - SyscallSucceeds()); - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &actual), - SyscallSucceeds()); - EXPECT_THAT(actual, EqualsSigset(expected)); - - // Try to add another signal. - sigset_t block; - sigemptyset(&block); - sigaddset(&block, kTestSignal2); - EXPECT_THAT(raw_sigprocmask(SIG_BLOCK, &block, nullptr), SyscallSucceeds()); - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &actual), - SyscallSucceeds()); - sigaddset(&expected, kTestSignal2); - EXPECT_THAT(actual, EqualsSigset(expected)); - - // Try to remove a signal. - sigset_t unblock; - sigemptyset(&unblock); - sigaddset(&unblock, kTestSignal1); - EXPECT_THAT(raw_sigprocmask(SIG_UNBLOCK, &unblock, nullptr), - SyscallSucceeds()); - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &actual), - SyscallSucceeds()); - sigdelset(&expected, kTestSignal1); - EXPECT_THAT(actual, EqualsSigset(expected)); -} - -// Test that the signal mask actually blocks signals. -TEST_F(SigProcMaskTest, SignalHandler) { - sigset_t mask; - - // clear the signal mask - sigemptyset(&mask); - EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, &mask, nullptr), SyscallSucceeds()); - - // Check the initial signal counts. - EXPECT_EQ(0, signal_count[kTestSignal1]); - EXPECT_EQ(0, signal_count[kTestSignal2]); - - // Check that both kTestSignal1 and kTestSignal2 are not blocked. - raise(kTestSignal1); - raise(kTestSignal2); - EXPECT_EQ(1, signal_count[kTestSignal1]); - EXPECT_EQ(1, signal_count[kTestSignal2]); - - // Block kTestSignal1. - sigaddset(&mask, kTestSignal1); - EXPECT_THAT(raw_sigprocmask(SIG_BLOCK, &mask, nullptr), SyscallSucceeds()); - - // Check that kTestSignal1 is blocked. - raise(kTestSignal1); - raise(kTestSignal2); - EXPECT_EQ(1, signal_count[kTestSignal1]); - EXPECT_EQ(2, signal_count[kTestSignal2]); - - // Unblock kTestSignal1. - sigaddset(&mask, kTestSignal1); - EXPECT_THAT(raw_sigprocmask(SIG_UNBLOCK, &mask, nullptr), SyscallSucceeds()); - - // Check that the unblocked kTestSignal1 has been delivered. - EXPECT_EQ(2, signal_count[kTestSignal1]); - EXPECT_EQ(2, signal_count[kTestSignal2]); -} - -// Check that sigprocmask correctly handles aliasing of the set and oldset -// pointers. Regression test for b/30502311. -TEST_F(SigProcMaskTest, AliasedSets) { - sigset_t mask; - - // Set a mask in which only kTestSignal1 is blocked. - sigset_t mask1; - sigemptyset(&mask1); - sigaddset(&mask1, kTestSignal1); - mask = mask1; - ASSERT_THAT(raw_sigprocmask(SIG_SETMASK, &mask, nullptr), SyscallSucceeds()); - - // Exchange it with a mask in which only kTestSignal2 is blocked. - sigset_t mask2; - sigemptyset(&mask2); - sigaddset(&mask2, kTestSignal2); - mask = mask2; - ASSERT_THAT(raw_sigprocmask(SIG_SETMASK, &mask, &mask), SyscallSucceeds()); - - // Check that the exchange succeeeded: - // mask should now contain the previously-set mask blocking only kTestSignal1. - EXPECT_THAT(mask, EqualsSigset(mask1)); - // The current mask should block only kTestSignal2. - ASSERT_THAT(raw_sigprocmask(0, nullptr, &mask), SyscallSucceeds()); - EXPECT_THAT(mask, EqualsSigset(mask2)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/sigstop.cc b/test/syscalls/linux/sigstop.cc deleted file mode 100644 index b2fcedd62..000000000 --- a/test/syscalls/linux/sigstop.cc +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <signal.h> -#include <stdlib.h> -#include <sys/select.h> - -#include "gtest/gtest.h" -#include "absl/flags/flag.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -ABSL_FLAG(bool, sigstop_test_child, false, - "If true, run the SigstopTest child workload."); - -namespace gvisor { -namespace testing { - -namespace { - -constexpr absl::Duration kChildStartupDelay = absl::Seconds(5); -constexpr absl::Duration kChildMainThreadDelay = absl::Seconds(10); -constexpr absl::Duration kChildExtraThreadDelay = absl::Seconds(15); -constexpr absl::Duration kPostSIGSTOPDelay = absl::Seconds(20); - -// Comparisons on absl::Duration aren't yet constexpr (2017-07-14), so we -// can't just use static_assert. -TEST(SigstopTest, TimesAreRelativelyConsistent) { - EXPECT_LT(kChildStartupDelay, kChildMainThreadDelay) - << "Child process will exit before the parent process attempts to stop " - "it"; - EXPECT_LT(kChildMainThreadDelay, kChildExtraThreadDelay) - << "Secondary thread in child process will exit before main thread, " - "causing it to exit with the wrong code"; - EXPECT_LT(kChildExtraThreadDelay, kPostSIGSTOPDelay) - << "Parent process stops waiting before child process may exit if " - "improperly stopped, rendering the test ineffective"; -} - -// Exit codes communicated from the child workload to the parent test process. -constexpr int kChildMainThreadExitCode = 10; -constexpr int kChildExtraThreadExitCode = 11; - -TEST(SigstopTest, Correctness) { - pid_t child_pid = -1; - int execve_errno = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec("/proc/self/exe", {"/proc/self/exe", "--sigstop_test_child"}, - {}, nullptr, &child_pid, &execve_errno)); - - ASSERT_GT(child_pid, 0); - ASSERT_EQ(execve_errno, 0); - - // Wait for the child subprocess to start the second thread before stopping - // it. - absl::SleepFor(kChildStartupDelay); - ASSERT_THAT(kill(child_pid, SIGSTOP), SyscallSucceeds()); - int status; - EXPECT_THAT(RetryEINTR(waitpid)(child_pid, &status, WUNTRACED), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFSTOPPED(status)); - EXPECT_EQ(SIGSTOP, WSTOPSIG(status)); - - // Sleep for longer than either of the sleeps in the child subprocess, - // expecting the child to stay alive because it's stopped. - absl::SleepFor(kPostSIGSTOPDelay); - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, WNOHANG), - SyscallSucceedsWithValue(0)); - - // Resume the child. - ASSERT_THAT(kill(child_pid, SIGCONT), SyscallSucceeds()); - - EXPECT_THAT(RetryEINTR(waitpid)(child_pid, &status, WCONTINUED), - SyscallSucceedsWithValue(child_pid)); - EXPECT_TRUE(WIFCONTINUED(status)); - - // Expect it to die. - ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - ASSERT_TRUE(WIFEXITED(status)); - ASSERT_EQ(WEXITSTATUS(status), kChildMainThreadExitCode); -} - -// Like base:SleepFor, but tries to avoid counting time spent stopped due to a -// stop signal toward the sleep. -// -// This is required due to an inconsistency in how nanosleep(2) and stop signals -// interact on Linux. When nanosleep is interrupted, it writes the remaining -// time back to its second timespec argument, so that if nanosleep is -// interrupted by a signal handler then userspace can immediately call nanosleep -// again with that timespec. However, if nanosleep is automatically restarted -// (because it's interrupted by a signal that is not delivered to a handler, -// such as a stop signal), it's restarted based on the timer's former *absolute* -// expiration time (via ERESTART_RESTARTBLOCK => SYS_restart_syscall => -// hrtimer_nanosleep_restart). This means that time spent stopped is effectively -// counted as time spent sleeping, resulting in less time spent sleeping than -// expected. -// -// Dividing the sleep into multiple smaller sleeps limits the impact of this -// effect to the length of each sleep during which a stop occurs; for example, -// if a sleeping process is only stopped once, SleepIgnoreStopped can -// under-sleep by at most 100ms. -void SleepIgnoreStopped(absl::Duration d) { - absl::Duration const max_sleep = absl::Milliseconds(100); - while (d > absl::ZeroDuration()) { - absl::Duration to_sleep = std::min(d, max_sleep); - absl::SleepFor(to_sleep); - d -= to_sleep; - } -} - -void RunChild() { - // Start another thread that attempts to call exit_group with a different - // error code, in order to verify that SIGSTOP stops this thread as well. - ScopedThread t([] { - SleepIgnoreStopped(kChildExtraThreadDelay); - exit(kChildExtraThreadExitCode); - }); - SleepIgnoreStopped(kChildMainThreadDelay); - exit(kChildMainThreadExitCode); -} - -} // namespace - -} // namespace testing -} // namespace gvisor - -int main(int argc, char** argv) { - gvisor::testing::TestInit(&argc, &argv); - - if (absl::GetFlag(FLAGS_sigstop_test_child)) { - gvisor::testing::RunChild(); - return 1; - } - - return gvisor::testing::RunAllTests(); -} diff --git a/test/syscalls/linux/sigtimedwait.cc b/test/syscalls/linux/sigtimedwait.cc deleted file mode 100644 index 4f8afff15..000000000 --- a/test/syscalls/linux/sigtimedwait.cc +++ /dev/null @@ -1,323 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/wait.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/file_descriptor.h" -#include "test/util/logging.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" -#include "test/util/timer_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// N.B. main() blocks SIGALRM and SIGCHLD on all threads. - -constexpr int kAlarmSecs = 12; - -void NoopHandler(int sig, siginfo_t* info, void* context) {} - -TEST(SigtimedwaitTest, InvalidTimeout) { - sigset_t mask; - sigemptyset(&mask); - struct timespec timeout = {0, 1000000001}; - EXPECT_THAT(sigtimedwait(&mask, nullptr, &timeout), - SyscallFailsWithErrno(EINVAL)); - timeout = {-1, 0}; - EXPECT_THAT(sigtimedwait(&mask, nullptr, &timeout), - SyscallFailsWithErrno(EINVAL)); - timeout = {0, -1}; - EXPECT_THAT(sigtimedwait(&mask, nullptr, &timeout), - SyscallFailsWithErrno(EINVAL)); -} - -// No random save as the test relies on alarm timing. Cooperative save tests -// already cover the save between alarm and wait. -TEST(SigtimedwaitTest, AlarmReturnsAlarm_NoRandomSave) { - struct itimerval itv = {}; - itv.it_value.tv_sec = kAlarmSecs; - const auto itimer_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_REAL, itv)); - - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, SIGALRM); - siginfo_t info = {}; - EXPECT_THAT(RetryEINTR(sigtimedwait)(&mask, &info, nullptr), - SyscallSucceedsWithValue(SIGALRM)); - EXPECT_EQ(SIGALRM, info.si_signo); -} - -// No random save as the test relies on alarm timing. Cooperative save tests -// already cover the save between alarm and wait. -TEST(SigtimedwaitTest, NullTimeoutReturnsEINTR_NoRandomSave) { - struct sigaction sa; - sa.sa_sigaction = NoopHandler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_SIGINFO; - const auto action_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa)); - - const auto mask_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGALRM)); - - struct itimerval itv = {}; - itv.it_value.tv_sec = kAlarmSecs; - const auto itimer_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_REAL, itv)); - - sigset_t mask; - sigemptyset(&mask); - EXPECT_THAT(sigtimedwait(&mask, nullptr, nullptr), - SyscallFailsWithErrno(EINTR)); -} - -TEST(SigtimedwaitTest, LegitTimeoutReturnsEAGAIN) { - sigset_t mask; - sigemptyset(&mask); - struct timespec timeout = {1, 0}; // 1 second - EXPECT_THAT(RetryEINTR(sigtimedwait)(&mask, nullptr, &timeout), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST(SigtimedwaitTest, ZeroTimeoutReturnsEAGAIN) { - sigset_t mask; - sigemptyset(&mask); - struct timespec timeout = {0, 0}; // 0 second - EXPECT_THAT(sigtimedwait(&mask, nullptr, &timeout), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST(SigtimedwaitTest, KillGeneratedSIGCHLD) { - EXPECT_THAT(kill(getpid(), SIGCHLD), SyscallSucceeds()); - - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, SIGCHLD); - struct timespec ts = {5, 0}; - EXPECT_THAT(RetryEINTR(sigtimedwait)(&mask, nullptr, &ts), - SyscallSucceedsWithValue(SIGCHLD)); -} - -TEST(SigtimedwaitTest, ChildExitGeneratedSIGCHLD) { - pid_t pid = fork(); - if (pid == 0) { - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - - int status; - EXPECT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) << status; - - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, SIGCHLD); - struct timespec ts = {5, 0}; - EXPECT_THAT(RetryEINTR(sigtimedwait)(&mask, nullptr, &ts), - SyscallSucceedsWithValue(SIGCHLD)); -} - -TEST(SigtimedwaitTest, ChildExitGeneratedSIGCHLDWithHandler) { - // Setup handler for SIGCHLD, but don't unblock it. - struct sigaction sa; - sa.sa_sigaction = NoopHandler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_SIGINFO; - const auto action_cleanup = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGCHLD, sa)); - - pid_t pid = fork(); - if (pid == 0) { - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, SIGCHLD); - struct timespec ts = {5, 0}; - EXPECT_THAT(RetryEINTR(sigtimedwait)(&mask, nullptr, &ts), - SyscallSucceedsWithValue(SIGCHLD)); - - int status; - EXPECT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) << status; -} - -// sigtimedwait cannot catch SIGKILL. -TEST(SigtimedwaitTest, SIGKILLUncaught) { - // This is a regression test for sigtimedwait dequeuing SIGKILLs, thus - // preventing the task from exiting. - // - // The explanation below is specific to behavior in gVisor. The Linux behavior - // here is irrelevant because without a bug that prevents delivery of SIGKILL, - // none of this behavior is visible (in Linux or gVisor). - // - // SIGKILL is rather intrusive. Simply sending the SIGKILL marks - // ThreadGroup.exitStatus as exiting with SIGKILL, before the SIGKILL is even - // delivered. - // - // As a result, we cannot simply exit the child with a different exit code if - // it survives and expect to see that code in waitpid because: - // 1. PrepareGroupExit will override Task.exitStatus with - // ThreadGroup.exitStatus. - // 2. waitpid(2) will always return ThreadGroup.exitStatus rather than - // Task.exitStatus. - // - // We could use exit(2) to set Task.exitStatus without override, and a SIGCHLD - // handler to receive Task.exitStatus in the parent, but with that much - // test complexity, it is cleaner to simply use a pipe to notify the parent - // that we survived. - constexpr auto kSigtimedwaitSetupTime = absl::Seconds(2); - - int pipe_fds[2]; - ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds()); - FileDescriptor rfd(pipe_fds[0]); - FileDescriptor wfd(pipe_fds[1]); - - pid_t pid = fork(); - if (pid == 0) { - rfd.reset(); - - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, SIGKILL); - RetryEINTR(sigtimedwait)(&mask, nullptr, nullptr); - - // Survived. - char c = 'a'; - TEST_PCHECK(WriteFd(wfd.get(), &c, 1) == 1); - _exit(1); - } - ASSERT_THAT(pid, SyscallSucceeds()); - - wfd.reset(); - - // Wait for child to block in sigtimedwait, then kill it. - absl::SleepFor(kSigtimedwaitSetupTime); - - // Sending SIGKILL will attempt to enqueue the signal twice: once in the - // normal signal sending path, and once to all Tasks in the ThreadGroup when - // applying SIGKILL side-effects. - // - // If we use kill(2), the former will be on the ThreadGroup signal queue and - // the latter will be on the Task signal queue. sigtimedwait can only dequeue - // one signal, so the other would kill the Task, masking bugs. - // - // If we use tkill(2), the former will be on the Task signal queue and the - // latter will be dropped as a duplicate. Then sigtimedwait can theoretically - // dequeue the single SIGKILL. - EXPECT_THAT(syscall(SYS_tkill, pid, SIGKILL), SyscallSucceeds()); - - int status; - EXPECT_THAT(RetryEINTR(waitpid)(pid, &status, 0), - SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL) << status; - - // Child shouldn't have survived. - char c; - EXPECT_THAT(ReadFd(rfd.get(), &c, 1), SyscallSucceedsWithValue(0)); -} - -TEST(SigtimedwaitTest, IgnoredUnmaskedSignal) { - constexpr int kSigno = SIGUSR1; - constexpr auto kSigtimedwaitSetupTime = absl::Seconds(2); - constexpr auto kSigtimedwaitTimeout = absl::Seconds(5); - ASSERT_GT(kSigtimedwaitTimeout, kSigtimedwaitSetupTime); - - // Ensure that kSigno is ignored, and unmasked on this thread. - struct sigaction sa = {}; - sa.sa_handler = SIG_IGN; - const auto scoped_sigaction = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(kSigno, sa)); - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, kSigno); - auto scoped_sigmask = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, mask)); - - // Create a thread which will send us kSigno while we are blocked in - // sigtimedwait. - pid_t tid = gettid(); - ScopedThread sigthread([&] { - absl::SleepFor(kSigtimedwaitSetupTime); - EXPECT_THAT(tgkill(getpid(), tid, kSigno), SyscallSucceeds()); - }); - - // sigtimedwait should not observe kSigno since it is ignored and already - // unmasked, causing it to be dropped before it is enqueued. - struct timespec timeout_ts = absl::ToTimespec(kSigtimedwaitTimeout); - EXPECT_THAT(RetryEINTR(sigtimedwait)(&mask, nullptr, &timeout_ts), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST(SigtimedwaitTest, IgnoredMaskedSignal) { - constexpr int kSigno = SIGUSR1; - constexpr auto kSigtimedwaitSetupTime = absl::Seconds(2); - constexpr auto kSigtimedwaitTimeout = absl::Seconds(5); - ASSERT_GT(kSigtimedwaitTimeout, kSigtimedwaitSetupTime); - - // Ensure that kSigno is ignored, and masked on this thread. - struct sigaction sa = {}; - sa.sa_handler = SIG_IGN; - const auto scoped_sigaction = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(kSigno, sa)); - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, kSigno); - auto scoped_sigmask = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, mask)); - - // Create a thread which will send us kSigno while we are blocked in - // sigtimedwait. - pid_t tid = gettid(); - ScopedThread sigthread([&] { - absl::SleepFor(kSigtimedwaitSetupTime); - EXPECT_THAT(tgkill(getpid(), tid, kSigno), SyscallSucceeds()); - }); - - // sigtimedwait should observe kSigno since it is normally masked, causing it - // to be enqueued despite being ignored. - struct timespec timeout_ts = absl::ToTimespec(kSigtimedwaitTimeout); - EXPECT_THAT(RetryEINTR(sigtimedwait)(&mask, nullptr, &timeout_ts), - SyscallSucceedsWithValue(kSigno)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor - -int main(int argc, char** argv) { - // These tests depend on delivering SIGALRM/SIGCHLD to the main thread or in - // sigtimedwait. Block them so that any other threads created by TestInit will - // also have them blocked. - sigset_t set; - sigemptyset(&set); - sigaddset(&set, SIGALRM); - sigaddset(&set, SIGCHLD); - TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); - - gvisor::testing::TestInit(&argc, &argv); - return gvisor::testing::RunAllTests(); -} diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc deleted file mode 100644 index 3a07ac8d2..000000000 --- a/test/syscalls/linux/socket.cc +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/socket.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -TEST(SocketTest, UnixSocketPairProtocol) { - int socks[2]; - ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, PF_UNIX, socks), - SyscallSucceeds()); - close(socks[0]); - close(socks[1]); -} - -TEST(SocketTest, ProtocolUnix) { - struct { - int domain, type, protocol; - } tests[] = { - {AF_UNIX, SOCK_STREAM, PF_UNIX}, - {AF_UNIX, SOCK_SEQPACKET, PF_UNIX}, - {AF_UNIX, SOCK_DGRAM, PF_UNIX}, - }; - for (int i = 0; i < ABSL_ARRAYSIZE(tests); i++) { - ASSERT_NO_ERRNO_AND_VALUE( - Socket(tests[i].domain, tests[i].type, tests[i].protocol)); - } -} - -TEST(SocketTest, ProtocolInet) { - struct { - int domain, type, protocol; - } tests[] = { - {AF_INET, SOCK_DGRAM, IPPROTO_UDP}, - {AF_INET, SOCK_STREAM, IPPROTO_TCP}, - }; - for (int i = 0; i < ABSL_ARRAYSIZE(tests); i++) { - ASSERT_NO_ERRNO_AND_VALUE( - Socket(tests[i].domain, tests[i].type, tests[i].protocol)); - } -} - -using SocketOpenTest = ::testing::TestWithParam<int>; - -// UDS cannot be opened. -TEST_P(SocketOpenTest, Unix) { - // FIXME(b/142001530): Open incorrectly succeeds on gVisor. - SKIP_IF(IsRunningOnGvisor()); - - FileDescriptor bound = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX)); - - struct sockaddr_un addr = - ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(/*abstract=*/false, AF_UNIX)); - - ASSERT_THAT(bind(bound.get(), reinterpret_cast<struct sockaddr*>(&addr), - sizeof(addr)), - SyscallSucceeds()); - - EXPECT_THAT(open(addr.sun_path, GetParam()), SyscallFailsWithErrno(ENXIO)); -} - -INSTANTIATE_TEST_SUITE_P(OpenModes, SocketOpenTest, - ::testing::Values(O_RDONLY, O_RDWR)); - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_abstract.cc b/test/syscalls/linux/socket_abstract.cc deleted file mode 100644 index 00999f192..000000000 --- a/test/syscalls/linux/socket_abstract.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/socket_generic.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/socket_unix.h" -#include "test/syscalls/linux/socket_unix_cmsg.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return ApplyVec<SocketPairKind>( - AbstractBoundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET}, - List<int>{0, SOCK_NONBLOCK})); -} - -INSTANTIATE_TEST_SUITE_P( - AbstractUnixSockets, AllSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -INSTANTIATE_TEST_SUITE_P( - AbstractUnixSockets, UnixSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -INSTANTIATE_TEST_SUITE_P( - AbstractUnixSockets, UnixSocketPairCmsgTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device.cc b/test/syscalls/linux/socket_bind_to_device.cc deleted file mode 100644 index 6b27f6eab..000000000 --- a/test/syscalls/linux/socket_bind_to_device.cc +++ /dev/null @@ -1,313 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <arpa/inet.h> -#include <linux/if_tun.h> -#include <net/if.h> -#include <netinet/in.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include <cstdio> -#include <cstring> -#include <map> -#include <memory> -#include <unordered_map> -#include <unordered_set> -#include <utility> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_bind_to_device_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/capability_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -using std::string; - -// Test fixture for SO_BINDTODEVICE tests. -class BindToDeviceTest : public ::testing::TestWithParam<SocketKind> { - protected: - void SetUp() override { - printf("Testing case: %s\n", GetParam().description.c_str()); - ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) - << "CAP_NET_RAW is required to use SO_BINDTODEVICE"; - - interface_name_ = "eth1"; - auto interface_names = GetInterfaceNames(); - if (interface_names.find(interface_name_) == interface_names.end()) { - // Need a tunnel. - tunnel_ = ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New()); - interface_name_ = tunnel_->GetName(); - ASSERT_FALSE(interface_name_.empty()); - } - socket_ = ASSERT_NO_ERRNO_AND_VALUE(GetParam().Create()); - } - - string interface_name() const { return interface_name_; } - - int socket_fd() const { return socket_->get(); } - - private: - std::unique_ptr<Tunnel> tunnel_; - string interface_name_; - std::unique_ptr<FileDescriptor> socket_; -}; - -constexpr char kIllegalIfnameChar = '/'; - -// Tests getsockopt of the default value. -TEST_P(BindToDeviceTest, GetsockoptDefault) { - char name_buffer[IFNAMSIZ * 2]; - char original_name_buffer[IFNAMSIZ * 2]; - socklen_t name_buffer_size; - - // Read the default SO_BINDTODEVICE. - memset(original_name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); - for (size_t i = 0; i <= sizeof(name_buffer); i++) { - memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); - name_buffer_size = i; - EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, - name_buffer, &name_buffer_size), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(name_buffer_size, 0); - EXPECT_EQ(memcmp(name_buffer, original_name_buffer, sizeof(name_buffer)), - 0); - } -} - -// Tests setsockopt of invalid device name. -TEST_P(BindToDeviceTest, SetsockoptInvalidDeviceName) { - char name_buffer[IFNAMSIZ * 2]; - socklen_t name_buffer_size; - - // Set an invalid device name. - memset(name_buffer, kIllegalIfnameChar, 5); - name_buffer_size = 5; - EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, - name_buffer_size), - SyscallFailsWithErrno(ENODEV)); -} - -// Tests setsockopt of a buffer with a valid device name but not -// null-terminated, with different sizes of buffer. -TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithoutNullTermination) { - char name_buffer[IFNAMSIZ * 2]; - socklen_t name_buffer_size; - - strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1); - // Intentionally overwrite the null at the end. - memset(name_buffer + interface_name().size(), kIllegalIfnameChar, - sizeof(name_buffer) - interface_name().size()); - for (size_t i = 1; i <= sizeof(name_buffer); i++) { - name_buffer_size = i; - SCOPED_TRACE(absl::StrCat("Buffer size: ", i)); - // It should only work if the size provided is exactly right. - if (name_buffer_size == interface_name().size()) { - EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, - name_buffer, name_buffer_size), - SyscallSucceeds()); - } else { - EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, - name_buffer, name_buffer_size), - SyscallFailsWithErrno(ENODEV)); - } - } -} - -// Tests setsockopt of a buffer with a valid device name and null-terminated, -// with different sizes of buffer. -TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithNullTermination) { - char name_buffer[IFNAMSIZ * 2]; - socklen_t name_buffer_size; - - strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1); - // Don't overwrite the null at the end. - memset(name_buffer + interface_name().size() + 1, kIllegalIfnameChar, - sizeof(name_buffer) - interface_name().size() - 1); - for (size_t i = 1; i <= sizeof(name_buffer); i++) { - name_buffer_size = i; - SCOPED_TRACE(absl::StrCat("Buffer size: ", i)); - // It should only work if the size provided is at least the right size. - if (name_buffer_size >= interface_name().size()) { - EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, - name_buffer, name_buffer_size), - SyscallSucceeds()); - } else { - EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, - name_buffer, name_buffer_size), - SyscallFailsWithErrno(ENODEV)); - } - } -} - -// Tests that setsockopt of an invalid device name doesn't unset the previous -// valid setsockopt. -TEST_P(BindToDeviceTest, SetsockoptValidThenInvalid) { - char name_buffer[IFNAMSIZ * 2]; - socklen_t name_buffer_size; - - // Write successfully. - strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); - ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, - sizeof(name_buffer)), - SyscallSucceeds()); - - // Read it back successfully. - memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); - name_buffer_size = sizeof(name_buffer); - EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, - &name_buffer_size), - SyscallSucceeds()); - EXPECT_EQ(name_buffer_size, interface_name().size() + 1); - EXPECT_STREQ(name_buffer, interface_name().c_str()); - - // Write unsuccessfully. - memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); - name_buffer_size = 5; - EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, - sizeof(name_buffer)), - SyscallFailsWithErrno(ENODEV)); - - // Read it back successfully, it's unchanged. - memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); - name_buffer_size = sizeof(name_buffer); - EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, - &name_buffer_size), - SyscallSucceeds()); - EXPECT_EQ(name_buffer_size, interface_name().size() + 1); - EXPECT_STREQ(name_buffer, interface_name().c_str()); -} - -// Tests that setsockopt of zero-length string correctly unsets the previous -// value. -TEST_P(BindToDeviceTest, SetsockoptValidThenClear) { - char name_buffer[IFNAMSIZ * 2]; - socklen_t name_buffer_size; - - // Write successfully. - strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); - EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, - sizeof(name_buffer)), - SyscallSucceeds()); - - // Read it back successfully. - memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); - name_buffer_size = sizeof(name_buffer); - EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, - &name_buffer_size), - SyscallSucceeds()); - EXPECT_EQ(name_buffer_size, interface_name().size() + 1); - EXPECT_STREQ(name_buffer, interface_name().c_str()); - - // Clear it successfully. - name_buffer_size = 0; - EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, - name_buffer_size), - SyscallSucceeds()); - - // Read it back successfully, it's cleared. - memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); - name_buffer_size = sizeof(name_buffer); - EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, - &name_buffer_size), - SyscallSucceeds()); - EXPECT_EQ(name_buffer_size, 0); -} - -// Tests that setsockopt of empty string correctly unsets the previous -// value. -TEST_P(BindToDeviceTest, SetsockoptValidThenClearWithNull) { - char name_buffer[IFNAMSIZ * 2]; - socklen_t name_buffer_size; - - // Write successfully. - strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); - EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, - sizeof(name_buffer)), - SyscallSucceeds()); - - // Read it back successfully. - memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); - name_buffer_size = sizeof(name_buffer); - EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, - &name_buffer_size), - SyscallSucceeds()); - EXPECT_EQ(name_buffer_size, interface_name().size() + 1); - EXPECT_STREQ(name_buffer, interface_name().c_str()); - - // Clear it successfully. - memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); - name_buffer[0] = 0; - name_buffer_size = sizeof(name_buffer); - EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, - name_buffer_size), - SyscallSucceeds()); - - // Read it back successfully, it's cleared. - memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); - name_buffer_size = sizeof(name_buffer); - EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, - &name_buffer_size), - SyscallSucceeds()); - EXPECT_EQ(name_buffer_size, 0); -} - -// Tests getsockopt with different buffer sizes. -TEST_P(BindToDeviceTest, GetsockoptDevice) { - char name_buffer[IFNAMSIZ * 2]; - socklen_t name_buffer_size; - - // Write successfully. - strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); - ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, - sizeof(name_buffer)), - SyscallSucceeds()); - - // Read it back at various buffer sizes. - for (size_t i = 0; i <= sizeof(name_buffer); i++) { - memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); - name_buffer_size = i; - SCOPED_TRACE(absl::StrCat("Buffer size: ", i)); - // Linux only allows a buffer at least IFNAMSIZ, even if less would suffice - // for this interface name. - if (name_buffer_size >= IFNAMSIZ) { - EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, - name_buffer, &name_buffer_size), - SyscallSucceeds()); - EXPECT_EQ(name_buffer_size, interface_name().size() + 1); - EXPECT_STREQ(name_buffer, interface_name().c_str()); - } else { - EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, - name_buffer, &name_buffer_size), - SyscallFailsWithErrno(EINVAL)); - EXPECT_EQ(name_buffer_size, i); - } - } -} - -INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceTest, - ::testing::Values(IPv4UDPUnboundSocket(0), - IPv4TCPUnboundSocket(0))); - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc deleted file mode 100644 index 5ed57625c..000000000 --- a/test/syscalls/linux/socket_bind_to_device_distribution.cc +++ /dev/null @@ -1,401 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <arpa/inet.h> -#include <linux/if_tun.h> -#include <net/if.h> -#include <netinet/in.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include <atomic> -#include <cstdio> -#include <cstring> -#include <map> -#include <memory> -#include <unordered_map> -#include <unordered_set> -#include <utility> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_bind_to_device_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/capability_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -using std::string; -using std::vector; - -struct EndpointConfig { - std::string bind_to_device; - double expected_ratio; -}; - -struct DistributionTestCase { - std::string name; - std::vector<EndpointConfig> endpoints; -}; - -struct ListenerConnector { - TestAddress listener; - TestAddress connector; -}; - -// Test fixture for SO_BINDTODEVICE tests the distribution of packets received -// with varying SO_BINDTODEVICE settings. -class BindToDeviceDistributionTest - : public ::testing::TestWithParam< - ::testing::tuple<ListenerConnector, DistributionTestCase>> { - protected: - void SetUp() override { - printf("Testing case: %s, listener=%s, connector=%s\n", - ::testing::get<1>(GetParam()).name.c_str(), - ::testing::get<0>(GetParam()).listener.description.c_str(), - ::testing::get<0>(GetParam()).connector.description.c_str()); - ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) - << "CAP_NET_RAW is required to use SO_BINDTODEVICE"; - } -}; - -PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) { - switch (family) { - case AF_INET: - return static_cast<uint16_t>( - reinterpret_cast<sockaddr_in const*>(&addr)->sin_port); - case AF_INET6: - return static_cast<uint16_t>( - reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port); - default: - return PosixError(EINVAL, - absl::StrCat("unknown socket family: ", family)); - } -} - -PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) { - switch (family) { - case AF_INET: - reinterpret_cast<sockaddr_in*>(addr)->sin_port = port; - return NoError(); - case AF_INET6: - reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port; - return NoError(); - default: - return PosixError(EINVAL, - absl::StrCat("unknown socket family: ", family)); - } -} - -// Binds sockets to different devices and then creates many TCP connections. -// Checks that the distribution of connections received on the sockets matches -// the expectation. -TEST_P(BindToDeviceDistributionTest, Tcp) { - auto const& [listener_connector, test] = GetParam(); - - TestAddress const& listener = listener_connector.listener; - TestAddress const& connector = listener_connector.connector; - sockaddr_storage listen_addr = listener.addr; - sockaddr_storage conn_addr = connector.addr; - - auto interface_names = GetInterfaceNames(); - - // Create the listening sockets. - std::vector<FileDescriptor> listener_fds; - std::vector<std::unique_ptr<Tunnel>> all_tunnels; - for (auto const& endpoint : test.endpoints) { - if (!endpoint.bind_to_device.empty() && - interface_names.find(endpoint.bind_to_device) == - interface_names.end()) { - all_tunnels.push_back( - ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device))); - interface_names.insert(endpoint.bind_to_device); - } - - listener_fds.push_back(ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP))); - int fd = listener_fds.back().get(); - - ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, - endpoint.bind_to_device.c_str(), - endpoint.bind_to_device.size() + 1), - SyscallSucceeds()); - ASSERT_THAT( - bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(fd, 40), SyscallSucceeds()); - - // On the first bind we need to determine which port was bound. - if (listener_fds.size() > 1) { - continue; - } - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT( - getsockname(listener_fds[0].get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - } - - constexpr int kConnectAttempts = 10000; - std::atomic<int> connects_received = ATOMIC_VAR_INIT(0); - std::vector<int> accept_counts(listener_fds.size(), 0); - std::vector<std::unique_ptr<ScopedThread>> listen_threads( - listener_fds.size()); - - for (int i = 0; i < listener_fds.size(); i++) { - listen_threads[i] = absl::make_unique<ScopedThread>( - [&listener_fds, &accept_counts, &connects_received, i, - kConnectAttempts]() { - do { - auto fd = Accept(listener_fds[i].get(), nullptr, nullptr); - if (!fd.ok()) { - // Another thread has shutdown our read side causing the accept to - // fail. - ASSERT_GE(connects_received, kConnectAttempts) - << "errno = " << fd.error(); - return; - } - // Receive some data from a socket to be sure that the connect() - // system call has been completed on another side. - // Do a short read and then close the socket to trigger a RST. This - // ensures that both ends of the connection are cleaned up and no - // goroutines hang around in TIME-WAIT. We do this so that this test - // does not timeout under gotsan runs where lots of goroutines can - // cause the test to use absurd amounts of memory. - // - // See: https://tools.ietf.org/html/rfc2525#page-50 section 2.17 - uint16_t data; - EXPECT_THAT( - RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0), - SyscallSucceedsWithValue(sizeof(data))); - accept_counts[i]++; - } while (++connects_received < kConnectAttempts); - - // Shutdown all sockets to wake up other threads. - for (auto const& listener_fd : listener_fds) { - shutdown(listener_fd.get(), SHUT_RDWR); - } - }); - } - - for (int i = 0; i < kConnectAttempts; i++) { - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT( - RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceeds()); - - // Do two separate sends to ensure two segments are received. This is - // required for netstack where read is incorrectly assuming a whole - // segment is read when endpoint.Read() is called which is technically - // incorrect as the syscall that invoked endpoint.Read() may only - // consume it partially. This results in a case where a close() of - // such a socket does not trigger a RST in netstack due to the - // endpoint assuming that the endpoint has no unread data. - EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), - SyscallSucceedsWithValue(sizeof(i))); - - // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly - // generates a RST. - if (IsRunningOnGvisor()) { - EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), - SyscallSucceedsWithValue(sizeof(i))); - } - } - - // Join threads to be sure that all connections have been counted. - for (auto const& listen_thread : listen_threads) { - listen_thread->Join(); - } - // Check that connections are distributed correctly among listening sockets. - for (int i = 0; i < accept_counts.size(); i++) { - EXPECT_THAT( - accept_counts[i], - EquivalentWithin(static_cast<int>(kConnectAttempts * - test.endpoints[i].expected_ratio), - 0.10)) - << "endpoint " << i << " got the wrong number of packets"; - } -} - -// Binds sockets to different devices and then sends many UDP packets. Checks -// that the distribution of packets received on the sockets matches the -// expectation. -TEST_P(BindToDeviceDistributionTest, Udp) { - auto const& [listener_connector, test] = GetParam(); - - TestAddress const& listener = listener_connector.listener; - TestAddress const& connector = listener_connector.connector; - sockaddr_storage listen_addr = listener.addr; - sockaddr_storage conn_addr = connector.addr; - - auto interface_names = GetInterfaceNames(); - - // Create the listening socket. - std::vector<FileDescriptor> listener_fds; - std::vector<std::unique_ptr<Tunnel>> all_tunnels; - for (auto const& endpoint : test.endpoints) { - if (!endpoint.bind_to_device.empty() && - interface_names.find(endpoint.bind_to_device) == - interface_names.end()) { - all_tunnels.push_back( - ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device))); - interface_names.insert(endpoint.bind_to_device); - } - - listener_fds.push_back( - ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0))); - int fd = listener_fds.back().get(); - - ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, - endpoint.bind_to_device.c_str(), - endpoint.bind_to_device.size() + 1), - SyscallSucceeds()); - ASSERT_THAT( - bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), - SyscallSucceeds()); - - // On the first bind we need to determine which port was bound. - if (listener_fds.size() > 1) { - continue; - } - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT( - getsockname(listener_fds[0].get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - } - - constexpr int kConnectAttempts = 10000; - std::atomic<int> packets_received = ATOMIC_VAR_INIT(0); - std::vector<int> packets_per_socket(listener_fds.size(), 0); - std::vector<std::unique_ptr<ScopedThread>> receiver_threads( - listener_fds.size()); - - for (int i = 0; i < listener_fds.size(); i++) { - receiver_threads[i] = absl::make_unique<ScopedThread>( - [&listener_fds, &packets_per_socket, &packets_received, i]() { - do { - struct sockaddr_storage addr = {}; - socklen_t addrlen = sizeof(addr); - int data; - - auto ret = RetryEINTR(recvfrom)( - listener_fds[i].get(), &data, sizeof(data), 0, - reinterpret_cast<struct sockaddr*>(&addr), &addrlen); - - if (packets_received < kConnectAttempts) { - ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data))); - } - - if (ret != sizeof(data)) { - // Another thread may have shutdown our read side causing the - // recvfrom to fail. - break; - } - - packets_received++; - packets_per_socket[i]++; - - // A response is required to synchronize with the main thread, - // otherwise the main thread can send more than can fit into receive - // queues. - EXPECT_THAT(RetryEINTR(sendto)( - listener_fds[i].get(), &data, sizeof(data), 0, - reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceedsWithValue(sizeof(data))); - } while (packets_received < kConnectAttempts); - - // Shutdown all sockets to wake up other threads. - for (auto const& listener_fd : listener_fds) { - shutdown(listener_fd.get(), SHUT_RDWR); - } - }); - } - - for (int i = 0; i < kConnectAttempts; i++) { - FileDescriptor const fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); - EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0, - reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceedsWithValue(sizeof(i))); - int data; - EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0), - SyscallSucceedsWithValue(sizeof(data))); - } - - // Join threads to be sure that all connections have been counted. - for (auto const& receiver_thread : receiver_threads) { - receiver_thread->Join(); - } - // Check that packets are distributed correctly among listening sockets. - for (int i = 0; i < packets_per_socket.size(); i++) { - EXPECT_THAT( - packets_per_socket[i], - EquivalentWithin(static_cast<int>(kConnectAttempts * - test.endpoints[i].expected_ratio), - 0.10)) - << "endpoint " << i << " got the wrong number of packets"; - } -} - -std::vector<DistributionTestCase> GetDistributionTestCases() { - return std::vector<DistributionTestCase>{ - {"Even distribution among sockets not bound to device", - {{"", 1. / 3}, {"", 1. / 3}, {"", 1. / 3}}}, - {"Sockets bound to other interfaces get no packets", - {{"eth1", 0}, {"", 1. / 2}, {"", 1. / 2}}}, - {"Bound has priority over unbound", {{"eth1", 0}, {"", 0}, {"lo", 1}}}, - {"Even distribution among sockets bound to device", - {{"eth1", 0}, {"lo", 1. / 2}, {"lo", 1. / 2}}}, - }; -} - -INSTANTIATE_TEST_SUITE_P( - BindToDeviceTest, BindToDeviceDistributionTest, - ::testing::Combine(::testing::Values( - // Listeners bound to IPv4 addresses refuse - // connections using IPv6 addresses. - ListenerConnector{V4Any(), V4Loopback()}, - ListenerConnector{V4Loopback(), V4MappedLoopback()}), - ::testing::ValuesIn(GetDistributionTestCases()))); - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc deleted file mode 100644 index 637d1151a..000000000 --- a/test/syscalls/linux/socket_bind_to_device_sequence.cc +++ /dev/null @@ -1,511 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <arpa/inet.h> -#include <linux/capability.h> -#include <linux/if_tun.h> -#include <net/if.h> -#include <netinet/in.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include <cstdio> -#include <cstring> -#include <map> -#include <memory> -#include <unordered_map> -#include <unordered_set> -#include <utility> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_bind_to_device_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/capability_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -using std::string; -using std::vector; - -// Test fixture for SO_BINDTODEVICE tests the results of sequences of socket -// binding. -class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { - protected: - void SetUp() override { - printf("Testing case: %s\n", GetParam().description.c_str()); - ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) - << "CAP_NET_RAW is required to use SO_BINDTODEVICE"; - socket_factory_ = GetParam(); - - interface_names_ = GetInterfaceNames(); - } - - PosixErrorOr<std::unique_ptr<FileDescriptor>> NewSocket() const { - return socket_factory_.Create(); - } - - // Gets a device by device_id. If the device_id has been seen before, returns - // the previously returned device. If not, finds or creates a new device. - // Returns an empty string on failure. - void GetDevice(int device_id, string* device_name) { - auto device = devices_.find(device_id); - if (device != devices_.end()) { - *device_name = device->second; - return; - } - - // Need to pick a new device. Try ethernet first. - *device_name = absl::StrCat("eth", next_unused_eth_); - if (interface_names_.find(*device_name) != interface_names_.end()) { - devices_[device_id] = *device_name; - next_unused_eth_++; - return; - } - - // Need to make a new tunnel device. gVisor tests should have enough - // ethernet devices to never reach here. - ASSERT_FALSE(IsRunningOnGvisor()); - // Need a tunnel. - tunnels_.push_back(ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New())); - devices_[device_id] = tunnels_.back()->GetName(); - *device_name = devices_[device_id]; - } - - // Release the socket - void ReleaseSocket(int socket_id) { - // Close the socket that was made in a previous action. The socket_id - // indicates which socket to close based on index into the list of actions. - sockets_to_close_.erase(socket_id); - } - - // SetDevice changes the bind_to_device option. It does not bind or re-bind. - void SetDevice(int socket_id, int device_id) { - auto socket_fd = sockets_to_close_[socket_id]->get(); - string device_name; - ASSERT_NO_FATAL_FAILURE(GetDevice(device_id, &device_name)); - EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, - device_name.c_str(), device_name.size() + 1), - SyscallSucceedsWithValue(0)); - } - - // Bind a socket with the reuse options and bind_to_device options. Checks - // that all steps succeed and that the bind command's error matches want. - // Sets the socket_id to uniquely identify the socket bound if it is not - // nullptr. - void BindSocket(bool reuse_port, bool reuse_addr, int device_id = 0, - int want = 0, int* socket_id = nullptr) { - next_socket_id_++; - sockets_to_close_[next_socket_id_] = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket_fd = sockets_to_close_[next_socket_id_]->get(); - if (socket_id != nullptr) { - *socket_id = next_socket_id_; - } - - // If reuse_port is indicated, do that. - if (reuse_port) { - EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceedsWithValue(0)); - } - - // If reuse_addr is indicated, do that. - if (reuse_addr) { - EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceedsWithValue(0)); - } - - // If the device is non-zero, bind to that device. - if (device_id != 0) { - string device_name; - ASSERT_NO_FATAL_FAILURE(GetDevice(device_id, &device_name)); - EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, - device_name.c_str(), device_name.size() + 1), - SyscallSucceedsWithValue(0)); - char get_device[100]; - socklen_t get_device_size = 100; - EXPECT_THAT(getsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, get_device, - &get_device_size), - SyscallSucceedsWithValue(0)); - } - - struct sockaddr_in addr = {}; - addr.sin_family = AF_INET; - addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - addr.sin_port = port_; - if (want == 0) { - ASSERT_THAT( - bind(socket_fd, reinterpret_cast<const struct sockaddr*>(&addr), - sizeof(addr)), - SyscallSucceeds()); - } else { - ASSERT_THAT( - bind(socket_fd, reinterpret_cast<const struct sockaddr*>(&addr), - sizeof(addr)), - SyscallFailsWithErrno(want)); - } - - if (port_ == 0) { - // We don't yet know what port we'll be using so we need to fetch it and - // remember it for future commands. - socklen_t addr_size = sizeof(addr); - ASSERT_THAT( - getsockname(socket_fd, reinterpret_cast<struct sockaddr*>(&addr), - &addr_size), - SyscallSucceeds()); - port_ = addr.sin_port; - } - } - - private: - SocketKind socket_factory_; - // devices maps from the device id in the test case to the name of the device. - std::unordered_map<int, string> devices_; - // These are the tunnels that were created for the test and will be destroyed - // by the destructor. - vector<std::unique_ptr<Tunnel>> tunnels_; - // A list of all interface names before the test started. - std::unordered_set<string> interface_names_; - // The next ethernet device to use when requested a device. - int next_unused_eth_ = 1; - // The port for all tests. Originally 0 (any) and later set to the port that - // all further commands will use. - in_port_t port_ = 0; - // sockets_to_close_ is a map from action index to the socket that was - // created. - std::unordered_map<int, - std::unique_ptr<gvisor::testing::FileDescriptor>> - sockets_to_close_; - int next_socket_id_ = 0; -}; - -TEST_P(BindToDeviceSequenceTest, BindTwiceWithDeviceFails) { - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 3)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 3, EADDRINUSE)); -} - -TEST_P(BindToDeviceSequenceTest, BindToDevice) { - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 1)); - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 2)); -} - -TEST_P(BindToDeviceSequenceTest, BindToDeviceAndThenWithoutDevice) { - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 0, EADDRINUSE)); -} - -TEST_P(BindToDeviceSequenceTest, BindWithoutDevice) { - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 0, EADDRINUSE)); -} - -TEST_P(BindToDeviceSequenceTest, BindWithDevice) { - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 123, 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 456, 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 789, 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 0, EADDRINUSE)); -} - -TEST_P(BindToDeviceSequenceTest, BindWithReuse) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reusePort */ true, /* reuse_addr */ false)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ true, /* reuse_addr */ false, - /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 0)); -} - -TEST_P(BindToDeviceSequenceTest, BindingWithReuseAndDevice) { - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 456)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse_port */ true, /* reuse_addr */ false)); - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 789)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 999, EADDRINUSE)); -} - -TEST_P(BindToDeviceSequenceTest, MixingReuseAndNotReuseByBindingToDevice) { - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 123, 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 456, 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 789, 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 999, 0)); -} - -TEST_P(BindToDeviceSequenceTest, CannotBindTo0AfterMixingReuseAndNotReuse) { - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 456)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 0, EADDRINUSE)); -} - -TEST_P(BindToDeviceSequenceTest, BindAndRelease) { - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); - int to_release; - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 0, 0, &to_release)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 345, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 789)); - // Release the bind to device 0 and try again. - ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 345)); -} - -TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) { - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 0, EADDRINUSE)); -} - -TEST_P(BindToDeviceSequenceTest, BindWithReuseAddr) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reusePort */ false, /* reuse_addr */ true)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0)); -} - -TEST_P(BindToDeviceSequenceTest, - CannotBindTo0AfterMixingReuseAddrAndNotReuseAddr) { - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 456)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ true, - /* bind_to_device */ 0, EADDRINUSE)); -} - -TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReusePort) { - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ true, - /* bind_to_device */ 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ true, - /* bind_to_device */ 0, EADDRINUSE)); -} - -TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReuseAddr) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ true, - /* bind_to_device */ 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ true, - /* bind_to_device */ 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 0, EADDRINUSE)); -} - -TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReusePort) { - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ true, - /* bind_to_device */ 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ true, - /* bind_to_device */ 0, EADDRINUSE)); -} - -TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReuseAddr) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ true, - /* bind_to_device */ 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ true, - /* bind_to_device */ 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 0, EADDRINUSE)); -} - -TEST_P(BindToDeviceSequenceTest, BindReusePortThenReuseAddrReusePort) { - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ true, - /* bind_to_device */ 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ true, - /* bind_to_device */ 0, EADDRINUSE)); -} - -TEST_P(BindToDeviceSequenceTest, BindReuseAddrThenReuseAddr) { - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 0, EADDRINUSE)); -} - -// This behavior seems like a bug? -TEST_P(BindToDeviceSequenceTest, - BindReuseAddrThenReuseAddrReusePortThenReuseAddr) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ true, - /* bind_to_device */ 0)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, - /* reuse_addr */ false, - /* bind_to_device */ 0)); -} - -// Repro test for gvisor.dev/issue/1217. Not replicated in ports_test.go as this -// test is different from the others and wouldn't fit well there. -TEST_P(BindToDeviceSequenceTest, BindAndReleaseDifferentDevice) { - int to_release; - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 3, 0, &to_release)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, - /* reuse_addr */ false, - /* bind_to_device */ 3, EADDRINUSE)); - // Change the device. Since the socket was already bound, this should have no - // effect. - SetDevice(to_release, 2); - // Release the bind to device 3 and try again. - ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release)); - ASSERT_NO_FATAL_FAILURE(BindSocket( - /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 3)); -} - -INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceSequenceTest, - ::testing::Values(IPv4UDPUnboundSocket(0), - IPv4TCPUnboundSocket(0))); - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_util.cc b/test/syscalls/linux/socket_bind_to_device_util.cc deleted file mode 100644 index f4ee775bd..000000000 --- a/test/syscalls/linux/socket_bind_to_device_util.cc +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_bind_to_device_util.h" - -#include <arpa/inet.h> -#include <fcntl.h> -#include <linux/if_tun.h> -#include <net/if.h> -#include <netinet/in.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> -#include <unistd.h> - -#include <cstdio> -#include <cstring> -#include <map> -#include <memory> -#include <unordered_map> -#include <unordered_set> -#include <utility> -#include <vector> - -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -using std::string; - -PosixErrorOr<std::unique_ptr<Tunnel>> Tunnel::New(string tunnel_name) { - int fd; - RETURN_ERROR_IF_SYSCALL_FAIL(fd = open("/dev/net/tun", O_RDWR)); - - // Using `new` to access a non-public constructor. - auto new_tunnel = absl::WrapUnique(new Tunnel(fd)); - - ifreq ifr = {}; - ifr.ifr_flags = IFF_TUN; - strncpy(ifr.ifr_name, tunnel_name.c_str(), sizeof(ifr.ifr_name)); - - RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(fd, TUNSETIFF, &ifr)); - new_tunnel->name_ = ifr.ifr_name; - return new_tunnel; -} - -std::unordered_set<string> GetInterfaceNames() { - struct if_nameindex* interfaces = if_nameindex(); - std::unordered_set<string> names; - if (interfaces == nullptr) { - return names; - } - for (auto interface = interfaces; - interface->if_index != 0 || interface->if_name != nullptr; interface++) { - names.insert(interface->if_name); - } - if_freenameindex(interfaces); - return names; -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_util.h b/test/syscalls/linux/socket_bind_to_device_util.h deleted file mode 100644 index f941ccc86..000000000 --- a/test/syscalls/linux/socket_bind_to_device_util.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_ -#define GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_ - -#include <arpa/inet.h> -#include <linux/if_tun.h> -#include <net/if.h> -#include <netinet/in.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> -#include <unistd.h> - -#include <cstdio> -#include <cstring> -#include <map> -#include <memory> -#include <string> -#include <unordered_map> -#include <unordered_set> -#include <utility> -#include <vector> - -#include "absl/memory/memory.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -class Tunnel { - public: - static PosixErrorOr<std::unique_ptr<Tunnel>> New( - std::string tunnel_name = ""); - const std::string& GetName() const { return name_; } - - ~Tunnel() { - if (fd_ != -1) { - close(fd_); - } - } - - private: - Tunnel(int fd) : fd_(fd) {} - int fd_ = -1; - std::string name_; -}; - -std::unordered_set<std::string> GetInterfaceNames(); - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_ diff --git a/test/syscalls/linux/socket_blocking.cc b/test/syscalls/linux/socket_blocking.cc deleted file mode 100644 index 7e88aa2d9..000000000 --- a/test/syscalls/linux/socket_blocking.cc +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_blocking.h" - -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include <cstdio> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" -#include "test/util/timer_util.h" - -namespace gvisor { -namespace testing { - -TEST_P(BlockingSocketPairTest, RecvBlocks) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[100]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - constexpr auto kDuration = absl::Milliseconds(200); - auto before = Now(CLOCK_MONOTONIC); - - const ScopedThread t([&]() { - absl::SleepFor(kDuration); - ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - }); - - char received_data[sizeof(sent_data)] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - auto after = Now(CLOCK_MONOTONIC); - EXPECT_GE(after - before, kDuration); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_blocking.h b/test/syscalls/linux/socket_blocking.h deleted file mode 100644 index db26e5ef5..000000000 --- a/test/syscalls/linux/socket_blocking.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_BLOCKING_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_BLOCKING_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of blocking connected sockets. -using BlockingSocketPairTest = SocketPairTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_BLOCKING_H_ diff --git a/test/syscalls/linux/socket_filesystem.cc b/test/syscalls/linux/socket_filesystem.cc deleted file mode 100644 index 287359363..000000000 --- a/test/syscalls/linux/socket_filesystem.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/socket_generic.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/socket_unix.h" -#include "test/syscalls/linux/socket_unix_cmsg.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return ApplyVec<SocketPairKind>( - FilesystemBoundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET}, - List<int>{0, SOCK_NONBLOCK})); -} - -INSTANTIATE_TEST_SUITE_P( - FilesystemUnixSockets, AllSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -INSTANTIATE_TEST_SUITE_P( - FilesystemUnixSockets, UnixSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -INSTANTIATE_TEST_SUITE_P( - FilesystemUnixSockets, UnixSocketPairCmsgTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc deleted file mode 100644 index f7d6139f1..000000000 --- a/test/syscalls/linux/socket_generic.cc +++ /dev/null @@ -1,820 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_generic.h" - -#include <stdio.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -// This file is a generic socket test file. It must be built with another file -// that provides the test types. - -namespace gvisor { -namespace testing { - -TEST_P(AllSocketPairTest, BasicReadWrite) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char buf[20]; - const std::string data = "abc"; - ASSERT_THAT(WriteFd(sockets->first_fd(), data.c_str(), 3), - SyscallSucceedsWithValue(3)); - ASSERT_THAT(ReadFd(sockets->second_fd(), buf, 3), - SyscallSucceedsWithValue(3)); - EXPECT_EQ(data, absl::string_view(buf, 3)); -} - -TEST_P(AllSocketPairTest, BasicSendRecv) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char sent_data[512]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - char received_data[sizeof(sent_data)]; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), 0), - SyscallSucceedsWithValue(sizeof(received_data))); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(AllSocketPairTest, BasicSendmmsg) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char sent_data[200]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - std::vector<struct mmsghdr> msgs(10); - std::vector<struct iovec> iovs(msgs.size()); - const int chunk_size = sizeof(sent_data) / msgs.size(); - for (size_t i = 0; i < msgs.size(); i++) { - iovs[i].iov_len = chunk_size; - iovs[i].iov_base = &sent_data[i * chunk_size]; - msgs[i].msg_hdr.msg_iov = &iovs[i]; - msgs[i].msg_hdr.msg_iovlen = 1; - } - - ASSERT_THAT( - RetryEINTR(sendmmsg)(sockets->first_fd(), &msgs[0], msgs.size(), 0), - SyscallSucceedsWithValue(msgs.size())); - - for (const struct mmsghdr& msg : msgs) { - EXPECT_EQ(chunk_size, msg.msg_len); - } - - char received_data[sizeof(sent_data)]; - for (size_t i = 0; i < msgs.size(); i++) { - ASSERT_THAT(ReadFd(sockets->second_fd(), &received_data[i * chunk_size], - chunk_size), - SyscallSucceedsWithValue(chunk_size)); - } - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(AllSocketPairTest, BasicRecvmmsg) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char sent_data[200]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - char received_data[sizeof(sent_data)]; - std::vector<struct mmsghdr> msgs(10); - std::vector<struct iovec> iovs(msgs.size()); - const int chunk_size = sizeof(sent_data) / msgs.size(); - for (size_t i = 0; i < msgs.size(); i++) { - iovs[i].iov_len = chunk_size; - iovs[i].iov_base = &received_data[i * chunk_size]; - msgs[i].msg_hdr.msg_iov = &iovs[i]; - msgs[i].msg_hdr.msg_iovlen = 1; - } - - for (size_t i = 0; i < msgs.size(); i++) { - ASSERT_THAT( - WriteFd(sockets->first_fd(), &sent_data[i * chunk_size], chunk_size), - SyscallSucceedsWithValue(chunk_size)); - } - - ASSERT_THAT(RetryEINTR(recvmmsg)(sockets->second_fd(), &msgs[0], msgs.size(), - 0, nullptr), - SyscallSucceedsWithValue(msgs.size())); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - for (const struct mmsghdr& msg : msgs) { - EXPECT_EQ(chunk_size, msg.msg_len); - } -} - -TEST_P(AllSocketPairTest, SendmsgRecvmsg10KB) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - std::vector<char> sent_data(10 * 1024); - RandomizeBuffer(sent_data.data(), sent_data.size()); - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data.data(), sent_data.size())); - - std::vector<char> received_data(sent_data.size()); - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sockets->second_fd(), received_data.data(), - received_data.size())); - - EXPECT_EQ(0, - memcmp(sent_data.data(), received_data.data(), sent_data.size())); -} - -// This test validates that a sendmsg/recvmsg w/ MSG_CTRUNC is a no-op on -// input flags. -TEST_P(AllSocketPairTest, SendmsgRecvmsgMsgCtruncNoop) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - std::vector<char> sent_data(10 * 1024); - RandomizeBuffer(sent_data.data(), sent_data.size()); - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data.data(), sent_data.size())); - - std::vector<char> received_data(sent_data.size()); - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(int)) + CMSG_SPACE(sizeof(struct ucred))]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct iovec iov; - iov.iov_base = &received_data[0]; - iov.iov_len = received_data.size(); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - // MSG_CTRUNC should be a no-op. - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_CTRUNC), - SyscallSucceedsWithValue(received_data.size())); - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - EXPECT_EQ(cmsg, nullptr); - EXPECT_EQ(msg.msg_controllen, 0); - EXPECT_EQ(0, - memcmp(sent_data.data(), received_data.data(), sent_data.size())); -} - -TEST_P(AllSocketPairTest, SendmsgRecvmsg16KB) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - std::vector<char> sent_data(16 * 1024); - RandomizeBuffer(sent_data.data(), sent_data.size()); - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data.data(), sent_data.size())); - - std::vector<char> received_data(sent_data.size()); - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sockets->second_fd(), received_data.data(), - received_data.size())); - - EXPECT_EQ(0, - memcmp(sent_data.data(), received_data.data(), sent_data.size())); -} - -TEST_P(AllSocketPairTest, RecvmsgMsghdrFlagsNotClearedOnFailure) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char received_data[10] = {}; - - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - struct msghdr msg = {}; - msg.msg_flags = -1; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); - - // Check that msghdr flags were not changed. - EXPECT_EQ(msg.msg_flags, -1); -} - -TEST_P(AllSocketPairTest, RecvmsgMsghdrFlagsCleared) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[10]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[sizeof(sent_data)] = {}; - - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - struct msghdr msg = {}; - msg.msg_flags = -1; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(sent_data))); - - // Check that msghdr flags were cleared. - EXPECT_EQ(msg.msg_flags, 0); -} - -TEST_P(AllSocketPairTest, RecvmsgPeekMsghdrFlagsCleared) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[10]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[sizeof(sent_data)] = {}; - - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - struct msghdr msg = {}; - msg.msg_flags = -1; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_PEEK), - SyscallSucceedsWithValue(sizeof(sent_data))); - EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(sent_data))); - - // Check that msghdr flags were cleared. - EXPECT_EQ(msg.msg_flags, 0); -} - -TEST_P(AllSocketPairTest, RecvmsgIovNotUpdated) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[10]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[sizeof(sent_data) * 2] = {}; - - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - struct msghdr msg = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(sent_data))); - - // Check that the iovec length was not updated. - EXPECT_EQ(msg.msg_iov->iov_len, sizeof(received_data)); -} - -TEST_P(AllSocketPairTest, RecvmmsgInvalidTimeout) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char buf[10]; - struct mmsghdr msg = {}; - struct iovec iov = {}; - iov.iov_len = sizeof(buf); - iov.iov_base = buf; - msg.msg_hdr.msg_iov = &iov; - msg.msg_hdr.msg_iovlen = 1; - struct timespec timeout = {-1, -1}; - ASSERT_THAT(RetryEINTR(recvmmsg)(sockets->first_fd(), &msg, 1, 0, &timeout), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(AllSocketPairTest, RecvmmsgTimeoutBeforeRecv) { - // There is a known bug in the Linux recvmmsg(2) causing it to block forever - // if the timeout expires while blocking for the first message. - SKIP_IF(!IsRunningOnGvisor()); - - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char buf[10]; - struct mmsghdr msg = {}; - struct iovec iov = {}; - iov.iov_len = sizeof(buf); - iov.iov_base = buf; - msg.msg_hdr.msg_iov = &iov; - msg.msg_hdr.msg_iovlen = 1; - struct timespec timeout = {}; - ASSERT_THAT(RetryEINTR(recvmmsg)(sockets->first_fd(), &msg, 1, 0, &timeout), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(AllSocketPairTest, MsgPeek) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char sent_data[50]; - memset(&sent_data, 0, sizeof(sent_data)); - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[sizeof(sent_data)]; - for (int i = 0; i < 3; i++) { - memset(received_data, 0, sizeof(received_data)); - EXPECT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), MSG_PEEK), - SyscallSucceedsWithValue(sizeof(received_data))); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(received_data))); - } - - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), 0), - SyscallSucceedsWithValue(sizeof(received_data))); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(received_data))); -} - -TEST_P(AllSocketPairTest, LingerSocketOption) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - struct linger got_linger = {-1, -1}; - socklen_t length = sizeof(struct linger); - EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, - &got_linger, &length), - SyscallSucceedsWithValue(0)); - struct linger want_linger = {}; - EXPECT_EQ(0, memcmp(&want_linger, &got_linger, sizeof(struct linger))); - EXPECT_EQ(sizeof(struct linger), length); -} - -TEST_P(AllSocketPairTest, KeepAliveSocketOption) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - int keepalive = -1; - socklen_t length = sizeof(int); - EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE, - &keepalive, &length), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(0, keepalive); - EXPECT_EQ(sizeof(int), length); -} - -TEST_P(AllSocketPairTest, RcvBufSucceeds) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - int size = 0; - socklen_t size_size = sizeof(size); - EXPECT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVBUF, &size, &size_size), - SyscallSucceeds()); - EXPECT_GT(size, 0); -} - -TEST_P(AllSocketPairTest, SndBufSucceeds) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - int size = 0; - socklen_t size_size = sizeof(size); - EXPECT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, &size, &size_size), - SyscallSucceeds()); - EXPECT_GT(size, 0); -} - -TEST_P(AllSocketPairTest, RecvTimeoutReadSucceeds) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = 0, .tv_usec = 10 - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), - SyscallSucceeds()); - - char buf[20] = {}; - EXPECT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(AllSocketPairTest, RecvTimeoutRecvSucceeds) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = 0, .tv_usec = 10 - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), - SyscallSucceeds()); - - char buf[20] = {}; - EXPECT_THAT(RetryEINTR(recv)(sockets->first_fd(), buf, sizeof(buf), 0), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(AllSocketPairTest, RecvTimeoutRecvOneSecondSucceeds) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = 1, .tv_usec = 0 - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), - SyscallSucceeds()); - - char buf[20] = {}; - EXPECT_THAT(RetryEINTR(recv)(sockets->first_fd(), buf, sizeof(buf), 0), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgSucceeds) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = 0, .tv_usec = 10 - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), - SyscallSucceeds()); - - struct msghdr msg = {}; - char buf[20] = {}; - struct iovec iov; - iov.iov_base = buf; - iov.iov_len = sizeof(buf); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - EXPECT_THAT(RetryEINTR(recvmsg)(sockets->first_fd(), &msg, 0), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(AllSocketPairTest, SendTimeoutDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - timeval actual_tv = {.tv_sec = -1, .tv_usec = -1}; - socklen_t len = sizeof(actual_tv); - EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, - &actual_tv, &len), - SyscallSucceeds()); - EXPECT_EQ(actual_tv.tv_sec, 0); - EXPECT_EQ(actual_tv.tv_usec, 0); -} - -TEST_P(AllSocketPairTest, SetGetSendTimeout) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - timeval tv = {.tv_sec = 89, .tv_usec = 42000}; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), - SyscallSucceeds()); - - timeval actual_tv = {}; - socklen_t len = sizeof(actual_tv); - EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, - &actual_tv, &len), - SyscallSucceeds()); - EXPECT_EQ(actual_tv.tv_sec, 89); - EXPECT_EQ(actual_tv.tv_usec, 42000); -} - -TEST_P(AllSocketPairTest, SetGetSendTimeoutLargerArg) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval_with_extra { - struct timeval tv; - int64_t extra_data; - } ABSL_ATTRIBUTE_PACKED; - - timeval_with_extra tv_extra = { - .tv = {.tv_sec = 0, .tv_usec = 123000}, - }; - - EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, - &tv_extra, sizeof(tv_extra)), - SyscallSucceeds()); - - timeval_with_extra actual_tv = {}; - socklen_t len = sizeof(actual_tv); - EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, - &actual_tv, &len), - SyscallSucceeds()); - EXPECT_EQ(actual_tv.tv.tv_sec, 0); - EXPECT_EQ(actual_tv.tv.tv_usec, 123000); -} - -TEST_P(AllSocketPairTest, SendTimeoutAllowsWrite) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = 0, .tv_usec = 10 - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), - SyscallSucceeds()); - - char buf[20] = {}; - ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); -} - -TEST_P(AllSocketPairTest, SendTimeoutAllowsSend) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = 0, .tv_usec = 10 - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), - SyscallSucceeds()); - - char buf[20] = {}; - ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0), - SyscallSucceedsWithValue(sizeof(buf))); -} - -TEST_P(AllSocketPairTest, SendTimeoutAllowsSendmsg) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = 0, .tv_usec = 10 - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), - SyscallSucceeds()); - - char buf[20] = {}; - ASSERT_NO_FATAL_FAILURE(SendNullCmsg(sockets->first_fd(), buf, sizeof(buf))); -} - -TEST_P(AllSocketPairTest, RecvTimeoutDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - timeval actual_tv = {.tv_sec = -1, .tv_usec = -1}; - socklen_t len = sizeof(actual_tv); - EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, - &actual_tv, &len), - SyscallSucceeds()); - EXPECT_EQ(actual_tv.tv_sec, 0); - EXPECT_EQ(actual_tv.tv_usec, 0); -} - -TEST_P(AllSocketPairTest, SetGetRecvTimeout) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - timeval tv = {.tv_sec = 123, .tv_usec = 456000}; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), - SyscallSucceeds()); - - timeval actual_tv = {}; - socklen_t len = sizeof(actual_tv); - EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, - &actual_tv, &len), - SyscallSucceeds()); - EXPECT_EQ(actual_tv.tv_sec, 123); - EXPECT_EQ(actual_tv.tv_usec, 456000); -} - -TEST_P(AllSocketPairTest, SetGetRecvTimeoutLargerArg) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval_with_extra { - struct timeval tv; - int64_t extra_data; - } ABSL_ATTRIBUTE_PACKED; - - timeval_with_extra tv_extra = { - .tv = {.tv_sec = 0, .tv_usec = 432000}, - }; - - EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, - &tv_extra, sizeof(tv_extra)), - SyscallSucceeds()); - - timeval_with_extra actual_tv = {}; - socklen_t len = sizeof(actual_tv); - EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, - &actual_tv, &len), - SyscallSucceeds()); - EXPECT_EQ(actual_tv.tv.tv_sec, 0); - EXPECT_EQ(actual_tv.tv.tv_usec, 432000); -} - -TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgOneSecondSucceeds) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = 1, .tv_usec = 0 - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), - SyscallSucceeds()); - - struct msghdr msg = {}; - char buf[20] = {}; - struct iovec iov; - iov.iov_base = buf; - iov.iov_len = sizeof(buf); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - EXPECT_THAT(RetryEINTR(recvmsg)(sockets->first_fd(), &msg, 0), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(AllSocketPairTest, RecvTimeoutUsecTooLarge) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = 0, .tv_usec = 2000000 // 2 seconds. - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), - SyscallFailsWithErrno(EDOM)); -} - -TEST_P(AllSocketPairTest, SendTimeoutUsecTooLarge) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = 0, .tv_usec = 2000000 // 2 seconds. - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), - SyscallFailsWithErrno(EDOM)); -} - -TEST_P(AllSocketPairTest, RecvTimeoutUsecNeg) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = 0, .tv_usec = -1 - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), - SyscallFailsWithErrno(EDOM)); -} - -TEST_P(AllSocketPairTest, SendTimeoutUsecNeg) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = 0, .tv_usec = -1 - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), - SyscallFailsWithErrno(EDOM)); -} - -TEST_P(AllSocketPairTest, RecvTimeoutNegSecRead) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = -1, .tv_usec = 0 - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), - SyscallSucceeds()); - - char buf[20] = {}; - EXPECT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(AllSocketPairTest, RecvTimeoutNegSecRecv) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = -1, .tv_usec = 0 - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), - SyscallSucceeds()); - - char buf[20] = {}; - EXPECT_THAT(RetryEINTR(recv)(sockets->first_fd(), buf, sizeof(buf), 0), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(AllSocketPairTest, RecvTimeoutNegSecRecvmsg) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = -1, .tv_usec = 0 - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), - SyscallSucceeds()); - - struct msghdr msg = {}; - char buf[20] = {}; - struct iovec iov; - iov.iov_base = buf; - iov.iov_len = sizeof(buf); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - EXPECT_THAT(RetryEINTR(recvmsg)(sockets->first_fd(), &msg, 0), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(AllSocketPairTest, RecvWaitAll) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[100]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[sizeof(sent_data)] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), MSG_WAITALL), - SyscallSucceedsWithValue(sizeof(sent_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(AllSocketPairTest, RecvWaitAllDontWait) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char data[100] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), data, sizeof(data), - MSG_WAITALL | MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(AllSocketPairTest, RecvTimeoutWaitAll) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = 0, .tv_usec = 200000 // 200ms - }; - EXPECT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, - sizeof(tv)), - SyscallSucceeds()); - - char sent_data[100]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[sizeof(sent_data) * 2] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), MSG_WAITALL), - SyscallSucceedsWithValue(sizeof(sent_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(AllSocketPairTest, GetSockoptType) { - int type = GetParam().type; - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - for (const int fd : {sockets->first_fd(), sockets->second_fd()}) { - int opt; - socklen_t optlen = sizeof(opt); - EXPECT_THAT(getsockopt(fd, SOL_SOCKET, SO_TYPE, &opt, &optlen), - SyscallSucceeds()); - - // Type may have SOCK_NONBLOCK and SOCK_CLOEXEC ORed into it. Remove these - // before comparison. - type &= ~(SOCK_NONBLOCK | SOCK_CLOEXEC); - EXPECT_EQ(opt, type) << absl::StrFormat( - "getsockopt(%d, SOL_SOCKET, SO_TYPE, &opt, &optlen) => opt=%d was " - "unexpected", - fd, opt); - } -} - -TEST_P(AllSocketPairTest, GetSockoptDomain) { - const int domain = GetParam().domain; - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - for (const int fd : {sockets->first_fd(), sockets->second_fd()}) { - int opt; - socklen_t optlen = sizeof(opt); - EXPECT_THAT(getsockopt(fd, SOL_SOCKET, SO_DOMAIN, &opt, &optlen), - SyscallSucceeds()); - EXPECT_EQ(opt, domain) << absl::StrFormat( - "getsockopt(%d, SOL_SOCKET, SO_DOMAIN, &opt, &optlen) => opt=%d was " - "unexpected", - fd, opt); - } -} - -TEST_P(AllSocketPairTest, GetSockoptProtocol) { - const int protocol = GetParam().protocol; - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - for (const int fd : {sockets->first_fd(), sockets->second_fd()}) { - int opt; - socklen_t optlen = sizeof(opt); - EXPECT_THAT(getsockopt(fd, SOL_SOCKET, SO_PROTOCOL, &opt, &optlen), - SyscallSucceeds()); - EXPECT_EQ(opt, protocol) << absl::StrFormat( - "getsockopt(%d, SOL_SOCKET, SO_PROTOCOL, &opt, &optlen) => opt=%d was " - "unexpected", - fd, opt); - } -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_generic.h b/test/syscalls/linux/socket_generic.h deleted file mode 100644 index 00ae7bfc3..000000000 --- a/test/syscalls/linux/socket_generic.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_GENERIC_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_GENERIC_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of blocking and non-blocking -// connected stream sockets. -using AllSocketPairTest = SocketPairTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_GENERIC_H_ diff --git a/test/syscalls/linux/socket_generic_stress.cc b/test/syscalls/linux/socket_generic_stress.cc deleted file mode 100644 index 6a232238d..000000000 --- a/test/syscalls/linux/socket_generic_stress.cc +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdio.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of connected sockets. -using ConnectStressTest = SocketPairTest; - -TEST_P(ConnectStressTest, Reset65kTimes) { - for (int i = 0; i < 1 << 16; ++i) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - // Send some data to ensure that the connection gets reset and the port gets - // released immediately. This avoids either end entering TIME-WAIT. - char sent_data[100] = {}; - ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - } -} - -INSTANTIATE_TEST_SUITE_P( - AllConnectedSockets, ConnectStressTest, - ::testing::Values(IPv6UDPBidirectionalBindSocketPair(0), - IPv4UDPBidirectionalBindSocketPair(0), - DualStackUDPBidirectionalBindSocketPair(0), - - // Without REUSEADDR, we get port exhaustion on Linux. - SetSockOpt(SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn)(IPv6TCPAcceptBindSocketPair(0)), - SetSockOpt(SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn)(IPv4TCPAcceptBindSocketPair(0)), - SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)( - DualStackTCPAcceptBindSocketPair(0)))); - -// Test fixture for tests that apply to pairs of connected sockets created with -// a persistent listener (if applicable). -using PersistentListenerConnectStressTest = SocketPairTest; - -TEST_P(PersistentListenerConnectStressTest, 65kTimes) { - for (int i = 0; i < 1 << 16; ++i) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - } -} - -INSTANTIATE_TEST_SUITE_P( - AllConnectedSockets, PersistentListenerConnectStressTest, - ::testing::Values( - IPv6UDPBidirectionalBindSocketPair(0), - IPv4UDPBidirectionalBindSocketPair(0), - DualStackUDPBidirectionalBindSocketPair(0), - - // Without REUSEADDR, we get port exhaustion on Linux. - SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)( - IPv6TCPAcceptBindPersistentListenerSocketPair(0)), - SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)( - IPv4TCPAcceptBindPersistentListenerSocketPair(0)), - SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)( - DualStackTCPAcceptBindPersistentListenerSocketPair(0)))); - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc deleted file mode 100644 index b24618a88..000000000 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ /dev/null @@ -1,2264 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <arpa/inet.h> -#include <netinet/in.h> -#include <netinet/tcp.h> -#include <poll.h> -#include <string.h> - -#include <atomic> -#include <iostream> -#include <memory> -#include <string> -#include <tuple> -#include <utility> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/save_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -using ::testing::Gt; - -PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) { - switch (family) { - case AF_INET: - return static_cast<uint16_t>( - reinterpret_cast<sockaddr_in const*>(&addr)->sin_port); - case AF_INET6: - return static_cast<uint16_t>( - reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port); - default: - return PosixError(EINVAL, - absl::StrCat("unknown socket family: ", family)); - } -} - -PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) { - switch (family) { - case AF_INET: - reinterpret_cast<sockaddr_in*>(addr)->sin_port = port; - return NoError(); - case AF_INET6: - reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port; - return NoError(); - default: - return PosixError(EINVAL, - absl::StrCat("unknown socket family: ", family)); - } -} - -struct TestParam { - TestAddress listener; - TestAddress connector; -}; - -std::string DescribeTestParam(::testing::TestParamInfo<TestParam> const& info) { - return absl::StrCat("Listen", info.param.listener.description, "_Connect", - info.param.connector.description); -} - -using SocketInetLoopbackTest = ::testing::TestWithParam<TestParam>; - -TEST(BadSocketPairArgs, ValidateErrForBadCallsToSocketPair) { - int fd[2] = {}; - - // Valid AF but invalid for socketpair(2) return ESOCKTNOSUPPORT. - ASSERT_THAT(socketpair(AF_INET, 0, 0, fd), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - ASSERT_THAT(socketpair(AF_INET6, 0, 0, fd), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - - // Invalid AF will return ENOAFSUPPORT. - ASSERT_THAT(socketpair(AF_MAX, 0, 0, fd), - SyscallFailsWithErrno(EAFNOSUPPORT)); - ASSERT_THAT(socketpair(8675309, 0, 0, fd), - SyscallFailsWithErrno(EAFNOSUPPORT)); -} - -enum class Operation { - Bind, - Connect, - SendTo, -}; - -std::string OperationToString(Operation operation) { - switch (operation) { - case Operation::Bind: - return "Bind"; - case Operation::Connect: - return "Connect"; - case Operation::SendTo: - return "SendTo"; - } -} - -using OperationSequence = std::vector<Operation>; - -using DualStackSocketTest = - ::testing::TestWithParam<std::tuple<TestAddress, OperationSequence>>; - -TEST_P(DualStackSocketTest, AddressOperations) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_DGRAM, 0)); - - const TestAddress& addr = std::get<0>(GetParam()); - const OperationSequence& operations = std::get<1>(GetParam()); - - auto addr_in = reinterpret_cast<const sockaddr*>(&addr.addr); - - // sockets may only be bound once. Both `connect` and `sendto` cause a socket - // to be bound. - bool bound = false; - for (const Operation& operation : operations) { - bool sockname = false; - bool peername = false; - switch (operation) { - case Operation::Bind: { - ASSERT_NO_ERRNO(SetAddrPort( - addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 0)); - - int bind_ret = bind(fd.get(), addr_in, addr.addr_len); - - // Dual stack sockets may only be bound to AF_INET6. - if (!bound && addr.family() == AF_INET6) { - EXPECT_THAT(bind_ret, SyscallSucceeds()); - bound = true; - - sockname = true; - } else { - EXPECT_THAT(bind_ret, SyscallFailsWithErrno(EINVAL)); - } - break; - } - case Operation::Connect: { - ASSERT_NO_ERRNO(SetAddrPort( - addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 1337)); - - EXPECT_THAT(connect(fd.get(), addr_in, addr.addr_len), - SyscallSucceeds()) - << GetAddrStr(addr_in); - bound = true; - - sockname = true; - peername = true; - - break; - } - case Operation::SendTo: { - const char payload[] = "hello"; - ASSERT_NO_ERRNO(SetAddrPort( - addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 1337)); - - ssize_t sendto_ret = sendto(fd.get(), &payload, sizeof(payload), 0, - addr_in, addr.addr_len); - - EXPECT_THAT(sendto_ret, SyscallSucceedsWithValue(sizeof(payload))); - sockname = !bound; - bound = true; - break; - } - } - - if (sockname) { - sockaddr_storage sock_addr; - socklen_t addrlen = sizeof(sock_addr); - ASSERT_THAT(getsockname(fd.get(), reinterpret_cast<sockaddr*>(&sock_addr), - &addrlen), - SyscallSucceeds()); - ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6)); - - auto sock_addr_in6 = reinterpret_cast<const sockaddr_in6*>(&sock_addr); - - if (operation == Operation::SendTo) { - EXPECT_EQ(sock_addr_in6->sin6_family, AF_INET6); - EXPECT_TRUE(IN6_IS_ADDR_UNSPECIFIED(sock_addr_in6->sin6_addr.s6_addr32)) - << OperationToString(operation) << " getsocknam=" - << GetAddrStr(reinterpret_cast<sockaddr*>(&sock_addr)); - - EXPECT_NE(sock_addr_in6->sin6_port, 0); - } else if (IN6_IS_ADDR_V4MAPPED( - reinterpret_cast<const sockaddr_in6*>(addr_in) - ->sin6_addr.s6_addr32)) { - EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED(sock_addr_in6->sin6_addr.s6_addr32)) - << OperationToString(operation) << " getsocknam=" - << GetAddrStr(reinterpret_cast<sockaddr*>(&sock_addr)); - } - } - - if (peername) { - sockaddr_storage peer_addr; - socklen_t addrlen = sizeof(peer_addr); - ASSERT_THAT(getpeername(fd.get(), reinterpret_cast<sockaddr*>(&peer_addr), - &addrlen), - SyscallSucceeds()); - ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6)); - - if (addr.family() == AF_INET || - IN6_IS_ADDR_V4MAPPED(reinterpret_cast<const sockaddr_in6*>(addr_in) - ->sin6_addr.s6_addr32)) { - EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED( - reinterpret_cast<const sockaddr_in6*>(&peer_addr) - ->sin6_addr.s6_addr32)) - << OperationToString(operation) << " getpeername=" - << GetAddrStr(reinterpret_cast<sockaddr*>(&peer_addr)); - } - } - } -} - -// TODO(gvisor.dev/issues/1556): uncomment V4MappedAny. -INSTANTIATE_TEST_SUITE_P( - All, DualStackSocketTest, - ::testing::Combine( - ::testing::Values(V4Any(), V4Loopback(), /*V4MappedAny(),*/ - V4MappedLoopback(), V6Any(), V6Loopback()), - ::testing::ValuesIn<OperationSequence>( - {{Operation::Bind, Operation::Connect, Operation::SendTo}, - {Operation::Bind, Operation::SendTo, Operation::Connect}, - {Operation::Connect, Operation::Bind, Operation::SendTo}, - {Operation::Connect, Operation::SendTo, Operation::Bind}, - {Operation::SendTo, Operation::Bind, Operation::Connect}, - {Operation::SendTo, Operation::Connect, Operation::Bind}})), - [](::testing::TestParamInfo< - std::tuple<TestAddress, OperationSequence>> const& info) { - const TestAddress& addr = std::get<0>(info.param); - const OperationSequence& operations = std::get<1>(info.param); - std::string s = addr.description; - for (const Operation& operation : operations) { - absl::StrAppend(&s, OperationToString(operation)); - } - return s; - }); - -void tcpSimpleConnectTest(TestAddress const& listener, - TestAddress const& connector, bool unbound) { - // Create the listening socket. - const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - if (!unbound) { - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); - } - ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - - // Connect to the listening socket. - const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceeds()); - - // Accept the connection. - // - // We have to assign a name to the accepted socket, as unamed temporary - // objects are destructed upon full evaluation of the expression it is in, - // potentially causing the connecting socket to fail to shutdown properly. - auto accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); - - ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RDWR), SyscallSucceeds()); - - ASSERT_THAT(shutdown(conn_fd.get(), SHUT_RDWR), SyscallSucceeds()); -} - -TEST_P(SocketInetLoopbackTest, TCP) { - auto const& param = GetParam(); - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - - tcpSimpleConnectTest(listener, connector, true); -} - -TEST_P(SocketInetLoopbackTest, TCPListenUnbound) { - auto const& param = GetParam(); - - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - - tcpSimpleConnectTest(listener, connector, false); -} - -TEST_P(SocketInetLoopbackTest, TCPListenClose) { - auto const& param = GetParam(); - - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - - constexpr int kAcceptCount = 32; - constexpr int kBacklog = kAcceptCount * 2; - constexpr int kFDs = 128; - constexpr int kThreadCount = 4; - constexpr int kFDsPerThread = kFDs / kThreadCount; - - // Create the listening socket. - FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - - DisableSave ds; // Too many system calls. - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - FileDescriptor clients[kFDs]; - std::unique_ptr<ScopedThread> threads[kThreadCount]; - for (int i = 0; i < kFDs; i++) { - clients[i] = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); - } - for (int i = 0; i < kThreadCount; i++) { - threads[i] = absl::make_unique<ScopedThread>([&connector, &conn_addr, - &clients, i]() { - for (int j = 0; j < kFDsPerThread; j++) { - int k = i * kFDsPerThread + j; - int ret = - connect(clients[k].get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len); - if (ret != 0) { - EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); - } - } - }); - } - for (int i = 0; i < kThreadCount; i++) { - threads[i]->Join(); - } - for (int i = 0; i < kAcceptCount; i++) { - auto accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); - } - // TODO(b/138400178): Fix cooperative S/R failure when ds.reset() is invoked - // before function end. - // ds.reset(); -} - -TEST_P(SocketInetLoopbackTest, TCPbacklog) { - auto const& param = GetParam(); - - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - - // Create the listening socket. - const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), 2), SyscallSucceeds()); - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - int i = 0; - while (1) { - int ret; - - // Connect to the listening socket. - const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ret = connect(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len); - if (ret != 0) { - EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); - struct pollfd pfd = { - .fd = conn_fd.get(), - .events = POLLOUT, - }; - ret = poll(&pfd, 1, 3000); - if (ret == 0) break; - EXPECT_THAT(ret, SyscallSucceedsWithValue(1)); - } - EXPECT_THAT(RetryEINTR(send)(conn_fd.get(), &i, sizeof(i), 0), - SyscallSucceedsWithValue(sizeof(i))); - ASSERT_THAT(shutdown(conn_fd.get(), SHUT_RDWR), SyscallSucceeds()); - i++; - } - - for (; i != 0; i--) { - // Accept the connection. - // - // We have to assign a name to the accepted socket, as unamed temporary - // objects are destructed upon full evaluation of the expression it is in, - // potentially causing the connecting socket to fail to shutdown properly. - auto accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); - } -} - -// TCPFinWait2Test creates a pair of connected sockets then closes one end to -// trigger FIN_WAIT2 state for the closed endpoint. Then it binds the same local -// IP/port on a new socket and tries to connect. The connect should fail w/ -// an EADDRINUSE. Then we wait till the FIN_WAIT2 timeout is over and try the -// connect again with a new socket and this time it should succeed. -// -// TCP timers are not S/R today, this can cause this test to be flaky when run -// under random S/R due to timer being reset on a restore. -TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { - auto const& param = GetParam(); - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - - // Create the listening socket. - const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); - - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - - // Connect to the listening socket. - FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - - // Lower FIN_WAIT2 state to 5 seconds for test. - constexpr int kTCPLingerTimeout = 5; - EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, - &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), - SyscallSucceedsWithValue(0)); - - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceeds()); - - // Accept the connection. - auto accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); - - // Get the address/port bound by the connecting socket. - sockaddr_storage conn_bound_addr; - socklen_t conn_addrlen = connector.addr_len; - ASSERT_THAT( - getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), - &conn_addrlen), - SyscallSucceeds()); - - // close the connecting FD to trigger FIN_WAIT2 on the connected fd. - conn_fd.reset(); - - // Now bind and connect a new socket. - const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - - // Disable cooperative saves after this point. As a save between the first - // bind/connect and the second one can cause the linger timeout timer to - // be restarted causing the final bind/connect to fail. - DisableSave ds; - - // TODO(gvisor.dev/issue/1030): Portmanager does not track all 5 tuple - // reservations which causes the bind() to succeed on gVisor but connect - // correctly fails. - if (IsRunningOnGvisor()) { - ASSERT_THAT( - bind(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), - conn_addrlen), - SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), - reinterpret_cast<sockaddr*>(&conn_addr), - conn_addrlen), - SyscallFailsWithErrno(EADDRINUSE)); - } else { - ASSERT_THAT( - bind(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), - conn_addrlen), - SyscallFailsWithErrno(EADDRINUSE)); - } - - // Sleep for a little over the linger timeout to reduce flakiness in - // save/restore tests. - absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 2)); - - ds.reset(); - - if (!IsRunningOnGvisor()) { - ASSERT_THAT( - bind(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), - conn_addrlen), - SyscallSucceeds()); - } - ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), - reinterpret_cast<sockaddr*>(&conn_addr), - conn_addrlen), - SyscallSucceeds()); -} - -// TCPLinger2TimeoutAfterClose creates a pair of connected sockets -// then closes one end to trigger FIN_WAIT2 state for the closed endpont. -// It then sleeps for the TCP_LINGER2 timeout and verifies that bind/ -// connecting the same address succeeds. -// -// TCP timers are not S/R today, this can cause this test to be flaky when run -// under random S/R due to timer being reset on a restore. -TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) { - auto const& param = GetParam(); - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - - // Create the listening socket. - const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); - - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - - // Connect to the listening socket. - FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceeds()); - - // Accept the connection. - auto accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); - - // Get the address/port bound by the connecting socket. - sockaddr_storage conn_bound_addr; - socklen_t conn_addrlen = connector.addr_len; - ASSERT_THAT( - getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), - &conn_addrlen), - SyscallSucceeds()); - - constexpr int kTCPLingerTimeout = 5; - EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, - &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), - SyscallSucceedsWithValue(0)); - - // close the connecting FD to trigger FIN_WAIT2 on the connected fd. - conn_fd.reset(); - - absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); - - // Now bind and connect a new socket and verify that we can immediately - // rebind the address bound by the conn_fd as it never entered TIME_WAIT. - const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - - ASSERT_THAT(bind(conn_fd2.get(), - reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen), - SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), - reinterpret_cast<sockaddr*>(&conn_addr), - conn_addrlen), - SyscallSucceeds()); -} - -// TCPResetAfterClose creates a pair of connected sockets then closes -// one end to trigger FIN_WAIT2 state for the closed endpoint verifies -// that we generate RSTs for any new data after the socket is fully -// closed. -TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) { - auto const& param = GetParam(); - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - - // Create the listening socket. - const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); - - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - - // Connect to the listening socket. - FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceeds()); - - // Accept the connection. - auto accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); - - // close the connecting FD to trigger FIN_WAIT2 on the connected fd. - conn_fd.reset(); - - int data = 1234; - - // Now send data which should trigger a RST as the other end should - // have timed out and closed the socket. - EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0), - SyscallSucceeds()); - // Sleep for a shortwhile to get a RST back. - absl::SleepFor(absl::Seconds(1)); - - // Try writing again and we should get an EPIPE back. - EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0), - SyscallFailsWithErrno(EPIPE)); - - // Trying to read should return zero as the other end did send - // us a FIN. We do it twice to verify that the RST does not cause an - // ECONNRESET on the read after EOF has been read by applicaiton. - EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0), - SyscallSucceedsWithValue(0)); - EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0), - SyscallSucceedsWithValue(0)); -} - -// This test is disabled under random save as the the restore run -// results in the stack.Seed() being different which can cause -// sequence number of final connect to be one that is considered -// old and can cause the test to be flaky. -TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) { - auto const& param = GetParam(); - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - - // Create the listening socket. - const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); - - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - - // Connect to the listening socket. - FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - - // We disable saves after this point as a S/R causes the netstack seed - // to be regenerated which changes what ports/ISN is picked for a given - // tuple (src ip,src port, dst ip, dst port). This can cause the final - // SYN to use a sequence number that looks like one from the current - // connection in TIME_WAIT and will not be accepted causing the test - // to timeout. - // - // TODO(gvisor.dev/issue/940): S/R portSeed/portHint - DisableSave ds; - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceeds()); - - // Accept the connection. - auto accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); - - // Get the address/port bound by the connecting socket. - sockaddr_storage conn_bound_addr; - socklen_t conn_addrlen = connector.addr_len; - ASSERT_THAT( - getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), - &conn_addrlen), - SyscallSucceeds()); - - // close the accept FD to trigger TIME_WAIT on the accepted socket which - // should cause the conn_fd to follow CLOSE_WAIT->LAST_ACK->CLOSED instead of - // TIME_WAIT. - accepted.reset(); - absl::SleepFor(absl::Seconds(1)); - conn_fd.reset(); - absl::SleepFor(absl::Seconds(1)); - - // Now bind and connect a new socket and verify that we can immediately - // rebind the address bound by the conn_fd as it never entered TIME_WAIT. - const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - - ASSERT_THAT(bind(conn_fd2.get(), - reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen), - SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), - reinterpret_cast<sockaddr*>(&conn_addr), - conn_addrlen), - SyscallSucceeds()); -} - -TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { - auto const& param = GetParam(); - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - - // Create the listening socket. - const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); - - const uint16_t port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - - // Set the userTimeout on the listening socket. - constexpr int kUserTimeout = 10; - ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, - &kUserTimeout, sizeof(kUserTimeout)), - SyscallSucceeds()); - - // Connect to the listening socket. - FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceeds()); - - // Accept the connection. - auto accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); - // Verify that the accepted socket inherited the user timeout set on - // listening socket. - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(accepted.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), - SyscallSucceeds()); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kUserTimeout); -} - -// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not -// saved. Enable S/R once issue is fixed. -TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) { - // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not - // saved. Enable S/R issue is fixed. - DisableSave ds; - - auto const& param = GetParam(); - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - - // Create the listening socket. - const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); - - const uint16_t port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - - // Set the TCP_DEFER_ACCEPT on the listening socket. - constexpr int kTCPDeferAccept = 3; - ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, - &kTCPDeferAccept, sizeof(kTCPDeferAccept)), - SyscallSucceeds()); - - // Connect to the listening socket. - FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceeds()); - - // Set the listening socket to nonblock so that we can verify that there is no - // connection in queue despite the connect above succeeding since the peer has - // sent no data and TCP_DEFER_ACCEPT is set on the listening socket. Set the - // FD to O_NONBLOCK. - int opts; - ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds()); - opts |= O_NONBLOCK; - ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); - - ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Set FD back to blocking. - opts &= ~O_NONBLOCK; - ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); - - // Now write some data to the socket. - int data = 0; - ASSERT_THAT(RetryEINTR(write)(conn_fd.get(), &data, sizeof(data)), - SyscallSucceedsWithValue(sizeof(data))); - - // This should now cause the connection to complete and be delivered to the - // accept socket. - - // Accept the connection. - auto accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); - - // Verify that the accepted socket returns the data written. - int get = -1; - ASSERT_THAT(RetryEINTR(recv)(accepted.get(), &get, sizeof(get), 0), - SyscallSucceedsWithValue(sizeof(get))); - - EXPECT_EQ(get, data); -} - -// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not -// saved. Enable S/R once issue is fixed. -TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout_NoRandomSave) { - // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not - // saved. Enable S/R once issue is fixed. - DisableSave ds; - - auto const& param = GetParam(); - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - - // Create the listening socket. - const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); - - const uint16_t port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - - // Set the TCP_DEFER_ACCEPT on the listening socket. - constexpr int kTCPDeferAccept = 3; - ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, - &kTCPDeferAccept, sizeof(kTCPDeferAccept)), - SyscallSucceeds()); - - // Connect to the listening socket. - FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceeds()); - - // Set the listening socket to nonblock so that we can verify that there is no - // connection in queue despite the connect above succeeding since the peer has - // sent no data and TCP_DEFER_ACCEPT is set on the listening socket. Set the - // FD to O_NONBLOCK. - int opts; - ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds()); - opts |= O_NONBLOCK; - ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); - - // Verify that there is no acceptable connection before TCP_DEFER_ACCEPT - // timeout is hit. - absl::SleepFor(absl::Seconds(kTCPDeferAccept - 1)); - ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Set FD back to blocking. - opts &= ~O_NONBLOCK; - ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); - - // Now sleep for a little over the TCP_DEFER_ACCEPT duration. When the timeout - // is hit a SYN-ACK should be retransmitted by the listener as a last ditch - // attempt to complete the connection with or without data. - absl::SleepFor(absl::Seconds(2)); - - // Verify that we have a connection that can be accepted even though no - // data was written. - auto accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); -} - -INSTANTIATE_TEST_SUITE_P( - All, SocketInetLoopbackTest, - ::testing::Values( - // Listeners bound to IPv4 addresses refuse connections using IPv6 - // addresses. - TestParam{V4Any(), V4Any()}, TestParam{V4Any(), V4Loopback()}, - TestParam{V4Any(), V4MappedAny()}, - TestParam{V4Any(), V4MappedLoopback()}, - TestParam{V4Loopback(), V4Any()}, TestParam{V4Loopback(), V4Loopback()}, - TestParam{V4Loopback(), V4MappedLoopback()}, - TestParam{V4MappedAny(), V4Any()}, - TestParam{V4MappedAny(), V4Loopback()}, - TestParam{V4MappedAny(), V4MappedAny()}, - TestParam{V4MappedAny(), V4MappedLoopback()}, - TestParam{V4MappedLoopback(), V4Any()}, - TestParam{V4MappedLoopback(), V4Loopback()}, - TestParam{V4MappedLoopback(), V4MappedLoopback()}, - - // Listeners bound to IN6ADDR_ANY accept all connections. - TestParam{V6Any(), V4Any()}, TestParam{V6Any(), V4Loopback()}, - TestParam{V6Any(), V4MappedAny()}, - TestParam{V6Any(), V4MappedLoopback()}, TestParam{V6Any(), V6Any()}, - TestParam{V6Any(), V6Loopback()}, - - // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4 - // addresses. - TestParam{V6Loopback(), V6Any()}, - TestParam{V6Loopback(), V6Loopback()}), - DescribeTestParam); - -using SocketInetReusePortTest = ::testing::TestWithParam<TestParam>; - -// TODO(gvisor.dev/issue/940): Remove _NoRandomSave when portHint/stack.Seed is -// saved/restored. -TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { - auto const& param = GetParam(); - - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - sockaddr_storage listen_addr = listener.addr; - sockaddr_storage conn_addr = connector.addr; - constexpr int kThreadCount = 3; - constexpr int kConnectAttempts = 10000; - - // Create the listening socket. - FileDescriptor listener_fds[kThreadCount]; - for (int i = 0; i < kThreadCount; i++) { - listener_fds[i] = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - int fd = listener_fds[i].get(); - - ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT( - bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(fd, 40), SyscallSucceeds()); - - // On the first bind we need to determine which port was bound. - if (i != 0) { - continue; - } - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT( - getsockname(listener_fds[0].get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - } - - std::atomic<int> connects_received = ATOMIC_VAR_INIT(0); - std::unique_ptr<ScopedThread> listen_thread[kThreadCount]; - int accept_counts[kThreadCount] = {}; - // TODO(avagin): figure how to not disable S/R for the whole test. - // We need to take into account that this test executes a lot of system - // calls from many threads. - DisableSave ds; - - for (int i = 0; i < kThreadCount; i++) { - listen_thread[i] = absl::make_unique<ScopedThread>( - [&listener_fds, &accept_counts, i, &connects_received]() { - do { - auto fd = Accept(listener_fds[i].get(), nullptr, nullptr); - if (!fd.ok()) { - if (connects_received >= kConnectAttempts) { - // Another thread have shutdown our read side causing the - // accept to fail. - break; - } - ASSERT_NO_ERRNO(fd); - break; - } - // Receive some data from a socket to be sure that the connect() - // system call has been completed on another side. - // Do a short read and then close the socket to trigger a RST. This - // ensures that both ends of the connection are cleaned up and no - // goroutines hang around in TIME-WAIT. We do this so that this test - // does not timeout under gotsan runs where lots of goroutines can - // cause the test to use absurd amounts of memory. - // - // See: https://tools.ietf.org/html/rfc2525#page-50 section 2.17 - uint16_t data; - EXPECT_THAT( - RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0), - SyscallSucceedsWithValue(sizeof(data))); - accept_counts[i]++; - } while (++connects_received < kConnectAttempts); - - // Shutdown all sockets to wake up other threads. - for (int j = 0; j < kThreadCount; j++) { - shutdown(listener_fds[j].get(), SHUT_RDWR); - } - }); - } - - ScopedThread connecting_thread([&connector, &conn_addr]() { - for (int i = 0; i < kConnectAttempts; i++) { - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT( - RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceeds()); - - // Do two separate sends to ensure two segments are received. This is - // required for netstack where read is incorrectly assuming a whole - // segment is read when endpoint.Read() is called which is technically - // incorrect as the syscall that invoked endpoint.Read() may only - // consume it partially. This results in a case where a close() of - // such a socket does not trigger a RST in netstack due to the - // endpoint assuming that the endpoint has no unread data. - EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), - SyscallSucceedsWithValue(sizeof(i))); - - // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly - // generates a RST. - if (IsRunningOnGvisor()) { - EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), - SyscallSucceedsWithValue(sizeof(i))); - } - } - }); - - // Join threads to be sure that all connections have been counted - connecting_thread.Join(); - for (int i = 0; i < kThreadCount; i++) { - listen_thread[i]->Join(); - } - // Check that connections are distributed fairly between listening sockets - for (int i = 0; i < kThreadCount; i++) - EXPECT_THAT(accept_counts[i], - EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); -} - -TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { - auto const& param = GetParam(); - - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - sockaddr_storage listen_addr = listener.addr; - sockaddr_storage conn_addr = connector.addr; - constexpr int kThreadCount = 3; - - // Create the listening socket. - FileDescriptor listener_fds[kThreadCount]; - for (int i = 0; i < kThreadCount; i++) { - listener_fds[i] = - ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0)); - int fd = listener_fds[i].get(); - - ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT( - bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), - SyscallSucceeds()); - - // On the first bind we need to determine which port was bound. - if (i != 0) { - continue; - } - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT( - getsockname(listener_fds[0].get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - } - - constexpr int kConnectAttempts = 10000; - std::atomic<int> packets_received = ATOMIC_VAR_INIT(0); - std::unique_ptr<ScopedThread> receiver_thread[kThreadCount]; - int packets_per_socket[kThreadCount] = {}; - // TODO(avagin): figure how to not disable S/R for the whole test. - DisableSave ds; // Too expensive. - - for (int i = 0; i < kThreadCount; i++) { - receiver_thread[i] = absl::make_unique<ScopedThread>( - [&listener_fds, &packets_per_socket, i, &packets_received]() { - do { - struct sockaddr_storage addr = {}; - socklen_t addrlen = sizeof(addr); - int data; - - auto ret = RetryEINTR(recvfrom)( - listener_fds[i].get(), &data, sizeof(data), 0, - reinterpret_cast<struct sockaddr*>(&addr), &addrlen); - - if (packets_received < kConnectAttempts) { - ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data))); - } - - if (ret != sizeof(data)) { - // Another thread may have shutdown our read side causing the - // recvfrom to fail. - break; - } - - packets_received++; - packets_per_socket[i]++; - - // A response is required to synchronize with the main thread, - // otherwise the main thread can send more than can fit into receive - // queues. - EXPECT_THAT(RetryEINTR(sendto)( - listener_fds[i].get(), &data, sizeof(data), 0, - reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceedsWithValue(sizeof(data))); - } while (packets_received < kConnectAttempts); - - // Shutdown all sockets to wake up other threads. - for (int j = 0; j < kThreadCount; j++) - shutdown(listener_fds[j].get(), SHUT_RDWR); - }); - } - - ScopedThread main_thread([&connector, &conn_addr]() { - for (int i = 0; i < kConnectAttempts; i++) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); - EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0, - reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceedsWithValue(sizeof(i))); - int data; - EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0), - SyscallSucceedsWithValue(sizeof(data))); - } - }); - - main_thread.Join(); - - // Join threads to be sure that all connections have been counted - for (int i = 0; i < kThreadCount; i++) { - receiver_thread[i]->Join(); - } - // Check that packets are distributed fairly between listening sockets. - for (int i = 0; i < kThreadCount; i++) - EXPECT_THAT(packets_per_socket[i], - EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); -} - -TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort) { - auto const& param = GetParam(); - - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - sockaddr_storage listen_addr = listener.addr; - sockaddr_storage conn_addr = connector.addr; - constexpr int kThreadCount = 3; - - // TODO(b/141211329): endpointsByNic.seed has to be saved/restored. - const DisableSave ds141211329; - - // Create listening sockets. - FileDescriptor listener_fds[kThreadCount]; - for (int i = 0; i < kThreadCount; i++) { - listener_fds[i] = - ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0)); - int fd = listener_fds[i].get(); - - ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT( - bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), - SyscallSucceeds()); - - // On the first bind we need to determine which port was bound. - if (i != 0) { - continue; - } - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT( - getsockname(listener_fds[0].get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - } - - constexpr int kConnectAttempts = 10; - FileDescriptor client_fds[kConnectAttempts]; - - // Do the first run without save/restore. - DisableSave ds; - for (int i = 0; i < kConnectAttempts; i++) { - client_fds[i] = - ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); - EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0, - reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceedsWithValue(sizeof(i))); - } - ds.reset(); - - // Check that a mapping of client and server sockets has - // not been change after save/restore. - for (int i = 0; i < kConnectAttempts; i++) { - EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0, - reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceedsWithValue(sizeof(i))); - } - - struct pollfd pollfds[kThreadCount]; - for (int i = 0; i < kThreadCount; i++) { - pollfds[i].fd = listener_fds[i].get(); - pollfds[i].events = POLLIN; - } - - std::map<uint16_t, int> portToFD; - - int received = 0; - while (received < kConnectAttempts * 2) { - ASSERT_THAT(poll(pollfds, kThreadCount, -1), - SyscallSucceedsWithValue(Gt(0))); - - for (int i = 0; i < kThreadCount; i++) { - if ((pollfds[i].revents & POLLIN) == 0) { - continue; - } - - received++; - - const int fd = pollfds[i].fd; - struct sockaddr_storage addr = {}; - socklen_t addrlen = sizeof(addr); - int data; - EXPECT_THAT(RetryEINTR(recvfrom)( - fd, &data, sizeof(data), 0, - reinterpret_cast<struct sockaddr*>(&addr), &addrlen), - SyscallSucceedsWithValue(sizeof(data))); - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(connector.family(), addr)); - auto prev_port = portToFD.find(port); - // Check that all packets from one client have been delivered to the - // same server socket. - if (prev_port == portToFD.end()) { - portToFD[port] = fd; - } else { - EXPECT_EQ(portToFD[port], fd); - } - } - } -} - -INSTANTIATE_TEST_SUITE_P( - All, SocketInetReusePortTest, - ::testing::Values( - // Listeners bound to IPv4 addresses refuse connections using IPv6 - // addresses. - TestParam{V4Any(), V4Loopback()}, - TestParam{V4Loopback(), V4MappedLoopback()}, - - // Listeners bound to IN6ADDR_ANY accept all connections. - TestParam{V6Any(), V4Loopback()}, TestParam{V6Any(), V6Loopback()}, - - // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4 - // addresses. - TestParam{V6Loopback(), V6Loopback()}), - DescribeTestParam); - -struct ProtocolTestParam { - std::string description; - int type; -}; - -std::string DescribeProtocolTestParam( - ::testing::TestParamInfo<ProtocolTestParam> const& info) { - return info.param.description; -} - -using SocketMultiProtocolInetLoopbackTest = - ::testing::TestWithParam<ProtocolTestParam>; - -TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedLoopbackOnlyReservesV4) { - auto const& param = GetParam(); - - for (int i = 0; true; i++) { - // Bind the v4 loopback on a dual stack socket. - TestAddress const& test_addr_dual = V4MappedLoopback(); - sockaddr_storage addr_dual = test_addr_dual.addr; - const FileDescriptor fd_dual = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_dual.family(), param.type, 0)); - ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), - test_addr_dual.addr_len), - SyscallSucceeds()); - - // Get the port that we bound. - socklen_t addrlen = test_addr_dual.addr_len; - ASSERT_THAT(getsockname(fd_dual.get(), - reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), - SyscallSucceeds()); - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); - - // Verify that we can still bind the v6 loopback on the same port. - TestAddress const& test_addr_v6 = V6Loopback(); - sockaddr_storage addr_v6 = test_addr_v6.addr; - ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port)); - const FileDescriptor fd_v6 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0)); - int ret = bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), - test_addr_v6.addr_len); - if (ret == -1 && errno == EADDRINUSE) { - // Port may have been in use. - ASSERT_LT(i, 100); // Give up after 100 tries. - continue; - } - ASSERT_THAT(ret, SyscallSucceeds()); - - // Verify that binding the v4 loopback with the same port on a v4 socket - // fails. - TestAddress const& test_addr_v4 = V4Loopback(); - sockaddr_storage addr_v4 = test_addr_v4.addr; - ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port)); - const FileDescriptor fd_v4 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4), - test_addr_v4.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - - // No need to try again. - break; - } -} - -TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedAnyOnlyReservesV4) { - auto const& param = GetParam(); - - for (int i = 0; true; i++) { - // Bind the v4 any on a dual stack socket. - TestAddress const& test_addr_dual = V4MappedAny(); - sockaddr_storage addr_dual = test_addr_dual.addr; - const FileDescriptor fd_dual = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_dual.family(), param.type, 0)); - ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), - test_addr_dual.addr_len), - SyscallSucceeds()); - - // Get the port that we bound. - socklen_t addrlen = test_addr_dual.addr_len; - ASSERT_THAT(getsockname(fd_dual.get(), - reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), - SyscallSucceeds()); - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); - - // Verify that we can still bind the v6 loopback on the same port. - TestAddress const& test_addr_v6 = V6Loopback(); - sockaddr_storage addr_v6 = test_addr_v6.addr; - ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port)); - const FileDescriptor fd_v6 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0)); - int ret = bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), - test_addr_v6.addr_len); - if (ret == -1 && errno == EADDRINUSE) { - // Port may have been in use. - ASSERT_LT(i, 100); // Give up after 100 tries. - continue; - } - ASSERT_THAT(ret, SyscallSucceeds()); - - // Verify that binding the v4 loopback with the same port on a v4 socket - // fails. - TestAddress const& test_addr_v4 = V4Loopback(); - sockaddr_storage addr_v4 = test_addr_v4.addr; - ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port)); - const FileDescriptor fd_v4 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4), - test_addr_v4.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - - // No need to try again. - break; - } -} - -TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) { - auto const& param = GetParam(); - - // Bind the v6 any on a dual stack socket. - TestAddress const& test_addr_dual = V6Any(); - sockaddr_storage addr_dual = test_addr_dual.addr; - const FileDescriptor fd_dual = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0)); - ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), - test_addr_dual.addr_len), - SyscallSucceeds()); - - // Get the port that we bound. - socklen_t addrlen = test_addr_dual.addr_len; - ASSERT_THAT(getsockname(fd_dual.get(), - reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), - SyscallSucceeds()); - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); - - // Verify that binding the v6 loopback with the same port fails. - TestAddress const& test_addr_v6 = V6Loopback(); - sockaddr_storage addr_v6 = test_addr_v6.addr; - ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port)); - const FileDescriptor fd_v6 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), - test_addr_v6.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - - // Verify that binding the v4 loopback on the same port with a v6 socket - // fails. - TestAddress const& test_addr_v4_mapped = V4MappedLoopback(); - sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr; - ASSERT_NO_ERRNO( - SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, port)); - const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_v4_mapped.family(), param.type, 0)); - ASSERT_THAT( - bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped), - test_addr_v4_mapped.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - - // Verify that binding the v4 loopback on the same port with a v4 socket - // fails. - TestAddress const& test_addr_v4 = V4Loopback(); - sockaddr_storage addr_v4 = test_addr_v4.addr; - ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port)); - const FileDescriptor fd_v4 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4), - test_addr_v4.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); -} - -TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) { - auto const& param = GetParam(); - - for (int i = 0; true; i++) { - // Bind the v6 any on a v6-only socket. - TestAddress const& test_addr_dual = V6Any(); - sockaddr_storage addr_dual = test_addr_dual.addr; - const FileDescriptor fd_dual = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_dual.family(), param.type, 0)); - EXPECT_THAT(setsockopt(fd_dual.get(), IPPROTO_IPV6, IPV6_V6ONLY, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), - test_addr_dual.addr_len), - SyscallSucceeds()); - - // Get the port that we bound. - socklen_t addrlen = test_addr_dual.addr_len; - ASSERT_THAT(getsockname(fd_dual.get(), - reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), - SyscallSucceeds()); - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); - - // Verify that binding the v6 loopback with the same port fails. - TestAddress const& test_addr_v6 = V6Loopback(); - sockaddr_storage addr_v6 = test_addr_v6.addr; - ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port)); - const FileDescriptor fd_v6 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), - test_addr_v6.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - - // Verify that we can still bind the v4 loopback on the same port. - TestAddress const& test_addr_v4_mapped = V4MappedLoopback(); - sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr; - ASSERT_NO_ERRNO( - SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, port)); - const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_v4_mapped.family(), param.type, 0)); - int ret = - bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped), - test_addr_v4_mapped.addr_len); - if (ret == -1 && errno == EADDRINUSE) { - // Port may have been in use. - ASSERT_LT(i, 100); // Give up after 100 tries. - continue; - } - ASSERT_THAT(ret, SyscallSucceeds()); - - // No need to try again. - break; - } -} - -TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { - auto const& param = GetParam(); - - // FIXME(b/76031995): Support disabling SO_REUSEADDR for TCP sockets and make - // it disabled by default. - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM); - - for (int i = 0; true; i++) { - // Bind the v6 loopback on a dual stack socket. - TestAddress const& test_addr = V6Loopback(); - sockaddr_storage bound_addr = test_addr.addr; - const FileDescriptor bound_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - test_addr.addr_len), - SyscallSucceeds()); - - // Listen iff TCP. - if (param.type == SOCK_STREAM) { - ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); - } - - // Get the port that we bound. - socklen_t bound_addr_len = test_addr.addr_len; - ASSERT_THAT( - getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - &bound_addr_len), - SyscallSucceeds()); - - // Connect to bind an ephemeral port. - const FileDescriptor connected_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT( - connect(connected_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - bound_addr_len), - SyscallSucceeds()); - - // Get the ephemeral port. - sockaddr_storage connected_addr = {}; - socklen_t connected_addr_len = sizeof(connected_addr); - ASSERT_THAT(getsockname(connected_fd.get(), - reinterpret_cast<sockaddr*>(&connected_addr), - &connected_addr_len), - SyscallSucceeds()); - uint16_t const ephemeral_port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); - - // Verify that we actually got an ephemeral port. - ASSERT_NE(ephemeral_port, 0); - - // Verify that the ephemeral port is reserved. - const FileDescriptor checking_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - EXPECT_THAT( - bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), - connected_addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - - // Verify that binding the v6 loopback with the same port fails. - TestAddress const& test_addr_v6 = V6Loopback(); - sockaddr_storage addr_v6 = test_addr_v6.addr; - ASSERT_NO_ERRNO( - SetAddrPort(test_addr_v6.family(), &addr_v6, ephemeral_port)); - const FileDescriptor fd_v6 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), - test_addr_v6.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - - // Verify that binding the v4 any with the same port fails. - TestAddress const& test_addr_v4_any = V4Any(); - sockaddr_storage addr_v4_any = test_addr_v4_any.addr; - ASSERT_NO_ERRNO( - SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, ephemeral_port)); - const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_v4_any.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any), - test_addr_v4_any.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - - // Verify that we can still bind the v4 loopback on the same port. - TestAddress const& test_addr_v4_mapped = V4MappedLoopback(); - sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr; - ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, - ephemeral_port)); - const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_v4_mapped.family(), param.type, 0)); - int ret = - bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped), - test_addr_v4_mapped.addr_len); - if (ret == -1 && errno == EADDRINUSE) { - // Port may have been in use. - ASSERT_LT(i, 100); // Give up after 100 tries. - continue; - } - EXPECT_THAT(ret, SyscallSucceeds()); - - // No need to try again. - break; - } -} - -TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) { - auto const& param = GetParam(); - - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_DGRAM); - - // Bind the v6 loopback on a dual stack socket. - TestAddress const& test_addr = V6Loopback(); - sockaddr_storage bound_addr = test_addr.addr; - const FileDescriptor bound_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - test_addr.addr_len), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Listen iff TCP. - if (param.type == SOCK_STREAM) { - ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); - } - - // Get the port that we bound. - socklen_t bound_addr_len = test_addr.addr_len; - ASSERT_THAT( - getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - &bound_addr_len), - SyscallSucceeds()); - - // Connect to bind an ephemeral port. - const FileDescriptor connected_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(connect(connected_fd.get(), - reinterpret_cast<sockaddr*>(&bound_addr), bound_addr_len), - SyscallSucceeds()); - - // Get the ephemeral port. - sockaddr_storage connected_addr = {}; - socklen_t connected_addr_len = sizeof(connected_addr); - ASSERT_THAT(getsockname(connected_fd.get(), - reinterpret_cast<sockaddr*>(&connected_addr), - &connected_addr_len), - SyscallSucceeds()); - uint16_t const ephemeral_port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); - - // Verify that we actually got an ephemeral port. - ASSERT_NE(ephemeral_port, 0); - - // Verify that the ephemeral port is not reserved. - const FileDescriptor checking_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - EXPECT_THAT( - bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), - connected_addr_len), - SyscallSucceeds()); -} - -TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { - auto const& param = GetParam(); - - // FIXME(b/76031995): Support disabling SO_REUSEADDR for TCP sockets and make - // it disabled by default. - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM); - - for (int i = 0; true; i++) { - // Bind the v4 loopback on a dual stack socket. - TestAddress const& test_addr = V4MappedLoopback(); - sockaddr_storage bound_addr = test_addr.addr; - const FileDescriptor bound_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - test_addr.addr_len), - SyscallSucceeds()); - - // Listen iff TCP. - if (param.type == SOCK_STREAM) { - ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); - } - - // Get the port that we bound. - socklen_t bound_addr_len = test_addr.addr_len; - ASSERT_THAT( - getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - &bound_addr_len), - SyscallSucceeds()); - - // Connect to bind an ephemeral port. - const FileDescriptor connected_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT( - connect(connected_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - bound_addr_len), - SyscallSucceeds()); - - // Get the ephemeral port. - sockaddr_storage connected_addr = {}; - socklen_t connected_addr_len = sizeof(connected_addr); - ASSERT_THAT(getsockname(connected_fd.get(), - reinterpret_cast<sockaddr*>(&connected_addr), - &connected_addr_len), - SyscallSucceeds()); - uint16_t const ephemeral_port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); - - // Verify that we actually got an ephemeral port. - ASSERT_NE(ephemeral_port, 0); - - // Verify that the ephemeral port is reserved. - const FileDescriptor checking_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - EXPECT_THAT( - bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), - connected_addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - - // Verify that binding the v4 loopback on the same port with a v4 socket - // fails. - TestAddress const& test_addr_v4 = V4Loopback(); - sockaddr_storage addr_v4 = test_addr_v4.addr; - ASSERT_NO_ERRNO( - SetAddrPort(test_addr_v4.family(), &addr_v4, ephemeral_port)); - const FileDescriptor fd_v4 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0)); - EXPECT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4), - test_addr_v4.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - - // Verify that binding the v6 any on the same port with a dual-stack socket - // fails. - TestAddress const& test_addr_v6_any = V6Any(); - sockaddr_storage addr_v6_any = test_addr_v6_any.addr; - ASSERT_NO_ERRNO( - SetAddrPort(test_addr_v6_any.family(), &addr_v6_any, ephemeral_port)); - const FileDescriptor fd_v6_any = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_v6_any.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v6_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any), - test_addr_v6_any.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - - // For some reason, binding the TCP v6-only any is flaky on Linux. Maybe we - // tend to run out of ephemeral ports? Regardless, binding the v6 loopback - // seems pretty reliable. Only try to bind the v6-only any on UDP and - // gVisor. - - int ret = -1; - - if (!IsRunningOnGvisor() && param.type == SOCK_STREAM) { - // Verify that we can still bind the v6 loopback on the same port. - TestAddress const& test_addr_v6 = V6Loopback(); - sockaddr_storage addr_v6 = test_addr_v6.addr; - ASSERT_NO_ERRNO( - SetAddrPort(test_addr_v6.family(), &addr_v6, ephemeral_port)); - const FileDescriptor fd_v6 = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_v6.family(), param.type, 0)); - ret = bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), - test_addr_v6.addr_len); - } else { - // Verify that we can still bind the v6 any on the same port with a - // v6-only socket. - const FileDescriptor fd_v6_only_any = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_v6_any.family(), param.type, 0)); - EXPECT_THAT(setsockopt(fd_v6_only_any.get(), IPPROTO_IPV6, IPV6_V6ONLY, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - ret = - bind(fd_v6_only_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any), - test_addr_v6_any.addr_len); - } - - if (ret == -1 && errno == EADDRINUSE) { - // Port may have been in use. - ASSERT_LT(i, 100); // Give up after 100 tries. - continue; - } - EXPECT_THAT(ret, SyscallSucceeds()); - - // No need to try again. - break; - } -} - -TEST_P(SocketMultiProtocolInetLoopbackTest, - V4MappedEphemeralPortReservedResueAddr) { - auto const& param = GetParam(); - - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_DGRAM); - - // Bind the v4 loopback on a dual stack socket. - TestAddress const& test_addr = V4MappedLoopback(); - sockaddr_storage bound_addr = test_addr.addr; - const FileDescriptor bound_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - test_addr.addr_len), - SyscallSucceeds()); - - ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Listen iff TCP. - if (param.type == SOCK_STREAM) { - ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); - } - - // Get the port that we bound. - socklen_t bound_addr_len = test_addr.addr_len; - ASSERT_THAT( - getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - &bound_addr_len), - SyscallSucceeds()); - - // Connect to bind an ephemeral port. - const FileDescriptor connected_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(connect(connected_fd.get(), - reinterpret_cast<sockaddr*>(&bound_addr), bound_addr_len), - SyscallSucceeds()); - - // Get the ephemeral port. - sockaddr_storage connected_addr = {}; - socklen_t connected_addr_len = sizeof(connected_addr); - ASSERT_THAT(getsockname(connected_fd.get(), - reinterpret_cast<sockaddr*>(&connected_addr), - &connected_addr_len), - SyscallSucceeds()); - uint16_t const ephemeral_port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); - - // Verify that we actually got an ephemeral port. - ASSERT_NE(ephemeral_port, 0); - - // Verify that the ephemeral port is not reserved. - const FileDescriptor checking_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - EXPECT_THAT( - bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), - connected_addr_len), - SyscallSucceeds()); -} - -TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { - auto const& param = GetParam(); - - // FIXME(b/76031995): Support disabling SO_REUSEADDR for TCP sockets and make - // it disabled by default. - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM); - - for (int i = 0; true; i++) { - // Bind the v4 loopback on a v4 socket. - TestAddress const& test_addr = V4Loopback(); - sockaddr_storage bound_addr = test_addr.addr; - const FileDescriptor bound_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - test_addr.addr_len), - SyscallSucceeds()); - - // Listen iff TCP. - if (param.type == SOCK_STREAM) { - ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); - } - - // Get the port that we bound. - socklen_t bound_addr_len = test_addr.addr_len; - ASSERT_THAT( - getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - &bound_addr_len), - SyscallSucceeds()); - - // Connect to bind an ephemeral port. - const FileDescriptor connected_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT( - connect(connected_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - bound_addr_len), - SyscallSucceeds()); - - // Get the ephemeral port. - sockaddr_storage connected_addr = {}; - socklen_t connected_addr_len = sizeof(connected_addr); - ASSERT_THAT(getsockname(connected_fd.get(), - reinterpret_cast<sockaddr*>(&connected_addr), - &connected_addr_len), - SyscallSucceeds()); - uint16_t const ephemeral_port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); - - // Verify that we actually got an ephemeral port. - ASSERT_NE(ephemeral_port, 0); - - // Verify that the ephemeral port is reserved. - const FileDescriptor checking_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - EXPECT_THAT( - bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), - connected_addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - - // Verify that binding the v4 loopback on the same port with a v6 socket - // fails. - TestAddress const& test_addr_v4_mapped = V4MappedLoopback(); - sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr; - ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, - ephemeral_port)); - const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_v4_mapped.family(), param.type, 0)); - EXPECT_THAT( - bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped), - test_addr_v4_mapped.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - - // Verify that binding the v6 any on the same port with a dual-stack socket - // fails. - TestAddress const& test_addr_v6_any = V6Any(); - sockaddr_storage addr_v6_any = test_addr_v6_any.addr; - ASSERT_NO_ERRNO( - SetAddrPort(test_addr_v6_any.family(), &addr_v6_any, ephemeral_port)); - const FileDescriptor fd_v6_any = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_v6_any.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v6_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any), - test_addr_v6_any.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - - // For some reason, binding the TCP v6-only any is flaky on Linux. Maybe we - // tend to run out of ephemeral ports? Regardless, binding the v6 loopback - // seems pretty reliable. Only try to bind the v6-only any on UDP and - // gVisor. - - int ret = -1; - - if (!IsRunningOnGvisor() && param.type == SOCK_STREAM) { - // Verify that we can still bind the v6 loopback on the same port. - TestAddress const& test_addr_v6 = V6Loopback(); - sockaddr_storage addr_v6 = test_addr_v6.addr; - ASSERT_NO_ERRNO( - SetAddrPort(test_addr_v6.family(), &addr_v6, ephemeral_port)); - const FileDescriptor fd_v6 = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_v6.family(), param.type, 0)); - ret = bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), - test_addr_v6.addr_len); - } else { - // Verify that we can still bind the v6 any on the same port with a - // v6-only socket. - const FileDescriptor fd_v6_only_any = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_v6_any.family(), param.type, 0)); - EXPECT_THAT(setsockopt(fd_v6_only_any.get(), IPPROTO_IPV6, IPV6_V6ONLY, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - ret = - bind(fd_v6_only_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any), - test_addr_v6_any.addr_len); - } - - if (ret == -1 && errno == EADDRINUSE) { - // Port may have been in use. - ASSERT_LT(i, 100); // Give up after 100 tries. - continue; - } - EXPECT_THAT(ret, SyscallSucceeds()); - - // No need to try again. - break; - } -} - -TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) { - auto const& param = GetParam(); - - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_DGRAM); - - // Bind the v4 loopback on a v4 socket. - TestAddress const& test_addr = V4Loopback(); - sockaddr_storage bound_addr = test_addr.addr; - const FileDescriptor bound_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - - ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - test_addr.addr_len), - SyscallSucceeds()); - - // Listen iff TCP. - if (param.type == SOCK_STREAM) { - ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); - } - - // Get the port that we bound. - socklen_t bound_addr_len = test_addr.addr_len; - ASSERT_THAT( - getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - &bound_addr_len), - SyscallSucceeds()); - - // Connect to bind an ephemeral port. - const FileDescriptor connected_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - - ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - ASSERT_THAT(connect(connected_fd.get(), - reinterpret_cast<sockaddr*>(&bound_addr), bound_addr_len), - SyscallSucceeds()); - - // Get the ephemeral port. - sockaddr_storage connected_addr = {}; - socklen_t connected_addr_len = sizeof(connected_addr); - ASSERT_THAT(getsockname(connected_fd.get(), - reinterpret_cast<sockaddr*>(&connected_addr), - &connected_addr_len), - SyscallSucceeds()); - uint16_t const ephemeral_port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); - - // Verify that we actually got an ephemeral port. - ASSERT_NE(ephemeral_port, 0); - - // Verify that the ephemeral port is not reserved. - const FileDescriptor checking_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - EXPECT_THAT( - bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), - connected_addr_len), - SyscallSucceeds()); -} - -TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) { - auto const& param = GetParam(); - TestAddress const& test_addr = V4Loopback(); - sockaddr_storage addr = test_addr.addr; - - for (int i = 0; i < 2; i++) { - const int portreuse1 = i % 2; - auto s1 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - int fd1 = s1.get(); - socklen_t addrlen = test_addr.addr_len; - - EXPECT_THAT( - setsockopt(fd1, SOL_SOCKET, SO_REUSEPORT, &portreuse1, sizeof(int)), - SyscallSucceeds()); - - ASSERT_THAT(bind(fd1, reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceeds()); - - ASSERT_THAT(getsockname(fd1, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - if (param.type == SOCK_STREAM) { - ASSERT_THAT(listen(fd1, 1), SyscallSucceeds()); - } - - // j is less than 4 to check that the port reuse logic works correctly after - // closing bound sockets. - for (int j = 0; j < 4; j++) { - const int portreuse2 = j % 2; - auto s2 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - int fd2 = s2.get(); - - EXPECT_THAT( - setsockopt(fd2, SOL_SOCKET, SO_REUSEPORT, &portreuse2, sizeof(int)), - SyscallSucceeds()); - - std::cout << portreuse1 << " " << portreuse2; - int ret = bind(fd2, reinterpret_cast<sockaddr*>(&addr), addrlen); - - // Verify that two sockets can be bound to the same port only if - // SO_REUSEPORT is set for both of them. - if (!portreuse1 || !portreuse2) { - ASSERT_THAT(ret, SyscallFailsWithErrno(EADDRINUSE)); - } else { - ASSERT_THAT(ret, SyscallSucceeds()); - } - } - } -} - -// Check that when a socket was bound to an address with REUSEPORT and then -// closed, we can bind a different socket to the same address without needing -// REUSEPORT. -TEST_P(SocketMultiProtocolInetLoopbackTest, NoReusePortFollowingReusePort) { - auto const& param = GetParam(); - TestAddress const& test_addr = V4Loopback(); - sockaddr_storage addr = test_addr.addr; - - auto s = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - int fd = s.get(); - socklen_t addrlen = test_addr.addr_len; - int portreuse = 1; - ASSERT_THAT( - setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &portreuse, sizeof(portreuse)), - SyscallSucceeds()); - ASSERT_THAT(bind(fd, reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceeds()); - ASSERT_THAT(getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - ASSERT_EQ(addrlen, test_addr.addr_len); - - s.reset(); - - // Open a new socket and bind to the same address, but w/o REUSEPORT. - s = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - fd = s.get(); - portreuse = 0; - ASSERT_THAT( - setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &portreuse, sizeof(portreuse)), - SyscallSucceeds()); - ASSERT_THAT(bind(fd, reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceeds()); -} - -INSTANTIATE_TEST_SUITE_P( - AllFamilies, SocketMultiProtocolInetLoopbackTest, - ::testing::Values(ProtocolTestParam{"TCP", SOCK_STREAM}, - ProtocolTestParam{"UDP", SOCK_DGRAM}), - DescribeProtocolTestParam); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_loopback_blocking.cc b/test/syscalls/linux/socket_ip_loopback_blocking.cc deleted file mode 100644 index fda252dd7..000000000 --- a/test/syscalls/linux/socket_ip_loopback_blocking.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <netinet/tcp.h> - -#include <vector> - -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_blocking.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return VecCat<SocketPairKind>( - std::vector<SocketPairKind>{ - IPv6UDPBidirectionalBindSocketPair(0), - IPv4UDPBidirectionalBindSocketPair(0), - }, - ApplyVecToVec<SocketPairKind>( - std::vector<Middleware>{ - NoOp, SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &kSockOptOn)}, - std::vector<SocketPairKind>{ - IPv6TCPAcceptBindSocketPair(0), - IPv4TCPAcceptBindSocketPair(0), - })); -} - -INSTANTIATE_TEST_SUITE_P( - BlockingIPSockets, BlockingSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc deleted file mode 100644 index 27779e47c..000000000 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ /dev/null @@ -1,912 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_ip_tcp_generic.h" - -#include <netinet/in.h> -#include <netinet/tcp.h> -#include <poll.h> -#include <stdio.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "absl/memory/memory.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -TEST_P(TCPSocketPairTest, TcpInfoSucceeds) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct tcp_info opt = {}; - socklen_t optLen = sizeof(opt); - EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_TCP, TCP_INFO, &opt, &optLen), - SyscallSucceeds()); -} - -TEST_P(TCPSocketPairTest, ShortTcpInfoSucceeds) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct tcp_info opt = {}; - socklen_t optLen = 1; - EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_TCP, TCP_INFO, &opt, &optLen), - SyscallSucceeds()); -} - -TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceeds) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct tcp_info opt = {}; - socklen_t optLen = 0; - EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_TCP, TCP_INFO, &opt, &optLen), - SyscallSucceeds()); -} - -// This test validates that an RST is sent instead of a FIN when data is -// unread on calls to close(2). -TEST_P(TCPSocketPairTest, RSTSentOnCloseWithUnreadData) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char buf[10] = {}; - ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - // Wait until t_ sees the data on its side but don't read it. - struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0}; - constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data. - ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - - // Now close the connected without reading the data. - ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); - - // Wait for the other end to receive the RST (up to 20 seconds). - struct pollfd poll_fd2 = {sockets->first_fd(), POLLIN | POLLHUP, 0}; - ASSERT_THAT(RetryEINTR(poll)(&poll_fd2, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - - // A shutdown with unread data will cause a RST to be sent instead - // of a FIN, per RFC 2525 section 2.17; this is also what Linux does. - ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)), - SyscallFailsWithErrno(ECONNRESET)); -} - -// This test will validate that a RST will cause POLLHUP to trigger. -TEST_P(TCPSocketPairTest, RSTCausesPollHUP) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char buf[10] = {}; - ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - // Wait until second sees the data on its side but don't read it. - struct pollfd poll_fd = {sockets->second_fd(), POLLIN, 0}; - constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data. - ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - EXPECT_EQ(poll_fd.revents & POLLIN, POLLIN); - - // Confirm we at least have one unread byte. - int bytes_available = 0; - ASSERT_THAT( - RetryEINTR(ioctl)(sockets->second_fd(), FIONREAD, &bytes_available), - SyscallSucceeds()); - EXPECT_GT(bytes_available, 0); - - // Now close the connected socket without reading the data from the second, - // this will cause a RST and we should see that with POLLHUP. - ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); - - // Wait for the other end to receive the RST (up to 20 seconds). - struct pollfd poll_fd3 = {sockets->first_fd(), POLLHUP, 0}; - ASSERT_THAT(RetryEINTR(poll)(&poll_fd3, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - ASSERT_NE(poll_fd3.revents & POLLHUP, 0); -} - -// This test validates that even if a RST is sent the other end will not -// get an ECONNRESET until it's read all data. -TEST_P(TCPSocketPairTest, RSTSentOnCloseWithUnreadDataAllowsReadBuffered) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char buf[10] = {}; - ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - ASSERT_THAT(RetryEINTR(write)(sockets->second_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - // Wait until second sees the data on its side but don't read it. - struct pollfd poll_fd = {sockets->second_fd(), POLLIN, 0}; - constexpr int kPollTimeoutMs = 30000; // Wait up to 30 seconds for the data. - ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - - // Wait until first sees the data on its side but don't read it. - struct pollfd poll_fd2 = {sockets->first_fd(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&poll_fd2, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - - // Now close the connected socket without reading the data from the second. - ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); - - // Wait for the other end to receive the RST (up to 30 seconds). - struct pollfd poll_fd3 = {sockets->first_fd(), POLLHUP, 0}; - ASSERT_THAT(RetryEINTR(poll)(&poll_fd3, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - - // Since we also have data buffered we should be able to read it before - // the syscall will fail with ECONNRESET. - ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - // A shutdown with unread data will cause a RST to be sent instead - // of a FIN, per RFC 2525 section 2.17; this is also what Linux does. - ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)), - SyscallFailsWithErrno(ECONNRESET)); -} - -// This test will verify that a clean shutdown (FIN) is preformed when there -// is unread data but only the write side is closed. -TEST_P(TCPSocketPairTest, FINSentOnShutdownWrWithUnreadData) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char buf[10] = {}; - ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - // Wait until t_ sees the data on its side but don't read it. - struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0}; - constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data. - ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - - // Now shutdown the write end leaving the read end open. - ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_WR), SyscallSucceeds()); - - // Wait for the other end to receive the FIN (up to 20 seconds). - struct pollfd poll_fd2 = {sockets->first_fd(), POLLIN | POLLHUP, 0}; - ASSERT_THAT(RetryEINTR(poll)(&poll_fd2, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - - // Since we didn't shutdown the read end this will be a clean close. - ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(0)); -} - -// This test will verify that when data is received by a socket, even if it's -// not read SHUT_RD will not cause any packets to be generated. -TEST_P(TCPSocketPairTest, ShutdownRdShouldCauseNoPacketsWithUnreadData) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char buf[10] = {}; - ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - // Wait until t_ sees the data on its side but don't read it. - struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0}; - constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data. - ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - - // Now shutdown the read end, this will generate no packets to the other end. - ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RD), SyscallSucceeds()); - - // We should not receive any events on the other side of the socket. - struct pollfd poll_fd2 = {sockets->first_fd(), POLLIN | POLLHUP, 0}; - constexpr int kPollNoResponseTimeoutMs = 3000; - ASSERT_THAT(RetryEINTR(poll)(&poll_fd2, 1, kPollNoResponseTimeoutMs), - SyscallSucceedsWithValue(0)); // Timeout. -} - -// This test will verify that a socket which has unread data will still allow -// the data to be read after shutting down the read side, and once there is no -// unread data left, then read will return an EOF. -TEST_P(TCPSocketPairTest, ShutdownRdAllowsReadOfReceivedDataBeforeEOF) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char buf[10] = {}; - ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - // Wait until t_ sees the data on its side but don't read it. - struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0}; - constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data. - ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - - // Now shutdown the read end. - ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RD), SyscallSucceeds()); - - // Even though we did a SHUT_RD on the read end we can still read the data. - ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - // After reading all of the data, reading the closed read end returns EOF. - ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(0)); -} - -// This test verifies that a shutdown(wr) by the server after sending -// data allows the client to still read() the queued data and a client -// close after sending response allows server to read the incoming -// response. -TEST_P(TCPSocketPairTest, ShutdownWrServerClientClose) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char buf[10] = {}; - ScopedThread t([&]() { - ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - ASSERT_THAT(close(sockets->release_first_fd()), - SyscallSucceedsWithValue(0)); - }); - ASSERT_THAT(RetryEINTR(write)(sockets->second_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - ASSERT_THAT(RetryEINTR(shutdown)(sockets->second_fd(), SHUT_WR), - SyscallSucceedsWithValue(0)); - t.Join(); - - ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); -} - -TEST_P(TCPSocketPairTest, ClosedReadNonBlockingSocket) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - // Set the read end to O_NONBLOCK. - int opts = 0; - ASSERT_THAT(opts = fcntl(sockets->second_fd(), F_GETFL), SyscallSucceeds()); - ASSERT_THAT(fcntl(sockets->second_fd(), F_SETFL, opts | O_NONBLOCK), - SyscallSucceeds()); - - char buf[10] = {}; - ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0), - SyscallSucceedsWithValue(sizeof(buf))); - - // Wait until second_fd sees the data and then recv it. - struct pollfd poll_fd = {sockets->second_fd(), POLLIN, 0}; - constexpr int kPollTimeoutMs = 2000; // Wait up to 2 seconds for the data. - ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), buf, sizeof(buf), 0), - SyscallSucceedsWithValue(sizeof(buf))); - - // Now shutdown the write end leaving the read end open. - ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); - - // Wait for close notification and recv again. - struct pollfd poll_fd2 = {sockets->second_fd(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&poll_fd2, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), buf, sizeof(buf), 0), - SyscallSucceedsWithValue(0)); -} - -TEST_P(TCPSocketPairTest, - ShutdownRdUnreadDataShouldCauseNoPacketsUnlessClosed) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char buf[10] = {}; - ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - // Wait until t_ sees the data on its side but don't read it. - struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0}; - constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data. - ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - - // Now shutdown the read end, this will generate no packets to the other end. - ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RD), SyscallSucceeds()); - - // We should not receive any events on the other side of the socket. - struct pollfd poll_fd2 = {sockets->first_fd(), POLLIN | POLLHUP, 0}; - constexpr int kPollNoResponseTimeoutMs = 3000; - ASSERT_THAT(RetryEINTR(poll)(&poll_fd2, 1, kPollNoResponseTimeoutMs), - SyscallSucceedsWithValue(0)); // Timeout. - - // Now since we've fully closed the connection it will generate a RST. - ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(poll)(&poll_fd2, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); // The other end has closed. - - // A shutdown with unread data will cause a RST to be sent instead - // of a FIN, per RFC 2525 section 2.17; this is also what Linux does. - ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)), - SyscallFailsWithErrno(ECONNRESET)); -} - -TEST_P(TCPSocketPairTest, TCPCorkDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - EXPECT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CORK, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -TEST_P(TCPSocketPairTest, SetTCPCork) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CORK, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - EXPECT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CORK, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CORK, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - EXPECT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CORK, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -TEST_P(TCPSocketPairTest, TCPCork) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CORK, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - constexpr char kData[] = "abc"; - ASSERT_THAT(WriteFd(sockets->first_fd(), kData, sizeof(kData)), - SyscallSucceedsWithValue(sizeof(kData))); - - ASSERT_NO_FATAL_FAILURE(RecvNoData(sockets->second_fd())); - - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CORK, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - // Create a receive buffer larger than kData. - char buf[(sizeof(kData) + 1) * 2] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), buf, sizeof(buf), 0), - SyscallSucceedsWithValue(sizeof(kData))); - EXPECT_EQ(absl::string_view(kData, sizeof(kData)), - absl::string_view(buf, sizeof(kData))); -} - -TEST_P(TCPSocketPairTest, TCPQuickAckDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_QUICKACK, &get, - &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); -} - -TEST_P(TCPSocketPairTest, SetTCPQuickAck) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_QUICKACK, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_QUICKACK, &get, - &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_QUICKACK, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_QUICKACK, &get, - &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); -} - -TEST_P(TCPSocketPairTest, SoKeepaliveDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - EXPECT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -TEST_P(TCPSocketPairTest, SetSoKeepalive) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - EXPECT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); - - ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - EXPECT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -TEST_P(TCPSocketPairTest, TCPKeepidleDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPIDLE, &get, - &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, 2 * 60 * 60); // 2 hours. -} - -TEST_P(TCPSocketPairTest, TCPKeepintvlDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPINTVL, &get, - &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, 75); // 75 seconds. -} - -TEST_P(TCPSocketPairTest, SetTCPKeepidleZero) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - constexpr int kZero = 0; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPIDLE, &kZero, - sizeof(kZero)), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(TCPSocketPairTest, SetTCPKeepintvlZero) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - constexpr int kZero = 0; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPINTVL, - &kZero, sizeof(kZero)), - SyscallFailsWithErrno(EINVAL)); -} - -// Copied from include/net/tcp.h. -constexpr int MAX_TCP_KEEPIDLE = 32767; -constexpr int MAX_TCP_KEEPINTVL = 32767; - -TEST_P(TCPSocketPairTest, SetTCPKeepidleAboveMax) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - constexpr int kAboveMax = MAX_TCP_KEEPIDLE + 1; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPIDLE, - &kAboveMax, sizeof(kAboveMax)), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(TCPSocketPairTest, SetTCPKeepintvlAboveMax) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - constexpr int kAboveMax = MAX_TCP_KEEPINTVL + 1; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPINTVL, - &kAboveMax, sizeof(kAboveMax)), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(TCPSocketPairTest, SetTCPKeepidleToMax) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPIDLE, - &MAX_TCP_KEEPIDLE, sizeof(MAX_TCP_KEEPIDLE)), - SyscallSucceedsWithValue(0)); - - int get = -1; - socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPIDLE, &get, - &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, MAX_TCP_KEEPIDLE); -} - -TEST_P(TCPSocketPairTest, SetTCPKeepintvlToMax) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPINTVL, - &MAX_TCP_KEEPINTVL, sizeof(MAX_TCP_KEEPINTVL)), - SyscallSucceedsWithValue(0)); - - int get = -1; - socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPINTVL, &get, - &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, MAX_TCP_KEEPINTVL); -} - -TEST_P(TCPSocketPairTest, SetOOBInline) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_OOBINLINE, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - EXPECT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_OOBINLINE, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); -} - -TEST_P(TCPSocketPairTest, MsgTruncMsgPeek) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[512]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - - // Read half of the data with MSG_TRUNC | MSG_PEEK. This way there will still - // be some data left to read in the next step even if the data gets consumed. - char received_data1[sizeof(sent_data) / 2] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data1, - sizeof(received_data1), MSG_TRUNC | MSG_PEEK), - SyscallSucceedsWithValue(sizeof(received_data1))); - - // Check that we didn't get anything. - char zeros[sizeof(received_data1)] = {}; - EXPECT_EQ(0, memcmp(zeros, received_data1, sizeof(received_data1))); - - // Check that all of the data is still there. - char received_data2[sizeof(sent_data)] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data2, - sizeof(received_data2), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - - EXPECT_EQ(0, memcmp(received_data2, sent_data, sizeof(sent_data))); -} - -TEST_P(TCPSocketPairTest, SetCongestionControlSucceedsForSupported) { - // This is Linux's net/tcp.h TCP_CA_NAME_MAX. - const int kTcpCaNameMax = 16; - - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - // Netstack only supports reno & cubic so we only test these two values here. - { - const char kSetCC[kTcpCaNameMax] = "reno"; - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, - &kSetCC, strlen(kSetCC)), - SyscallSucceedsWithValue(0)); - - char got_cc[kTcpCaNameMax]; - memset(got_cc, '1', sizeof(got_cc)); - socklen_t optlen = sizeof(got_cc); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, - &got_cc, &optlen), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC))); - } - { - const char kSetCC[kTcpCaNameMax] = "cubic"; - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, - &kSetCC, strlen(kSetCC)), - SyscallSucceedsWithValue(0)); - - char got_cc[kTcpCaNameMax]; - memset(got_cc, '1', sizeof(got_cc)); - socklen_t optlen = sizeof(got_cc); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, - &got_cc, &optlen), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC))); - } -} - -TEST_P(TCPSocketPairTest, SetGetTCPCongestionShortReadBuffer) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - { - // Verify that getsockopt/setsockopt work with buffers smaller than - // kTcpCaNameMax. - const char kSetCC[] = "cubic"; - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, - &kSetCC, strlen(kSetCC)), - SyscallSucceedsWithValue(0)); - - char got_cc[sizeof(kSetCC)]; - socklen_t optlen = sizeof(got_cc); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, - &got_cc, &optlen), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(got_cc))); - } -} - -TEST_P(TCPSocketPairTest, SetGetTCPCongestionLargeReadBuffer) { - // This is Linux's net/tcp.h TCP_CA_NAME_MAX. - const int kTcpCaNameMax = 16; - - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - { - // Verify that getsockopt works with buffers larger than - // kTcpCaNameMax. - const char kSetCC[] = "cubic"; - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, - &kSetCC, strlen(kSetCC)), - SyscallSucceedsWithValue(0)); - - char got_cc[kTcpCaNameMax + 5]; - socklen_t optlen = sizeof(got_cc); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, - &got_cc, &optlen), - SyscallSucceedsWithValue(0)); - // Linux copies the minimum of kTcpCaNameMax or the length of the passed in - // buffer and sets optlen to the number of bytes actually copied - // irrespective of the actual length of the congestion control name. - EXPECT_EQ(kTcpCaNameMax, optlen); - EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC))); - } -} - -TEST_P(TCPSocketPairTest, SetCongestionControlFailsForUnsupported) { - // This is Linux's net/tcp.h TCP_CA_NAME_MAX. - const int kTcpCaNameMax = 16; - - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char old_cc[kTcpCaNameMax]; - socklen_t optlen = sizeof(old_cc); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, - &old_cc, &optlen), - SyscallSucceedsWithValue(0)); - - const char kSetCC[] = "invalid_ca_cc"; - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, - &kSetCC, strlen(kSetCC)), - SyscallFailsWithErrno(ENOENT)); - - char got_cc[kTcpCaNameMax]; - optlen = sizeof(got_cc); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, - &got_cc, &optlen), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(old_cc))); -} - -// Linux and Netstack both default to a 60s TCP_LINGER2 timeout. -constexpr int kDefaultTCPLingerTimeout = 60; - -TEST_P(TCPSocketPairTest, TCPLingerTimeoutDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - EXPECT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kDefaultTCPLingerTimeout); -} - -TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZeroOrLess) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - constexpr int kZero = 0; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero, - sizeof(kZero)), - SyscallSucceedsWithValue(0)); - - constexpr int kNegative = -1234; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, - &kNegative, sizeof(kNegative)), - SyscallSucceedsWithValue(0)); -} - -TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout - // on linux (defaults to 60 seconds on linux). - constexpr int kAboveDefault = kDefaultTCPLingerTimeout + 1; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, - &kAboveDefault, sizeof(kAboveDefault)), - SyscallSucceedsWithValue(0)); - - int get = -1; - socklen_t get_len = sizeof(get); - EXPECT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kDefaultTCPLingerTimeout); -} - -TEST_P(TCPSocketPairTest, SetTCPLingerTimeout) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout - // on linux (defaults to 60 seconds on linux). - constexpr int kTCPLingerTimeout = kDefaultTCPLingerTimeout - 1; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, - &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), - SyscallSucceedsWithValue(0)); - - int get = -1; - socklen_t get_len = sizeof(get); - EXPECT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kTCPLingerTimeout); -} - -TEST_P(TCPSocketPairTest, TestTCPCloseWithData) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ScopedThread t([&]() { - // Close one end to trigger sending of a FIN. - ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_WR), SyscallSucceeds()); - char buf[3]; - ASSERT_THAT(read(sockets->second_fd(), buf, 3), - SyscallSucceedsWithValue(3)); - absl::SleepFor(absl::Milliseconds(50)); - ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); - }); - - absl::SleepFor(absl::Milliseconds(50)); - // Send some data then close. - constexpr char kStr[] = "abc"; - ASSERT_THAT(write(sockets->first_fd(), kStr, 3), SyscallSucceedsWithValue(3)); - t.Join(); - ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); -} - -TEST_P(TCPSocketPairTest, TCPUserTimeoutDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, - &get, &get_len), - SyscallSucceeds()); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, 0); // 0 ms (disabled). -} - -TEST_P(TCPSocketPairTest, SetTCPUserTimeoutZero) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - constexpr int kZero = 0; - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, - &kZero, sizeof(kZero)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, - &get, &get_len), - SyscallSucceeds()); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, 0); // 0 ms (disabled). -} - -TEST_P(TCPSocketPairTest, SetTCPUserTimeoutBelowZero) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - constexpr int kNeg = -10; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, - &kNeg, sizeof(kNeg)), - SyscallFailsWithErrno(EINVAL)); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, - &get, &get_len), - SyscallSucceeds()); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, 0); // 0 ms (disabled). -} - -TEST_P(TCPSocketPairTest, SetTCPUserTimeoutAboveZero) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - constexpr int kAbove = 10; - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, - &kAbove, sizeof(kAbove)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, - &get, &get_len), - SyscallSucceeds()); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kAbove); -} - -TEST_P(TCPSocketPairTest, TCPResetDuringClose_NoRandomSave) { - DisableSave ds; // Too many syscalls. - constexpr int kThreadCount = 1000; - std::unique_ptr<ScopedThread> instances[kThreadCount]; - for (int i = 0; i < kThreadCount; i++) { - instances[i] = absl::make_unique<ScopedThread>([&]() { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ScopedThread t([&]() { - // Close one end to trigger sending of a FIN. - struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0}; - // Wait up to 20 seconds for the data. - constexpr int kPollTimeoutMs = 20000; - ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), - SyscallSucceedsWithValue(1)); - ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); - }); - - // Send some data then close. - constexpr char kStr[] = "abc"; - ASSERT_THAT(write(sockets->first_fd(), kStr, 3), - SyscallSucceedsWithValue(3)); - absl::SleepFor(absl::Milliseconds(10)); - ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); - t.Join(); - }); - } - for (int i = 0; i < kThreadCount; i++) { - instances[i]->Join(); - } -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_tcp_generic.h b/test/syscalls/linux/socket_ip_tcp_generic.h deleted file mode 100644 index a3eff3c73..000000000 --- a/test/syscalls/linux/socket_ip_tcp_generic.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IP_TCP_GENERIC_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IP_TCP_GENERIC_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of connected TCP sockets. -using TCPSocketPairTest = SocketPairTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IP_TCP_GENERIC_H_ diff --git a/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc b/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc deleted file mode 100644 index 4e79d21f4..000000000 --- a/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <netinet/tcp.h> - -#include <vector> - -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_ip_tcp_generic.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return ApplyVecToVec<SocketPairKind>( - std::vector<Middleware>{ - NoOp, SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &kSockOptOn)}, - std::vector<SocketPairKind>{ - IPv6TCPAcceptBindSocketPair(0), - IPv4TCPAcceptBindSocketPair(0), - DualStackTCPAcceptBindSocketPair(0), - }); -} - -INSTANTIATE_TEST_SUITE_P( - AllTCPSockets, TCPSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_tcp_loopback.cc b/test/syscalls/linux/socket_ip_tcp_loopback.cc deleted file mode 100644 index 9db3037bc..000000000 --- a/test/syscalls/linux/socket_ip_tcp_loopback.cc +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_generic.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return { - IPv6TCPAcceptBindSocketPair(0), - IPv4TCPAcceptBindSocketPair(0), - DualStackTCPAcceptBindSocketPair(0), - }; -} - -INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, AllSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc b/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc deleted file mode 100644 index f996b93d2..000000000 --- a/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <netinet/tcp.h> - -#include <vector> - -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_stream_blocking.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return ApplyVecToVec<SocketPairKind>( - std::vector<Middleware>{ - NoOp, SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &kSockOptOn)}, - std::vector<SocketPairKind>{ - IPv6TCPAcceptBindSocketPair(0), - IPv4TCPAcceptBindSocketPair(0), - DualStackTCPAcceptBindSocketPair(0), - }); -} - -INSTANTIATE_TEST_SUITE_P( - BlockingTCPSockets, BlockingStreamSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc b/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc deleted file mode 100644 index ffa377210..000000000 --- a/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <netinet/tcp.h> - -#include <vector> - -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_non_blocking.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return ApplyVecToVec<SocketPairKind>( - std::vector<Middleware>{ - NoOp, SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &kSockOptOn)}, - std::vector<SocketPairKind>{ - IPv6TCPAcceptBindSocketPair(SOCK_NONBLOCK), - IPv4TCPAcceptBindSocketPair(SOCK_NONBLOCK), - }); -} - -INSTANTIATE_TEST_SUITE_P( - NonBlockingTCPSockets, NonBlockingSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_tcp_udp_generic.cc b/test/syscalls/linux/socket_ip_tcp_udp_generic.cc deleted file mode 100644 index f178f1af9..000000000 --- a/test/syscalls/linux/socket_ip_tcp_udp_generic.cc +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <netinet/in.h> -#include <netinet/tcp.h> -#include <poll.h> -#include <stdio.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Test fixture for tests that apply to pairs of TCP and UDP sockets. -using TcpUdpSocketPairTest = SocketPairTest; - -TEST_P(TcpUdpSocketPairTest, ShutdownWrFollowedBySendIsError) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - // Now shutdown the write end of the first. - ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_WR), SyscallSucceeds()); - - char buf[10] = {}; - ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0), - SyscallFailsWithErrno(EPIPE)); -} - -std::vector<SocketPairKind> GetSocketPairs() { - return VecCat<SocketPairKind>( - ApplyVec<SocketPairKind>( - IPv6UDPBidirectionalBindSocketPair, - AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - IPv4UDPBidirectionalBindSocketPair, - AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - DualStackUDPBidirectionalBindSocketPair, - AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - IPv6TCPAcceptBindSocketPair, - AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - IPv4TCPAcceptBindSocketPair, - AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - DualStackTCPAcceptBindSocketPair, - AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK}))); -} - -INSTANTIATE_TEST_SUITE_P( - AllIPSockets, TcpUdpSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc deleted file mode 100644 index 1c533fdf2..000000000 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ /dev/null @@ -1,458 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_ip_udp_generic.h" - -#include <errno.h> -#include <netinet/in.h> -#include <netinet/tcp.h> -#include <poll.h> -#include <stdio.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -TEST_P(UDPSocketPairTest, MulticastTTLDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, - &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, 1); -} - -TEST_P(UDPSocketPairTest, SetUDPMulticastTTLMin) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - constexpr int kMin = 0; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, - &kMin, sizeof(kMin)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, - &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kMin); -} - -TEST_P(UDPSocketPairTest, SetUDPMulticastTTLMax) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - constexpr int kMax = 255; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, - &kMax, sizeof(kMax)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, - &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kMax); -} - -TEST_P(UDPSocketPairTest, SetUDPMulticastTTLNegativeOne) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - constexpr int kArbitrary = 6; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, - &kArbitrary, sizeof(kArbitrary)), - SyscallSucceeds()); - - constexpr int kNegOne = -1; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, - &kNegOne, sizeof(kNegOne)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, - &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, 1); -} - -TEST_P(UDPSocketPairTest, SetUDPMulticastTTLBelowMin) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - constexpr int kBelowMin = -2; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, - &kBelowMin, sizeof(kBelowMin)), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(UDPSocketPairTest, SetUDPMulticastTTLAboveMax) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - constexpr int kAboveMax = 256; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, - &kAboveMax, sizeof(kAboveMax)), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(UDPSocketPairTest, SetUDPMulticastTTLChar) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - constexpr char kArbitrary = 6; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, - &kArbitrary, sizeof(kArbitrary)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, - &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kArbitrary); -} - -TEST_P(UDPSocketPairTest, SetEmptyIPAddMembership) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct ip_mreqn req = {}; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &req, sizeof(req)), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(UDPSocketPairTest, MulticastLoopDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, - &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); -} - -TEST_P(UDPSocketPairTest, SetMulticastLoop) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, - &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, - &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); -} - -TEST_P(UDPSocketPairTest, SetMulticastLoopChar) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - constexpr char kSockOptOnChar = kSockOptOn; - constexpr char kSockOptOffChar = kSockOptOff; - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, - &kSockOptOffChar, sizeof(kSockOptOffChar)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, - &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, - &kSockOptOnChar, sizeof(kSockOptOnChar)), - SyscallSucceeds()); - - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, - &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); -} - -TEST_P(UDPSocketPairTest, ReuseAddrDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -TEST_P(UDPSocketPairTest, SetReuseAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - - ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); - - ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - ASSERT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -TEST_P(UDPSocketPairTest, ReusePortDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -TEST_P(UDPSocketPairTest, SetReusePort) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); - - ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - ASSERT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -TEST_P(UDPSocketPairTest, SetReuseAddrReusePort) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - - ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); - - ASSERT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); -} - -// Test getsockopt for a socket which is not set with IP_PKTINFO option. -TEST_P(UDPSocketPairTest, IPPKTINFODefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - - ASSERT_THAT( - getsockopt(sockets->first_fd(), SOL_IP, IP_PKTINFO, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -// Test setsockopt and getsockopt for a socket with IP_PKTINFO option. -TEST_P(UDPSocketPairTest, SetAndGetIPPKTINFO) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int level = SOL_IP; - int type = IP_PKTINFO; - - // Check getsockopt before IP_PKTINFO is set. - int get = -1; - socklen_t get_len = sizeof(get); - - ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceedsWithValue(0)); - - ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get, kSockOptOn); - EXPECT_EQ(get_len, sizeof(get)); - - ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOff, - sizeof(kSockOptOff)), - SyscallSucceedsWithValue(0)); - - ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get, kSockOptOff); - EXPECT_EQ(get_len, sizeof(get)); -} - -// Holds TOS or TClass information for IPv4 or IPv6 respectively. -struct RecvTosOption { - int level; - int option; -}; - -RecvTosOption GetRecvTosOption(int domain) { - TEST_CHECK(domain == AF_INET || domain == AF_INET6); - RecvTosOption opt; - switch (domain) { - case AF_INET: - opt.level = IPPROTO_IP; - opt.option = IP_RECVTOS; - break; - case AF_INET6: - opt.level = IPPROTO_IPV6; - opt.option = IPV6_RECVTCLASS; - break; - } - return opt; -} - -// Ensure that Receiving TOS or TCLASS is off by default. -TEST_P(UDPSocketPairTest, RecvTosDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - RecvTosOption t = GetRecvTosOption(GetParam().domain); - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -// Test that setting and getting IP_RECVTOS or IPV6_RECVTCLASS works as -// expected. -TEST_P(UDPSocketPairTest, SetRecvTos) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - RecvTosOption t = GetRecvTosOption(GetParam().domain); - - ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOff, - sizeof(kSockOptOff)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); - - ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - ASSERT_THAT( - getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); -} - -// Test that any socket (including IPv6 only) accepts the IPv4 TOS option: this -// mirrors behavior in linux. -TEST_P(UDPSocketPairTest, TOSRecvMismatch) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - RecvTosOption t = GetRecvTosOption(AF_INET); - int get = -1; - socklen_t get_len = sizeof(get); - - ASSERT_THAT( - getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), - SyscallSucceedsWithValue(0)); -} - -// Test that an IPv4 socket does not support the IPv6 TClass option. -TEST_P(UDPSocketPairTest, TClassRecvMismatch) { - // This should only test AF_INET sockets for the mismatch behavior. - SKIP_IF(GetParam().domain != AF_INET); - - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - - ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IPV6, IPV6_RECVTCLASS, - &get, &get_len), - SyscallFailsWithErrno(EOPNOTSUPP)); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_udp_generic.h b/test/syscalls/linux/socket_ip_udp_generic.h deleted file mode 100644 index 106c54e9f..000000000 --- a/test/syscalls/linux/socket_ip_udp_generic.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IP_UDP_GENERIC_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IP_UDP_GENERIC_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of connected UDP sockets. -using UDPSocketPairTest = SocketPairTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IP_UDP_GENERIC_H_ diff --git a/test/syscalls/linux/socket_ip_udp_loopback.cc b/test/syscalls/linux/socket_ip_udp_loopback.cc deleted file mode 100644 index c7fa44884..000000000 --- a/test/syscalls/linux/socket_ip_udp_loopback.cc +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_generic.h" -#include "test/syscalls/linux/socket_ip_udp_generic.h" -#include "test/syscalls/linux/socket_non_stream.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return { - IPv6UDPBidirectionalBindSocketPair(0), - IPv4UDPBidirectionalBindSocketPair(0), - DualStackUDPBidirectionalBindSocketPair(0), - }; -} - -INSTANTIATE_TEST_SUITE_P( - AllUDPSockets, AllSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -INSTANTIATE_TEST_SUITE_P( - AllUDPSockets, NonStreamSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -INSTANTIATE_TEST_SUITE_P( - AllUDPSockets, UDPSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc b/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc deleted file mode 100644 index d6925a8df..000000000 --- a/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_non_stream_blocking.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return { - IPv6UDPBidirectionalBindSocketPair(0), - IPv4UDPBidirectionalBindSocketPair(0), - }; -} - -INSTANTIATE_TEST_SUITE_P( - BlockingUDPSockets, BlockingNonStreamSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc b/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc deleted file mode 100644 index d675eddc6..000000000 --- a/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_non_blocking.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return { - IPv6UDPBidirectionalBindSocketPair(SOCK_NONBLOCK), - IPv4UDPBidirectionalBindSocketPair(SOCK_NONBLOCK), - }; -} - -INSTANTIATE_TEST_SUITE_P( - NonBlockingUDPSockets, NonBlockingSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc deleted file mode 100644 index ca597e267..000000000 --- a/test/syscalls/linux/socket_ip_unbound.cc +++ /dev/null @@ -1,443 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <arpa/inet.h> -#include <netinet/in.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include <cstdio> -#include <cstring> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of IP sockets. -using IPUnboundSocketTest = SimpleSocketTest; - -TEST_P(IPUnboundSocketTest, TtlDefault) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - int get = -1; - socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get, &get_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get, 64); - EXPECT_EQ(get_sz, sizeof(get)); -} - -TEST_P(IPUnboundSocketTest, SetTtl) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - int get1 = -1; - socklen_t get1_sz = sizeof(get1); - EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get1, &get1_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get1_sz, sizeof(get1)); - - int set = 100; - if (set == get1) { - set += 1; - } - socklen_t set_sz = sizeof(set); - EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set, set_sz), - SyscallSucceedsWithValue(0)); - - int get2 = -1; - socklen_t get2_sz = sizeof(get2); - EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get2, &get2_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get2_sz, sizeof(get2)); - EXPECT_EQ(get2, set); -} - -TEST_P(IPUnboundSocketTest, ResetTtlToDefault) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - int get1 = -1; - socklen_t get1_sz = sizeof(get1); - EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get1, &get1_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get1_sz, sizeof(get1)); - - int set1 = 100; - if (set1 == get1) { - set1 += 1; - } - socklen_t set1_sz = sizeof(set1); - EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set1, set1_sz), - SyscallSucceedsWithValue(0)); - - int set2 = -1; - socklen_t set2_sz = sizeof(set2); - EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set2, set2_sz), - SyscallSucceedsWithValue(0)); - - int get2 = -1; - socklen_t get2_sz = sizeof(get2); - EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get2, &get2_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get2_sz, sizeof(get2)); - EXPECT_EQ(get2, get1); -} - -TEST_P(IPUnboundSocketTest, ZeroTtl) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - int set = 0; - socklen_t set_sz = sizeof(set); - EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set, set_sz), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(IPUnboundSocketTest, InvalidLargeTtl) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - int set = 256; - socklen_t set_sz = sizeof(set); - EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set, set_sz), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(IPUnboundSocketTest, InvalidNegativeTtl) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - int set = -2; - socklen_t set_sz = sizeof(set); - EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set, set_sz), - SyscallFailsWithErrno(EINVAL)); -} - -struct TOSOption { - int level; - int option; - int cmsg_level; -}; - -constexpr int INET_ECN_MASK = 3; - -static TOSOption GetTOSOption(int domain) { - TOSOption opt; - switch (domain) { - case AF_INET: - opt.level = IPPROTO_IP; - opt.option = IP_TOS; - opt.cmsg_level = SOL_IP; - break; - case AF_INET6: - opt.level = IPPROTO_IPV6; - opt.option = IPV6_TCLASS; - opt.cmsg_level = SOL_IPV6; - break; - } - return opt; -} - -TEST_P(IPUnboundSocketTest, TOSDefault) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - TOSOption t = GetTOSOption(GetParam().domain); - int get = -1; - socklen_t get_sz = sizeof(get); - constexpr int kDefaultTOS = 0; - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_sz, sizeof(get)); - EXPECT_EQ(get, kDefaultTOS); -} - -TEST_P(IPUnboundSocketTest, SetTOS) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - int set = 0xC0; - socklen_t set_sz = sizeof(set); - TOSOption t = GetTOSOption(GetParam().domain); - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), - SyscallSucceedsWithValue(0)); - - int get = -1; - socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_sz, sizeof(get)); - EXPECT_EQ(get, set); -} - -TEST_P(IPUnboundSocketTest, ZeroTOS) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - int set = 0; - socklen_t set_sz = sizeof(set); - TOSOption t = GetTOSOption(GetParam().domain); - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), - SyscallSucceedsWithValue(0)); - int get = -1; - socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_sz, sizeof(get)); - EXPECT_EQ(get, set); -} - -TEST_P(IPUnboundSocketTest, InvalidLargeTOS) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - // Test with exceeding the byte space. - int set = 256; - constexpr int kDefaultTOS = 0; - socklen_t set_sz = sizeof(set); - TOSOption t = GetTOSOption(GetParam().domain); - if (GetParam().domain == AF_INET) { - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), - SyscallSucceedsWithValue(0)); - } else { - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), - SyscallFailsWithErrno(EINVAL)); - } - int get = -1; - socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_sz, sizeof(get)); - EXPECT_EQ(get, kDefaultTOS); -} - -TEST_P(IPUnboundSocketTest, CheckSkipECN) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - int set = 0xFF; - socklen_t set_sz = sizeof(set); - TOSOption t = GetTOSOption(GetParam().domain); - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), - SyscallSucceedsWithValue(0)); - int expect = static_cast<uint8_t>(set); - if (GetParam().protocol == IPPROTO_TCP) { - expect &= ~INET_ECN_MASK; - } - int get = -1; - socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_sz, sizeof(get)); - EXPECT_EQ(get, expect); -} - -TEST_P(IPUnboundSocketTest, ZeroTOSOptionSize) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - int set = 0xC0; - socklen_t set_sz = 0; - TOSOption t = GetTOSOption(GetParam().domain); - if (GetParam().domain == AF_INET) { - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), - SyscallSucceedsWithValue(0)); - } else { - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), - SyscallFailsWithErrno(EINVAL)); - } - int get = -1; - socklen_t get_sz = 0; - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_sz, 0); - EXPECT_EQ(get, -1); -} - -TEST_P(IPUnboundSocketTest, SmallTOSOptionSize) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - int set = 0xC0; - constexpr int kDefaultTOS = 0; - TOSOption t = GetTOSOption(GetParam().domain); - for (socklen_t i = 1; i < sizeof(int); i++) { - int expect_tos; - socklen_t expect_sz; - if (GetParam().domain == AF_INET) { - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, i), - SyscallSucceedsWithValue(0)); - expect_tos = set; - expect_sz = sizeof(uint8_t); - } else { - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, i), - SyscallFailsWithErrno(EINVAL)); - expect_tos = kDefaultTOS; - expect_sz = i; - } - uint get = -1; - socklen_t get_sz = i; - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_sz, expect_sz); - // Account for partial copies by getsockopt, retrieve the lower - // bits specified by get_sz, while comparing against expect_tos. - EXPECT_EQ(get & ~(~0 << (get_sz * 8)), expect_tos); - } -} - -TEST_P(IPUnboundSocketTest, LargeTOSOptionSize) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - int set = 0xC0; - TOSOption t = GetTOSOption(GetParam().domain); - for (socklen_t i = sizeof(int); i < 10; i++) { - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, i), - SyscallSucceedsWithValue(0)); - int get = -1; - socklen_t get_sz = i; - // We expect the system call handler to only copy atmost sizeof(int) bytes - // as asserted by the check below. Hence, we do not expect the copy to - // overflow in getsockopt. - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_sz, sizeof(int)); - EXPECT_EQ(get, set); - } -} - -TEST_P(IPUnboundSocketTest, NegativeTOS) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - int set = -1; - socklen_t set_sz = sizeof(set); - TOSOption t = GetTOSOption(GetParam().domain); - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), - SyscallSucceedsWithValue(0)); - int expect; - if (GetParam().domain == AF_INET) { - expect = static_cast<uint8_t>(set); - if (GetParam().protocol == IPPROTO_TCP) { - expect &= ~INET_ECN_MASK; - } - } else { - // On IPv6 TCLASS, setting -1 has the effect of resetting the - // TrafficClass. - expect = 0; - } - int get = -1; - socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_sz, sizeof(get)); - EXPECT_EQ(get, expect); -} - -TEST_P(IPUnboundSocketTest, InvalidNegativeTOS) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - int set = -2; - socklen_t set_sz = sizeof(set); - TOSOption t = GetTOSOption(GetParam().domain); - int expect; - if (GetParam().domain == AF_INET) { - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), - SyscallSucceedsWithValue(0)); - expect = static_cast<uint8_t>(set); - if (GetParam().protocol == IPPROTO_TCP) { - expect &= ~INET_ECN_MASK; - } - } else { - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), - SyscallFailsWithErrno(EINVAL)); - expect = 0; - } - int get = 0; - socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_sz, sizeof(get)); - EXPECT_EQ(get, expect); -} - -TEST_P(IPUnboundSocketTest, NullTOS) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - TOSOption t = GetTOSOption(GetParam().domain); - int set_sz = sizeof(int); - if (GetParam().domain == AF_INET) { - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, nullptr, set_sz), - SyscallFailsWithErrno(EFAULT)); - } else { // AF_INET6 - // The AF_INET6 behavior is not yet compatible. gVisor will try to read - // optval from user memory at syscall handler, it needs substantial - // refactoring to implement this behavior just for IPv6. - if (IsRunningOnGvisor()) { - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, nullptr, set_sz), - SyscallFailsWithErrno(EFAULT)); - } else { - // Linux's IPv6 stack treats nullptr optval as input of 0, so the call - // succeeds. (net/ipv6/ipv6_sockglue.c, do_ipv6_setsockopt()) - // - // Linux's implementation would need fixing as passing a nullptr as optval - // and non-zero optlen may not be valid. - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, nullptr, set_sz), - SyscallSucceedsWithValue(0)); - } - } - socklen_t get_sz = sizeof(int); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, nullptr, &get_sz), - SyscallFailsWithErrno(EFAULT)); - int get = -1; - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, nullptr), - SyscallFailsWithErrno(EFAULT)); -} - -TEST_P(IPUnboundSocketTest, InsufficientBufferTOS) { - SKIP_IF(GetParam().protocol == IPPROTO_TCP); - - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - TOSOption t = GetTOSOption(GetParam().domain); - - in_addr addr4; - in6_addr addr6; - ASSERT_THAT(inet_pton(AF_INET, "127.0.0.1", &addr4), ::testing::Eq(1)); - ASSERT_THAT(inet_pton(AF_INET6, "fe80::", &addr6), ::testing::Eq(1)); - - cmsghdr cmsg = {}; - cmsg.cmsg_len = sizeof(cmsg); - cmsg.cmsg_level = t.cmsg_level; - cmsg.cmsg_type = t.option; - - msghdr msg = {}; - msg.msg_control = &cmsg; - msg.msg_controllen = sizeof(cmsg); - if (GetParam().domain == AF_INET) { - msg.msg_name = &addr4; - msg.msg_namelen = sizeof(addr4); - } else { - msg.msg_name = &addr6; - msg.msg_namelen = sizeof(addr6); - } - - EXPECT_THAT(sendmsg(socket->get(), &msg, 0), SyscallFailsWithErrno(EINVAL)); -} - -INSTANTIATE_TEST_SUITE_P( - IPUnboundSockets, IPUnboundSocketTest, - ::testing::ValuesIn(VecCat<SocketKind>(VecCat<SocketKind>( - ApplyVec<SocketKind>(IPv4UDPUnboundSocket, - AllBitwiseCombinations(List<int>{SOCK_DGRAM}, - List<int>{0, - SOCK_NONBLOCK})), - ApplyVec<SocketKind>(IPv6UDPUnboundSocket, - AllBitwiseCombinations(List<int>{SOCK_DGRAM}, - List<int>{0, - SOCK_NONBLOCK})), - ApplyVec<SocketKind>(IPv4TCPUnboundSocket, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{0, - SOCK_NONBLOCK})), - ApplyVec<SocketKind>(IPv6TCPUnboundSocket, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{ - 0, SOCK_NONBLOCK})))))); - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc deleted file mode 100644 index 80f12b0a9..000000000 --- a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h" - -#include <netinet/in.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include <cstdio> -#include <cstring> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -// Verifies that a newly instantiated TCP socket does not have the -// broadcast socket option enabled. -TEST_P(IPv4TCPUnboundExternalNetworkingSocketTest, TCPBroadcastDefault) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - int get = -1; - socklen_t get_sz = sizeof(get); - EXPECT_THAT( - getsockopt(socket->get(), SOL_SOCKET, SO_BROADCAST, &get, &get_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get, kSockOptOff); - EXPECT_EQ(get_sz, sizeof(get)); -} - -// Verifies that a newly instantiated TCP socket returns true after enabling -// the broadcast socket option. -TEST_P(IPv4TCPUnboundExternalNetworkingSocketTest, SetTCPBroadcast) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - EXPECT_THAT(setsockopt(socket->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceedsWithValue(0)); - - int get = -1; - socklen_t get_sz = sizeof(get); - EXPECT_THAT( - getsockopt(socket->get(), SOL_SOCKET, SO_BROADCAST, &get, &get_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get, kSockOptOn); - EXPECT_EQ(get_sz, sizeof(get)); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h deleted file mode 100644 index fb582b224..000000000 --- a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_TCP_UNBOUND_EXTERNAL_NETWORKING_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_TCP_UNBOUND_EXTERNAL_NETWORKING_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to unbound IPv4 TCP sockets in a sandbox -// with external networking support. -using IPv4TCPUnboundExternalNetworkingSocketTest = SimpleSocketTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_TCP_UNBOUND_EXTERNAL_NETWORKING_H_ diff --git a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc deleted file mode 100644 index 797c4174e..000000000 --- a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h" - -#include <vector> - -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketKind> GetSockets() { - return ApplyVec<SocketKind>( - IPv4TCPUnboundSocket, - AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK})); -} - -INSTANTIATE_TEST_SUITE_P(IPv4TCPUnboundSockets, - IPv4TCPUnboundExternalNetworkingSocketTest, - ::testing::ValuesIn(GetSockets())); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc deleted file mode 100644 index bc4b07a62..000000000 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ /dev/null @@ -1,2216 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_ipv4_udp_unbound.h" - -#include <arpa/inet.h> -#include <net/if.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/un.h> - -#include <cstdio> - -#include "gtest/gtest.h" -#include "absl/memory/memory.h" -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -// Check that packets are not received without a group membership. Default send -// interface configured by bind. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind the first FD to the loopback. This is an alternative to - // IP_MULTICAST_IF for setting the default send interface. - auto sender_addr = V4Loopback(); - EXPECT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), - SyscallSucceeds()); - - // Bind the second FD to the v4 any address. If multicast worked like unicast, - // this would ensure that we get the packet. - auto receiver_addr = V4Any(); - EXPECT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Send the multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we did not receive the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// Check that not setting a default send interface prevents multicast packets -// from being sent. Group membership interface configured by address. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddrNoDefaultSendIf) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind the second FD to the v4 any address to ensure that we can receive any - // unicast packet. - auto receiver_addr = V4Any(); - EXPECT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register to receive multicast packets. - ip_mreq group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallFailsWithErrno(ENETUNREACH)); -} - -// Check that not setting a default send interface prevents multicast packets -// from being sent. Group membership interface configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNicNoDefaultSendIf) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind the second FD to the v4 any address to ensure that we can receive any - // unicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register to receive multicast packets. - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallFailsWithErrno(ENETUNREACH)); -} - -// Check that multicast works when the default send interface is configured by -// bind and the group membership is configured by address. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddr) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind the first FD to the loopback. This is an alternative to - // IP_MULTICAST_IF for setting the default send interface. - auto sender_addr = V4Loopback(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), - SyscallSucceeds()); - - // Bind the second FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register to receive multicast packets. - ip_mreq group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that multicast works when the default send interface is configured by -// bind and the group membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind the first FD to the loopback. This is an alternative to - // IP_MULTICAST_IF for setting the default send interface. - auto sender_addr = V4Loopback(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), - SyscallSucceeds()); - - // Bind the second FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register to receive multicast packets. - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that multicast works when the default send interface is configured by -// IP_MULTICAST_IF, the send address is specified in sendto, and the group -// membership is configured by address. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddr) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Set the default send interface. - ip_mreq iface = {}; - iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallSucceeds()); - - // Bind the second FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register to receive multicast packets. - ip_mreq group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that multicast works when the default send interface is configured by -// IP_MULTICAST_IF, the send address is specified in sendto, and the group -// membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Set the default send interface. - ip_mreqn iface = {}; - iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallSucceeds()); - - // Bind the second FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register to receive multicast packets. - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that multicast works when the default send interface is configured by -// IP_MULTICAST_IF, the send address is specified in connect, and the group -// membership is configured by address. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrConnect) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Set the default send interface. - ip_mreq iface = {}; - iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallSucceeds()); - - // Bind the second FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register to receive multicast packets. - ip_mreq group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto connect_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - ASSERT_THAT( - RetryEINTR(connect)(socket1->get(), - reinterpret_cast<sockaddr*>(&connect_addr.addr), - connect_addr.addr_len), - SyscallSucceeds()); - - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that multicast works when the default send interface is configured by -// IP_MULTICAST_IF, the send address is specified in connect, and the group -// membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Set the default send interface. - ip_mreqn iface = {}; - iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallSucceeds()); - - // Bind the second FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register to receive multicast packets. - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto connect_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - ASSERT_THAT( - RetryEINTR(connect)(socket1->get(), - reinterpret_cast<sockaddr*>(&connect_addr.addr), - connect_addr.addr_len), - SyscallSucceeds()); - - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that multicast works when the default send interface is configured by -// IP_MULTICAST_IF, the send address is specified in sendto, and the group -// membership is configured by address. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelf) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Set the default send interface. - ip_mreq iface = {}; - iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallSucceeds()); - - // Bind the first FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register to receive multicast packets. - ip_mreq group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that multicast works when the default send interface is configured by -// IP_MULTICAST_IF, the send address is specified in sendto, and the group -// membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Set the default send interface. - ip_mreqn iface = {}; - iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallSucceeds()); - - // Bind the first FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register to receive multicast packets. - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that multicast works when the default send interface is configured by -// IP_MULTICAST_IF, the send address is specified in connect, and the group -// membership is configured by address. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfConnect) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Set the default send interface. - ip_mreq iface = {}; - iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallSucceeds()); - - // Bind the first FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register to receive multicast packets. - ip_mreq group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto connect_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - EXPECT_THAT( - RetryEINTR(connect)(socket1->get(), - reinterpret_cast<sockaddr*>(&connect_addr.addr), - connect_addr.addr_len), - SyscallSucceeds()); - - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we did not receive the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// Check that multicast works when the default send interface is configured by -// IP_MULTICAST_IF, the send address is specified in connect, and the group -// membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Set the default send interface. - ip_mreqn iface = {}; - iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallSucceeds()); - - // Bind the first FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register to receive multicast packets. - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto connect_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - ASSERT_THAT( - RetryEINTR(connect)(socket1->get(), - reinterpret_cast<sockaddr*>(&connect_addr.addr), - connect_addr.addr_len), - SyscallSucceeds()); - - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we did not receive the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// Check that multicast works when the default send interface is configured by -// IP_MULTICAST_IF, the send address is specified in sendto, and the group -// membership is configured by address. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfNoLoop) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Set the default send interface. - ip_mreq iface = {}; - iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallSucceeds()); - - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_LOOP, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - // Bind the first FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register to receive multicast packets. - ip_mreq group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that multicast works when the default send interface is configured by -// IP_MULTICAST_IF, the send address is specified in sendto, and the group -// membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Set the default send interface. - ip_mreqn iface = {}; - iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallSucceeds()); - - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_LOOP, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - // Bind the second FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register to receive multicast packets. - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that dropping a group membership that does not exist fails. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastInvalidDrop) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Unregister from a membership that we didn't have. - ip_mreq group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, - sizeof(group)), - SyscallFailsWithErrno(EADDRNOTAVAIL)); -} - -// Check that dropping a group membership prevents multicast packets from being -// delivered. Default send address configured by bind and group membership -// interface configured by address. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropAddr) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind the first FD to the loopback. This is an alternative to - // IP_MULTICAST_IF for setting the default send interface. - auto sender_addr = V4Loopback(); - EXPECT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), - SyscallSucceeds()); - - // Bind the second FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - EXPECT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register and unregister to receive multicast packets. - ip_mreq group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we did not receive the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// Check that dropping a group membership prevents multicast packets from being -// delivered. Default send address configured by bind and group membership -// interface configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind the first FD to the loopback. This is an alternative to - // IP_MULTICAST_IF for setting the default send interface. - auto sender_addr = V4Loopback(); - EXPECT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), - SyscallSucceeds()); - - // Bind the second FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - EXPECT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register and unregister to receive multicast packets. - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we did not receive the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfZero) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreqn iface = {}; - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallSucceeds()); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfInvalidNic) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreqn iface = {}; - iface.imr_ifindex = -1; - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallFailsWithErrno(EADDRNOTAVAIL)); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfInvalidAddr) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreq iface = {}; - iface.imr_interface.s_addr = inet_addr("255.255.255"); - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallFailsWithErrno(EADDRNOTAVAIL)); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetShort) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Create a valid full-sized request. - ip_mreqn iface = {}; - iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - - // Send an optlen of 1 to check that optlen is enforced. - EXPECT_THAT( - setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, 1), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfDefault) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - in_addr get = {}; - socklen_t size = sizeof(get); - ASSERT_THAT( - getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), - SyscallSucceeds()); - EXPECT_EQ(size, sizeof(get)); - EXPECT_EQ(get.s_addr, 0); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfDefaultReqn) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreqn get = {}; - socklen_t size = sizeof(get); - ASSERT_THAT( - getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), - SyscallSucceeds()); - - // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the - // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr. - // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr. - EXPECT_EQ(size, sizeof(in_addr)); - - // getsockopt(IP_MULTICAST_IF) will only return the interface address which - // hasn't been set. - EXPECT_EQ(get.imr_multiaddr.s_addr, 0); - EXPECT_EQ(get.imr_address.s_addr, 0); - EXPECT_EQ(get.imr_ifindex, 0); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetAddrGetReqn) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - in_addr set = {}; - set.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, - sizeof(set)), - SyscallSucceeds()); - - ip_mreqn get = {}; - socklen_t size = sizeof(get); - ASSERT_THAT( - getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), - SyscallSucceeds()); - - // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the - // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr. - // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr. - EXPECT_EQ(size, sizeof(in_addr)); - EXPECT_EQ(get.imr_multiaddr.s_addr, set.s_addr); - EXPECT_EQ(get.imr_address.s_addr, 0); - EXPECT_EQ(get.imr_ifindex, 0); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetReqAddrGetReqn) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreq set = {}; - set.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, - sizeof(set)), - SyscallSucceeds()); - - ip_mreqn get = {}; - socklen_t size = sizeof(get); - ASSERT_THAT( - getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), - SyscallSucceeds()); - - // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the - // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr. - // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr. - EXPECT_EQ(size, sizeof(in_addr)); - EXPECT_EQ(get.imr_multiaddr.s_addr, set.imr_interface.s_addr); - EXPECT_EQ(get.imr_address.s_addr, 0); - EXPECT_EQ(get.imr_ifindex, 0); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetNicGetReqn) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreqn set = {}; - set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, - sizeof(set)), - SyscallSucceeds()); - - ip_mreqn get = {}; - socklen_t size = sizeof(get); - ASSERT_THAT( - getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), - SyscallSucceeds()); - EXPECT_EQ(size, sizeof(in_addr)); - EXPECT_EQ(get.imr_multiaddr.s_addr, 0); - EXPECT_EQ(get.imr_address.s_addr, 0); - EXPECT_EQ(get.imr_ifindex, 0); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetAddr) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - in_addr set = {}; - set.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, - sizeof(set)), - SyscallSucceeds()); - - in_addr get = {}; - socklen_t size = sizeof(get); - ASSERT_THAT( - getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), - SyscallSucceeds()); - - EXPECT_EQ(size, sizeof(get)); - EXPECT_EQ(get.s_addr, set.s_addr); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetReqAddr) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreq set = {}; - set.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, - sizeof(set)), - SyscallSucceeds()); - - in_addr get = {}; - socklen_t size = sizeof(get); - ASSERT_THAT( - getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), - SyscallSucceeds()); - - EXPECT_EQ(size, sizeof(get)); - EXPECT_EQ(get.s_addr, set.imr_interface.s_addr); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetNic) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreqn set = {}; - set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, - sizeof(set)), - SyscallSucceeds()); - - in_addr get = {}; - socklen_t size = sizeof(get); - ASSERT_THAT( - getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), - SyscallSucceeds()); - EXPECT_EQ(size, sizeof(get)); - EXPECT_EQ(get.s_addr, 0); -} - -TEST_P(IPv4UDPUnboundSocketTest, TestJoinGroupNoIf) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallFailsWithErrno(ENODEV)); -} - -TEST_P(IPv4UDPUnboundSocketTest, TestJoinGroupInvalidIf) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreqn group = {}; - group.imr_address.s_addr = inet_addr("255.255.255"); - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallFailsWithErrno(ENODEV)); -} - -// Check that multiple memberships are not allowed on the same socket. -TEST_P(IPv4UDPUnboundSocketTest, TestMultipleJoinsOnSingleSocket) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto fd = socket1->get(); - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - - EXPECT_THAT( - setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)), - SyscallSucceeds()); - - EXPECT_THAT( - setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)), - SyscallFailsWithErrno(EADDRINUSE)); -} - -// Check that two sockets can join the same multicast group at the same time. -TEST_P(IPv4UDPUnboundSocketTest, TestTwoSocketsJoinSameMulticastGroup) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Drop the membership twice on each socket, the second call for each socket - // should fail. - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, - sizeof(group)), - SyscallFailsWithErrno(EADDRNOTAVAIL)); - EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, - sizeof(group)), - SyscallFailsWithErrno(EADDRNOTAVAIL)); -} - -// Check that two sockets can join the same multicast group at the same time, -// and both will receive data on it. -TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionOnTwoSockets) { - std::unique_ptr<SocketPair> socket_pairs[2] = { - absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), - ASSERT_NO_ERRNO_AND_VALUE(NewSocket())), - absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), - ASSERT_NO_ERRNO_AND_VALUE(NewSocket()))}; - - ip_mreq iface = {}, group = {}; - iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - auto receiver_addr = V4Any(); - int bound_port = 0; - - // Create two socketpairs with the exact same configuration. - for (auto& sockets : socket_pairs) { - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_REUSEPORT, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), - SyscallSucceeds()); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - // Get the port assigned. - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - // On the first iteration, save the port we are bound to. On the second - // iteration, verify the port is the same as the one from the first - // iteration. In other words, both sockets listen on the same port. - if (bound_port == 0) { - bound_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - } else { - EXPECT_EQ(bound_port, - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port); - } - } - - // Send a multicast packet to the group from two different sockets and verify - // it is received by both sockets that joined that group. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port; - for (auto& sockets : socket_pairs) { - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet on both sockets. - for (auto& sockets : socket_pairs) { - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); - } - } -} - -// Check that on two sockets that joined a group and listen on ANY, dropping -// memberships one by one will continue to deliver packets to both sockets until -// both memberships have been dropped. -TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) { - std::unique_ptr<SocketPair> socket_pairs[2] = { - absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), - ASSERT_NO_ERRNO_AND_VALUE(NewSocket())), - absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), - ASSERT_NO_ERRNO_AND_VALUE(NewSocket()))}; - - ip_mreq iface = {}, group = {}; - iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - auto receiver_addr = V4Any(); - int bound_port = 0; - - // Create two socketpairs with the exact same configuration. - for (auto& sockets : socket_pairs) { - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_REUSEPORT, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), - SyscallSucceeds()); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - // Get the port assigned. - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - // On the first iteration, save the port we are bound to. On the second - // iteration, verify the port is the same as the one from the first - // iteration. In other words, both sockets listen on the same port. - if (bound_port == 0) { - bound_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - } else { - EXPECT_EQ(bound_port, - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port); - } - } - - // Drop the membership of the first socket pair and verify data is still - // received. - ASSERT_THAT(setsockopt(socket_pairs[0]->second_fd(), IPPROTO_IP, - IP_DROP_MEMBERSHIP, &group, sizeof(group)), - SyscallSucceeds()); - // Send a packet from each socket_pair. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port; - for (auto& sockets : socket_pairs) { - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet on both sockets. - for (auto& sockets : socket_pairs) { - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); - } - } - - // Drop the membership of the second socket pair and verify data stops being - // received. - ASSERT_THAT(setsockopt(socket_pairs[1]->second_fd(), IPPROTO_IP, - IP_DROP_MEMBERSHIP, &group, sizeof(group)), - SyscallSucceeds()); - // Send a packet from each socket_pair. - for (auto& sockets : socket_pairs) { - char send_buf[200]; - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - char recv_buf[sizeof(send_buf)] = {}; - for (auto& sockets : socket_pairs) { - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, - sizeof(recv_buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); - } - } -} - -// Check that a receiving socket can bind to the multicast address before -// joining the group and receive data once the group has been joined. -TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind second socket (receiver) to the multicast address. - auto receiver_addr = V4Multicast(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - // Update receiver_addr with the correct port number. - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register to receive multicast packets. - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet on the first socket out the loopback interface. - ip_mreq iface = {}; - iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallSucceeds()); - auto sendto_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(sizeof(recv_buf))); - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that a receiving socket can bind to the multicast address and won't -// receive multicast data if it hasn't joined the group. -TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenNoJoinThenNoReceive) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind second socket (receiver) to the multicast address. - auto receiver_addr = V4Multicast(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - // Update receiver_addr with the correct port number. - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Send a multicast packet on the first socket out the loopback interface. - ip_mreq iface = {}; - iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallSucceeds()); - auto sendto_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we don't receive the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// Check that a socket can bind to a multicast address and still send out -// packets. -TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind second socket (receiver) to the ANY address. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Bind the first socket (sender) to the multicast address. - auto sender_addr = V4Multicast(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), - SyscallSucceeds()); - socklen_t sender_addr_len = sender_addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&sender_addr.addr), - &sender_addr_len), - SyscallSucceeds()); - EXPECT_EQ(sender_addr_len, sender_addr.addr_len); - - // Send a packet on the first socket to the loopback address. - auto sendto_addr = V4Loopback(); - reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(sizeof(recv_buf))); - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that a receiving socket can bind to the broadcast address and receive -// broadcast packets. -TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenReceive) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind second socket (receiver) to the broadcast address. - auto receiver_addr = V4Broadcast(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Send a broadcast packet on the first socket out the loopback interface. - EXPECT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceedsWithValue(0)); - // Note: Binding to the loopback interface makes the broadcast go out of it. - auto sender_bind_addr = V4Loopback(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_bind_addr.addr), - sender_bind_addr.addr_len), - SyscallSucceeds()); - auto sendto_addr = V4Broadcast(); - reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(sizeof(recv_buf))); - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that a socket can bind to the broadcast address and still send out -// packets. -TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind second socket (receiver) to the ANY address. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Bind the first socket (sender) to the broadcast address. - auto sender_addr = V4Broadcast(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), - SyscallSucceeds()); - socklen_t sender_addr_len = sender_addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&sender_addr.addr), - &sender_addr_len), - SyscallSucceeds()); - EXPECT_EQ(sender_addr_len, sender_addr.addr_len); - - // Send a packet on the first socket to the loopback address. - auto sendto_addr = V4Loopback(); - reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(sizeof(recv_buf))); - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that SO_REUSEADDR always delivers to the most recently bound socket. -TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - - std::vector<std::unique_ptr<FileDescriptor>> sockets; - sockets.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(NewSocket())); - - ASSERT_THAT(setsockopt(sockets[0]->get(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Bind the first socket to the loopback and take note of the selected port. - auto addr = V4Loopback(); - ASSERT_THAT(bind(sockets[0]->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(sockets[0]->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, addr.addr_len); - - constexpr int kMessageSize = 200; - - for (int i = 0; i < 10; i++) { - // Add a new receiver. - sockets.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(NewSocket())); - auto& last = sockets.back(); - ASSERT_THAT(setsockopt(last->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(last->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - - // Send a new message to the SO_REUSEADDR group. We use a new socket each - // time so that a new ephemeral port will be used each time. This ensures - // that we aren't doing REUSEPORT-like hash load blancing. - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - char send_buf[kMessageSize]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Verify that the most recent socket got the message. We don't expect any - // of the other sockets to have received it, but we will check that later. - char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT( - RetryEINTR(recv)(last->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), - SyscallSucceedsWithValue(sizeof(send_buf))); - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); - } - - // Verify that no other messages were received. - for (auto& socket : sockets) { - char recv_buf[kMessageSize] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); - } -} - -TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrThenReusePort) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind socket1 with REUSEADDR. - ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Bind the first socket to the loopback and take note of the selected port. - auto addr = V4Loopback(); - ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, addr.addr_len); - - // Bind socket2 to the same address as socket1, only with REUSEPORT. - ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); -} - -TEST_P(IPv4UDPUnboundSocketTest, BindReusePortThenReuseAddr) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind socket1 with REUSEPORT. - ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Bind the first socket to the loopback and take note of the selected port. - auto addr = V4Loopback(); - ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, addr.addr_len); - - // Bind socket2 to the same address as socket1, only with REUSEADDR. - ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); -} - -TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReusePort) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind socket1 with REUSEADDR and REUSEPORT. - ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Bind the first socket to the loopback and take note of the selected port. - auto addr = V4Loopback(); - ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, addr.addr_len); - - // Bind socket2 to the same address as socket1, only with REUSEPORT. - ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - - // Bind socket3 to the same address as socket1, only with REUSEADDR. - ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); -} - -TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReuseAddr) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind socket1 with REUSEADDR and REUSEPORT. - ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Bind the first socket to the loopback and take note of the selected port. - auto addr = V4Loopback(); - ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, addr.addr_len); - - // Bind socket2 to the same address as socket1, only with REUSEADDR. - ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - - // Bind socket3 to the same address as socket1, only with REUSEPORT. - ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); -} - -TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable1) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind socket1 with REUSEADDR and REUSEPORT. - ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Bind the first socket to the loopback and take note of the selected port. - auto addr = V4Loopback(); - ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, addr.addr_len); - - // Bind socket2 to the same address as socket1, only with REUSEPORT. - ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - - // Close socket2 to revert to just socket1 with REUSEADDR and REUSEPORT. - socket2->reset(); - - // Bind socket3 to the same address as socket1, only with REUSEADDR. - ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); -} - -TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable2) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind socket1 with REUSEADDR and REUSEPORT. - ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Bind the first socket to the loopback and take note of the selected port. - auto addr = V4Loopback(); - ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, addr.addr_len); - - // Bind socket2 to the same address as socket1, only with REUSEADDR. - ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - - // Close socket2 to revert to just socket1 with REUSEADDR and REUSEPORT. - socket2->reset(); - - // Bind socket3 to the same address as socket1, only with REUSEPORT. - ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); -} - -TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReusePort) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind socket1 with REUSEADDR and REUSEPORT. - ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Bind the first socket to the loopback and take note of the selected port. - auto addr = V4Loopback(); - ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, addr.addr_len); - - // Bind socket2 to the same address as socket1, also with REUSEADDR and - // REUSEPORT. - ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - - // Bind socket3 to the same address as socket1, only with REUSEPORT. - ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); -} - -TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReuseAddr) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind socket1 with REUSEADDR and REUSEPORT. - ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Bind the first socket to the loopback and take note of the selected port. - auto addr = V4Loopback(); - ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, addr.addr_len); - - // Bind socket2 to the same address as socket1, also with REUSEADDR and - // REUSEPORT. - ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - - // Bind socket3 to the same address as socket1, only with REUSEADDR. - ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); -} - -// Check that REUSEPORT takes precedence over REUSEADDR. -TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { - auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto receiver2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ASSERT_THAT(setsockopt(receiver1->get(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(receiver1->get(), SOL_SOCKET, SO_REUSEPORT, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Bind the first socket to the loopback and take note of the selected port. - auto addr = V4Loopback(); - ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(receiver1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, addr.addr_len); - - // Bind receiver2 to the same address as socket1, also with REUSEADDR and - // REUSEPORT. - ASSERT_THAT(setsockopt(receiver2->get(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(receiver2->get(), SOL_SOCKET, SO_REUSEPORT, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(receiver2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - - constexpr int kMessageSize = 10; - - for (int i = 0; i < 100; ++i) { - // Send a new message to the REUSEADDR/REUSEPORT group. We use a new socket - // each time so that a new ephemerial port will be used each time. This - // ensures that we cycle through hashes. - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - char send_buf[kMessageSize] = {}; - EXPECT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - } - - // Check that both receivers got messages. This checks that we are using load - // balancing (REUSEPORT) instead of the most recently bound socket - // (REUSEADDR). - char recv_buf[kMessageSize] = {}; - EXPECT_THAT(RetryEINTR(recv)(receiver1->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(kMessageSize)); - EXPECT_THAT(RetryEINTR(recv)(receiver2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(kMessageSize)); -} - -// Test that socket will receive packet info control message. -TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) { - // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. - SKIP_IF((IsRunningWithHostinet())); - - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto sender_addr = V4Loopback(); - int level = SOL_IP; - int type = IP_PKTINFO; - - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), - SyscallSucceeds()); - socklen_t sender_addr_len = sender_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&sender_addr.addr), - &sender_addr_len), - SyscallSucceeds()); - EXPECT_EQ(sender_addr_len, sender_addr.addr_len); - - auto receiver_addr = V4Loopback(); - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&sender_addr.addr)->sin_port; - ASSERT_THAT( - connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - - // Allow socket to receive control message. - ASSERT_THAT( - setsockopt(receiver->get(), level, type, &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Prepare message to send. - constexpr size_t kDataLength = 1024; - msghdr sent_msg = {}; - iovec sent_iov = {}; - char sent_data[kDataLength]; - sent_iov.iov_base = sent_data; - sent_iov.iov_len = kDataLength; - sent_msg.msg_iov = &sent_iov; - sent_msg.msg_iovlen = 1; - sent_msg.msg_flags = 0; - - ASSERT_THAT(RetryEINTR(sendmsg)(sender->get(), &sent_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - msghdr received_msg = {}; - iovec received_iov = {}; - char received_data[kDataLength]; - char received_cmsg_buf[CMSG_SPACE(sizeof(in_pktinfo))] = {}; - size_t cmsg_data_len = sizeof(in_pktinfo); - received_iov.iov_base = received_data; - received_iov.iov_len = kDataLength; - received_msg.msg_iov = &received_iov; - received_msg.msg_iovlen = 1; - received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); - received_msg.msg_control = received_cmsg_buf; - - ASSERT_THAT(RetryEINTR(recvmsg)(receiver->get(), &received_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); - EXPECT_EQ(cmsg->cmsg_level, level); - EXPECT_EQ(cmsg->cmsg_type, type); - - // Get loopback index. - ifreq ifr = {}; - absl::SNPrintF(ifr.ifr_name, IFNAMSIZ, "lo"); - ASSERT_THAT(ioctl(sender->get(), SIOCGIFINDEX, &ifr), SyscallSucceeds()); - ASSERT_NE(ifr.ifr_ifindex, 0); - - // Check the data - in_pktinfo received_pktinfo = {}; - memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(in_pktinfo)); - EXPECT_EQ(received_pktinfo.ipi_ifindex, ifr.ifr_ifindex); - EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, htonl(INADDR_LOOPBACK)); - EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, htonl(INADDR_LOOPBACK)); -} -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.h b/test/syscalls/linux/socket_ipv4_udp_unbound.h deleted file mode 100644 index f64c57645..000000000 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to IPv4 UDP sockets. -using IPv4UDPUnboundSocketTest = SimpleSocketTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc deleted file mode 100644 index 40e673625..000000000 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc +++ /dev/null @@ -1,1104 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h" - -#include <arpa/inet.h> -#include <ifaddrs.h> -#include <netinet/in.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include <cstdint> -#include <cstdio> -#include <cstring> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -TestAddress V4EmptyAddress() { - TestAddress t("V4Empty"); - t.addr.ss_family = AF_INET; - t.addr_len = sizeof(sockaddr_in); - return t; -} - -void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() { - got_if_infos_ = false; - - // Get interface list. - std::vector<std::string> if_names; - ASSERT_NO_ERRNO(if_helper_.Load()); - if_names = if_helper_.InterfaceList(AF_INET); - if (if_names.size() != 2) { - return; - } - - // Figure out which interface is where. - int lo = 0, eth = 1; - if (if_names[lo] != "lo") { - lo = 1; - eth = 0; - } - - if (if_names[lo] != "lo") { - return; - } - - lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(if_names[lo])); - lo_if_addr_ = if_helper_.GetAddr(AF_INET, if_names[lo]); - if (lo_if_addr_ == nullptr) { - return; - } - lo_if_sin_addr_ = reinterpret_cast<sockaddr_in*>(lo_if_addr_)->sin_addr; - - eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(if_names[eth])); - eth_if_addr_ = if_helper_.GetAddr(AF_INET, if_names[eth]); - if (eth_if_addr_ == nullptr) { - return; - } - eth_if_sin_addr_ = reinterpret_cast<sockaddr_in*>(eth_if_addr_)->sin_addr; - - got_if_infos_ = true; -} - -// Verifies that a newly instantiated UDP socket does not have the -// broadcast socket option enabled. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, UDPBroadcastDefault) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - int get = -1; - socklen_t get_sz = sizeof(get); - EXPECT_THAT( - getsockopt(socket->get(), SOL_SOCKET, SO_BROADCAST, &get, &get_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get, kSockOptOff); - EXPECT_EQ(get_sz, sizeof(get)); -} - -// Verifies that a newly instantiated UDP socket returns true after enabling -// the broadcast socket option. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SetUDPBroadcast) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - EXPECT_THAT(setsockopt(socket->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceedsWithValue(0)); - - int get = -1; - socklen_t get_sz = sizeof(get); - EXPECT_THAT( - getsockopt(socket->get(), SOL_SOCKET, SO_BROADCAST, &get, &get_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get, kSockOptOn); - EXPECT_EQ(get_sz, sizeof(get)); -} - -// Verifies that a broadcast UDP packet will arrive at all UDP sockets with -// the destination port number. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - UDPBroadcastReceivedOnExpectedPort) { - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto norcv = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Enable SO_BROADCAST on the sending socket. - ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceedsWithValue(0)); - - // Enable SO_REUSEPORT on the receiving sockets so that they may both be bound - // to the broadcast messages destination port. - ASSERT_THAT(setsockopt(rcvr1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceedsWithValue(0)); - ASSERT_THAT(setsockopt(rcvr2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceedsWithValue(0)); - - // Bind the first socket to the ANY address and let the system assign a port. - auto rcv1_addr = V4Any(); - ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), - rcv1_addr.addr_len), - SyscallSucceedsWithValue(0)); - // Retrieve port number from first socket so that it can be bound to the - // second socket. - socklen_t rcv_addr_sz = rcv1_addr.addr_len; - ASSERT_THAT( - getsockname(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), - &rcv_addr_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(rcv_addr_sz, rcv1_addr.addr_len); - auto port = reinterpret_cast<sockaddr_in*>(&rcv1_addr.addr)->sin_port; - - // Bind the second socket to the same address:port as the first. - ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), - rcv_addr_sz), - SyscallSucceedsWithValue(0)); - - // Bind the non-receiving socket to an ephemeral port. - auto norecv_addr = V4Any(); - ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr), - norecv_addr.addr_len), - SyscallSucceedsWithValue(0)); - - // Broadcast a test message. - auto dst_addr = V4Broadcast(); - reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = port; - constexpr char kTestMsg[] = "hello, world"; - EXPECT_THAT( - sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len), - SyscallSucceedsWithValue(sizeof(kTestMsg))); - - // Verify that the receiving sockets received the test message. - char buf[sizeof(kTestMsg)] = {}; - EXPECT_THAT(recv(rcvr1->get(), buf, sizeof(buf), 0), - SyscallSucceedsWithValue(sizeof(kTestMsg))); - EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); - memset(buf, 0, sizeof(buf)); - EXPECT_THAT(recv(rcvr2->get(), buf, sizeof(buf), 0), - SyscallSucceedsWithValue(sizeof(kTestMsg))); - EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); - - // Verify that the non-receiving socket did not receive the test message. - memset(buf, 0, sizeof(buf)); - EXPECT_THAT(RetryEINTR(recv)(norcv->get(), buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// Verifies that a broadcast UDP packet will arrive at all UDP sockets bound to -// the destination port number and either INADDR_ANY or INADDR_BROADCAST, but -// not a unicast address. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - UDPBroadcastReceivedOnExpectedAddresses) { - // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its - // IPv4 address on eth0. - SKIP_IF(!got_if_infos_); - - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto norcv = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Enable SO_BROADCAST on the sending socket. - ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceedsWithValue(0)); - - // Enable SO_REUSEPORT on all sockets so that they may all be bound to the - // broadcast messages destination port. - ASSERT_THAT(setsockopt(rcvr1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceedsWithValue(0)); - ASSERT_THAT(setsockopt(rcvr2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceedsWithValue(0)); - ASSERT_THAT(setsockopt(norcv->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceedsWithValue(0)); - - // Bind the first socket the ANY address and let the system assign a port. - auto rcv1_addr = V4Any(); - ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), - rcv1_addr.addr_len), - SyscallSucceedsWithValue(0)); - // Retrieve port number from first socket so that it can be bound to the - // second socket. - socklen_t rcv_addr_sz = rcv1_addr.addr_len; - ASSERT_THAT( - getsockname(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), - &rcv_addr_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(rcv_addr_sz, rcv1_addr.addr_len); - auto port = reinterpret_cast<sockaddr_in*>(&rcv1_addr.addr)->sin_port; - - // Bind the second socket to the broadcast address. - auto rcv2_addr = V4Broadcast(); - reinterpret_cast<sockaddr_in*>(&rcv2_addr.addr)->sin_port = port; - ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<sockaddr*>(&rcv2_addr.addr), - rcv2_addr.addr_len), - SyscallSucceedsWithValue(0)); - - // Bind the non-receiving socket to the unicast ethernet address. - auto norecv_addr = rcv1_addr; - reinterpret_cast<sockaddr_in*>(&norecv_addr.addr)->sin_addr = - eth_if_sin_addr_; - ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr), - norecv_addr.addr_len), - SyscallSucceedsWithValue(0)); - - // Broadcast a test message. - auto dst_addr = V4Broadcast(); - reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = port; - constexpr char kTestMsg[] = "hello, world"; - EXPECT_THAT( - sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len), - SyscallSucceedsWithValue(sizeof(kTestMsg))); - - // Verify that the receiving sockets received the test message. - char buf[sizeof(kTestMsg)] = {}; - EXPECT_THAT(recv(rcvr1->get(), buf, sizeof(buf), 0), - SyscallSucceedsWithValue(sizeof(kTestMsg))); - EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); - memset(buf, 0, sizeof(buf)); - EXPECT_THAT(recv(rcvr2->get(), buf, sizeof(buf), 0), - SyscallSucceedsWithValue(sizeof(kTestMsg))); - EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); - - // Verify that the non-receiving socket did not receive the test message. - memset(buf, 0, sizeof(buf)); - EXPECT_THAT(RetryEINTR(recv)(norcv->get(), buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// Verifies that a UDP broadcast can be sent and then received back on the same -// socket that is bound to the broadcast address (255.255.255.255). -// FIXME(b/141938460): This can be combined with the next test -// (UDPBroadcastSendRecvOnSocketBoundToAny). -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - UDPBroadcastSendRecvOnSocketBoundToBroadcast) { - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Enable SO_BROADCAST. - ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceedsWithValue(0)); - - // Bind the sender to the broadcast address. - auto src_addr = V4Broadcast(); - ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&src_addr.addr), - src_addr.addr_len), - SyscallSucceedsWithValue(0)); - socklen_t src_sz = src_addr.addr_len; - ASSERT_THAT(getsockname(sender->get(), - reinterpret_cast<sockaddr*>(&src_addr.addr), &src_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(src_sz, src_addr.addr_len); - - // Send the message. - auto dst_addr = V4Broadcast(); - reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&src_addr.addr)->sin_port; - constexpr char kTestMsg[] = "hello, world"; - EXPECT_THAT( - sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len), - SyscallSucceedsWithValue(sizeof(kTestMsg))); - - // Verify that the message was received. - char buf[sizeof(kTestMsg)] = {}; - EXPECT_THAT(RetryEINTR(recv)(sender->get(), buf, sizeof(buf), 0), - SyscallSucceedsWithValue(sizeof(kTestMsg))); - EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); -} - -// Verifies that a UDP broadcast can be sent and then received back on the same -// socket that is bound to the ANY address (0.0.0.0). -// FIXME(b/141938460): This can be combined with the previous test -// (UDPBroadcastSendRecvOnSocketBoundToBroadcast). -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - UDPBroadcastSendRecvOnSocketBoundToAny) { - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Enable SO_BROADCAST. - ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceedsWithValue(0)); - - // Bind the sender to the ANY address. - auto src_addr = V4Any(); - ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&src_addr.addr), - src_addr.addr_len), - SyscallSucceedsWithValue(0)); - socklen_t src_sz = src_addr.addr_len; - ASSERT_THAT(getsockname(sender->get(), - reinterpret_cast<sockaddr*>(&src_addr.addr), &src_sz), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(src_sz, src_addr.addr_len); - - // Send the message. - auto dst_addr = V4Broadcast(); - reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&src_addr.addr)->sin_port; - constexpr char kTestMsg[] = "hello, world"; - EXPECT_THAT( - sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len), - SyscallSucceedsWithValue(sizeof(kTestMsg))); - - // Verify that the message was received. - char buf[sizeof(kTestMsg)] = {}; - EXPECT_THAT(RetryEINTR(recv)(sender->get(), buf, sizeof(buf), 0), - SyscallSucceedsWithValue(sizeof(kTestMsg))); - EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); -} - -// Verifies that a UDP broadcast fails to send on a socket with SO_BROADCAST -// disabled. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendBroadcast) { - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Broadcast a test message without having enabled SO_BROADCAST on the sending - // socket. - auto addr = V4Broadcast(); - reinterpret_cast<sockaddr_in*>(&addr.addr)->sin_port = htons(12345); - constexpr char kTestMsg[] = "hello, world"; - - EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<sockaddr*>(&addr.addr), addr.addr_len), - SyscallFailsWithErrno(EACCES)); -} - -// Verifies that a UDP unicast on an unbound socket reaches its destination. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendUnicastOnUnbound) { - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto rcvr = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind the receiver and retrieve its address and port number. - sockaddr_in addr = {}; - addr.sin_family = AF_INET; - addr.sin_addr.s_addr = htonl(INADDR_ANY); - addr.sin_port = htons(0); - ASSERT_THAT(bind(rcvr->get(), reinterpret_cast<struct sockaddr*>(&addr), - sizeof(addr)), - SyscallSucceedsWithValue(0)); - memset(&addr, 0, sizeof(addr)); - socklen_t addr_sz = sizeof(addr); - ASSERT_THAT(getsockname(rcvr->get(), - reinterpret_cast<struct sockaddr*>(&addr), &addr_sz), - SyscallSucceedsWithValue(0)); - - // Send a test message to the receiver. - constexpr char kTestMsg[] = "hello, world"; - ASSERT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<struct sockaddr*>(&addr), addr_sz), - SyscallSucceedsWithValue(sizeof(kTestMsg))); - char buf[sizeof(kTestMsg)] = {}; - ASSERT_THAT(recv(rcvr->get(), buf, sizeof(buf), 0), - SyscallSucceedsWithValue(sizeof(kTestMsg))); -} - -// Check that multicast packets won't be delivered to the sending socket with no -// set interface or group membership. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - TestSendMulticastSelfNoGroup) { - // FIXME(b/125485338): A group membership is not required for external - // multicast on gVisor. - SKIP_IF(IsRunningOnGvisor()); - - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - auto bind_addr = V4Any(); - ASSERT_THAT(bind(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr), - bind_addr.addr_len), - SyscallSucceeds()); - socklen_t bind_addr_len = bind_addr.addr_len; - ASSERT_THAT( - getsockname(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr), - &bind_addr_len), - SyscallSucceeds()); - EXPECT_EQ(bind_addr_len, bind_addr.addr_len); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&bind_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we did not receive the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(socket->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// Check that multicast packets will be delivered to the sending socket without -// setting an interface. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelf) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - auto bind_addr = V4Any(); - ASSERT_THAT(bind(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr), - bind_addr.addr_len), - SyscallSucceeds()); - socklen_t bind_addr_len = bind_addr.addr_len; - ASSERT_THAT( - getsockname(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr), - &bind_addr_len), - SyscallSucceeds()); - EXPECT_EQ(bind_addr_len, bind_addr.addr_len); - - // Register to receive multicast packets. - ip_mreq group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - ASSERT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&bind_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that multicast packets won't be delivered to the sending socket with no -// set interface and IP_MULTICAST_LOOP disabled. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - TestSendMulticastSelfLoopOff) { - auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - auto bind_addr = V4Any(); - ASSERT_THAT(bind(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr), - bind_addr.addr_len), - SyscallSucceeds()); - socklen_t bind_addr_len = bind_addr.addr_len; - ASSERT_THAT( - getsockname(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr), - &bind_addr_len), - SyscallSucceeds()); - EXPECT_EQ(bind_addr_len, bind_addr.addr_len); - - // Disable multicast looping. - EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_MULTICAST_LOOP, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - // Register to receive multicast packets. - ip_mreq group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&bind_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we did not receive the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT( - RetryEINTR(recv)(socket->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// Check that multicast packets won't be delivered to another socket with no -// set interface or group membership. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastNoGroup) { - // FIXME(b/125485338): A group membership is not required for external - // multicast on gVisor. - SKIP_IF(IsRunningOnGvisor()); - - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind the second FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we did not receive the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// Check that multicast packets will be delivered to another socket without -// setting an interface. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticast) { - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind the second FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Register to receive multicast packets. - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that multicast packets won't be delivered to another socket with no -// set interface and IP_MULTICAST_LOOP disabled on the sending socket. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - TestSendMulticastSenderNoLoop) { - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind the second FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Disable multicast looping on the sender. - EXPECT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_LOOP, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - // Register to receive multicast packets. - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - EXPECT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we did not receive the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// Check that multicast packets will be delivered to the sending socket without -// setting an interface and IP_MULTICAST_LOOP disabled on the receiving socket. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - TestSendMulticastReceiverNoLoop) { - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Bind the second FD to the v4 any address to ensure that we can receive the - // multicast packet. - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - - // Disable multicast looping on the receiver. - ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_MULTICAST_LOOP, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - // Register to receive multicast packets. - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Send a multicast packet. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Check that we received the multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); -} - -// Check that two sockets can join the same multicast group at the same time, -// and both will receive data on it when bound to the ANY address. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - TestSendMulticastToTwoBoundToAny) { - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - std::unique_ptr<FileDescriptor> receivers[2] = { - ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), - ASSERT_NO_ERRNO_AND_VALUE(NewSocket())}; - - ip_mreq group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - auto receiver_addr = V4Any(); - int bound_port = 0; - for (auto& receiver : receivers) { - ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - // Bind to ANY to receive multicast packets. - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - EXPECT_EQ( - htonl(INADDR_ANY), - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr); - // On the first iteration, save the port we are bound to. On the second - // iteration, verify the port is the same as the one from the first - // iteration. In other words, both sockets listen on the same port. - if (bound_port == 0) { - bound_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - } else { - EXPECT_EQ(bound_port, - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port); - } - - // Register to receive multicast packets. - ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), - SyscallSucceeds()); - } - - // Send a multicast packet to the group and verify both receivers get it. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - for (auto& receiver : receivers) { - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); - } -} - -// Check that two sockets can join the same multicast group at the same time, -// and both will receive data on it when bound to the multicast address. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - TestSendMulticastToTwoBoundToMulticastAddress) { - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - std::unique_ptr<FileDescriptor> receivers[2] = { - ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), - ASSERT_NO_ERRNO_AND_VALUE(NewSocket())}; - - ip_mreq group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - auto receiver_addr = V4Multicast(); - int bound_port = 0; - for (auto& receiver : receivers) { - ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - EXPECT_EQ( - inet_addr(kMulticastAddress), - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr); - // On the first iteration, save the port we are bound to. On the second - // iteration, verify the port is the same as the one from the first - // iteration. In other words, both sockets listen on the same port. - if (bound_port == 0) { - bound_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - } else { - EXPECT_EQ( - inet_addr(kMulticastAddress), - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr); - EXPECT_EQ(bound_port, - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port); - } - - // Register to receive multicast packets. - ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), - SyscallSucceeds()); - } - - // Send a multicast packet to the group and verify both receivers get it. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - for (auto& receiver : receivers) { - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); - } -} - -// Check that two sockets can join the same multicast group at the same time, -// and with one bound to the wildcard address and the other bound to the -// multicast address, both will receive data. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - TestSendMulticastToTwoBoundToAnyAndMulticastAddress) { - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - std::unique_ptr<FileDescriptor> receivers[2] = { - ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), - ASSERT_NO_ERRNO_AND_VALUE(NewSocket())}; - - ip_mreq group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - // The first receiver binds to the wildcard address. - auto receiver_addr = V4Any(); - int bound_port = 0; - for (auto& receiver : receivers) { - ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - // On the first iteration, save the port we are bound to and change the - // receiver address from V4Any to V4Multicast so the second receiver binds - // to that. On the second iteration, verify the port is the same as the one - // from the first iteration but the address is different. - if (bound_port == 0) { - EXPECT_EQ( - htonl(INADDR_ANY), - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr); - bound_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - receiver_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port = - bound_port; - } else { - EXPECT_EQ( - inet_addr(kMulticastAddress), - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr); - EXPECT_EQ(bound_port, - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port); - } - - // Register to receive multicast packets. - ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), - SyscallSucceeds()); - } - - // Send a multicast packet to the group and verify both receivers get it. - auto send_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port; - char send_buf[200]; - RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - for (auto& receiver : receivers) { - char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); - EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); - } -} - -// Check that when receiving a looped-back multicast packet, its source address -// is not a multicast address. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - IpMulticastLoopbackFromAddr) { - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - int receiver_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - - ip_mreq group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Connect to the multicast address. This binds us to the outgoing interface - // and allows us to get its IP (to be compared against the src-IP on the - // receiver side). - auto sendto_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = receiver_port; - ASSERT_THAT(RetryEINTR(connect)( - sender->get(), reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceeds()); - auto sender_addr = V4EmptyAddress(); - ASSERT_THAT( - getsockname(sender->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - &sender_addr.addr_len), - SyscallSucceeds()); - ASSERT_EQ(sizeof(struct sockaddr_in), sender_addr.addr_len); - sockaddr_in* sender_addr_in = - reinterpret_cast<sockaddr_in*>(&sender_addr.addr); - - // Send a multicast packet. - char send_buf[4] = {}; - ASSERT_THAT(RetryEINTR(send)(sender->get(), send_buf, sizeof(send_buf), 0), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Receive a multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - auto src_addr = V4EmptyAddress(); - ASSERT_THAT( - RetryEINTR(recvfrom)(receiver->get(), recv_buf, sizeof(recv_buf), 0, - reinterpret_cast<sockaddr*>(&src_addr.addr), - &src_addr.addr_len), - SyscallSucceedsWithValue(sizeof(recv_buf))); - ASSERT_EQ(sizeof(struct sockaddr_in), src_addr.addr_len); - sockaddr_in* src_addr_in = reinterpret_cast<sockaddr_in*>(&src_addr.addr); - - // Verify that the received source IP:port matches the sender one. - EXPECT_EQ(sender_addr_in->sin_port, src_addr_in->sin_port); - EXPECT_EQ(sender_addr_in->sin_addr.s_addr, src_addr_in->sin_addr.s_addr); -} - -// Check that when setting the IP_MULTICAST_IF option to both an index pointing -// to the loopback interface and an address pointing to the non-loopback -// interface, a multicast packet sent out uses the latter as its source address. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - IpMulticastLoopbackIfNicAndAddr) { - // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its - // IPv4 address on eth0. - SKIP_IF(!got_if_infos_); - - // Create receiver, bind to ANY and join the multicast group. - auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); - socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - int receiver_port = - reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_ifindex = lo_if_idx_; - ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallSucceeds()); - - // Set outgoing multicast interface config, with NIC and addr pointing to - // different interfaces. - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - ip_mreqn iface = {}; - iface.imr_ifindex = lo_if_idx_; - iface.imr_address = eth_if_sin_addr_; - ASSERT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallSucceeds()); - - // Send a multicast packet. - auto sendto_addr = V4Multicast(); - reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = receiver_port; - char send_buf[4] = {}; - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); - - // Receive a multicast packet. - char recv_buf[sizeof(send_buf)] = {}; - auto src_addr = V4EmptyAddress(); - ASSERT_THAT( - RetryEINTR(recvfrom)(receiver->get(), recv_buf, sizeof(recv_buf), 0, - reinterpret_cast<sockaddr*>(&src_addr.addr), - &src_addr.addr_len), - SyscallSucceedsWithValue(sizeof(recv_buf))); - ASSERT_EQ(sizeof(struct sockaddr_in), src_addr.addr_len); - sockaddr_in* src_addr_in = reinterpret_cast<sockaddr_in*>(&src_addr.addr); - - // FIXME (b/137781162): When sending a multicast packet use the proper logic - // to determine the packet's src-IP. - SKIP_IF(IsRunningOnGvisor()); - - // Verify the received source address. - EXPECT_EQ(eth_if_sin_addr_.s_addr, src_addr_in->sin_addr.s_addr); -} - -// Check that when we are bound to one interface we can set IP_MULTICAST_IF to -// another interface. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - IpMulticastLoopbackBindToOneIfSetMcastIfToAnother) { - // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its - // IPv4 address on eth0. - SKIP_IF(!got_if_infos_); - - // FIXME (b/137790511): When bound to one interface it is not possible to set - // IP_MULTICAST_IF to a different interface. - SKIP_IF(IsRunningOnGvisor()); - - // Create sender and bind to eth interface. - auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - ASSERT_THAT(bind(sender->get(), eth_if_addr_, sizeof(sockaddr_in)), - SyscallSucceeds()); - - // Run through all possible combinations of index and address for - // IP_MULTICAST_IF that selects the loopback interface. - struct { - int imr_ifindex; - struct in_addr imr_address; - } test_data[] = { - {lo_if_idx_, {}}, - {0, lo_if_sin_addr_}, - {lo_if_idx_, lo_if_sin_addr_}, - {lo_if_idx_, eth_if_sin_addr_}, - }; - for (auto t : test_data) { - ip_mreqn iface = {}; - iface.imr_ifindex = t.imr_ifindex; - iface.imr_address = t.imr_address; - EXPECT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallSucceeds()) - << "imr_index=" << iface.imr_ifindex - << " imr_address=" << GetAddr4Str(&iface.imr_address); - } -} -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h deleted file mode 100644 index bec2e96ee..000000000 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_EXTERNAL_NETWORKING_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_EXTERNAL_NETWORKING_H_ - -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to unbound IPv4 UDP sockets in a sandbox -// with external networking support. -class IPv4UDPUnboundExternalNetworkingSocketTest : public SimpleSocketTest { - protected: - void SetUp(); - - IfAddrHelper if_helper_; - - // got_if_infos_ is set to false if SetUp() could not obtain all interface - // infos that we need. - bool got_if_infos_; - - // Interface infos. - int lo_if_idx_; - int eth_if_idx_; - sockaddr* lo_if_addr_; - sockaddr* eth_if_addr_; - in_addr lo_if_sin_addr_; - in_addr eth_if_sin_addr_; -}; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_EXTERNAL_NETWORKING_H_ diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc deleted file mode 100644 index f6e64c157..000000000 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h" - -#include <vector> - -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketKind> GetSockets() { - return ApplyVec<SocketKind>( - IPv4UDPUnboundSocket, - AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK})); -} - -INSTANTIATE_TEST_SUITE_P(IPv4UDPUnboundSockets, - IPv4UDPUnboundExternalNetworkingSocketTest, - ::testing::ValuesIn(GetSockets())); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc deleted file mode 100644 index f121c044d..000000000 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_ipv4_udp_unbound.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -INSTANTIATE_TEST_SUITE_P( - IPv4UDPSockets, IPv4UDPUnboundSocketTest, - ::testing::ValuesIn(ApplyVec<SocketKind>(IPv4UDPUnboundSocket, - AllBitwiseCombinations(List<int>{ - 0, SOCK_NONBLOCK})))); - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_netdevice.cc b/test/syscalls/linux/socket_netdevice.cc deleted file mode 100644 index 15d4b85a7..000000000 --- a/test/syscalls/linux/socket_netdevice.cc +++ /dev/null @@ -1,184 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <linux/netlink.h> -#include <linux/rtnetlink.h> -#include <linux/sockios.h> -#include <sys/ioctl.h> -#include <sys/socket.h> - -#include "gtest/gtest.h" -#include "absl/base/internal/endian.h" -#include "test/syscalls/linux/socket_netlink_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -// Tests for netdevice queries. - -namespace gvisor { -namespace testing { - -namespace { - -using ::testing::AnyOf; -using ::testing::Eq; - -TEST(NetdeviceTest, Loopback) { - FileDescriptor sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - - // Prepare the request. - struct ifreq ifr; - snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); - - // Check for a non-zero interface index. - ASSERT_THAT(ioctl(sock.get(), SIOCGIFINDEX, &ifr), SyscallSucceeds()); - EXPECT_NE(ifr.ifr_ifindex, 0); - - // Check that the loopback is zero hardware address. - ASSERT_THAT(ioctl(sock.get(), SIOCGIFHWADDR, &ifr), SyscallSucceeds()); - EXPECT_EQ(ifr.ifr_hwaddr.sa_data[0], 0); - EXPECT_EQ(ifr.ifr_hwaddr.sa_data[1], 0); - EXPECT_EQ(ifr.ifr_hwaddr.sa_data[2], 0); - EXPECT_EQ(ifr.ifr_hwaddr.sa_data[3], 0); - EXPECT_EQ(ifr.ifr_hwaddr.sa_data[4], 0); - EXPECT_EQ(ifr.ifr_hwaddr.sa_data[5], 0); -} - -TEST(NetdeviceTest, Netmask) { - // We need an interface index to identify the loopback device. - FileDescriptor sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - struct ifreq ifr; - snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); - ASSERT_THAT(ioctl(sock.get(), SIOCGIFINDEX, &ifr), SyscallSucceeds()); - EXPECT_NE(ifr.ifr_ifindex, 0); - - // Use a netlink socket to get the netmask, which we'll then compare to the - // netmask obtained via ioctl. - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); - - struct request { - struct nlmsghdr hdr; - struct rtgenmsg rgm; - }; - - constexpr uint32_t kSeq = 12345; - - struct request req; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETADDR; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - req.hdr.nlmsg_seq = kSeq; - req.rgm.rtgen_family = AF_UNSPEC; - - // Iterate through messages until we find the one containing the prefix length - // (i.e. netmask) for the loopback device. - int prefixlen = -1; - ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), - [&](const struct nlmsghdr* hdr) { - EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWADDR), Eq(NLMSG_DONE))); - - EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI) - << std::hex << hdr->nlmsg_flags; - - EXPECT_EQ(hdr->nlmsg_seq, kSeq); - EXPECT_EQ(hdr->nlmsg_pid, port); - - if (hdr->nlmsg_type != RTM_NEWADDR) { - return; - } - - // RTM_NEWADDR contains at least the header and ifaddrmsg. - EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct ifaddrmsg)); - - struct ifaddrmsg* ifaddrmsg = - reinterpret_cast<struct ifaddrmsg*>(NLMSG_DATA(hdr)); - if (ifaddrmsg->ifa_index == static_cast<uint32_t>(ifr.ifr_ifindex) && - ifaddrmsg->ifa_family == AF_INET) { - prefixlen = ifaddrmsg->ifa_prefixlen; - } - }, - false)); - - ASSERT_GE(prefixlen, 0); - - // Netmask is stored big endian in struct sockaddr_in, so we do the same for - // comparison. - uint32_t mask = 0xffffffff << (32 - prefixlen); - mask = absl::gbswap_32(mask); - - // Check that the loopback interface has the correct subnet mask. - snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); - ASSERT_THAT(ioctl(sock.get(), SIOCGIFNETMASK, &ifr), SyscallSucceeds()); - EXPECT_EQ(ifr.ifr_netmask.sa_family, AF_INET); - struct sockaddr_in* sin = - reinterpret_cast<struct sockaddr_in*>(&ifr.ifr_netmask); - EXPECT_EQ(sin->sin_addr.s_addr, mask); -} - -TEST(NetdeviceTest, InterfaceName) { - FileDescriptor sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - - // Prepare the request. - struct ifreq ifr; - snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); - - // Check for a non-zero interface index. - ASSERT_THAT(ioctl(sock.get(), SIOCGIFINDEX, &ifr), SyscallSucceeds()); - EXPECT_NE(ifr.ifr_ifindex, 0); - - // Check that SIOCGIFNAME finds the loopback interface. - snprintf(ifr.ifr_name, IFNAMSIZ, "foo"); - ASSERT_THAT(ioctl(sock.get(), SIOCGIFNAME, &ifr), SyscallSucceeds()); - EXPECT_STREQ(ifr.ifr_name, "lo"); -} - -TEST(NetdeviceTest, InterfaceFlags) { - FileDescriptor sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - - // Prepare the request. - struct ifreq ifr; - snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); - - // Check that SIOCGIFFLAGS marks the interface with IFF_LOOPBACK, IFF_UP, and - // IFF_RUNNING. - ASSERT_THAT(ioctl(sock.get(), SIOCGIFFLAGS, &ifr), SyscallSucceeds()); - EXPECT_EQ(ifr.ifr_flags & IFF_UP, IFF_UP); - EXPECT_EQ(ifr.ifr_flags & IFF_RUNNING, IFF_RUNNING); -} - -TEST(NetdeviceTest, InterfaceMTU) { - FileDescriptor sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - - // Prepare the request. - struct ifreq ifr = {}; - snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); - - // Check that SIOCGIFMTU returns a nonzero MTU. - ASSERT_THAT(ioctl(sock.get(), SIOCGIFMTU, &ifr), SyscallSucceeds()); - EXPECT_GT(ifr.ifr_mtu, 0); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink.cc b/test/syscalls/linux/socket_netlink.cc deleted file mode 100644 index 4ec0fd4fa..000000000 --- a/test/syscalls/linux/socket_netlink.cc +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <linux/netlink.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -// Tests for all netlink socket protocols. - -namespace gvisor { -namespace testing { - -namespace { - -// NetlinkTest parameter is the protocol to test. -using NetlinkTest = ::testing::TestWithParam<int>; - -// Netlink sockets must be SOCK_DGRAM or SOCK_RAW. -TEST_P(NetlinkTest, Types) { - const int protocol = GetParam(); - - EXPECT_THAT(socket(AF_NETLINK, SOCK_STREAM, protocol), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - EXPECT_THAT(socket(AF_NETLINK, SOCK_SEQPACKET, protocol), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - EXPECT_THAT(socket(AF_NETLINK, SOCK_RDM, protocol), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - EXPECT_THAT(socket(AF_NETLINK, SOCK_DCCP, protocol), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - EXPECT_THAT(socket(AF_NETLINK, SOCK_PACKET, protocol), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - - int fd; - EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_DGRAM, protocol), SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - - EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_RAW, protocol), SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST_P(NetlinkTest, AutomaticPort) { - const int protocol = GetParam(); - - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol)); - - struct sockaddr_nl addr = {}; - addr.nl_family = AF_NETLINK; - - EXPECT_THAT( - bind(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), - SyscallSucceeds()); - - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, sizeof(addr)); - // This is the only netlink socket in the process, so it should get the PID as - // the port id. - // - // N.B. Another process could theoretically have explicitly reserved our pid - // as a port ID, but that is very unlikely. - EXPECT_EQ(addr.nl_pid, getpid()); -} - -// Calling connect automatically binds to an automatic port. -TEST_P(NetlinkTest, ConnectBinds) { - const int protocol = GetParam(); - - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol)); - - struct sockaddr_nl addr = {}; - addr.nl_family = AF_NETLINK; - - EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - sizeof(addr)), - SyscallSucceeds()); - - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, sizeof(addr)); - - // Each test is running in a pid namespace, so another process can explicitly - // reserve our pid as a port ID. In this case, a negative portid value will be - // set. - if (static_cast<pid_t>(addr.nl_pid) > 0) { - EXPECT_EQ(addr.nl_pid, getpid()); - } - - memset(&addr, 0, sizeof(addr)); - addr.nl_family = AF_NETLINK; - - // Connecting again is allowed, but keeps the same port. - EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - sizeof(addr)), - SyscallSucceeds()); - - addrlen = sizeof(addr); - EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, sizeof(addr)); - EXPECT_EQ(addr.nl_pid, getpid()); -} - -TEST_P(NetlinkTest, GetPeerName) { - const int protocol = GetParam(); - - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol)); - - struct sockaddr_nl addr = {}; - socklen_t addrlen = sizeof(addr); - - EXPECT_THAT(getpeername(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, sizeof(addr)); - EXPECT_EQ(addr.nl_family, AF_NETLINK); - // Peer is the kernel if we didn't connect elsewhere. - EXPECT_EQ(addr.nl_pid, 0); -} - -INSTANTIATE_TEST_SUITE_P(ProtocolTest, NetlinkTest, - ::testing::Values(NETLINK_ROUTE, - NETLINK_KOBJECT_UEVENT)); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc deleted file mode 100644 index e5aed1eec..000000000 --- a/test/syscalls/linux/socket_netlink_route.cc +++ /dev/null @@ -1,990 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <arpa/inet.h> -#include <ifaddrs.h> -#include <linux/if.h> -#include <linux/netlink.h> -#include <linux/rtnetlink.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include <iostream> -#include <vector> - -#include "gtest/gtest.h" -#include "absl/strings/str_format.h" -#include "absl/types/optional.h" -#include "test/syscalls/linux/socket_netlink_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/capability_util.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -// Tests for NETLINK_ROUTE sockets. - -namespace gvisor { -namespace testing { - -namespace { - -constexpr uint32_t kSeq = 12345; - -using ::testing::AnyOf; -using ::testing::Eq; - -// Parameters for SockOptTest. They are: -// 0: Socket option to query. -// 1: A predicate to run on the returned sockopt value. Should return true if -// the value is considered ok. -// 2: A description of what the sockopt value is expected to be. Should complete -// the sentence "<value> was unexpected, expected <description>" -using SockOptTest = ::testing::TestWithParam< - std::tuple<int, std::function<bool(int)>, std::string>>; - -TEST_P(SockOptTest, GetSockOpt) { - int sockopt = std::get<0>(GetParam()); - auto verifier = std::get<1>(GetParam()); - std::string verifier_description = std::get<2>(GetParam()); - - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)); - - int res; - socklen_t len = sizeof(res); - - EXPECT_THAT(getsockopt(fd.get(), SOL_SOCKET, sockopt, &res, &len), - SyscallSucceeds()); - - EXPECT_EQ(len, sizeof(res)); - EXPECT_TRUE(verifier(res)) << absl::StrFormat( - "getsockopt(%d, SOL_SOCKET, %d, &res, &len) => res=%d was unexpected, " - "expected %s", - fd.get(), sockopt, res, verifier_description); -} - -std::function<bool(int)> IsPositive() { - return [](int val) { return val > 0; }; -} - -std::function<bool(int)> IsEqual(int target) { - return [target](int val) { return val == target; }; -} - -INSTANTIATE_TEST_SUITE_P( - NetlinkRouteTest, SockOptTest, - ::testing::Values( - std::make_tuple(SO_SNDBUF, IsPositive(), "positive send buffer size"), - std::make_tuple(SO_RCVBUF, IsPositive(), - "positive receive buffer size"), - std::make_tuple(SO_TYPE, IsEqual(SOCK_RAW), - absl::StrFormat("SOCK_RAW (%d)", SOCK_RAW)), - std::make_tuple(SO_DOMAIN, IsEqual(AF_NETLINK), - absl::StrFormat("AF_NETLINK (%d)", AF_NETLINK)), - std::make_tuple(SO_PROTOCOL, IsEqual(NETLINK_ROUTE), - absl::StrFormat("NETLINK_ROUTE (%d)", NETLINK_ROUTE)), - std::make_tuple(SO_PASSCRED, IsEqual(0), "0"))); - -// Validates the reponses to RTM_GETLINK + NLM_F_DUMP. -void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) { - EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWLINK), Eq(NLMSG_DONE))); - - EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI) - << std::hex << hdr->nlmsg_flags; - - EXPECT_EQ(hdr->nlmsg_seq, seq); - EXPECT_EQ(hdr->nlmsg_pid, port); - - if (hdr->nlmsg_type != RTM_NEWLINK) { - return; - } - - // RTM_NEWLINK contains at least the header and ifinfomsg. - EXPECT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg))); - - // TODO(mpratt): Check ifinfomsg contents and following attrs. -} - -PosixError DumpLinks( - const FileDescriptor& fd, uint32_t seq, - const std::function<void(const struct nlmsghdr* hdr)>& fn) { - struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifm; - }; - - struct request req = {}; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - req.hdr.nlmsg_seq = seq; - req.ifm.ifi_family = AF_UNSPEC; - - return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false); -} - -TEST(NetlinkRouteTest, GetLinkDump) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); - - // Loopback is common among all tests, check that it's found. - bool loopbackFound = false; - ASSERT_NO_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { - CheckGetLinkResponse(hdr, kSeq, port); - if (hdr->nlmsg_type != RTM_NEWLINK) { - return; - } - ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg))); - const struct ifinfomsg* msg = - reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr)); - std::cout << "Found interface idx=" << msg->ifi_index - << ", type=" << std::hex << msg->ifi_type; - if (msg->ifi_type == ARPHRD_LOOPBACK) { - loopbackFound = true; - EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0); - } - })); - EXPECT_TRUE(loopbackFound); -} - -struct Link { - int index; - std::string name; -}; - -PosixErrorOr<absl::optional<Link>> FindLoopbackLink() { - ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); - - absl::optional<Link> link; - RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { - if (hdr->nlmsg_type != RTM_NEWLINK || - hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) { - return; - } - const struct ifinfomsg* msg = - reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr)); - if (msg->ifi_type == ARPHRD_LOOPBACK) { - const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); - if (rta == nullptr) { - // Ignore links that do not have a name. - return; - } - - link = Link(); - link->index = msg->ifi_index; - link->name = std::string(reinterpret_cast<const char*>(RTA_DATA(rta))); - } - })); - return link; -} - -// CheckLinkMsg checks a netlink message against an expected link. -void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) { - ASSERT_THAT(hdr->nlmsg_type, Eq(RTM_NEWLINK)); - ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg))); - const struct ifinfomsg* msg = - reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr)); - EXPECT_EQ(msg->ifi_index, link.index); - - const struct rtattr* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); - EXPECT_NE(nullptr, rta) << "IFLA_IFNAME not found in message."; - if (rta != nullptr) { - std::string name(reinterpret_cast<const char*>(RTA_DATA(rta))); - EXPECT_EQ(name, link.name); - } -} - -TEST(NetlinkRouteTest, GetLinkByIndex) { - absl::optional<Link> loopback_link = - ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); - ASSERT_TRUE(loopback_link.has_value()); - - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifm; - }; - - struct request req = {}; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST; - req.hdr.nlmsg_seq = kSeq; - req.ifm.ifi_family = AF_UNSPEC; - req.ifm.ifi_index = loopback_link->index; - - bool found = false; - ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), - [&](const struct nlmsghdr* hdr) { - CheckLinkMsg(hdr, *loopback_link); - found = true; - }, - false)); - EXPECT_TRUE(found) << "Netlink response does not contain any links."; -} - -TEST(NetlinkRouteTest, GetLinkByName) { - absl::optional<Link> loopback_link = - ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); - ASSERT_TRUE(loopback_link.has_value()); - - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifm; - struct rtattr rtattr; - char ifname[IFNAMSIZ]; - char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; - }; - - struct request req = {}; - req.hdr.nlmsg_type = RTM_GETLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST; - req.hdr.nlmsg_seq = kSeq; - req.ifm.ifi_family = AF_UNSPEC; - req.rtattr.rta_type = IFLA_IFNAME; - req.rtattr.rta_len = RTA_LENGTH(loopback_link->name.size() + 1); - strncpy(req.ifname, loopback_link->name.c_str(), sizeof(req.ifname)); - req.hdr.nlmsg_len = - NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len); - - bool found = false; - ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), - [&](const struct nlmsghdr* hdr) { - CheckLinkMsg(hdr, *loopback_link); - found = true; - }, - false)); - EXPECT_TRUE(found) << "Netlink response does not contain any links."; -} - -TEST(NetlinkRouteTest, GetLinkByIndexNotFound) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifm; - }; - - struct request req = {}; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST; - req.hdr.nlmsg_seq = kSeq; - req.ifm.ifi_family = AF_UNSPEC; - req.ifm.ifi_index = 1234590; - - EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), - PosixErrorIs(ENODEV, ::testing::_)); -} - -TEST(NetlinkRouteTest, GetLinkByNameNotFound) { - const std::string name = "nodevice?!"; - - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifm; - struct rtattr rtattr; - char ifname[IFNAMSIZ]; - char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; - }; - - struct request req = {}; - req.hdr.nlmsg_type = RTM_GETLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST; - req.hdr.nlmsg_seq = kSeq; - req.ifm.ifi_family = AF_UNSPEC; - req.rtattr.rta_type = IFLA_IFNAME; - req.rtattr.rta_len = RTA_LENGTH(name.size() + 1); - strncpy(req.ifname, name.c_str(), sizeof(req.ifname)); - req.hdr.nlmsg_len = - NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len); - - EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), - PosixErrorIs(ENODEV, ::testing::_)); -} - -TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifm; - }; - - struct request req = {}; - req.hdr.nlmsg_len = sizeof(req); - // If type & 0x3 is equal to 0x2, this means a get request - // which doesn't require CAP_SYS_ADMIN. - req.hdr.nlmsg_type = ((__RTM_MAX + 1024) & (~0x3)) | 0x2; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - req.hdr.nlmsg_seq = kSeq; - req.ifm.ifi_family = AF_UNSPEC; - - EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), - PosixErrorIs(EOPNOTSUPP, ::testing::_)); -} - -TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifm; - }; - - struct request req = {}; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - req.hdr.nlmsg_seq = kSeq; - req.ifm.ifi_family = AF_UNSPEC; - - struct iovec iov = {}; - iov.iov_base = &req; - iov.iov_len = sizeof(req); - - struct msghdr msg = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - // No destination required; it defaults to pid 0, the kernel. - - ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds()); - - // Small enough to ensure that the response doesn't fit. - constexpr size_t kBufferSize = 10; - std::vector<char> buf(kBufferSize); - iov.iov_base = buf.data(); - iov.iov_len = buf.size(); - - ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0), - SyscallSucceedsWithValue(kBufferSize)); - EXPECT_EQ((msg.msg_flags & MSG_TRUNC), MSG_TRUNC); -} - -TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifm; - }; - - struct request req = {}; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - req.hdr.nlmsg_seq = kSeq; - req.ifm.ifi_family = AF_UNSPEC; - - struct iovec iov = {}; - iov.iov_base = &req; - iov.iov_len = sizeof(req); - - struct msghdr msg = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - // No destination required; it defaults to pid 0, the kernel. - - ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds()); - - // Small enough to ensure that the response doesn't fit. - constexpr size_t kBufferSize = 10; - std::vector<char> buf(kBufferSize); - iov.iov_base = buf.data(); - iov.iov_len = buf.size(); - - int res = 0; - ASSERT_THAT(res = RetryEINTR(recvmsg)(fd.get(), &msg, MSG_TRUNC), - SyscallSucceeds()); - EXPECT_GT(res, kBufferSize); - EXPECT_EQ((msg.msg_flags & MSG_TRUNC), MSG_TRUNC); -} - -TEST(NetlinkRouteTest, ControlMessageIgnored) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); - - struct request { - struct nlmsghdr control_hdr; - struct nlmsghdr message_hdr; - struct ifinfomsg ifm; - }; - - struct request req = {}; - - // This control message is ignored. We still receive a response for the - // following RTM_GETLINK. - req.control_hdr.nlmsg_len = sizeof(req.control_hdr); - req.control_hdr.nlmsg_type = NLMSG_DONE; - req.control_hdr.nlmsg_seq = kSeq; - - req.message_hdr.nlmsg_len = sizeof(req.message_hdr) + sizeof(req.ifm); - req.message_hdr.nlmsg_type = RTM_GETLINK; - req.message_hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - req.message_hdr.nlmsg_seq = kSeq; - - req.ifm.ifi_family = AF_UNSPEC; - - ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), - [&](const struct nlmsghdr* hdr) { - CheckGetLinkResponse(hdr, kSeq, port); - }, - false)); -} - -TEST(NetlinkRouteTest, GetAddrDump) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); - - struct request { - struct nlmsghdr hdr; - struct rtgenmsg rgm; - }; - - struct request req; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETADDR; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - req.hdr.nlmsg_seq = kSeq; - req.rgm.rtgen_family = AF_UNSPEC; - - ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), - [&](const struct nlmsghdr* hdr) { - EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWADDR), Eq(NLMSG_DONE))); - - EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI) - << std::hex << hdr->nlmsg_flags; - - EXPECT_EQ(hdr->nlmsg_seq, kSeq); - EXPECT_EQ(hdr->nlmsg_pid, port); - - if (hdr->nlmsg_type != RTM_NEWADDR) { - return; - } - - // RTM_NEWADDR contains at least the header and ifaddrmsg. - EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct ifaddrmsg)); - - // TODO(mpratt): Check ifaddrmsg contents and following attrs. - }, - false)); -} - -TEST(NetlinkRouteTest, LookupAll) { - struct ifaddrs* if_addr_list = nullptr; - auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); }); - - // Not a syscall but we can use the syscall matcher as glibc sets errno. - ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds()); - - int count = 0; - for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) { - if (!i->ifa_addr || (i->ifa_addr->sa_family != AF_INET && - i->ifa_addr->sa_family != AF_INET6)) { - continue; - } - count++; - } - ASSERT_GT(count, 0); -} - -TEST(NetlinkRouteTest, AddAddr) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - - absl::optional<Link> loopback_link = - ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); - ASSERT_TRUE(loopback_link.has_value()); - - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifaddrmsg ifa; - struct rtattr rtattr; - struct in_addr addr; - char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; - }; - - struct request req = {}; - req.hdr.nlmsg_type = RTM_NEWADDR; - req.hdr.nlmsg_seq = kSeq; - req.ifa.ifa_family = AF_INET; - req.ifa.ifa_prefixlen = 24; - req.ifa.ifa_flags = 0; - req.ifa.ifa_scope = 0; - req.ifa.ifa_index = loopback_link->index; - req.rtattr.rta_type = IFA_LOCAL; - req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr)); - inet_pton(AF_INET, "10.0.0.1", &req.addr); - req.hdr.nlmsg_len = - NLMSG_LENGTH(sizeof(req.ifa)) + NLMSG_ALIGN(req.rtattr.rta_len); - - // Create should succeed, as no such address in kernel. - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_ACK; - EXPECT_NO_ERRNO( - NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len)); - - // Replace an existing address should succeed. - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_REPLACE | NLM_F_ACK; - req.hdr.nlmsg_seq++; - EXPECT_NO_ERRNO( - NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len)); - - // Create exclusive should fail, as we created the address above. - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK; - req.hdr.nlmsg_seq++; - EXPECT_THAT( - NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len), - PosixErrorIs(EEXIST, ::testing::_)); -} - -// GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request. -TEST(NetlinkRouteTest, GetRouteDump) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); - - struct request { - struct nlmsghdr hdr; - struct rtmsg rtm; - }; - - struct request req = {}; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETROUTE; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - req.hdr.nlmsg_seq = kSeq; - req.rtm.rtm_family = AF_UNSPEC; - - bool routeFound = false; - bool dstFound = true; - ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), - [&](const struct nlmsghdr* hdr) { - // Validate the reponse to RTM_GETROUTE + NLM_F_DUMP. - EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWROUTE), Eq(NLMSG_DONE))); - - EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI) - << std::hex << hdr->nlmsg_flags; - - EXPECT_EQ(hdr->nlmsg_seq, kSeq); - EXPECT_EQ(hdr->nlmsg_pid, port); - - // The test should not proceed if it's not a RTM_NEWROUTE message. - if (hdr->nlmsg_type != RTM_NEWROUTE) { - return; - } - - // RTM_NEWROUTE contains at least the header and rtmsg. - ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct rtmsg))); - const struct rtmsg* msg = - reinterpret_cast<const struct rtmsg*>(NLMSG_DATA(hdr)); - // NOTE: rtmsg fields are char fields. - std::cout << "Found route table=" << static_cast<int>(msg->rtm_table) - << ", protocol=" << static_cast<int>(msg->rtm_protocol) - << ", scope=" << static_cast<int>(msg->rtm_scope) - << ", type=" << static_cast<int>(msg->rtm_type); - - int len = RTM_PAYLOAD(hdr); - bool rtDstFound = false; - for (struct rtattr* attr = RTM_RTA(msg); RTA_OK(attr, len); - attr = RTA_NEXT(attr, len)) { - if (attr->rta_type == RTA_DST) { - char address[INET_ADDRSTRLEN] = {}; - inet_ntop(AF_INET, RTA_DATA(attr), address, sizeof(address)); - std::cout << ", dst=" << address; - rtDstFound = true; - } - } - - std::cout << std::endl; - - if (msg->rtm_table == RT_TABLE_MAIN) { - routeFound = true; - dstFound = rtDstFound && dstFound; - } - }, - false)); - // At least one route found in main route table. - EXPECT_TRUE(routeFound); - // Found RTA_DST for each route in main table. - EXPECT_TRUE(dstFound); -} - -// GetRouteRequest tests a RTM_GETROUTE request with RTM_F_LOOKUP_TABLE flag. -TEST(NetlinkRouteTest, GetRouteRequest) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); - - struct __attribute__((__packed__)) request { - struct nlmsghdr hdr; - struct rtmsg rtm; - struct nlattr nla; - struct in_addr sin_addr; - }; - - constexpr uint32_t kSeq = 12345; - - struct request req = {}; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETROUTE; - req.hdr.nlmsg_flags = NLM_F_REQUEST; - req.hdr.nlmsg_seq = kSeq; - - req.rtm.rtm_family = AF_INET; - req.rtm.rtm_dst_len = 32; - req.rtm.rtm_src_len = 0; - req.rtm.rtm_tos = 0; - req.rtm.rtm_table = RT_TABLE_UNSPEC; - req.rtm.rtm_protocol = RTPROT_UNSPEC; - req.rtm.rtm_scope = RT_SCOPE_UNIVERSE; - req.rtm.rtm_type = RTN_UNSPEC; - req.rtm.rtm_flags = RTM_F_LOOKUP_TABLE; - - req.nla.nla_len = 8; - req.nla.nla_type = RTA_DST; - inet_aton("127.0.0.2", &req.sin_addr); - - bool rtDstFound = false; - ASSERT_NO_ERRNO(NetlinkRequestResponseSingle( - fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { - // Validate the reponse to RTM_GETROUTE request with RTM_F_LOOKUP_TABLE - // flag. - EXPECT_THAT(hdr->nlmsg_type, RTM_NEWROUTE); - - EXPECT_TRUE(hdr->nlmsg_flags == 0) << std::hex << hdr->nlmsg_flags; - - EXPECT_EQ(hdr->nlmsg_seq, kSeq); - EXPECT_EQ(hdr->nlmsg_pid, port); - - // RTM_NEWROUTE contains at least the header and rtmsg. - ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct rtmsg))); - const struct rtmsg* msg = - reinterpret_cast<const struct rtmsg*>(NLMSG_DATA(hdr)); - - // NOTE: rtmsg fields are char fields. - std::cout << "Found route table=" << static_cast<int>(msg->rtm_table) - << ", protocol=" << static_cast<int>(msg->rtm_protocol) - << ", scope=" << static_cast<int>(msg->rtm_scope) - << ", type=" << static_cast<int>(msg->rtm_type); - - EXPECT_EQ(msg->rtm_family, AF_INET); - EXPECT_EQ(msg->rtm_dst_len, 32); - EXPECT_TRUE((msg->rtm_flags & RTM_F_CLONED) == RTM_F_CLONED) - << std::hex << msg->rtm_flags; - - int len = RTM_PAYLOAD(hdr); - std::cout << ", len=" << len; - for (struct rtattr* attr = RTM_RTA(msg); RTA_OK(attr, len); - attr = RTA_NEXT(attr, len)) { - if (attr->rta_type == RTA_DST) { - char address[INET_ADDRSTRLEN] = {}; - inet_ntop(AF_INET, RTA_DATA(attr), address, sizeof(address)); - std::cout << ", dst=" << address; - rtDstFound = true; - } else if (attr->rta_type == RTA_OIF) { - const char* oif = reinterpret_cast<const char*>(RTA_DATA(attr)); - std::cout << ", oif=" << oif; - } - } - - std::cout << std::endl; - })); - // Found RTA_DST for RTM_F_LOOKUP_TABLE. - EXPECT_TRUE(rtDstFound); -} - -// RecvmsgTrunc tests the recvmsg MSG_TRUNC flag with zero length output -// buffer. MSG_TRUNC with a zero length buffer should consume subsequent -// messages off the socket. -TEST(NetlinkRouteTest, RecvmsgTrunc) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct rtgenmsg rgm; - }; - - struct request req; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETADDR; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - req.hdr.nlmsg_seq = kSeq; - req.rgm.rtgen_family = AF_UNSPEC; - - struct iovec iov = {}; - iov.iov_base = &req; - iov.iov_len = sizeof(req); - - struct msghdr msg = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds()); - - iov.iov_base = NULL; - iov.iov_len = 0; - - int trunclen, trunclen2; - - // Note: This test assumes at least two messages are returned by the - // RTM_GETADDR request. That means at least one RTM_NEWLINK message and one - // NLMSG_DONE message. We cannot read all the messages without blocking - // because we would need to read the message into a buffer and check the - // nlmsg_type for NLMSG_DONE. However, the test depends on reading into a - // zero-length buffer. - - // First, call recvmsg with MSG_TRUNC. This will read the full message from - // the socket and return it's full length. Subsequent calls to recvmsg will - // read the next messages from the socket. - ASSERT_THAT(trunclen = RetryEINTR(recvmsg)(fd.get(), &msg, MSG_TRUNC), - SyscallSucceeds()); - - // Message should always be truncated. However, While the destination iov is - // zero length, MSG_TRUNC returns the size of the next message so it should - // not be zero. - ASSERT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC); - ASSERT_NE(trunclen, 0); - // Returned length is at least the header and ifaddrmsg. - EXPECT_GE(trunclen, sizeof(struct nlmsghdr) + sizeof(struct ifaddrmsg)); - - // Reset the msg_flags to make sure that the recvmsg call is setting them - // properly. - msg.msg_flags = 0; - - // Make a second recvvmsg call to get the next message. - ASSERT_THAT(trunclen2 = RetryEINTR(recvmsg)(fd.get(), &msg, MSG_TRUNC), - SyscallSucceeds()); - ASSERT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC); - ASSERT_NE(trunclen2, 0); - - // Assert that the received messages are not the same. - // - // We are calling recvmsg with a zero length buffer so we have no way to - // inspect the messages to make sure they are not equal in value. The best - // we can do is to compare their lengths. - ASSERT_NE(trunclen, trunclen2); -} - -// RecvmsgTruncPeek tests recvmsg with the combination of the MSG_TRUNC and -// MSG_PEEK flags and a zero length output buffer. This is normally used to -// read the full length of the next message on the socket without consuming -// it, so a properly sized buffer can be allocated to store the message. This -// test tests that scenario. -TEST(NetlinkRouteTest, RecvmsgTruncPeek) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct rtgenmsg rgm; - }; - - struct request req; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETADDR; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - req.hdr.nlmsg_seq = kSeq; - req.rgm.rtgen_family = AF_UNSPEC; - - struct iovec iov = {}; - iov.iov_base = &req; - iov.iov_len = sizeof(req); - - struct msghdr msg = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds()); - - int type = -1; - do { - int peeklen; - int len; - - iov.iov_base = NULL; - iov.iov_len = 0; - - // Call recvmsg with MSG_PEEK and MSG_TRUNC. This will peek at the message - // and return it's full length. - // See: MSG_TRUNC http://man7.org/linux/man-pages/man2/recv.2.html - ASSERT_THAT( - peeklen = RetryEINTR(recvmsg)(fd.get(), &msg, MSG_PEEK | MSG_TRUNC), - SyscallSucceeds()); - - // Message should always be truncated. - ASSERT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC); - ASSERT_NE(peeklen, 0); - - // Reset the message flags for the next call. - msg.msg_flags = 0; - - // Make the actual call to recvmsg to get the actual data. We will use - // the length returned from the peek call for the allocated buffer size.. - std::vector<char> buf(peeklen); - iov.iov_base = buf.data(); - iov.iov_len = buf.size(); - ASSERT_THAT(len = RetryEINTR(recvmsg)(fd.get(), &msg, 0), - SyscallSucceeds()); - - // Message should not be truncated since we allocated the correct buffer - // size. - EXPECT_NE(msg.msg_flags & MSG_TRUNC, MSG_TRUNC); - - // MSG_PEEK should have left data on the socket and the subsequent call - // with should have retrieved the same data. Both calls should have - // returned the message's full length so they should be equal. - ASSERT_NE(len, 0); - ASSERT_EQ(peeklen, len); - - for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data()); - NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) { - type = hdr->nlmsg_type; - } - } while (type != NLMSG_DONE && type != NLMSG_ERROR); -} - -// No SCM_CREDENTIALS are received without SO_PASSCRED set. -TEST(NetlinkRouteTest, NoPasscredNoCreds) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - - ASSERT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOff, - sizeof(kSockOptOff)), - SyscallSucceeds()); - - struct request { - struct nlmsghdr hdr; - struct rtgenmsg rgm; - }; - - struct request req; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETADDR; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - req.hdr.nlmsg_seq = kSeq; - req.rgm.rtgen_family = AF_UNSPEC; - - struct iovec iov = {}; - iov.iov_base = &req; - iov.iov_len = sizeof(req); - - struct msghdr msg = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds()); - - iov.iov_base = NULL; - iov.iov_len = 0; - - char control[CMSG_SPACE(sizeof(struct ucred))] = {}; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - // Note: This test assumes at least one message is returned by the - // RTM_GETADDR request. - ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0), SyscallSucceeds()); - - // No control messages. - EXPECT_EQ(CMSG_FIRSTHDR(&msg), nullptr); -} - -// SCM_CREDENTIALS are received with SO_PASSCRED set. -TEST(NetlinkRouteTest, PasscredCreds) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - - ASSERT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - struct request { - struct nlmsghdr hdr; - struct rtgenmsg rgm; - }; - - struct request req; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETADDR; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - req.hdr.nlmsg_seq = kSeq; - req.rgm.rtgen_family = AF_UNSPEC; - - struct iovec iov = {}; - iov.iov_base = &req; - iov.iov_len = sizeof(req); - - struct msghdr msg = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds()); - - iov.iov_base = NULL; - iov.iov_len = 0; - - char control[CMSG_SPACE(sizeof(struct ucred))] = {}; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - // Note: This test assumes at least one message is returned by the - // RTM_GETADDR request. - ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0), SyscallSucceeds()); - - struct ucred creds; - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(creds))); - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); - - memcpy(&creds, CMSG_DATA(cmsg), sizeof(creds)); - - // The peer is the kernel, which is "PID" 0. - EXPECT_EQ(creds.pid, 0); - // The kernel identifies as root. Also allow nobody in case this test is - // running in a userns without root mapped. - EXPECT_THAT(creds.uid, AnyOf(Eq(0), Eq(65534))); - EXPECT_THAT(creds.gid, AnyOf(Eq(0), Eq(65534))); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc deleted file mode 100644 index 53eb3b6b2..000000000 --- a/test/syscalls/linux/socket_netlink_route_util.cc +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_netlink_route_util.h" - -#include <linux/if.h> -#include <linux/netlink.h> -#include <linux/rtnetlink.h> - -#include "absl/types/optional.h" -#include "test/syscalls/linux/socket_netlink_util.h" - -namespace gvisor { -namespace testing { -namespace { - -constexpr uint32_t kSeq = 12345; - -} // namespace - -PosixError DumpLinks( - const FileDescriptor& fd, uint32_t seq, - const std::function<void(const struct nlmsghdr* hdr)>& fn) { - struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifm; - }; - - struct request req = {}; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - req.hdr.nlmsg_seq = seq; - req.ifm.ifi_family = AF_UNSPEC; - - return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false); -} - -PosixErrorOr<std::vector<Link>> DumpLinks() { - ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); - - std::vector<Link> links; - RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { - if (hdr->nlmsg_type != RTM_NEWLINK || - hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) { - return; - } - const struct ifinfomsg* msg = - reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr)); - const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); - if (rta == nullptr) { - // Ignore links that do not have a name. - return; - } - - links.emplace_back(); - links.back().index = msg->ifi_index; - links.back().type = msg->ifi_type; - links.back().name = - std::string(reinterpret_cast<const char*>(RTA_DATA(rta))); - })); - return links; -} - -PosixErrorOr<absl::optional<Link>> FindLoopbackLink() { - ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); - for (const auto& link : links) { - if (link.type == ARPHRD_LOOPBACK) { - return absl::optional<Link>(link); - } - } - return absl::optional<Link>(); -} - -PosixError LinkAddLocalAddr(int index, int family, int prefixlen, - const void* addr, int addrlen) { - ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifaddrmsg ifaddr; - char attrbuf[512]; - }; - - struct request req = {}; - req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr)); - req.hdr.nlmsg_type = RTM_NEWADDR; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; - req.hdr.nlmsg_seq = kSeq; - req.ifaddr.ifa_index = index; - req.ifaddr.ifa_family = family; - req.ifaddr.ifa_prefixlen = prefixlen; - - struct rtattr* rta = reinterpret_cast<struct rtattr*>( - reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); - rta->rta_type = IFA_LOCAL; - rta->rta_len = RTA_LENGTH(addrlen); - req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); - memcpy(RTA_DATA(rta), addr, addrlen); - - return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); -} - -PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) { - ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifinfo; - char pad[NLMSG_ALIGNTO]; - }; - - struct request req = {}; - req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo)); - req.hdr.nlmsg_type = RTM_NEWLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; - req.hdr.nlmsg_seq = kSeq; - req.ifinfo.ifi_index = index; - req.ifinfo.ifi_flags = flags; - req.ifinfo.ifi_change = change; - - return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); -} - -PosixError LinkSetMacAddr(int index, const void* addr, int addrlen) { - ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifinfo; - char attrbuf[512]; - }; - - struct request req = {}; - req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo)); - req.hdr.nlmsg_type = RTM_NEWLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; - req.hdr.nlmsg_seq = kSeq; - req.ifinfo.ifi_index = index; - - struct rtattr* rta = reinterpret_cast<struct rtattr*>( - reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); - rta->rta_type = IFLA_ADDRESS; - rta->rta_len = RTA_LENGTH(addrlen); - req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); - memcpy(RTA_DATA(rta), addr, addrlen); - - return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h deleted file mode 100644 index 2c018e487..000000000 --- a/test/syscalls/linux/socket_netlink_route_util.h +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ - -#include <linux/netlink.h> -#include <linux/rtnetlink.h> - -#include <vector> - -#include "absl/types/optional.h" -#include "test/syscalls/linux/socket_netlink_util.h" - -namespace gvisor { -namespace testing { - -struct Link { - int index; - int16_t type; - std::string name; -}; - -PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq, - const std::function<void(const struct nlmsghdr* hdr)>& fn); - -PosixErrorOr<std::vector<Link>> DumpLinks(); - -PosixErrorOr<absl::optional<Link>> FindLoopbackLink(); - -// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface. -PosixError LinkAddLocalAddr(int index, int family, int prefixlen, - const void* addr, int addrlen); - -// LinkChangeFlags changes interface flags. E.g. IFF_UP. -PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change); - -// LinkSetMacAddr sets IFLA_ADDRESS attribute of the interface. -PosixError LinkSetMacAddr(int index, const void* addr, int addrlen); - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ diff --git a/test/syscalls/linux/socket_netlink_uevent.cc b/test/syscalls/linux/socket_netlink_uevent.cc deleted file mode 100644 index da425bed4..000000000 --- a/test/syscalls/linux/socket_netlink_uevent.cc +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <linux/filter.h> -#include <linux/netlink.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_netlink_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -// Tests for NETLINK_KOBJECT_UEVENT sockets. -// -// gVisor never sends any messages on these sockets, so we don't test the events -// themselves. - -namespace gvisor { -namespace testing { - -namespace { - -// SO_PASSCRED can be enabled. Since no messages are sent in gVisor, we don't -// actually test receiving credentials. -TEST(NetlinkUeventTest, PassCred) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT)); - - EXPECT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); -} - -// SO_DETACH_FILTER fails without a filter already installed. -TEST(NetlinkUeventTest, DetachNoFilter) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT)); - - int opt; - EXPECT_THAT( - setsockopt(fd.get(), SOL_SOCKET, SO_DETACH_FILTER, &opt, sizeof(opt)), - SyscallFailsWithErrno(ENOENT)); -} - -// We can attach a BPF filter. -TEST(NetlinkUeventTest, AttachFilter) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT)); - - // Minimal BPF program: a single ret. - struct sock_filter filter = {0x6, 0, 0, 0}; - struct sock_fprog prog = {}; - prog.len = 1; - prog.filter = &filter; - - EXPECT_THAT( - setsockopt(fd.get(), SOL_SOCKET, SO_ATTACH_FILTER, &prog, sizeof(prog)), - SyscallSucceeds()); - - int opt; - EXPECT_THAT( - setsockopt(fd.get(), SOL_SOCKET, SO_DETACH_FILTER, &opt, sizeof(opt)), - SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc deleted file mode 100644 index 952eecfe8..000000000 --- a/test/syscalls/linux/socket_netlink_util.cc +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_netlink_util.h" - -#include <linux/if_arp.h> -#include <linux/netlink.h> -#include <linux/rtnetlink.h> -#include <sys/socket.h> - -#include <vector> - -#include "absl/strings/str_cat.h" -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -PosixErrorOr<FileDescriptor> NetlinkBoundSocket(int protocol) { - FileDescriptor fd; - ASSIGN_OR_RETURN_ERRNO(fd, Socket(AF_NETLINK, SOCK_RAW, protocol)); - - struct sockaddr_nl addr = {}; - addr.nl_family = AF_NETLINK; - - RETURN_ERROR_IF_SYSCALL_FAIL( - bind(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr))); - MaybeSave(); - - return std::move(fd); -} - -PosixErrorOr<uint32_t> NetlinkPortID(int fd) { - struct sockaddr_nl addr; - socklen_t addrlen = sizeof(addr); - - RETURN_ERROR_IF_SYSCALL_FAIL( - getsockname(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen)); - MaybeSave(); - - return static_cast<uint32_t>(addr.nl_pid); -} - -PosixError NetlinkRequestResponse( - const FileDescriptor& fd, void* request, size_t len, - const std::function<void(const struct nlmsghdr* hdr)>& fn, - bool expect_nlmsgerr) { - struct iovec iov = {}; - iov.iov_base = request; - iov.iov_len = len; - - struct msghdr msg = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - // No destination required; it defaults to pid 0, the kernel. - - RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(sendmsg)(fd.get(), &msg, 0)); - - constexpr size_t kBufferSize = 4096; - std::vector<char> buf(kBufferSize); - iov.iov_base = buf.data(); - iov.iov_len = buf.size(); - - // If NLM_F_MULTI is set, response is a series of messages that ends with a - // NLMSG_DONE message. - int type = -1; - int flags = 0; - do { - int len; - RETURN_ERROR_IF_SYSCALL_FAIL(len = RetryEINTR(recvmsg)(fd.get(), &msg, 0)); - - // We don't bother with the complexity of dealing with truncated messages. - // We must allocate a large enough buffer up front. - if ((msg.msg_flags & MSG_TRUNC) == MSG_TRUNC) { - return PosixError(EIO, - absl::StrCat("Received truncated message with flags: ", - msg.msg_flags)); - } - - for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data()); - NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) { - fn(hdr); - flags = hdr->nlmsg_flags; - type = hdr->nlmsg_type; - // Done should include an integer payload for dump_done_errno. - // See net/netlink/af_netlink.c:netlink_dump - // Some tools like the 'ip' tool check the minimum length of the - // NLMSG_DONE message. - if (type == NLMSG_DONE) { - EXPECT_GE(hdr->nlmsg_len, NLMSG_LENGTH(sizeof(int))); - } - } - } while ((flags & NLM_F_MULTI) && type != NLMSG_DONE && type != NLMSG_ERROR); - - if (expect_nlmsgerr) { - EXPECT_EQ(type, NLMSG_ERROR); - } else if (flags & NLM_F_MULTI) { - EXPECT_EQ(type, NLMSG_DONE); - } - return NoError(); -} - -PosixError NetlinkRequestResponseSingle( - const FileDescriptor& fd, void* request, size_t len, - const std::function<void(const struct nlmsghdr* hdr)>& fn) { - struct iovec iov = {}; - iov.iov_base = request; - iov.iov_len = len; - - struct msghdr msg = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - // No destination required; it defaults to pid 0, the kernel. - - RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(sendmsg)(fd.get(), &msg, 0)); - - constexpr size_t kBufferSize = 4096; - std::vector<char> buf(kBufferSize); - iov.iov_base = buf.data(); - iov.iov_len = buf.size(); - - int ret; - RETURN_ERROR_IF_SYSCALL_FAIL(ret = RetryEINTR(recvmsg)(fd.get(), &msg, 0)); - - // We don't bother with the complexity of dealing with truncated messages. - // We must allocate a large enough buffer up front. - if ((msg.msg_flags & MSG_TRUNC) == MSG_TRUNC) { - return PosixError( - EIO, - absl::StrCat("Received truncated message with flags: ", msg.msg_flags)); - } - - for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data()); - NLMSG_OK(hdr, ret); hdr = NLMSG_NEXT(hdr, ret)) { - fn(hdr); - } - - return NoError(); -} - -PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq, - void* request, size_t len) { - // Dummy negative number for no error message received. - // We won't get a negative error number so there will be no confusion. - int err = -42; - RETURN_IF_ERRNO(NetlinkRequestResponse( - fd, request, len, - [&](const struct nlmsghdr* hdr) { - EXPECT_EQ(NLMSG_ERROR, hdr->nlmsg_type); - EXPECT_EQ(hdr->nlmsg_seq, seq); - EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr)); - - const struct nlmsgerr* msg = - reinterpret_cast<const struct nlmsgerr*>(NLMSG_DATA(hdr)); - err = -msg->error; - }, - true)); - return PosixError(err); -} - -const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr, - const struct ifinfomsg* msg, int16_t attr) { - const int ifi_space = NLMSG_SPACE(sizeof(*msg)); - int attrlen = hdr->nlmsg_len - ifi_space; - const struct rtattr* rta = reinterpret_cast<const struct rtattr*>( - reinterpret_cast<const uint8_t*>(hdr) + NLMSG_ALIGN(ifi_space)); - for (; RTA_OK(rta, attrlen); rta = RTA_NEXT(rta, attrlen)) { - if (rta->rta_type == attr) { - return rta; - } - } - return nullptr; -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h deleted file mode 100644 index e13ead406..000000000 --- a/test/syscalls/linux/socket_netlink_util.h +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_SOCKET_NETLINK_UTIL_H_ -#define GVISOR_TEST_SYSCALLS_SOCKET_NETLINK_UTIL_H_ - -#include <sys/socket.h> -// socket.h has to be included before if_arp.h. -#include <linux/if_arp.h> -#include <linux/netlink.h> -#include <linux/rtnetlink.h> - -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" - -namespace gvisor { -namespace testing { - -// Returns a bound netlink socket. -PosixErrorOr<FileDescriptor> NetlinkBoundSocket(int protocol); - -// Returns the port ID of the passed socket. -PosixErrorOr<uint32_t> NetlinkPortID(int fd); - -// Send the passed request and call fn on all response netlink messages. -// -// To be used on requests with NLM_F_MULTI reponses. -PosixError NetlinkRequestResponse( - const FileDescriptor& fd, void* request, size_t len, - const std::function<void(const struct nlmsghdr* hdr)>& fn, - bool expect_nlmsgerr); - -// Send the passed request and call fn on all response netlink messages. -// -// To be used on requests without NLM_F_MULTI reponses. -PosixError NetlinkRequestResponseSingle( - const FileDescriptor& fd, void* request, size_t len, - const std::function<void(const struct nlmsghdr* hdr)>& fn); - -// Send the passed request then expect and return an ack or error. -PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq, - void* request, size_t len); - -// Find rtnetlink attribute in message. -const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr, - const struct ifinfomsg* msg, int16_t attr); - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_SOCKET_NETLINK_UTIL_H_ diff --git a/test/syscalls/linux/socket_non_blocking.cc b/test/syscalls/linux/socket_non_blocking.cc deleted file mode 100644 index c3520cadd..000000000 --- a/test/syscalls/linux/socket_non_blocking.cc +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_non_blocking.h" - -#include <stdio.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -TEST_P(NonBlockingSocketPairTest, ReadNothingAvailable) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char buf[20] = {}; - ASSERT_THAT(ReadFd(sockets->first_fd(), buf, sizeof(buf)), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(NonBlockingSocketPairTest, RecvNothingAvailable) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char buf[20] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->first_fd(), buf, sizeof(buf), 0), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(NonBlockingSocketPairTest, RecvMsgNothingAvailable) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct iovec iov; - char buf[20] = {}; - iov.iov_base = buf; - iov.iov_len = sizeof(buf); - struct msghdr msg = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->first_fd(), &msg, 0), - SyscallFailsWithErrno(EAGAIN)); -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_non_blocking.h b/test/syscalls/linux/socket_non_blocking.h deleted file mode 100644 index bd3e02fd2..000000000 --- a/test/syscalls/linux/socket_non_blocking.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_BLOCKING_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_BLOCKING_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of connected non-blocking sockets. -using NonBlockingSocketPairTest = SocketPairTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_BLOCKING_H_ diff --git a/test/syscalls/linux/socket_non_stream.cc b/test/syscalls/linux/socket_non_stream.cc deleted file mode 100644 index c61817f14..000000000 --- a/test/syscalls/linux/socket_non_stream.cc +++ /dev/null @@ -1,337 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_non_stream.h" - -#include <stdio.h> -#include <sys/socket.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -TEST_P(NonStreamSocketPairTest, SendMsgTooLarge) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int sndbuf; - socklen_t length = sizeof(sndbuf); - ASSERT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, &sndbuf, &length), - SyscallSucceeds()); - - // Make the call too large to fit in the send buffer. - const int buffer_size = 3 * sndbuf; - - EXPECT_THAT(SendLargeSendMsg(sockets, buffer_size, false /* reader */), - SyscallFailsWithErrno(EMSGSIZE)); -} - -// Stream sockets allow data sent with a single (e.g. write, sendmsg) syscall -// to be read in pieces with multiple (e.g. read, recvmsg) syscalls. -// -// SplitRecv checks that control messages can only be read on the first (e.g. -// read, recvmsg) syscall, even if it doesn't provide space for the control -// message. -TEST_P(NonStreamSocketPairTest, SplitRecv) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char sent_data[512]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - char received_data[sizeof(sent_data) / 2]; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), 0), - SyscallSucceedsWithValue(sizeof(received_data))); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(received_data))); - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -// Stream sockets allow data sent with multiple sends to be read in a single -// recv. Datagram sockets do not. -// -// SingleRecv checks that only a single message is readable in a single recv. -TEST_P(NonStreamSocketPairTest, SingleRecv) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char sent_data1[20]; - RandomizeBuffer(sent_data1, sizeof(sent_data1)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data1, sizeof(sent_data1), 0), - SyscallSucceedsWithValue(sizeof(sent_data1))); - char sent_data2[20]; - RandomizeBuffer(sent_data2, sizeof(sent_data2)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data2, sizeof(sent_data2), 0), - SyscallSucceedsWithValue(sizeof(sent_data2))); - char received_data[sizeof(sent_data1) + sizeof(sent_data2)]; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data1))); - EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1))); -} - -TEST_P(NonStreamSocketPairTest, RecvmsgMsghdrFlagMsgTrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[10]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[sizeof(sent_data) / 2] = {}; - - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - struct msghdr msg = {}; - msg.msg_flags = -1; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data))); - - // Check that msghdr flags were updated. - EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC); -} - -// Stream sockets allow data sent with multiple sends to be peeked at in a -// single recv. Datagram sockets (except for unix sockets) do not. -// -// SinglePeek checks that only a single message is peekable in a single recv. -TEST_P(NonStreamSocketPairTest, SinglePeek) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char sent_data1[20]; - RandomizeBuffer(sent_data1, sizeof(sent_data1)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data1, sizeof(sent_data1), 0), - SyscallSucceedsWithValue(sizeof(sent_data1))); - char sent_data2[20]; - RandomizeBuffer(sent_data2, sizeof(sent_data2)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data2, sizeof(sent_data2), 0), - SyscallSucceedsWithValue(sizeof(sent_data2))); - char received_data[sizeof(sent_data1) + sizeof(sent_data2)]; - for (int i = 0; i < 3; i++) { - memset(received_data, 0, sizeof(received_data)); - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), MSG_PEEK), - SyscallSucceedsWithValue(sizeof(sent_data1))); - EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1))); - } - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(sent_data1), 0), - SyscallSucceedsWithValue(sizeof(sent_data1))); - EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1))); - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(sent_data2), 0), - SyscallSucceedsWithValue(sizeof(sent_data2))); - EXPECT_EQ(0, memcmp(sent_data2, received_data, sizeof(sent_data2))); -} - -TEST_P(NonStreamSocketPairTest, MsgTruncTruncation) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char sent_data[512]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - char received_data[sizeof(sent_data)] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data) / 2, MSG_TRUNC), - SyscallSucceedsWithValue(sizeof(sent_data))); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data) / 2)); - - // Check that we didn't get any extra data. - EXPECT_NE(0, memcmp(sent_data + sizeof(sent_data) / 2, - received_data + sizeof(received_data) / 2, - sizeof(sent_data) / 2)); -} - -TEST_P(NonStreamSocketPairTest, MsgTruncTruncationRecvmsgMsghdrFlagMsgTrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[10]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[sizeof(sent_data) / 2] = {}; - - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - struct msghdr msg = {}; - msg.msg_flags = -1; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC), - SyscallSucceedsWithValue(sizeof(sent_data))); - EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data))); - - // Check that msghdr flags were updated. - EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC); -} - -TEST_P(NonStreamSocketPairTest, MsgTruncSameSize) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char sent_data[512]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - char received_data[sizeof(sent_data)]; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), MSG_TRUNC), - SyscallSucceedsWithValue(sizeof(received_data))); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(NonStreamSocketPairTest, MsgTruncNotFull) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char sent_data[512]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - char received_data[2 * sizeof(sent_data)]; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), MSG_TRUNC), - SyscallSucceedsWithValue(sizeof(sent_data))); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -// This test tests reading from a socket with MSG_TRUNC and a zero length -// receive buffer. The user should be able to get the message length. -TEST_P(NonStreamSocketPairTest, RecvmsgMsgTruncZeroLen) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[10]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - - // The receive buffer is of zero length. - char received_data[0] = {}; - - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - struct msghdr msg = {}; - msg.msg_flags = -1; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - // The syscall succeeds returning the full size of the message on the socket. - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC), - SyscallSucceedsWithValue(sizeof(sent_data))); - - // Check that MSG_TRUNC is set on msghdr flags. - EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC); -} - -// This test tests reading from a socket with MSG_TRUNC | MSG_PEEK and a zero -// length receive buffer. The user should be able to get the message length -// without reading data off the socket. -TEST_P(NonStreamSocketPairTest, RecvmsgMsgTruncMsgPeekZeroLen) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[10]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - - // The receive buffer is of zero length. - char peek_data[0] = {}; - - struct iovec peek_iov; - peek_iov.iov_base = peek_data; - peek_iov.iov_len = sizeof(peek_data); - struct msghdr peek_msg = {}; - peek_msg.msg_flags = -1; - peek_msg.msg_iov = &peek_iov; - peek_msg.msg_iovlen = 1; - - // The syscall succeeds returning the full size of the message on the socket. - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg, - MSG_TRUNC | MSG_PEEK), - SyscallSucceedsWithValue(sizeof(sent_data))); - - // Check that MSG_TRUNC is set on msghdr flags because the receive buffer is - // smaller than the message size. - EXPECT_EQ(peek_msg.msg_flags & MSG_TRUNC, MSG_TRUNC); - - char received_data[sizeof(sent_data)] = {}; - - struct iovec received_iov; - received_iov.iov_base = received_data; - received_iov.iov_len = sizeof(received_data); - struct msghdr received_msg = {}; - received_msg.msg_flags = -1; - received_msg.msg_iov = &received_iov; - received_msg.msg_iovlen = 1; - - // Next we can read the actual data. - ASSERT_THAT( - RetryEINTR(recvmsg)(sockets->second_fd(), &received_msg, MSG_TRUNC), - SyscallSucceedsWithValue(sizeof(sent_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - // Check that MSG_TRUNC is not set on msghdr flags because we read the whole - // message. - EXPECT_EQ(received_msg.msg_flags & MSG_TRUNC, 0); -} - -// This test tests reading from a socket with MSG_TRUNC | MSG_PEEK and a zero -// length receive buffer and MSG_DONTWAIT. The user should be able to get an -// EAGAIN or EWOULDBLOCK error response. -TEST_P(NonStreamSocketPairTest, RecvmsgTruncPeekDontwaitZeroLen) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - // NOTE: We don't send any data on the socket. - - // The receive buffer is of zero length. - char peek_data[0] = {}; - - struct iovec peek_iov; - peek_iov.iov_base = peek_data; - peek_iov.iov_len = sizeof(peek_data); - struct msghdr peek_msg = {}; - peek_msg.msg_flags = -1; - peek_msg.msg_iov = &peek_iov; - peek_msg.msg_iovlen = 1; - - // recvmsg fails with EAGAIN because no data is available on the socket. - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg, - MSG_TRUNC | MSG_PEEK | MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_non_stream.h b/test/syscalls/linux/socket_non_stream.h deleted file mode 100644 index 469fbe6a2..000000000 --- a/test/syscalls/linux/socket_non_stream.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_STREAM_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_STREAM_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of connected non-stream sockets. -using NonStreamSocketPairTest = SocketPairTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_STREAM_H_ diff --git a/test/syscalls/linux/socket_non_stream_blocking.cc b/test/syscalls/linux/socket_non_stream_blocking.cc deleted file mode 100644 index b052f6e61..000000000 --- a/test/syscalls/linux/socket_non_stream_blocking.cc +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_non_stream_blocking.h" - -#include <stdio.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -TEST_P(BlockingNonStreamSocketPairTest, RecvLessThanBufferWaitAll) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[100]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[sizeof(sent_data) * 2] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), MSG_WAITALL), - SyscallSucceedsWithValue(sizeof(sent_data))); -} - -// This test tests reading from a socket with MSG_TRUNC | MSG_PEEK and a zero -// length receive buffer and MSG_DONTWAIT. The recvmsg call should block on -// reading the data. -TEST_P(BlockingNonStreamSocketPairTest, - RecvmsgTruncPeekDontwaitZeroLenBlocking) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - // NOTE: We don't initially send any data on the socket. - const int data_size = 10; - char sent_data[data_size]; - RandomizeBuffer(sent_data, data_size); - - // The receive buffer is of zero length. - char peek_data[0] = {}; - - struct iovec peek_iov; - peek_iov.iov_base = peek_data; - peek_iov.iov_len = sizeof(peek_data); - struct msghdr peek_msg = {}; - peek_msg.msg_flags = -1; - peek_msg.msg_iov = &peek_iov; - peek_msg.msg_iovlen = 1; - - ScopedThread t([&]() { - // The syscall succeeds returning the full size of the message on the - // socket. This should block until there is data on the socket. - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg, - MSG_TRUNC | MSG_PEEK), - SyscallSucceedsWithValue(data_size)); - }); - - absl::SleepFor(absl::Seconds(1)); - ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), sent_data, data_size, 0), - SyscallSucceedsWithValue(data_size)); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_non_stream_blocking.h b/test/syscalls/linux/socket_non_stream_blocking.h deleted file mode 100644 index 6e205a039..000000000 --- a/test/syscalls/linux/socket_non_stream_blocking.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_STREAM_BLOCKING_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_STREAM_BLOCKING_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of blocking connected non-stream -// sockets. -using BlockingNonStreamSocketPairTest = SocketPairTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_STREAM_BLOCKING_H_ diff --git a/test/syscalls/linux/socket_stream.cc b/test/syscalls/linux/socket_stream.cc deleted file mode 100644 index 6522b2e01..000000000 --- a/test/syscalls/linux/socket_stream.cc +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_stream.h" - -#include <stdio.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -TEST_P(StreamSocketPairTest, SplitRecv) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char sent_data[512]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - char received_data[sizeof(sent_data) / 2]; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), 0), - SyscallSucceedsWithValue(sizeof(received_data))); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(received_data))); - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), 0), - SyscallSucceedsWithValue(sizeof(received_data))); - EXPECT_EQ(0, memcmp(sent_data + sizeof(received_data), received_data, - sizeof(received_data))); -} - -// Stream sockets allow data sent with multiple sends to be read in a single -// recv. -// -// CoalescedRecv checks that multiple messages are readable in a single recv. -TEST_P(StreamSocketPairTest, CoalescedRecv) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char sent_data1[20]; - RandomizeBuffer(sent_data1, sizeof(sent_data1)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data1, sizeof(sent_data1), 0), - SyscallSucceedsWithValue(sizeof(sent_data1))); - char sent_data2[20]; - RandomizeBuffer(sent_data2, sizeof(sent_data2)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data2, sizeof(sent_data2), 0), - SyscallSucceedsWithValue(sizeof(sent_data2))); - char received_data[sizeof(sent_data1) + sizeof(sent_data2)]; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), 0), - SyscallSucceedsWithValue(sizeof(received_data))); - EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1))); - EXPECT_EQ(0, memcmp(sent_data2, received_data + sizeof(sent_data1), - sizeof(sent_data2))); -} - -TEST_P(StreamSocketPairTest, WriteOneSideClosed) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); - const char str[] = "abc"; - ASSERT_THAT(write(sockets->second_fd(), str, 3), - SyscallFailsWithErrno(EPIPE)); -} - -TEST_P(StreamSocketPairTest, RecvmsgMsghdrFlagsNoMsgTrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[10]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[sizeof(sent_data) / 2] = {}; - - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - struct msghdr msg = {}; - msg.msg_flags = -1; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data))); - - // Check that msghdr flags were cleared (MSG_TRUNC was not set). - ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0); -} - -TEST_P(StreamSocketPairTest, RecvmsgTruncZeroLen) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[10]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[0] = {}; - - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - struct msghdr msg = {}; - msg.msg_flags = -1; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC), - SyscallSucceedsWithValue(0)); - - // Check that msghdr flags were cleared (MSG_TRUNC was not set). - ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0); -} - -TEST_P(StreamSocketPairTest, RecvmsgTruncPeekZeroLen) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[10]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[0] = {}; - - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - struct msghdr msg = {}; - msg.msg_flags = -1; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT( - RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC | MSG_PEEK), - SyscallSucceedsWithValue(0)); - - // Check that msghdr flags were cleared (MSG_TRUNC was not set). - ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0); -} - -TEST_P(StreamSocketPairTest, MsgTrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char sent_data[512]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - char received_data[sizeof(sent_data)]; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data) / 2, MSG_TRUNC), - SyscallSucceedsWithValue(sizeof(sent_data) / 2)); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data) / 2)); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_stream.h b/test/syscalls/linux/socket_stream.h deleted file mode 100644 index b837b8f8c..000000000 --- a/test/syscalls/linux/socket_stream.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of blocking and non-blocking -// connected stream sockets. -using StreamSocketPairTest = SocketPairTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_H_ diff --git a/test/syscalls/linux/socket_stream_blocking.cc b/test/syscalls/linux/socket_stream_blocking.cc deleted file mode 100644 index 538ee2268..000000000 --- a/test/syscalls/linux/socket_stream_blocking.cc +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_stream_blocking.h" - -#include <stdio.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" -#include "test/util/timer_util.h" - -namespace gvisor { -namespace testing { - -TEST_P(BlockingStreamSocketPairTest, BlockPartialWriteClosed) { - // FIXME(b/35921550): gVisor doesn't support SO_SNDBUF on UDS, nor does it - // enforce any limit; it will write arbitrary amounts of data without - // blocking. - SKIP_IF(IsRunningOnGvisor()); - - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int buffer_size; - socklen_t length = sizeof(buffer_size); - ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, - &buffer_size, &length), - SyscallSucceeds()); - - int wfd = sockets->first_fd(); - ScopedThread t([wfd, buffer_size]() { - std::vector<char> buf(2 * buffer_size); - // Write more than fits in the buffer. Blocks then returns partial write - // when the other end is closed. The next call returns EPIPE. - // - // N.B. writes occur in chunks, so we may see less than buffer_size from - // the first call. - ASSERT_THAT(write(wfd, buf.data(), buf.size()), - SyscallSucceedsWithValue(::testing::Gt(0))); - ASSERT_THAT(write(wfd, buf.data(), buf.size()), - ::testing::AnyOf(SyscallFailsWithErrno(EPIPE), - SyscallFailsWithErrno(ECONNRESET))); - }); - - // Leave time for write to become blocked. - absl::SleepFor(absl::Seconds(1)); - - ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); -} - -// Random save may interrupt the call to sendmsg() in SendLargeSendMsg(), -// causing the write to be incomplete and the test to hang. -TEST_P(BlockingStreamSocketPairTest, SendMsgTooLarge_NoRandomSave) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int sndbuf; - socklen_t length = sizeof(sndbuf); - ASSERT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, &sndbuf, &length), - SyscallSucceeds()); - - // Make the call too large to fit in the send buffer. - const int buffer_size = 3 * sndbuf; - - EXPECT_THAT(SendLargeSendMsg(sockets, buffer_size, true /* reader */), - SyscallSucceedsWithValue(buffer_size)); -} - -TEST_P(BlockingStreamSocketPairTest, RecvLessThanBuffer) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[100]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[200] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); -} - -// Test that MSG_WAITALL causes recv to block until all requested data is -// received. Random save can interrupt blocking and cause received data to be -// returned, even if the amount received is less than the full requested amount. -TEST_P(BlockingStreamSocketPairTest, RecvLessThanBufferWaitAll_NoRandomSave) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[100]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - constexpr auto kDuration = absl::Milliseconds(200); - auto before = Now(CLOCK_MONOTONIC); - - const ScopedThread t([&]() { - absl::SleepFor(kDuration); - - // Don't let saving after the write interrupt the blocking recv. - const DisableSave ds; - - ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - }); - - char received_data[sizeof(sent_data) * 2] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), MSG_WAITALL), - SyscallSucceedsWithValue(sizeof(received_data))); - - auto after = Now(CLOCK_MONOTONIC); - EXPECT_GE(after - before, kDuration); -} - -TEST_P(BlockingStreamSocketPairTest, SendTimeout) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = 0, .tv_usec = 10 - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), - SyscallSucceeds()); - - std::vector<char> buf(kPageSize); - // We don't know how much data the socketpair will buffer, so we may do an - // arbitrarily large number of writes; saving after each write causes this - // test's time to explode. - const DisableSave ds; - for (;;) { - int ret; - ASSERT_THAT( - ret = RetryEINTR(send)(sockets->first_fd(), buf.data(), buf.size(), 0), - ::testing::AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EAGAIN))); - if (ret == -1) { - break; - } - } -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_stream_blocking.h b/test/syscalls/linux/socket_stream_blocking.h deleted file mode 100644 index 9fd19ff90..000000000 --- a/test/syscalls/linux/socket_stream_blocking.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_BLOCKING_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_BLOCKING_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of blocking connected stream -// sockets. -using BlockingStreamSocketPairTest = SocketPairTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_BLOCKING_H_ diff --git a/test/syscalls/linux/socket_stream_nonblock.cc b/test/syscalls/linux/socket_stream_nonblock.cc deleted file mode 100644 index 74d608741..000000000 --- a/test/syscalls/linux/socket_stream_nonblock.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_stream_nonblock.h" - -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/uio.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -using ::testing::Le; - -TEST_P(NonBlockingStreamSocketPairTest, SendMsgTooLarge) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int sndbuf; - socklen_t length = sizeof(sndbuf); - ASSERT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, &sndbuf, &length), - SyscallSucceeds()); - - // Make the call too large to fit in the send buffer. - const int buffer_size = 3 * sndbuf; - - EXPECT_THAT(SendLargeSendMsg(sockets, buffer_size, false /* reader */), - SyscallSucceedsWithValue(Le(buffer_size))); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_stream_nonblock.h b/test/syscalls/linux/socket_stream_nonblock.h deleted file mode 100644 index c3b7fad91..000000000 --- a/test/syscalls/linux/socket_stream_nonblock.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_NONBLOCK_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_NONBLOCK_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of non-blocking connected stream -// sockets. -using NonBlockingStreamSocketPairTest = SocketPairTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_NONBLOCK_H_ diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc deleted file mode 100644 index 5d3a39868..000000000 --- a/test/syscalls/linux/socket_test_util.cc +++ /dev/null @@ -1,912 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_test_util.h" - -#include <arpa/inet.h> -#include <poll.h> -#include <sys/socket.h> - -#include <memory> - -#include "gtest/gtest.h" -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "absl/time/clock.h" -#include "absl/types/optional.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -Creator<SocketPair> SyscallSocketPairCreator(int domain, int type, - int protocol) { - return [=]() -> PosixErrorOr<std::unique_ptr<FDSocketPair>> { - int pair[2]; - RETURN_ERROR_IF_SYSCALL_FAIL(socketpair(domain, type, protocol, pair)); - MaybeSave(); // Save on successful creation. - return absl::make_unique<FDSocketPair>(pair[0], pair[1]); - }; -} - -Creator<FileDescriptor> SyscallSocketCreator(int domain, int type, - int protocol) { - return [=]() -> PosixErrorOr<std::unique_ptr<FileDescriptor>> { - int fd = 0; - RETURN_ERROR_IF_SYSCALL_FAIL(fd = socket(domain, type, protocol)); - MaybeSave(); // Save on successful creation. - return absl::make_unique<FileDescriptor>(fd); - }; -} - -PosixErrorOr<struct sockaddr_un> UniqueUnixAddr(bool abstract, int domain) { - struct sockaddr_un addr = {}; - std::string path = NewTempAbsPathInDir("/tmp"); - if (path.size() >= sizeof(addr.sun_path)) { - return PosixError(EINVAL, - "Unable to generate a temp path of appropriate length"); - } - - if (abstract) { - // Indicate that the path is in the abstract namespace. - path[0] = 0; - } - memcpy(addr.sun_path, path.c_str(), path.length()); - addr.sun_family = domain; - return addr; -} - -Creator<SocketPair> AcceptBindSocketPairCreator(bool abstract, int domain, - int type, int protocol) { - return [=]() -> PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> { - ASSIGN_OR_RETURN_ERRNO(struct sockaddr_un bind_addr, - UniqueUnixAddr(abstract, domain)); - ASSIGN_OR_RETURN_ERRNO(struct sockaddr_un extra_addr, - UniqueUnixAddr(abstract, domain)); - - int bound; - RETURN_ERROR_IF_SYSCALL_FAIL(bound = socket(domain, type, protocol)); - MaybeSave(); // Successful socket creation. - RETURN_ERROR_IF_SYSCALL_FAIL( - bind(bound, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr))); - MaybeSave(); // Successful bind. - RETURN_ERROR_IF_SYSCALL_FAIL(listen(bound, /* backlog = */ 5)); - MaybeSave(); // Successful listen. - - int connected; - RETURN_ERROR_IF_SYSCALL_FAIL(connected = socket(domain, type, protocol)); - MaybeSave(); // Successful socket creation. - RETURN_ERROR_IF_SYSCALL_FAIL( - connect(connected, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr))); - MaybeSave(); // Successful connect. - - int accepted; - RETURN_ERROR_IF_SYSCALL_FAIL( - accepted = accept4(bound, nullptr, nullptr, - type & (SOCK_NONBLOCK | SOCK_CLOEXEC))); - MaybeSave(); // Successful connect. - - // Cleanup no longer needed resources. - RETURN_ERROR_IF_SYSCALL_FAIL(close(bound)); - MaybeSave(); // Dropped original socket. - - // Only unlink if path is not in abstract namespace. - if (bind_addr.sun_path[0] != 0) { - RETURN_ERROR_IF_SYSCALL_FAIL(unlink(bind_addr.sun_path)); - MaybeSave(); // Unlinked path. - } - - // accepted is before connected to destruct connected before accepted. - // Destructors for nonstatic member objects are called in the reverse order - // in which they appear in the class declaration. - return absl::make_unique<AddrFDSocketPair>(accepted, connected, bind_addr, - extra_addr); - }; -} - -Creator<SocketPair> FilesystemAcceptBindSocketPairCreator(int domain, int type, - int protocol) { - return AcceptBindSocketPairCreator(/* abstract= */ false, domain, type, - protocol); -} - -Creator<SocketPair> AbstractAcceptBindSocketPairCreator(int domain, int type, - int protocol) { - return AcceptBindSocketPairCreator(/* abstract= */ true, domain, type, - protocol); -} - -Creator<SocketPair> BidirectionalBindSocketPairCreator(bool abstract, - int domain, int type, - int protocol) { - return [=]() -> PosixErrorOr<std::unique_ptr<FDSocketPair>> { - ASSIGN_OR_RETURN_ERRNO(struct sockaddr_un addr1, - UniqueUnixAddr(abstract, domain)); - ASSIGN_OR_RETURN_ERRNO(struct sockaddr_un addr2, - UniqueUnixAddr(abstract, domain)); - - int sock1; - RETURN_ERROR_IF_SYSCALL_FAIL(sock1 = socket(domain, type, protocol)); - MaybeSave(); // Successful socket creation. - RETURN_ERROR_IF_SYSCALL_FAIL( - bind(sock1, reinterpret_cast<struct sockaddr*>(&addr1), sizeof(addr1))); - MaybeSave(); // Successful bind. - - int sock2; - RETURN_ERROR_IF_SYSCALL_FAIL(sock2 = socket(domain, type, protocol)); - MaybeSave(); // Successful socket creation. - RETURN_ERROR_IF_SYSCALL_FAIL( - bind(sock2, reinterpret_cast<struct sockaddr*>(&addr2), sizeof(addr2))); - MaybeSave(); // Successful bind. - - RETURN_ERROR_IF_SYSCALL_FAIL(connect( - sock1, reinterpret_cast<struct sockaddr*>(&addr2), sizeof(addr2))); - MaybeSave(); // Successful connect. - - RETURN_ERROR_IF_SYSCALL_FAIL(connect( - sock2, reinterpret_cast<struct sockaddr*>(&addr1), sizeof(addr1))); - MaybeSave(); // Successful connect. - - // Cleanup no longer needed resources. - - // Only unlink if path is not in abstract namespace. - if (addr1.sun_path[0] != 0) { - RETURN_ERROR_IF_SYSCALL_FAIL(unlink(addr1.sun_path)); - MaybeSave(); // Successful unlink. - } - - // Only unlink if path is not in abstract namespace. - if (addr2.sun_path[0] != 0) { - RETURN_ERROR_IF_SYSCALL_FAIL(unlink(addr2.sun_path)); - MaybeSave(); // Successful unlink. - } - - return absl::make_unique<FDSocketPair>(sock1, sock2); - }; -} - -Creator<SocketPair> FilesystemBidirectionalBindSocketPairCreator(int domain, - int type, - int protocol) { - return BidirectionalBindSocketPairCreator(/* abstract= */ false, domain, type, - protocol); -} - -Creator<SocketPair> AbstractBidirectionalBindSocketPairCreator(int domain, - int type, - int protocol) { - return BidirectionalBindSocketPairCreator(/* abstract= */ true, domain, type, - protocol); -} - -Creator<SocketPair> SocketpairGoferSocketPairCreator(int domain, int type, - int protocol) { - return [=]() -> PosixErrorOr<std::unique_ptr<FDSocketPair>> { - struct sockaddr_un addr = {}; - constexpr char kSocketGoferPath[] = "/socket"; - memcpy(addr.sun_path, kSocketGoferPath, sizeof(kSocketGoferPath)); - addr.sun_family = domain; - - int sock1; - RETURN_ERROR_IF_SYSCALL_FAIL(sock1 = socket(domain, type, protocol)); - MaybeSave(); // Successful socket creation. - RETURN_ERROR_IF_SYSCALL_FAIL(connect( - sock1, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr))); - MaybeSave(); // Successful connect. - - int sock2; - RETURN_ERROR_IF_SYSCALL_FAIL(sock2 = socket(domain, type, protocol)); - MaybeSave(); // Successful socket creation. - RETURN_ERROR_IF_SYSCALL_FAIL(connect( - sock2, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr))); - MaybeSave(); // Successful connect. - - // Make and close another socketpair to ensure that the duped ends of the - // first socketpair get closed. - // - // The problem is that there is no way to atomically send and close an FD. - // The closest that we can do is send and then immediately close the FD, - // which is what we do in the gofer. The gofer won't respond to another - // request until the reply is sent and the FD is closed, so forcing the - // gofer to handle another request will ensure that this has happened. - for (int i = 0; i < 2; i++) { - int sock; - RETURN_ERROR_IF_SYSCALL_FAIL(sock = socket(domain, type, protocol)); - RETURN_ERROR_IF_SYSCALL_FAIL(connect( - sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr))); - RETURN_ERROR_IF_SYSCALL_FAIL(close(sock)); - } - - return absl::make_unique<FDSocketPair>(sock1, sock2); - }; -} - -Creator<SocketPair> SocketpairGoferFileSocketPairCreator(int flags) { - return [=]() -> PosixErrorOr<std::unique_ptr<FDSocketPair>> { - constexpr char kSocketGoferPath[] = "/socket"; - - int sock1; - RETURN_ERROR_IF_SYSCALL_FAIL(sock1 = - open(kSocketGoferPath, O_RDWR | flags)); - MaybeSave(); // Successful socket creation. - - int sock2; - RETURN_ERROR_IF_SYSCALL_FAIL(sock2 = - open(kSocketGoferPath, O_RDWR | flags)); - MaybeSave(); // Successful socket creation. - - return absl::make_unique<FDSocketPair>(sock1, sock2); - }; -} - -Creator<SocketPair> UnboundSocketPairCreator(bool abstract, int domain, - int type, int protocol) { - return [=]() -> PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> { - ASSIGN_OR_RETURN_ERRNO(struct sockaddr_un addr1, - UniqueUnixAddr(abstract, domain)); - ASSIGN_OR_RETURN_ERRNO(struct sockaddr_un addr2, - UniqueUnixAddr(abstract, domain)); - - int sock1; - RETURN_ERROR_IF_SYSCALL_FAIL(sock1 = socket(domain, type, protocol)); - MaybeSave(); // Successful socket creation. - int sock2; - RETURN_ERROR_IF_SYSCALL_FAIL(sock2 = socket(domain, type, protocol)); - MaybeSave(); // Successful socket creation. - return absl::make_unique<AddrFDSocketPair>(sock1, sock2, addr1, addr2); - }; -} - -Creator<SocketPair> FilesystemUnboundSocketPairCreator(int domain, int type, - int protocol) { - return UnboundSocketPairCreator(/* abstract= */ false, domain, type, - protocol); -} - -Creator<SocketPair> AbstractUnboundSocketPairCreator(int domain, int type, - int protocol) { - return UnboundSocketPairCreator(/* abstract= */ true, domain, type, protocol); -} - -void LocalhostAddr(struct sockaddr_in* addr, bool dual_stack) { - addr->sin_family = AF_INET; - addr->sin_port = htons(0); - inet_pton(AF_INET, "127.0.0.1", - reinterpret_cast<void*>(&addr->sin_addr.s_addr)); -} - -void LocalhostAddr(struct sockaddr_in6* addr, bool dual_stack) { - addr->sin6_family = AF_INET6; - addr->sin6_port = htons(0); - if (dual_stack) { - inet_pton(AF_INET6, "::ffff:127.0.0.1", - reinterpret_cast<void*>(&addr->sin6_addr.s6_addr)); - } else { - inet_pton(AF_INET6, "::1", - reinterpret_cast<void*>(&addr->sin6_addr.s6_addr)); - } - addr->sin6_scope_id = 0; -} - -template <typename T> -PosixErrorOr<T> BindIP(int fd, bool dual_stack) { - T addr = {}; - LocalhostAddr(&addr, dual_stack); - RETURN_ERROR_IF_SYSCALL_FAIL( - bind(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr))); - socklen_t addrlen = sizeof(addr); - RETURN_ERROR_IF_SYSCALL_FAIL( - getsockname(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen)); - return addr; -} - -template <typename T> -PosixErrorOr<T> TCPBindAndListen(int fd, bool dual_stack) { - ASSIGN_OR_RETURN_ERRNO(T addr, BindIP<T>(fd, dual_stack)); - RETURN_ERROR_IF_SYSCALL_FAIL(listen(fd, /* backlog = */ 5)); - return addr; -} - -template <typename T> -PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> -CreateTCPConnectAcceptSocketPair(int bound, int connected, int type, - bool dual_stack, T bind_addr) { - int connect_result = 0; - RETURN_ERROR_IF_SYSCALL_FAIL( - (connect_result = RetryEINTR(connect)( - connected, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr))) == -1 && - errno == EINPROGRESS - ? 0 - : connect_result); - MaybeSave(); // Successful connect. - - if (connect_result == -1) { - struct pollfd connect_poll = {connected, POLLOUT | POLLERR | POLLHUP, 0}; - RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(poll)(&connect_poll, 1, 0)); - int error = 0; - socklen_t errorlen = sizeof(error); - RETURN_ERROR_IF_SYSCALL_FAIL( - getsockopt(connected, SOL_SOCKET, SO_ERROR, &error, &errorlen)); - errno = error; - RETURN_ERROR_IF_SYSCALL_FAIL( - /* connect */ error == 0 ? 0 : -1); - } - - int accepted = -1; - struct pollfd accept_poll = {bound, POLLIN, 0}; - while (accepted == -1) { - RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(poll)(&accept_poll, 1, 0)); - - RETURN_ERROR_IF_SYSCALL_FAIL( - (accepted = RetryEINTR(accept4)( - bound, nullptr, nullptr, type & (SOCK_NONBLOCK | SOCK_CLOEXEC))) == - -1 && - errno == EAGAIN - ? 0 - : accepted); - } - MaybeSave(); // Successful accept. - - // FIXME(b/110484944) - if (connect_result == -1) { - absl::SleepFor(absl::Seconds(1)); - } - - T extra_addr = {}; - LocalhostAddr(&extra_addr, dual_stack); - return absl::make_unique<AddrFDSocketPair>(connected, accepted, bind_addr, - extra_addr); -} - -template <typename T> -PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> CreateTCPAcceptBindSocketPair( - int bound, int connected, int type, bool dual_stack) { - ASSIGN_OR_RETURN_ERRNO(T bind_addr, TCPBindAndListen<T>(bound, dual_stack)); - - auto result = CreateTCPConnectAcceptSocketPair(bound, connected, type, - dual_stack, bind_addr); - - // Cleanup no longer needed resources. - RETURN_ERROR_IF_SYSCALL_FAIL(close(bound)); - MaybeSave(); // Successful close. - - return result; -} - -Creator<SocketPair> TCPAcceptBindSocketPairCreator(int domain, int type, - int protocol, - bool dual_stack) { - return [=]() -> PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> { - int bound; - RETURN_ERROR_IF_SYSCALL_FAIL(bound = socket(domain, type, protocol)); - MaybeSave(); // Successful socket creation. - - int connected; - RETURN_ERROR_IF_SYSCALL_FAIL(connected = socket(domain, type, protocol)); - MaybeSave(); // Successful socket creation. - - if (domain == AF_INET) { - return CreateTCPAcceptBindSocketPair<sockaddr_in>(bound, connected, type, - dual_stack); - } - return CreateTCPAcceptBindSocketPair<sockaddr_in6>(bound, connected, type, - dual_stack); - }; -} - -Creator<SocketPair> TCPAcceptBindPersistentListenerSocketPairCreator( - int domain, int type, int protocol, bool dual_stack) { - // These are lazily initialized below, on the first call to the returned - // lambda. These values are private to each returned lambda, but shared across - // invocations of a specific lambda. - // - // The sharing allows pairs created with the same parameters to share a - // listener. This prevents future connects from failing if the connecting - // socket selects a port which had previously been used by a listening socket - // that still has some connections in TIME-WAIT. - // - // The lazy initialization is to avoid creating sockets during parameter - // enumeration. This is important because parameters are enumerated during the - // build process where networking may not be available. - auto listener = std::make_shared<absl::optional<int>>(absl::optional<int>()); - auto addr4 = std::make_shared<absl::optional<sockaddr_in>>( - absl::optional<sockaddr_in>()); - auto addr6 = std::make_shared<absl::optional<sockaddr_in6>>( - absl::optional<sockaddr_in6>()); - - return [=]() -> PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> { - int connected; - RETURN_ERROR_IF_SYSCALL_FAIL(connected = socket(domain, type, protocol)); - MaybeSave(); // Successful socket creation. - - // Share the listener across invocations. - if (!listener->has_value()) { - int fd = socket(domain, type, protocol); - if (fd < 0) { - return PosixError(errno, absl::StrCat("socket(", domain, ", ", type, - ", ", protocol, ")")); - } - listener->emplace(fd); - MaybeSave(); // Successful socket creation. - } - - // Bind the listener once, but create a new connect/accept pair each - // time. - if (domain == AF_INET) { - if (!addr4->has_value()) { - addr4->emplace( - TCPBindAndListen<sockaddr_in>(listener->value(), dual_stack) - .ValueOrDie()); - } - return CreateTCPConnectAcceptSocketPair(listener->value(), connected, - type, dual_stack, addr4->value()); - } - if (!addr6->has_value()) { - addr6->emplace( - TCPBindAndListen<sockaddr_in6>(listener->value(), dual_stack) - .ValueOrDie()); - } - return CreateTCPConnectAcceptSocketPair(listener->value(), connected, type, - dual_stack, addr6->value()); - }; -} - -template <typename T> -PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> CreateUDPBoundSocketPair( - int sock1, int sock2, int type, bool dual_stack) { - ASSIGN_OR_RETURN_ERRNO(T addr1, BindIP<T>(sock1, dual_stack)); - ASSIGN_OR_RETURN_ERRNO(T addr2, BindIP<T>(sock2, dual_stack)); - - return absl::make_unique<AddrFDSocketPair>(sock1, sock2, addr1, addr2); -} - -template <typename T> -PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> -CreateUDPBidirectionalBindSocketPair(int sock1, int sock2, int type, - bool dual_stack) { - ASSIGN_OR_RETURN_ERRNO( - auto socks, CreateUDPBoundSocketPair<T>(sock1, sock2, type, dual_stack)); - - // Connect sock1 to sock2. - RETURN_ERROR_IF_SYSCALL_FAIL(connect(socks->first_fd(), socks->second_addr(), - socks->second_addr_size())); - MaybeSave(); // Successful connection. - - // Connect sock2 to sock1. - RETURN_ERROR_IF_SYSCALL_FAIL(connect(socks->second_fd(), socks->first_addr(), - socks->first_addr_size())); - MaybeSave(); // Successful connection. - - return socks; -} - -Creator<SocketPair> UDPBidirectionalBindSocketPairCreator(int domain, int type, - int protocol, - bool dual_stack) { - return [=]() -> PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> { - int sock1; - RETURN_ERROR_IF_SYSCALL_FAIL(sock1 = socket(domain, type, protocol)); - MaybeSave(); // Successful socket creation. - - int sock2; - RETURN_ERROR_IF_SYSCALL_FAIL(sock2 = socket(domain, type, protocol)); - MaybeSave(); // Successful socket creation. - - if (domain == AF_INET) { - return CreateUDPBidirectionalBindSocketPair<sockaddr_in>( - sock1, sock2, type, dual_stack); - } - return CreateUDPBidirectionalBindSocketPair<sockaddr_in6>(sock1, sock2, - type, dual_stack); - }; -} - -Creator<SocketPair> UDPUnboundSocketPairCreator(int domain, int type, - int protocol, bool dual_stack) { - return [=]() -> PosixErrorOr<std::unique_ptr<FDSocketPair>> { - int sock1; - RETURN_ERROR_IF_SYSCALL_FAIL(sock1 = socket(domain, type, protocol)); - MaybeSave(); // Successful socket creation. - - int sock2; - RETURN_ERROR_IF_SYSCALL_FAIL(sock2 = socket(domain, type, protocol)); - MaybeSave(); // Successful socket creation. - - return absl::make_unique<FDSocketPair>(sock1, sock2); - }; -} - -SocketPairKind Reversed(SocketPairKind const& base) { - auto const& creator = base.creator; - return SocketPairKind{ - absl::StrCat("reversed ", base.description), base.domain, base.type, - base.protocol, - [creator]() -> PosixErrorOr<std::unique_ptr<ReversedSocketPair>> { - ASSIGN_OR_RETURN_ERRNO(auto creator_value, creator()); - return absl::make_unique<ReversedSocketPair>(std::move(creator_value)); - }}; -} - -Creator<FileDescriptor> UnboundSocketCreator(int domain, int type, - int protocol) { - return [=]() -> PosixErrorOr<std::unique_ptr<FileDescriptor>> { - int sock; - RETURN_ERROR_IF_SYSCALL_FAIL(sock = socket(domain, type, protocol)); - MaybeSave(); // Successful socket creation. - - return absl::make_unique<FileDescriptor>(sock); - }; -} - -std::vector<SocketPairKind> IncludeReversals(std::vector<SocketPairKind> vec) { - return ApplyVecToVec<SocketPairKind>(std::vector<Middleware>{NoOp, Reversed}, - vec); -} - -SocketPairKind NoOp(SocketPairKind const& base) { return base; } - -void TransferTest(int fd1, int fd2) { - char buf1[20]; - RandomizeBuffer(buf1, sizeof(buf1)); - ASSERT_THAT(WriteFd(fd1, buf1, sizeof(buf1)), - SyscallSucceedsWithValue(sizeof(buf1))); - - char buf2[20]; - ASSERT_THAT(ReadFd(fd2, buf2, sizeof(buf2)), - SyscallSucceedsWithValue(sizeof(buf2))); - - EXPECT_EQ(0, memcmp(buf1, buf2, sizeof(buf1))); - - RandomizeBuffer(buf1, sizeof(buf1)); - ASSERT_THAT(WriteFd(fd2, buf1, sizeof(buf1)), - SyscallSucceedsWithValue(sizeof(buf1))); - - ASSERT_THAT(ReadFd(fd1, buf2, sizeof(buf2)), - SyscallSucceedsWithValue(sizeof(buf2))); - - EXPECT_EQ(0, memcmp(buf1, buf2, sizeof(buf1))); -} - -// Initializes the given buffer with random data. -void RandomizeBuffer(char* ptr, size_t len) { - uint32_t seed = time(nullptr); - for (size_t i = 0; i < len; ++i) { - ptr[i] = static_cast<char>(rand_r(&seed)); - } -} - -size_t CalculateUnixSockAddrLen(const char* sun_path) { - // Abstract addresses always return the full length. - if (sun_path[0] == 0) { - return sizeof(sockaddr_un); - } - // Filesystem addresses use the address length plus the 2 byte sun_family - // and null terminator. - return strlen(sun_path) + 3; -} - -struct sockaddr_storage AddrFDSocketPair::to_storage(const sockaddr_un& addr) { - struct sockaddr_storage addr_storage = {}; - memcpy(&addr_storage, &addr, sizeof(addr)); - return addr_storage; -} - -struct sockaddr_storage AddrFDSocketPair::to_storage(const sockaddr_in& addr) { - struct sockaddr_storage addr_storage = {}; - memcpy(&addr_storage, &addr, sizeof(addr)); - return addr_storage; -} - -struct sockaddr_storage AddrFDSocketPair::to_storage(const sockaddr_in6& addr) { - struct sockaddr_storage addr_storage = {}; - memcpy(&addr_storage, &addr, sizeof(addr)); - return addr_storage; -} - -SocketKind SimpleSocket(int fam, int type, int proto) { - return SocketKind{ - absl::StrCat("Family ", fam, ", type ", type, ", proto ", proto), fam, - type, proto, SyscallSocketCreator(fam, type, proto)}; -} - -ssize_t SendLargeSendMsg(const std::unique_ptr<SocketPair>& sockets, - size_t size, bool reader) { - const int rfd = sockets->second_fd(); - ScopedThread t([rfd, size, reader] { - if (!reader) { - return; - } - - // Potentially too many syscalls in the loop. - const DisableSave ds; - - std::vector<char> buf(size); - size_t total = 0; - - while (total < size) { - int ret = read(rfd, buf.data(), buf.size()); - if (ret == -1 && errno == EAGAIN) { - continue; - } - if (ret > 0) { - total += ret; - } - - // Assert to return on first failure. - ASSERT_THAT(ret, SyscallSucceeds()); - } - }); - - std::vector<char> buf(size); - - struct iovec iov = {}; - iov.iov_base = buf.data(); - iov.iov_len = buf.size(); - - struct msghdr msg = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - return RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0); -} - -namespace internal { -PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family, - SocketType type, bool reuse_addr) { - if (port < 0) { - return PosixError(EINVAL, "Invalid port"); - } - - // Both Ipv6 and Dualstack are AF_INET6. - int sock_fam = (family == AddressFamily::kIpv4 ? AF_INET : AF_INET6); - int sock_type = (type == SocketType::kTcp ? SOCK_STREAM : SOCK_DGRAM); - ASSIGN_OR_RETURN_ERRNO(auto fd, Socket(sock_fam, sock_type, 0)); - - if (reuse_addr) { - int one = 1; - RETURN_ERROR_IF_SYSCALL_FAIL( - setsockopt(fd.get(), SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))); - } - - // Try to bind. - sockaddr_storage storage = {}; - int storage_size = 0; - if (family == AddressFamily::kIpv4) { - sockaddr_in* addr = reinterpret_cast<sockaddr_in*>(&storage); - storage_size = sizeof(*addr); - addr->sin_family = AF_INET; - addr->sin_port = htons(port); - addr->sin_addr.s_addr = htonl(INADDR_ANY); - } else { - sockaddr_in6* addr = reinterpret_cast<sockaddr_in6*>(&storage); - storage_size = sizeof(*addr); - addr->sin6_family = AF_INET6; - addr->sin6_port = htons(port); - if (family == AddressFamily::kDualStack) { - inet_pton(AF_INET6, "::ffff:0.0.0.0", - reinterpret_cast<void*>(&addr->sin6_addr.s6_addr)); - } else { - addr->sin6_addr = in6addr_any; - } - } - - RETURN_ERROR_IF_SYSCALL_FAIL( - bind(fd.get(), reinterpret_cast<sockaddr*>(&storage), storage_size)); - - // If the user specified 0 as the port, we will return the port that the - // kernel gave us, otherwise we will validate that this socket bound to the - // requested port. - sockaddr_storage bound_storage = {}; - socklen_t bound_storage_size = sizeof(bound_storage); - RETURN_ERROR_IF_SYSCALL_FAIL( - getsockname(fd.get(), reinterpret_cast<sockaddr*>(&bound_storage), - &bound_storage_size)); - - int available_port = -1; - if (bound_storage.ss_family == AF_INET) { - sockaddr_in* addr = reinterpret_cast<sockaddr_in*>(&bound_storage); - available_port = ntohs(addr->sin_port); - } else if (bound_storage.ss_family == AF_INET6) { - sockaddr_in6* addr = reinterpret_cast<sockaddr_in6*>(&bound_storage); - available_port = ntohs(addr->sin6_port); - } else { - return PosixError(EPROTOTYPE, "Getsockname returned invalid family"); - } - - // If we requested a specific port make sure our bound port is that port. - if (port != 0 && available_port != port) { - return PosixError(EINVAL, - absl::StrCat("Bound port ", available_port, - " was not equal to requested port ", port)); - } - - // If we're trying to do a TCP socket, let's also try to listen. - if (type == SocketType::kTcp) { - RETURN_ERROR_IF_SYSCALL_FAIL(listen(fd.get(), 1)); - } - - return available_port; -} -} // namespace internal - -PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size) { - struct iovec iov; - iov.iov_base = buf; - iov.iov_len = buf_size; - msg->msg_iov = &iov; - msg->msg_iovlen = 1; - - int ret; - RETURN_ERROR_IF_SYSCALL_FAIL(ret = RetryEINTR(sendmsg)(sock, msg, 0)); - return ret; -} - -void RecvNoData(int sock) { - char data = 0; - struct iovec iov; - iov.iov_base = &data; - iov.iov_len = 1; - struct msghdr msg = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -TestAddress V4Any() { - TestAddress t("V4Any"); - t.addr.ss_family = AF_INET; - t.addr_len = sizeof(sockaddr_in); - reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = htonl(INADDR_ANY); - return t; -} - -TestAddress V4Loopback() { - TestAddress t("V4Loopback"); - t.addr.ss_family = AF_INET; - t.addr_len = sizeof(sockaddr_in); - reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = - htonl(INADDR_LOOPBACK); - return t; -} - -TestAddress V4MappedAny() { - TestAddress t("V4MappedAny"); - t.addr.ss_family = AF_INET6; - t.addr_len = sizeof(sockaddr_in6); - inet_pton(AF_INET6, "::ffff:0.0.0.0", - reinterpret_cast<sockaddr_in6*>(&t.addr)->sin6_addr.s6_addr); - return t; -} - -TestAddress V4MappedLoopback() { - TestAddress t("V4MappedLoopback"); - t.addr.ss_family = AF_INET6; - t.addr_len = sizeof(sockaddr_in6); - inet_pton(AF_INET6, "::ffff:127.0.0.1", - reinterpret_cast<sockaddr_in6*>(&t.addr)->sin6_addr.s6_addr); - return t; -} - -TestAddress V4Multicast() { - TestAddress t("V4Multicast"); - t.addr.ss_family = AF_INET; - t.addr_len = sizeof(sockaddr_in); - reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = - inet_addr(kMulticastAddress); - return t; -} - -TestAddress V4Broadcast() { - TestAddress t("V4Broadcast"); - t.addr.ss_family = AF_INET; - t.addr_len = sizeof(sockaddr_in); - reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = - htonl(INADDR_BROADCAST); - return t; -} - -TestAddress V6Any() { - TestAddress t("V6Any"); - t.addr.ss_family = AF_INET6; - t.addr_len = sizeof(sockaddr_in6); - reinterpret_cast<sockaddr_in6*>(&t.addr)->sin6_addr = in6addr_any; - return t; -} - -TestAddress V6Loopback() { - TestAddress t("V6Loopback"); - t.addr.ss_family = AF_INET6; - t.addr_len = sizeof(sockaddr_in6); - reinterpret_cast<sockaddr_in6*>(&t.addr)->sin6_addr = in6addr_loopback; - return t; -} - -// Checksum computes the internet checksum of a buffer. -uint16_t Checksum(uint16_t* buf, ssize_t buf_size) { - // Add up the 16-bit values in the buffer. - uint32_t total = 0; - for (unsigned int i = 0; i < buf_size; i += sizeof(*buf)) { - total += *buf; - buf++; - } - - // If buf has an odd size, add the remaining byte. - if (buf_size % 2) { - total += *(reinterpret_cast<unsigned char*>(buf) - 1); - } - - // This carries any bits past the lower 16 until everything fits in 16 bits. - while (total >> 16) { - uint16_t lower = total & 0xffff; - uint16_t upper = total >> 16; - total = lower + upper; - } - - return ~total; -} - -uint16_t IPChecksum(struct iphdr ip) { - return Checksum(reinterpret_cast<uint16_t*>(&ip), sizeof(ip)); -} - -// The pseudo-header defined in RFC 768 for calculating the UDP checksum. -struct udp_pseudo_hdr { - uint32_t srcip; - uint32_t destip; - char zero; - char protocol; - uint16_t udplen; -}; - -uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr, - const char* payload, ssize_t payload_len) { - struct udp_pseudo_hdr phdr = {}; - phdr.srcip = iphdr.saddr; - phdr.destip = iphdr.daddr; - phdr.zero = 0; - phdr.protocol = IPPROTO_UDP; - phdr.udplen = udphdr.len; - - ssize_t buf_size = sizeof(phdr) + sizeof(udphdr) + payload_len; - char* buf = static_cast<char*>(malloc(buf_size)); - memcpy(buf, &phdr, sizeof(phdr)); - memcpy(buf + sizeof(phdr), &udphdr, sizeof(udphdr)); - memcpy(buf + sizeof(phdr) + sizeof(udphdr), payload, payload_len); - - uint16_t csum = Checksum(reinterpret_cast<uint16_t*>(buf), buf_size); - free(buf); - return csum; -} - -uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload, - ssize_t payload_len) { - ssize_t buf_size = sizeof(icmphdr) + payload_len; - char* buf = static_cast<char*>(malloc(buf_size)); - memcpy(buf, &icmphdr, sizeof(icmphdr)); - memcpy(buf + sizeof(icmphdr), payload, payload_len); - - uint16_t csum = Checksum(reinterpret_cast<uint16_t*>(buf), buf_size); - free(buf); - return csum; -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h deleted file mode 100644 index 734b48b96..000000000 --- a/test/syscalls/linux/socket_test_util.h +++ /dev/null @@ -1,518 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_SOCKET_TEST_UTIL_H_ -#define GVISOR_TEST_SYSCALLS_SOCKET_TEST_UTIL_H_ - -#include <errno.h> -#include <netinet/ip.h> -#include <netinet/ip_icmp.h> -#include <netinet/udp.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include <functional> -#include <memory> -#include <string> -#include <utility> -#include <vector> - -#include "gtest/gtest.h" -#include "absl/strings/str_format.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -// Wrapper for socket(2) that returns a FileDescriptor. -inline PosixErrorOr<FileDescriptor> Socket(int family, int type, int protocol) { - int fd = socket(family, type, protocol); - MaybeSave(); - if (fd < 0) { - return PosixError( - errno, absl::StrFormat("socket(%d, %d, %d)", family, type, protocol)); - } - return FileDescriptor(fd); -} - -// Wrapper for accept(2) that returns a FileDescriptor. -inline PosixErrorOr<FileDescriptor> Accept(int sockfd, sockaddr* addr, - socklen_t* addrlen) { - int fd = RetryEINTR(accept)(sockfd, addr, addrlen); - MaybeSave(); - if (fd < 0) { - return PosixError( - errno, absl::StrFormat("accept(%d, %p, %p)", sockfd, addr, addrlen)); - } - return FileDescriptor(fd); -} - -// Wrapper for accept4(2) that returns a FileDescriptor. -inline PosixErrorOr<FileDescriptor> Accept4(int sockfd, sockaddr* addr, - socklen_t* addrlen, int flags) { - int fd = RetryEINTR(accept4)(sockfd, addr, addrlen, flags); - MaybeSave(); - if (fd < 0) { - return PosixError(errno, absl::StrFormat("accept4(%d, %p, %p, %#x)", sockfd, - addr, addrlen, flags)); - } - return FileDescriptor(fd); -} - -inline ssize_t SendFd(int fd, void* buf, size_t count, int flags) { - return internal::ApplyFileIoSyscall( - [&](size_t completed) { - return sendto(fd, static_cast<char*>(buf) + completed, - count - completed, flags, nullptr, 0); - }, - count); -} - -PosixErrorOr<struct sockaddr_un> UniqueUnixAddr(bool abstract, int domain); - -// A Creator<T> is a function that attempts to create and return a new T. (This -// is copy/pasted from cloud/gvisor/api/sandbox_util.h and is just duplicated -// here for clarity.) -template <typename T> -using Creator = std::function<PosixErrorOr<std::unique_ptr<T>>()>; - -// A SocketPair represents a pair of socket file descriptors owned by the -// SocketPair. -class SocketPair { - public: - virtual ~SocketPair() = default; - - virtual int first_fd() const = 0; - virtual int second_fd() const = 0; - virtual int release_first_fd() = 0; - virtual int release_second_fd() = 0; - virtual const struct sockaddr* first_addr() const = 0; - virtual const struct sockaddr* second_addr() const = 0; - virtual size_t first_addr_size() const = 0; - virtual size_t second_addr_size() const = 0; - virtual size_t first_addr_len() const = 0; - virtual size_t second_addr_len() const = 0; -}; - -// A FDSocketPair is a SocketPair that consists of only a pair of file -// descriptors. -class FDSocketPair : public SocketPair { - public: - FDSocketPair(int first_fd, int second_fd) - : first_(first_fd), second_(second_fd) {} - FDSocketPair(std::unique_ptr<FileDescriptor> first_fd, - std::unique_ptr<FileDescriptor> second_fd) - : first_(first_fd->release()), second_(second_fd->release()) {} - - int first_fd() const override { return first_.get(); } - int second_fd() const override { return second_.get(); } - int release_first_fd() override { return first_.release(); } - int release_second_fd() override { return second_.release(); } - const struct sockaddr* first_addr() const override { return nullptr; } - const struct sockaddr* second_addr() const override { return nullptr; } - size_t first_addr_size() const override { return 0; } - size_t second_addr_size() const override { return 0; } - size_t first_addr_len() const override { return 0; } - size_t second_addr_len() const override { return 0; } - - private: - FileDescriptor first_; - FileDescriptor second_; -}; - -// CalculateUnixSockAddrLen calculates the length returned by recvfrom(2) and -// recvmsg(2) for Unix sockets. -size_t CalculateUnixSockAddrLen(const char* sun_path); - -// A AddrFDSocketPair is a SocketPair that consists of a pair of file -// descriptors in addition to a pair of socket addresses. -class AddrFDSocketPair : public SocketPair { - public: - AddrFDSocketPair(int first_fd, int second_fd, - const struct sockaddr_un& first_address, - const struct sockaddr_un& second_address) - : first_(first_fd), - second_(second_fd), - first_addr_(to_storage(first_address)), - second_addr_(to_storage(second_address)), - first_len_(CalculateUnixSockAddrLen(first_address.sun_path)), - second_len_(CalculateUnixSockAddrLen(second_address.sun_path)), - first_size_(sizeof(first_address)), - second_size_(sizeof(second_address)) {} - - AddrFDSocketPair(int first_fd, int second_fd, - const struct sockaddr_in& first_address, - const struct sockaddr_in& second_address) - : first_(first_fd), - second_(second_fd), - first_addr_(to_storage(first_address)), - second_addr_(to_storage(second_address)), - first_len_(sizeof(first_address)), - second_len_(sizeof(second_address)), - first_size_(sizeof(first_address)), - second_size_(sizeof(second_address)) {} - - AddrFDSocketPair(int first_fd, int second_fd, - const struct sockaddr_in6& first_address, - const struct sockaddr_in6& second_address) - : first_(first_fd), - second_(second_fd), - first_addr_(to_storage(first_address)), - second_addr_(to_storage(second_address)), - first_len_(sizeof(first_address)), - second_len_(sizeof(second_address)), - first_size_(sizeof(first_address)), - second_size_(sizeof(second_address)) {} - - int first_fd() const override { return first_.get(); } - int second_fd() const override { return second_.get(); } - int release_first_fd() override { return first_.release(); } - int release_second_fd() override { return second_.release(); } - const struct sockaddr* first_addr() const override { - return reinterpret_cast<const struct sockaddr*>(&first_addr_); - } - const struct sockaddr* second_addr() const override { - return reinterpret_cast<const struct sockaddr*>(&second_addr_); - } - size_t first_addr_size() const override { return first_size_; } - size_t second_addr_size() const override { return second_size_; } - size_t first_addr_len() const override { return first_len_; } - size_t second_addr_len() const override { return second_len_; } - - private: - // to_storage coverts a sockaddr_* to a sockaddr_storage. - static struct sockaddr_storage to_storage(const sockaddr_un& addr); - static struct sockaddr_storage to_storage(const sockaddr_in& addr); - static struct sockaddr_storage to_storage(const sockaddr_in6& addr); - - FileDescriptor first_; - FileDescriptor second_; - const struct sockaddr_storage first_addr_; - const struct sockaddr_storage second_addr_; - const size_t first_len_; - const size_t second_len_; - const size_t first_size_; - const size_t second_size_; -}; - -// SyscallSocketPairCreator returns a Creator<SocketPair> that obtains file -// descriptors by invoking the socketpair() syscall. -Creator<SocketPair> SyscallSocketPairCreator(int domain, int type, - int protocol); - -// SyscallSocketCreator returns a Creator<FileDescriptor> that obtains a file -// descriptor by invoking the socket() syscall. -Creator<FileDescriptor> SyscallSocketCreator(int domain, int type, - int protocol); - -// FilesystemBidirectionalBindSocketPairCreator returns a Creator<SocketPair> -// that obtains file descriptors by invoking the bind() and connect() syscalls -// on filesystem paths. Only works for DGRAM sockets. -Creator<SocketPair> FilesystemBidirectionalBindSocketPairCreator(int domain, - int type, - int protocol); - -// AbstractBidirectionalBindSocketPairCreator returns a Creator<SocketPair> that -// obtains file descriptors by invoking the bind() and connect() syscalls on -// abstract namespace paths. Only works for DGRAM sockets. -Creator<SocketPair> AbstractBidirectionalBindSocketPairCreator(int domain, - int type, - int protocol); - -// SocketpairGoferSocketPairCreator returns a Creator<SocketPair> that -// obtains file descriptors by connect() syscalls on two sockets with socketpair -// gofer paths. -Creator<SocketPair> SocketpairGoferSocketPairCreator(int domain, int type, - int protocol); - -// SocketpairGoferFileSocketPairCreator returns a Creator<SocketPair> that -// obtains file descriptors by open() syscalls on socketpair gofer paths. -Creator<SocketPair> SocketpairGoferFileSocketPairCreator(int flags); - -// FilesystemAcceptBindSocketPairCreator returns a Creator<SocketPair> that -// obtains file descriptors by invoking the accept() and bind() syscalls on -// a filesystem path. Only works for STREAM and SEQPACKET sockets. -Creator<SocketPair> FilesystemAcceptBindSocketPairCreator(int domain, int type, - int protocol); - -// AbstractAcceptBindSocketPairCreator returns a Creator<SocketPair> that -// obtains file descriptors by invoking the accept() and bind() syscalls on a -// abstract namespace path. Only works for STREAM and SEQPACKET sockets. -Creator<SocketPair> AbstractAcceptBindSocketPairCreator(int domain, int type, - int protocol); - -// FilesystemUnboundSocketPairCreator returns a Creator<SocketPair> that obtains -// file descriptors by invoking the socket() syscall and generates a filesystem -// path for binding. -Creator<SocketPair> FilesystemUnboundSocketPairCreator(int domain, int type, - int protocol); - -// AbstractUnboundSocketPairCreator returns a Creator<SocketPair> that obtains -// file descriptors by invoking the socket() syscall and generates an abstract -// path for binding. -Creator<SocketPair> AbstractUnboundSocketPairCreator(int domain, int type, - int protocol); - -// TCPAcceptBindSocketPairCreator returns a Creator<SocketPair> that obtains -// file descriptors by invoking the accept() and bind() syscalls on TCP sockets. -Creator<SocketPair> TCPAcceptBindSocketPairCreator(int domain, int type, - int protocol, - bool dual_stack); - -// TCPAcceptBindPersistentListenerSocketPairCreator is like -// TCPAcceptBindSocketPairCreator, except it uses the same listening socket to -// create all SocketPairs. -Creator<SocketPair> TCPAcceptBindPersistentListenerSocketPairCreator( - int domain, int type, int protocol, bool dual_stack); - -// UDPBidirectionalBindSocketPairCreator returns a Creator<SocketPair> that -// obtains file descriptors by invoking the bind() and connect() syscalls on UDP -// sockets. -Creator<SocketPair> UDPBidirectionalBindSocketPairCreator(int domain, int type, - int protocol, - bool dual_stack); - -// UDPUnboundSocketPairCreator returns a Creator<SocketPair> that obtains file -// descriptors by creating UDP sockets. -Creator<SocketPair> UDPUnboundSocketPairCreator(int domain, int type, - int protocol, bool dual_stack); - -// UnboundSocketCreator returns a Creator<FileDescriptor> that obtains a file -// descriptor by creating a socket. -Creator<FileDescriptor> UnboundSocketCreator(int domain, int type, - int protocol); - -// A SocketPairKind couples a human-readable description of a socket pair with -// a function that creates such a socket pair. -struct SocketPairKind { - std::string description; - int domain; - int type; - int protocol; - Creator<SocketPair> creator; - - // Create creates a socket pair of this kind. - PosixErrorOr<std::unique_ptr<SocketPair>> Create() const { return creator(); } -}; - -// A SocketKind couples a human-readable description of a socket with -// a function that creates such a socket. -struct SocketKind { - std::string description; - int domain; - int type; - int protocol; - Creator<FileDescriptor> creator; - - // Create creates a socket pair of this kind. - PosixErrorOr<std::unique_ptr<FileDescriptor>> Create() const { - return creator(); - } -}; - -// A ReversedSocketPair wraps another SocketPair but flips the first and second -// file descriptors. ReversedSocketPair is used to test socket pairs that -// should be symmetric. -class ReversedSocketPair : public SocketPair { - public: - explicit ReversedSocketPair(std::unique_ptr<SocketPair> base) - : base_(std::move(base)) {} - - int first_fd() const override { return base_->second_fd(); } - int second_fd() const override { return base_->first_fd(); } - int release_first_fd() override { return base_->release_second_fd(); } - int release_second_fd() override { return base_->release_first_fd(); } - const struct sockaddr* first_addr() const override { - return base_->second_addr(); - } - const struct sockaddr* second_addr() const override { - return base_->first_addr(); - } - size_t first_addr_size() const override { return base_->second_addr_size(); } - size_t second_addr_size() const override { return base_->first_addr_size(); } - size_t first_addr_len() const override { return base_->second_addr_len(); } - size_t second_addr_len() const override { return base_->first_addr_len(); } - - private: - std::unique_ptr<SocketPair> base_; -}; - -// Reversed returns a SocketPairKind that represents SocketPairs created by -// flipping the file descriptors provided by another SocketPair. -SocketPairKind Reversed(SocketPairKind const& base); - -// IncludeReversals returns a vector<SocketPairKind> that returns all -// SocketPairKinds in `vec` as well as all SocketPairKinds obtained by flipping -// the file descriptors provided by the kinds in `vec`. -std::vector<SocketPairKind> IncludeReversals(std::vector<SocketPairKind> vec); - -// A Middleware is a function wraps a SocketPairKind. -using Middleware = std::function<SocketPairKind(SocketPairKind)>; - -// Reversed returns a SocketPairKind that represents SocketPairs created by -// flipping the file descriptors provided by another SocketPair. -template <typename T> -Middleware SetSockOpt(int level, int optname, T* value) { - return [=](SocketPairKind const& base) { - auto const& creator = base.creator; - return SocketPairKind{ - absl::StrCat("setsockopt(", level, ", ", optname, ", ", *value, ") ", - base.description), - base.domain, base.type, base.protocol, - [creator, level, optname, - value]() -> PosixErrorOr<std::unique_ptr<SocketPair>> { - ASSIGN_OR_RETURN_ERRNO(auto creator_value, creator()); - if (creator_value->first_fd() >= 0) { - RETURN_ERROR_IF_SYSCALL_FAIL(setsockopt( - creator_value->first_fd(), level, optname, value, sizeof(T))); - } - if (creator_value->second_fd() >= 0) { - RETURN_ERROR_IF_SYSCALL_FAIL(setsockopt( - creator_value->second_fd(), level, optname, value, sizeof(T))); - } - return creator_value; - }}; - }; -} - -constexpr int kSockOptOn = 1; -constexpr int kSockOptOff = 0; - -// NoOp returns the same SocketPairKind that it is passed. -SocketPairKind NoOp(SocketPairKind const& base); - -// TransferTest tests that data can be send back and fourth between two -// specified FDs. Note that calls to this function should be wrapped in -// ASSERT_NO_FATAL_FAILURE(). -void TransferTest(int fd1, int fd2); - -// Fills [buf, buf+len) with random bytes. -void RandomizeBuffer(char* buf, size_t len); - -// Base test fixture for tests that operate on pairs of connected sockets. -class SocketPairTest : public ::testing::TestWithParam<SocketPairKind> { - protected: - SocketPairTest() { - // gUnit uses printf, so so will we. - printf("Testing with %s\n", GetParam().description.c_str()); - fflush(stdout); - } - - PosixErrorOr<std::unique_ptr<SocketPair>> NewSocketPair() const { - return GetParam().Create(); - } -}; - -// Base test fixture for tests that operate on simple Sockets. -class SimpleSocketTest : public ::testing::TestWithParam<SocketKind> { - protected: - SimpleSocketTest() { - // gUnit uses printf, so so will we. - printf("Testing with %s\n", GetParam().description.c_str()); - } - - PosixErrorOr<std::unique_ptr<FileDescriptor>> NewSocket() const { - return GetParam().Create(); - } -}; - -SocketKind SimpleSocket(int fam, int type, int proto); - -// Send a buffer of size 'size' to sockets->first_fd(), returning the result of -// sendmsg. -// -// If reader, read from second_fd() until size bytes have been read. -ssize_t SendLargeSendMsg(const std::unique_ptr<SocketPair>& sockets, - size_t size, bool reader); - -// Initializes the given buffer with random data. -void RandomizeBuffer(char* ptr, size_t len); - -enum class AddressFamily { kIpv4 = 1, kIpv6 = 2, kDualStack = 3 }; -enum class SocketType { kUdp = 1, kTcp = 2 }; - -// Returns a PosixError or a port that is available. If 0 is specified as the -// port it will bind port 0 (and allow the kernel to select any free port). -// Otherwise, it will try to bind the specified port and validate that it can be -// used for the requested family and socket type. The final option is -// reuse_addr. This specifies whether SO_REUSEADDR should be applied before a -// bind(2) attempt. SO_REUSEADDR means that sockets in TIME_WAIT states or other -// bound UDP sockets would not cause an error on bind(2). This option should be -// set if subsequent calls to bind on the returned port will also use -// SO_REUSEADDR. -// -// Note: That this test will attempt to bind the ANY address for the respective -// protocol. -PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type, - bool reuse_addr); - -// FreeAvailablePort is used to return a port that was obtained by using -// the PortAvailable helper with port 0. -PosixError FreeAvailablePort(int port); - -// SendMsg converts a buffer to an iovec and adds it to msg before sending it. -PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size); - -// RecvNoData checks that no data is receivable on sock. -void RecvNoData(int sock); - -// Base test fixture for tests that apply to all kinds of pairs of connected -// sockets. -using AllSocketPairTest = SocketPairTest; - -struct TestAddress { - std::string description; - sockaddr_storage addr; - socklen_t addr_len; - - int family() const { return addr.ss_family; } - explicit TestAddress(std::string description = "") - : description(std::move(description)), addr(), addr_len() {} -}; - -constexpr char kMulticastAddress[] = "224.0.2.1"; -constexpr char kBroadcastAddress[] = "255.255.255.255"; - -TestAddress V4Any(); -TestAddress V4Broadcast(); -TestAddress V4Loopback(); -TestAddress V4MappedAny(); -TestAddress V4MappedLoopback(); -TestAddress V4Multicast(); -TestAddress V6Any(); -TestAddress V6Loopback(); - -// Compute the internet checksum of an IP header. -uint16_t IPChecksum(struct iphdr ip); - -// Compute the internet checksum of a UDP header. -uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr, - const char* payload, ssize_t payload_len); - -// Compute the internet checksum of an ICMP header. -uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload, - ssize_t payload_len); - -namespace internal { -PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family, - SocketType type, bool reuse_addr); -} // namespace internal - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_SOCKET_TEST_UTIL_H_ diff --git a/test/syscalls/linux/socket_test_util_impl.cc b/test/syscalls/linux/socket_test_util_impl.cc deleted file mode 100644 index ef661a0e3..000000000 --- a/test/syscalls/linux/socket_test_util_impl.cc +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type, - bool reuse_addr) { - return internal::TryPortAvailable(port, family, type, reuse_addr); -} - -PosixError FreeAvailablePort(int port) { return NoError(); } - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix.cc b/test/syscalls/linux/socket_unix.cc deleted file mode 100644 index 4cf1f76f1..000000000 --- a/test/syscalls/linux/socket_unix.cc +++ /dev/null @@ -1,273 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_unix.h" - -#include <errno.h> -#include <net/if.h> -#include <stdio.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include <vector> - -#include "gtest/gtest.h" -#include "absl/strings/string_view.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -// This file contains tests specific to Unix domain sockets. It does not contain -// tests for UDS control messages. Those belong in socket_unix_cmsg.cc. -// -// This file is a generic socket test file. It must be built with another file -// that provides the test types. - -namespace gvisor { -namespace testing { - -namespace { - -TEST_P(UnixSocketPairTest, InvalidGetSockOpt) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - int opt; - socklen_t optlen = sizeof(opt); - EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, -1, &opt, &optlen), - SyscallFailsWithErrno(ENOPROTOOPT)); -} - -TEST_P(UnixSocketPairTest, BindToBadName) { - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - constexpr char kBadName[] = "/some/path/that/does/not/exist"; - sockaddr_un sockaddr; - sockaddr.sun_family = AF_LOCAL; - memcpy(sockaddr.sun_path, kBadName, sizeof(kBadName)); - - EXPECT_THAT( - bind(pair->first_fd(), reinterpret_cast<struct sockaddr*>(&sockaddr), - sizeof(sockaddr)), - SyscallFailsWithErrno(ENOENT)); -} - -TEST_P(UnixSocketPairTest, BindToBadFamily) { - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - constexpr char kBadName[] = "/some/path/that/does/not/exist"; - sockaddr_un sockaddr; - sockaddr.sun_family = AF_INET; - memcpy(sockaddr.sun_path, kBadName, sizeof(kBadName)); - - EXPECT_THAT( - bind(pair->first_fd(), reinterpret_cast<struct sockaddr*>(&sockaddr), - sizeof(sockaddr)), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(UnixSocketPairTest, RecvmmsgTimeoutAfterRecv) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char sent_data[10]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - char received_data[sizeof(sent_data) * 2]; - std::vector<struct mmsghdr> msgs(2); - std::vector<struct iovec> iovs(msgs.size()); - const int chunk_size = sizeof(received_data) / msgs.size(); - for (size_t i = 0; i < msgs.size(); i++) { - iovs[i].iov_len = chunk_size; - iovs[i].iov_base = &received_data[i * chunk_size]; - msgs[i].msg_hdr.msg_iov = &iovs[i]; - msgs[i].msg_hdr.msg_iovlen = 1; - } - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - struct timespec timeout = {0, 1}; - ASSERT_THAT(RetryEINTR(recvmmsg)(sockets->second_fd(), &msgs[0], msgs.size(), - 0, &timeout), - SyscallSucceedsWithValue(1)); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - EXPECT_EQ(chunk_size, msgs[0].msg_len); -} - -TEST_P(UnixSocketPairTest, TIOCINQSucceeds) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - if (IsRunningOnGvisor()) { - // TODO(gvisor.dev/issue/273): Inherited host UDS don't support TIOCINQ. - // Skip the test. - int size = -1; - int ret = ioctl(sockets->first_fd(), TIOCINQ, &size); - SKIP_IF(ret == -1 && errno == ENOTTY); - } - - int size = -1; - EXPECT_THAT(ioctl(sockets->first_fd(), TIOCINQ, &size), SyscallSucceeds()); - EXPECT_EQ(size, 0); - - const char some_data[] = "dangerzone"; - ASSERT_THAT( - RetryEINTR(send)(sockets->second_fd(), &some_data, sizeof(some_data), 0), - SyscallSucceeds()); - EXPECT_THAT(ioctl(sockets->first_fd(), TIOCINQ, &size), SyscallSucceeds()); - EXPECT_EQ(size, sizeof(some_data)); - - // Linux only reports the first message's size, which is wrong. We test for - // the behavior described in the man page. - SKIP_IF(!IsRunningOnGvisor()); - - ASSERT_THAT( - RetryEINTR(send)(sockets->second_fd(), &some_data, sizeof(some_data), 0), - SyscallSucceeds()); - EXPECT_THAT(ioctl(sockets->first_fd(), TIOCINQ, &size), SyscallSucceeds()); - EXPECT_EQ(size, sizeof(some_data) * 2); -} - -TEST_P(UnixSocketPairTest, TIOCOUTQSucceeds) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - if (IsRunningOnGvisor()) { - // TODO(gvisor.dev/issue/273): Inherited host UDS don't support TIOCOUTQ. - // Skip the test. - int size = -1; - int ret = ioctl(sockets->second_fd(), TIOCOUTQ, &size); - SKIP_IF(ret == -1 && errno == ENOTTY); - } - - int size = -1; - EXPECT_THAT(ioctl(sockets->second_fd(), TIOCOUTQ, &size), SyscallSucceeds()); - EXPECT_EQ(size, 0); - - // Linux reports bogus numbers which are related to its internal allocations. - // We test for the behavior described in the man page. - SKIP_IF(!IsRunningOnGvisor()); - - const char some_data[] = "dangerzone"; - ASSERT_THAT( - RetryEINTR(send)(sockets->second_fd(), &some_data, sizeof(some_data), 0), - SyscallSucceeds()); - EXPECT_THAT(ioctl(sockets->second_fd(), TIOCOUTQ, &size), SyscallSucceeds()); - EXPECT_EQ(size, sizeof(some_data)); - - ASSERT_THAT( - RetryEINTR(send)(sockets->second_fd(), &some_data, sizeof(some_data), 0), - SyscallSucceeds()); - EXPECT_THAT(ioctl(sockets->second_fd(), TIOCOUTQ, &size), SyscallSucceeds()); - EXPECT_EQ(size, sizeof(some_data) * 2); -} - -TEST_P(UnixSocketPairTest, NetdeviceIoctlsSucceed) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - // Prepare the request. - struct ifreq ifr; - snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); - - // Check that the ioctl either succeeds or fails with ENODEV. - int err = ioctl(sockets->first_fd(), SIOCGIFINDEX, &ifr); - if (err < 0) { - ASSERT_EQ(errno, ENODEV); - } -} - -TEST_P(UnixSocketPairTest, Shutdown) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - const std::string data = "abc"; - ASSERT_THAT(WriteFd(sockets->first_fd(), data.c_str(), data.size()), - SyscallSucceedsWithValue(data.size())); - - ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds()); - ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds()); - - // Shutting down a socket does not clear the buffer. - char buf[3]; - ASSERT_THAT(ReadFd(sockets->second_fd(), buf, data.size()), - SyscallSucceedsWithValue(data.size())); - EXPECT_EQ(data, absl::string_view(buf, data.size())); -} - -TEST_P(UnixSocketPairTest, ShutdownRead) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RD), SyscallSucceeds()); - - // When the socket is shutdown for read, read behavior varies between - // different socket types. This is covered by the various ReadOneSideClosed - // test cases. - - // ... and the peer cannot write. - const std::string data = "abc"; - EXPECT_THAT(WriteFd(sockets->second_fd(), data.c_str(), data.size()), - SyscallFailsWithErrno(EPIPE)); - - // ... but the socket can still write. - ASSERT_THAT(WriteFd(sockets->first_fd(), data.c_str(), data.size()), - SyscallSucceedsWithValue(data.size())); - - // ... and the peer can still read. - char buf[3]; - EXPECT_THAT(ReadFd(sockets->second_fd(), buf, data.size()), - SyscallSucceedsWithValue(data.size())); - EXPECT_EQ(data, absl::string_view(buf, data.size())); -} - -TEST_P(UnixSocketPairTest, ShutdownWrite) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_WR), SyscallSucceeds()); - - // When the socket is shutdown for write, it cannot write. - const std::string data = "abc"; - EXPECT_THAT(WriteFd(sockets->first_fd(), data.c_str(), data.size()), - SyscallFailsWithErrno(EPIPE)); - - // ... and the peer read behavior varies between different socket types. This - // is covered by the various ReadOneSideClosed test cases. - - // ... but the peer can still write. - char buf[3]; - ASSERT_THAT(WriteFd(sockets->second_fd(), data.c_str(), data.size()), - SyscallSucceedsWithValue(data.size())); - - // ... and the socket can still read. - EXPECT_THAT(ReadFd(sockets->first_fd(), buf, data.size()), - SyscallSucceedsWithValue(data.size())); - EXPECT_EQ(data, absl::string_view(buf, data.size())); -} - -TEST_P(UnixSocketPairTest, SocketReopenFromProcfs) { - // TODO(b/122310852): We should be returning ENXIO and NOT EIO. - SKIP_IF(IsRunningOnGvisor()); - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - // Opening a socket pair via /proc/self/fd/X is a ENXIO. - for (const int fd : {sockets->first_fd(), sockets->second_fd()}) { - ASSERT_THAT(Open(absl::StrCat("/proc/self/fd/", fd), O_WRONLY), - PosixErrorIs(ENXIO, ::testing::_)); - } -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix.h b/test/syscalls/linux/socket_unix.h deleted file mode 100644 index 3625cc404..000000000 --- a/test/syscalls/linux/socket_unix.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of connected unix sockets. -using UnixSocketPairTest = SocketPairTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_H_ diff --git a/test/syscalls/linux/socket_unix_abstract_nonblock.cc b/test/syscalls/linux/socket_unix_abstract_nonblock.cc deleted file mode 100644 index 8bef76b67..000000000 --- a/test/syscalls/linux/socket_unix_abstract_nonblock.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/socket_non_blocking.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return ApplyVec<SocketPairKind>( - AbstractBoundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET}, - List<int>{SOCK_NONBLOCK})); -} - -INSTANTIATE_TEST_SUITE_P( - NonBlockingAbstractUnixSockets, NonBlockingSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_blocking_local.cc b/test/syscalls/linux/socket_unix_blocking_local.cc deleted file mode 100644 index 77cb8c6d6..000000000 --- a/test/syscalls/linux/socket_unix_blocking_local.cc +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/socket_blocking.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return VecCat<SocketPairKind>( - ApplyVec<SocketPairKind>( - UnixDomainSocketPair, - std::vector<int>{SOCK_STREAM, SOCK_SEQPACKET, SOCK_DGRAM}), - ApplyVec<SocketPairKind>( - FilesystemBoundUnixDomainSocketPair, - std::vector<int>{SOCK_STREAM, SOCK_SEQPACKET, SOCK_DGRAM}), - ApplyVec<SocketPairKind>( - AbstractBoundUnixDomainSocketPair, - std::vector<int>{SOCK_STREAM, SOCK_SEQPACKET, SOCK_DGRAM})); -} - -INSTANTIATE_TEST_SUITE_P( - NonBlockingUnixDomainSockets, BlockingSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_cmsg.cc b/test/syscalls/linux/socket_unix_cmsg.cc deleted file mode 100644 index a16899493..000000000 --- a/test/syscalls/linux/socket_unix_cmsg.cc +++ /dev/null @@ -1,1501 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_unix_cmsg.h" - -#include <errno.h> -#include <net/if.h> -#include <stdio.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/un.h> - -#include <vector> - -#include "gtest/gtest.h" -#include "absl/strings/string_view.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -// This file contains tests for control message in Unix domain sockets. -// -// This file is a generic socket test file. It must be built with another file -// that provides the test types. - -namespace gvisor { -namespace testing { - -namespace { - -TEST_P(UnixSocketPairCmsgTest, BasicFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char received_data[20]; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data, - sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); -} - -TEST_P(UnixSocketPairCmsgTest, BasicTwoFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair1 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - auto pair2 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - int sent_fds[] = {pair1->second_fd(), pair2->second_fd()}; - - ASSERT_NO_FATAL_FAILURE( - SendFDs(sockets->first_fd(), sent_fds, 2, sent_data, sizeof(sent_data))); - - char received_data[20]; - int received_fds[] = {-1, -1}; - - ASSERT_NO_FATAL_FAILURE(RecvFDs(sockets->second_fd(), received_fds, 2, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[0], pair1->first_fd())); - ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[1], pair2->first_fd())); -} - -TEST_P(UnixSocketPairCmsgTest, BasicThreeFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair1 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - auto pair2 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - auto pair3 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - int sent_fds[] = {pair1->second_fd(), pair2->second_fd(), pair3->second_fd()}; - - ASSERT_NO_FATAL_FAILURE( - SendFDs(sockets->first_fd(), sent_fds, 3, sent_data, sizeof(sent_data))); - - char received_data[20]; - int received_fds[] = {-1, -1, -1}; - - ASSERT_NO_FATAL_FAILURE(RecvFDs(sockets->second_fd(), received_fds, 3, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[0], pair1->first_fd())); - ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[1], pair2->first_fd())); - ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[2], pair3->first_fd())); -} - -TEST_P(UnixSocketPairCmsgTest, BadFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - int sent_fd = -1; - - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(sent_fd))]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_len = CMSG_LEN(sizeof(sent_fd)); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - memcpy(CMSG_DATA(cmsg), &sent_fd, sizeof(sent_fd)); - - struct iovec iov; - iov.iov_base = sent_data; - iov.iov_len = sizeof(sent_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0), - SyscallFailsWithErrno(EBADF)); -} - -TEST_P(UnixSocketPairCmsgTest, ShortCmsg) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - int sent_fd = -1; - - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(sent_fd))]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_len = 1; - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - memcpy(CMSG_DATA(cmsg), &sent_fd, sizeof(sent_fd)); - - struct iovec iov; - iov.iov_base = sent_data; - iov.iov_len = sizeof(sent_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0), - SyscallFailsWithErrno(EINVAL)); -} - -// BasicFDPassNoSpace starts off by sending a single FD just like BasicFDPass. -// The difference is that when calling recvmsg, no space for FDs is provided, -// only space for the cmsg header. -TEST_P(UnixSocketPairCmsgTest, BasicFDPassNoSpace) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char received_data[20]; - - struct msghdr msg = {}; - std::vector<char> control(CMSG_SPACE(0)); - msg.msg_control = &control[0]; - msg.msg_controllen = control.size(); - - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_controllen, 0); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -// BasicFDPassNoSpaceMsgCtrunc sends an FD, but does not provide any space to -// receive it. It then verifies that the MSG_CTRUNC flag is set in the msghdr. -TEST_P(UnixSocketPairCmsgTest, BasicFDPassNoSpaceMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - std::vector<char> control(CMSG_SPACE(0)); - msg.msg_control = &control[0]; - msg.msg_controllen = control.size(); - - char received_data[sizeof(sent_data)]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_controllen, 0); - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); -} - -// BasicFDPassNullControlMsgCtrunc sends an FD and sets contradictory values for -// msg_controllen and msg_control. msg_controllen is set to the correct size to -// accommodate the FD, but msg_control is set to NULL. In this case, msg_control -// should override msg_controllen. -TEST_P(UnixSocketPairCmsgTest, BasicFDPassNullControlMsgCtrunc) { - // FIXME(gvisor.dev/issue/207): Fix handling of NULL msg_control. - SKIP_IF(IsRunningOnGvisor()); - - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - msg.msg_controllen = CMSG_SPACE(1); - - char received_data[sizeof(sent_data)]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_controllen, 0); - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); -} - -// BasicFDPassNotEnoughSpaceMsgCtrunc sends an FD, but does not provide enough -// space to receive it. It then verifies that the MSG_CTRUNC flag is set in the -// msghdr. -TEST_P(UnixSocketPairCmsgTest, BasicFDPassNotEnoughSpaceMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - std::vector<char> control(CMSG_SPACE(0) + 1); - msg.msg_control = &control[0]; - msg.msg_controllen = control.size(); - - char received_data[sizeof(sent_data)]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_controllen, 0); - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); -} - -// BasicThreeFDPassTruncationMsgCtrunc sends three FDs, but only provides enough -// space to receive two of them. It then verifies that the MSG_CTRUNC flag is -// set in the msghdr. -TEST_P(UnixSocketPairCmsgTest, BasicThreeFDPassTruncationMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair1 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - auto pair2 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - auto pair3 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - int sent_fds[] = {pair1->second_fd(), pair2->second_fd(), pair3->second_fd()}; - - ASSERT_NO_FATAL_FAILURE( - SendFDs(sockets->first_fd(), sent_fds, 3, sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - std::vector<char> control(CMSG_SPACE(2 * sizeof(int))); - msg.msg_control = &control[0]; - msg.msg_controllen = control.size(); - - char received_data[sizeof(sent_data)]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(2 * sizeof(int))); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS); -} - -// BasicFDPassUnalignedRecv starts off by sending a single FD just like -// BasicFDPass. The difference is that when calling recvmsg, the length of the -// receive data is only aligned on a 4 byte boundry instead of the normal 8. -TEST_P(UnixSocketPairCmsgTest, BasicFDPassUnalignedRecv) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char received_data[20]; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvSingleFDUnaligned( - sockets->second_fd(), &fd, received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); -} - -// BasicFDPassUnalignedRecvNoMsgTrunc sends one FD and only provides enough -// space to receive just it. (Normally the minimum amount of space one would -// provide would be enough space for two FDs.) It then verifies that the -// MSG_CTRUNC flag is not set in the msghdr. -TEST_P(UnixSocketPairCmsgTest, BasicFDPassUnalignedRecvNoMsgTrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(int)) - sizeof(int)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_flags, 0); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS); -} - -// BasicTwoFDPassUnalignedRecvTruncationMsgTrunc sends two FDs, but only -// provides enough space to receive one of them. It then verifies that the -// MSG_CTRUNC flag is set in the msghdr. -TEST_P(UnixSocketPairCmsgTest, BasicTwoFDPassUnalignedRecvTruncationMsgTrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - int sent_fds[] = {pair->first_fd(), pair->second_fd()}; - - ASSERT_NO_FATAL_FAILURE( - SendFDs(sockets->first_fd(), sent_fds, 2, sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - // CMSG_SPACE rounds up to two FDs, we only want one. - char control[CMSG_SPACE(sizeof(int)) - sizeof(int)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS); -} - -TEST_P(UnixSocketPairCmsgTest, ConcurrentBasicFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - int sockfd1 = sockets->first_fd(); - auto recv_func = [sockfd1, sent_data]() { - char received_data[20]; - int fd = -1; - RecvSingleFD(sockfd1, &fd, received_data, sizeof(received_data)); - ASSERT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - char buf[20]; - ASSERT_THAT(ReadFd(fd, buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - ASSERT_THAT(WriteFd(fd, buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - }; - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->second_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - ScopedThread t(recv_func); - - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT(WriteFd(pair->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[20]; - ASSERT_THAT(ReadFd(pair->first_fd(), received_data, sizeof(received_data)), - SyscallSucceedsWithValue(sizeof(received_data))); - - t.Join(); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -// FDPassNoRecv checks that the control message can be safely ignored by using -// read(2) instead of recvmsg(2). -TEST_P(UnixSocketPairCmsgTest, FDPassNoRecv) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - // Read while ignoring the passed FD. - char received_data[20]; - ASSERT_THAT( - ReadFd(sockets->second_fd(), received_data, sizeof(received_data)), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - // Check that the socket still works for reads and writes. - ASSERT_NO_FATAL_FAILURE( - TransferTest(sockets->first_fd(), sockets->second_fd())); -} - -// FDPassInterspersed1 checks that sent control messages cannot be read before -// their associated data has been read. -TEST_P(UnixSocketPairCmsgTest, FDPassInterspersed1) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char written_data[20]; - RandomizeBuffer(written_data, sizeof(written_data)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), written_data, sizeof(written_data)), - SyscallSucceedsWithValue(sizeof(written_data))); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - // Check that we don't get a control message, but do get the data. - char received_data[20]; - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data)); - EXPECT_EQ(0, memcmp(written_data, received_data, sizeof(written_data))); -} - -// FDPassInterspersed2 checks that sent control messages cannot be read after -// their associated data has been read while ignoring the control message by -// using read(2) instead of recvmsg(2). -TEST_P(UnixSocketPairCmsgTest, FDPassInterspersed2) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char written_data[20]; - RandomizeBuffer(written_data, sizeof(written_data)); - ASSERT_THAT(WriteFd(sockets->first_fd(), written_data, sizeof(written_data)), - SyscallSucceedsWithValue(sizeof(written_data))); - - char received_data[20]; - ASSERT_THAT( - ReadFd(sockets->second_fd(), received_data, sizeof(received_data)), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - EXPECT_EQ(0, memcmp(written_data, received_data, sizeof(written_data))); -} - -TEST_P(UnixSocketPairCmsgTest, FDPassNotCoalesced) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data1[20]; - RandomizeBuffer(sent_data1, sizeof(sent_data1)); - - auto pair1 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair1->second_fd(), - sent_data1, sizeof(sent_data1))); - - char sent_data2[20]; - RandomizeBuffer(sent_data2, sizeof(sent_data2)); - - auto pair2 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair2->second_fd(), - sent_data2, sizeof(sent_data2))); - - char received_data1[sizeof(sent_data1) + sizeof(sent_data2)]; - int received_fd1 = -1; - - RecvSingleFD(sockets->second_fd(), &received_fd1, received_data1, - sizeof(received_data1), sizeof(sent_data1)); - - EXPECT_EQ(0, memcmp(sent_data1, received_data1, sizeof(sent_data1))); - TransferTest(pair1->first_fd(), pair1->second_fd()); - - char received_data2[sizeof(sent_data1) + sizeof(sent_data2)]; - int received_fd2 = -1; - - RecvSingleFD(sockets->second_fd(), &received_fd2, received_data2, - sizeof(received_data2), sizeof(sent_data2)); - - EXPECT_EQ(0, memcmp(sent_data2, received_data2, sizeof(sent_data2))); - TransferTest(pair2->first_fd(), pair2->second_fd()); -} - -TEST_P(UnixSocketPairCmsgTest, FDPassPeek) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char peek_data[20]; - int peek_fd = -1; - PeekSingleFD(sockets->second_fd(), &peek_fd, peek_data, sizeof(peek_data)); - EXPECT_EQ(0, memcmp(sent_data, peek_data, sizeof(sent_data))); - TransferTest(peek_fd, pair->first_fd()); - EXPECT_THAT(close(peek_fd), SyscallSucceeds()); - - char received_data[20]; - int received_fd = -1; - RecvSingleFD(sockets->second_fd(), &received_fd, received_data, - sizeof(received_data)); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - TransferTest(received_fd, pair->first_fd()); - EXPECT_THAT(close(received_fd), SyscallSucceeds()); -} - -TEST_P(UnixSocketPairCmsgTest, BasicCredPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - EXPECT_EQ(sent_creds.pid, received_creds.pid); - EXPECT_EQ(sent_creds.uid, received_creds.uid); - EXPECT_EQ(sent_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairCmsgTest, SendNullCredsBeforeSoPassCredRecvEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds { - 0, 65534, 65534 - }; - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairCmsgTest, SendNullCredsAfterSoPassCredRecvEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - SetSoPassCred(sockets->second_fd()); - - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); - - char received_data[20]; - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairCmsgTest, SendNullCredsBeforeSoPassCredSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->first_fd()); - - char received_data[20]; - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(UnixSocketPairCmsgTest, SendNullCredsAfterSoPassCredSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - SetSoPassCred(sockets->first_fd()); - - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); - - char received_data[20]; - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(UnixSocketPairCmsgTest, - SendNullCredsBeforeSoPassCredRecvEndAfterSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - SetSoPassCred(sockets->first_fd()); - - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairCmsgTest, WriteBeforeSoPassCredRecvEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds { - 0, 65534, 65534 - }; - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairCmsgTest, WriteAfterSoPassCredRecvEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - SetSoPassCred(sockets->second_fd()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[20]; - - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairCmsgTest, WriteBeforeSoPassCredSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - SetSoPassCred(sockets->first_fd()); - - char received_data[20]; - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(UnixSocketPairCmsgTest, WriteAfterSoPassCredSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - SetSoPassCred(sockets->first_fd()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[20]; - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(UnixSocketPairCmsgTest, WriteBeforeSoPassCredRecvEndAfterSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - SetSoPassCred(sockets->first_fd()); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairCmsgTest, CredPassTruncated) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - struct msghdr msg = {}; - char control[CMSG_SPACE(0) + sizeof(pid_t)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - EXPECT_EQ(msg.msg_controllen, sizeof(control)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); - - pid_t pid = 0; - memcpy(&pid, CMSG_DATA(cmsg), sizeof(pid)); - EXPECT_EQ(pid, sent_creds.pid); -} - -// CredPassNoMsgCtrunc passes a full set of credentials. It then verifies that -// receiving the full set does not result in MSG_CTRUNC being set in the msghdr. -TEST_P(UnixSocketPairCmsgTest, CredPassNoMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(struct ucred))]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - // The control message should not be truncated. - EXPECT_EQ(msg.msg_flags, 0); - EXPECT_EQ(msg.msg_controllen, sizeof(control)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct ucred))); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); -} - -// CredPassNoSpaceMsgCtrunc passes a full set of credentials. It then receives -// the data without providing space for any credentials and verifies that -// MSG_CTRUNC is set in the msghdr. -TEST_P(UnixSocketPairCmsgTest, CredPassNoSpaceMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - struct msghdr msg = {}; - char control[CMSG_SPACE(0)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - // The control message should be truncated. - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); - EXPECT_EQ(msg.msg_controllen, sizeof(control)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); -} - -// CredPassTruncatedMsgCtrunc passes a full set of credentials. It then receives -// the data while providing enough space for only the first field of the -// credentials and verifies that MSG_CTRUNC is set in the msghdr. -TEST_P(UnixSocketPairCmsgTest, CredPassTruncatedMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - struct msghdr msg = {}; - char control[CMSG_SPACE(0) + sizeof(pid_t)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - // The control message should be truncated. - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); - EXPECT_EQ(msg.msg_controllen, sizeof(control)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); -} - -TEST_P(UnixSocketPairCmsgTest, SoPassCred) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int opt; - socklen_t optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_FALSE(opt); - - optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_FALSE(opt); - - SetSoPassCred(sockets->first_fd()); - - optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_TRUE(opt); - - optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_FALSE(opt); - - int zero = 0; - EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &zero, - sizeof(zero)), - SyscallSucceeds()); - - optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_FALSE(opt); - - optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_FALSE(opt); -} - -TEST_P(UnixSocketPairCmsgTest, NoDataCredPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct msghdr msg = {}; - - struct iovec iov; - iov.iov_base = sent_data; - iov.iov_len = sizeof(sent_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - char control[CMSG_SPACE(0)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_CREDENTIALS; - cmsg->cmsg_len = CMSG_LEN(0); - - ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(UnixSocketPairCmsgTest, NoPassCred) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - char received_data[20]; - - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(UnixSocketPairCmsgTest, CredAndFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendCredsAndFD(sockets->first_fd(), sent_creds, - pair->second_fd(), sent_data, - sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - struct ucred received_creds; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds, - &fd, received_data, - sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - EXPECT_EQ(sent_creds.pid, received_creds.pid); - EXPECT_EQ(sent_creds.uid, received_creds.uid); - EXPECT_EQ(sent_creds.gid, received_creds.gid); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); -} - -TEST_P(UnixSocketPairCmsgTest, FDPassBeforeSoPassCred) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - struct ucred received_creds; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds, - &fd, received_data, - sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds { - 0, 65534, 65534 - }; - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); -} - -TEST_P(UnixSocketPairCmsgTest, FDPassAfterSoPassCred) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - SetSoPassCred(sockets->second_fd()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char received_data[20]; - struct ucred received_creds; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds, - &fd, received_data, - sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); -} - -TEST_P(UnixSocketPairCmsgTest, CloexecDroppedWhenFDPassed) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = ASSERT_NO_ERRNO_AND_VALUE( - UnixDomainSocketPair(SOCK_SEQPACKET | SOCK_CLOEXEC).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char received_data[20]; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data, - sizeof(received_data))); - - EXPECT_THAT(fcntl(fd, F_GETFD), SyscallSucceedsWithValue(0)); -} - -TEST_P(UnixSocketPairCmsgTest, CloexecRecvFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(int))]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct iovec iov; - char received_data[20]; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_CMSG_CLOEXEC), - SyscallSucceedsWithValue(sizeof(received_data))); - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS); - - int fd = -1; - memcpy(&fd, CMSG_DATA(cmsg), sizeof(int)); - - EXPECT_THAT(fcntl(fd, F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); -} - -TEST_P(UnixSocketPairCmsgTest, FDPassAfterSoPassCredWithoutCredSpace) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - SetSoPassCred(sockets->second_fd()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - char control[CMSG_LEN(0)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[20]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - EXPECT_EQ(msg.msg_controllen, sizeof(control)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); -} - -// This test will validate that MSG_CTRUNC as an input flag to recvmsg will -// not appear as an output flag on the control message when truncation doesn't -// happen. -TEST_P(UnixSocketPairCmsgTest, MsgCtruncInputIsNoop) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(int)) /* we're passing a single fd */]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct iovec iov; - char received_data[20]; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_CTRUNC), - SyscallSucceedsWithValue(sizeof(received_data))); - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS); - - // Now we should verify that MSG_CTRUNC wasn't set as an output flag. - EXPECT_EQ(msg.msg_flags & MSG_CTRUNC, 0); -} - -TEST_P(UnixSocketPairCmsgTest, FDPassAfterSoPassCredWithoutCredHeaderSpace) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - SetSoPassCred(sockets->second_fd()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - char control[CMSG_LEN(0) / 2]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[20]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - EXPECT_EQ(msg.msg_controllen, 0); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_cmsg.h b/test/syscalls/linux/socket_unix_cmsg.h deleted file mode 100644 index 431606903..000000000 --- a/test/syscalls/linux/socket_unix_cmsg.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_CMSG_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_CMSG_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of connected unix sockets about -// control messages. -using UnixSocketPairCmsgTest = SocketPairTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_CMSG_H_ diff --git a/test/syscalls/linux/socket_unix_dgram.cc b/test/syscalls/linux/socket_unix_dgram.cc deleted file mode 100644 index af0df4fb4..000000000 --- a/test/syscalls/linux/socket_unix_dgram.cc +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_unix_dgram.h" - -#include <stdio.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST_P(DgramUnixSocketPairTest, WriteOneSideClosed) { - // FIXME(b/35925052): gVisor datagram sockets return EPIPE instead of - // ECONNREFUSED. - SKIP_IF(IsRunningOnGvisor()); - - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); - constexpr char kStr[] = "abc"; - ASSERT_THAT(write(sockets->second_fd(), kStr, 3), - SyscallFailsWithErrno(ECONNREFUSED)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_dgram.h b/test/syscalls/linux/socket_unix_dgram.h deleted file mode 100644 index 0764ef85b..000000000 --- a/test/syscalls/linux/socket_unix_dgram.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_DGRAM_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_DGRAM_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of connected dgram unix sockets. -using DgramUnixSocketPairTest = SocketPairTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_DGRAM_H_ diff --git a/test/syscalls/linux/socket_unix_dgram_local.cc b/test/syscalls/linux/socket_unix_dgram_local.cc deleted file mode 100644 index 31d2d5216..000000000 --- a/test/syscalls/linux/socket_unix_dgram_local.cc +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/socket_non_stream.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/socket_unix_dgram.h" -#include "test/syscalls/linux/socket_unix_non_stream.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return VecCat<SocketPairKind>(VecCat<SocketPairKind>( - ApplyVec<SocketPairKind>( - UnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_DGRAM, SOCK_RAW}, - List<int>{0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - FilesystemBoundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_DGRAM, SOCK_RAW}, - List<int>{0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - AbstractBoundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_DGRAM, SOCK_RAW}, - List<int>{0, SOCK_NONBLOCK})))); -} - -INSTANTIATE_TEST_SUITE_P( - DgramUnixSockets, DgramUnixSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -INSTANTIATE_TEST_SUITE_P( - DgramUnixSockets, UnixNonStreamSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -INSTANTIATE_TEST_SUITE_P( - DgramUnixSockets, NonStreamSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_dgram_non_blocking.cc b/test/syscalls/linux/socket_unix_dgram_non_blocking.cc deleted file mode 100644 index 2db8b68d3..000000000 --- a/test/syscalls/linux/socket_unix_dgram_non_blocking.cc +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdio.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Test fixture for tests that apply to pairs of connected non-blocking dgram -// unix sockets. -using NonBlockingDgramUnixSocketPairTest = SocketPairTest; - -TEST_P(NonBlockingDgramUnixSocketPairTest, ReadOneSideClosed) { - if (IsRunningOnGvisor()) { - // FIXME(b/70803293): gVisor datagram sockets return 0 instead of - // EAGAIN. - return; - } - - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); - char data[10] = {}; - ASSERT_THAT(read(sockets->second_fd(), data, sizeof(data)), - SyscallFailsWithErrno(EAGAIN)); -} - -INSTANTIATE_TEST_SUITE_P( - NonBlockingDgramUnixSockets, NonBlockingDgramUnixSocketPairTest, - ::testing::ValuesIn(IncludeReversals(std::vector<SocketPairKind>{ - UnixDomainSocketPair(SOCK_DGRAM | SOCK_NONBLOCK), - FilesystemBoundUnixDomainSocketPair(SOCK_DGRAM | SOCK_NONBLOCK), - AbstractBoundUnixDomainSocketPair(SOCK_DGRAM | SOCK_NONBLOCK), - }))); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_domain.cc b/test/syscalls/linux/socket_unix_domain.cc deleted file mode 100644 index f7dff8b4d..000000000 --- a/test/syscalls/linux/socket_unix_domain.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/socket_generic.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return ApplyVec<SocketPairKind>( - UnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET}, - List<int>{0, SOCK_NONBLOCK})); -} - -INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, AllSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_filesystem_nonblock.cc b/test/syscalls/linux/socket_unix_filesystem_nonblock.cc deleted file mode 100644 index 6700b4d90..000000000 --- a/test/syscalls/linux/socket_unix_filesystem_nonblock.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/socket_non_blocking.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return ApplyVec<SocketPairKind>( - FilesystemBoundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET}, - List<int>{SOCK_NONBLOCK})); -} - -INSTANTIATE_TEST_SUITE_P( - NonBlockingFilesystemUnixSockets, NonBlockingSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_non_stream.cc b/test/syscalls/linux/socket_unix_non_stream.cc deleted file mode 100644 index 884319e1d..000000000 --- a/test/syscalls/linux/socket_unix_non_stream.cc +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_unix_non_stream.h" - -#include <stdio.h> -#include <sys/mman.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/memory_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -TEST_P(UnixNonStreamSocketPairTest, RecvMsgTooLarge) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int rcvbuf; - socklen_t length = sizeof(rcvbuf); - ASSERT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVBUF, &rcvbuf, &length), - SyscallSucceeds()); - - // Make the call larger than the receive buffer. - const int recv_size = 3 * rcvbuf; - - // Write a message that does fit in the receive buffer. - const int write_size = rcvbuf - kPageSize; - - std::vector<char> write_buf(write_size, 'a'); - const int ret = RetryEINTR(write)(sockets->second_fd(), write_buf.data(), - write_buf.size()); - if (ret < 0 && errno == ENOBUFS) { - // NOTE(b/116636318): Linux may stall the write for a long time and - // ultimately return ENOBUFS. Allow this error, since a retry will likely - // result in the same error. - return; - } - ASSERT_THAT(ret, SyscallSucceeds()); - - std::vector<char> recv_buf(recv_size); - - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sockets->first_fd(), recv_buf.data(), - recv_buf.size(), write_size)); - - recv_buf.resize(write_size); - EXPECT_EQ(recv_buf, write_buf); -} - -// Create a region of anonymous memory of size 'size', which is fragmented in -// FileMem. -// -// ptr contains the start address of the region. The returned vector contains -// all of the mappings to be unmapped when done. -PosixErrorOr<std::vector<Mapping>> CreateFragmentedRegion(const int size, - void** ptr) { - Mapping region; - ASSIGN_OR_RETURN_ERRNO(region, Mmap(nullptr, size, PROT_NONE, - MAP_ANONYMOUS | MAP_PRIVATE, -1, 0)); - - *ptr = region.ptr(); - - // Don't save hundreds of times for all of these mmaps. - DisableSave ds; - - std::vector<Mapping> pages; - - // Map and commit a single page at a time, mapping and committing an unrelated - // page between each call to force FileMem fragmentation. - for (uintptr_t addr = region.addr(); addr < region.endaddr(); - addr += kPageSize) { - Mapping page; - ASSIGN_OR_RETURN_ERRNO( - page, - Mmap(reinterpret_cast<void*>(addr), kPageSize, PROT_READ | PROT_WRITE, - MAP_ANONYMOUS | MAP_PRIVATE | MAP_FIXED, -1, 0)); - *reinterpret_cast<volatile char*>(page.ptr()) = 42; - - pages.emplace_back(std::move(page)); - - // Unrelated page elsewhere. - ASSIGN_OR_RETURN_ERRNO(page, - Mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, - MAP_ANONYMOUS | MAP_PRIVATE, -1, 0)); - *reinterpret_cast<volatile char*>(page.ptr()) = 42; - - pages.emplace_back(std::move(page)); - } - - // The mappings above have taken ownership of the region. - region.release(); - - return std::move(pages); -} - -// A contiguous iov that is heavily fragmented in FileMem can still be sent -// successfully. See b/115833655. -TEST_P(UnixNonStreamSocketPairTest, FragmentedSendMsg) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - const int buffer_size = UIO_MAXIOV * kPageSize; - // Extra page for message header overhead. - const int sndbuf = buffer_size + kPageSize; - // N.B. setsockopt(SO_SNDBUF) doubles the passed value. - const int set_sndbuf = sndbuf / 2; - - EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, - &set_sndbuf, sizeof(set_sndbuf)), - SyscallSucceeds()); - - int actual_sndbuf = 0; - socklen_t length = sizeof(actual_sndbuf); - ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, - &actual_sndbuf, &length), - SyscallSucceeds()); - - if (actual_sndbuf != sndbuf) { - // Unable to get the sndbuf we want. - // - // N.B. At minimum, the socketpair gofer should provide a socket that is - // already the correct size. - // - // TODO(b/35921550): When internal UDS support SO_SNDBUF, we can assert that - // we always get the right SO_SNDBUF on gVisor. - GTEST_SKIP() << "SO_SNDBUF = " << actual_sndbuf << ", want " << sndbuf; - } - - // Create a contiguous region of memory of 2*UIO_MAXIOV*PAGE_SIZE. We'll call - // sendmsg with a single iov, but the goal is to get the sentry to split this - // into > UIO_MAXIOV iovs when calling the kernel. - void* ptr; - std::vector<Mapping> pages = - ASSERT_NO_ERRNO_AND_VALUE(CreateFragmentedRegion(buffer_size, &ptr)); - - struct iovec iov = {}; - iov.iov_base = ptr; - iov.iov_len = buffer_size; - - struct msghdr msg = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - // NOTE(b/116636318,b/115833655): Linux has poor behavior in the presence of - // physical memory fragmentation. As a result, this may stall for a long time - // and ultimately return ENOBUFS. Allow this error, since it means that we - // made it to the host kernel and started the sendmsg. - EXPECT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0), - AnyOf(SyscallSucceedsWithValue(buffer_size), - SyscallFailsWithErrno(ENOBUFS))); -} - -// A contiguous iov that is heavily fragmented in FileMem can still be received -// into successfully. Regression test for b/115833655. -TEST_P(UnixNonStreamSocketPairTest, FragmentedRecvMsg) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - const int buffer_size = UIO_MAXIOV * kPageSize; - // Extra page for message header overhead. - const int sndbuf = buffer_size + kPageSize; - // N.B. setsockopt(SO_SNDBUF) doubles the passed value. - const int set_sndbuf = sndbuf / 2; - - EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, - &set_sndbuf, sizeof(set_sndbuf)), - SyscallSucceeds()); - - int actual_sndbuf = 0; - socklen_t length = sizeof(actual_sndbuf); - ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, - &actual_sndbuf, &length), - SyscallSucceeds()); - - if (actual_sndbuf != sndbuf) { - // Unable to get the sndbuf we want. - // - // N.B. At minimum, the socketpair gofer should provide a socket that is - // already the correct size. - // - // TODO(b/35921550): When internal UDS support SO_SNDBUF, we can assert that - // we always get the right SO_SNDBUF on gVisor. - GTEST_SKIP() << "SO_SNDBUF = " << actual_sndbuf << ", want " << sndbuf; - } - - std::vector<char> write_buf(buffer_size, 'a'); - const int ret = RetryEINTR(write)(sockets->first_fd(), write_buf.data(), - write_buf.size()); - if (ret < 0 && errno == ENOBUFS) { - // NOTE(b/116636318): Linux may stall the write for a long time and - // ultimately return ENOBUFS. Allow this error, since a retry will likely - // result in the same error. - return; - } - ASSERT_THAT(ret, SyscallSucceeds()); - - // Create a contiguous region of memory of 2*UIO_MAXIOV*PAGE_SIZE. We'll call - // sendmsg with a single iov, but the goal is to get the sentry to split this - // into > UIO_MAXIOV iovs when calling the kernel. - void* ptr; - std::vector<Mapping> pages = - ASSERT_NO_ERRNO_AND_VALUE(CreateFragmentedRegion(buffer_size, &ptr)); - - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg( - sockets->second_fd(), reinterpret_cast<char*>(ptr), buffer_size)); - - EXPECT_EQ(0, memcmp(write_buf.data(), ptr, buffer_size)); -} - -TEST_P(UnixNonStreamSocketPairTest, SendTimeout) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct timeval tv { - .tv_sec = 0, .tv_usec = 10 - }; - EXPECT_THAT( - setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), - SyscallSucceeds()); - - const int buf_size = 5 * kPageSize; - EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, &buf_size, - sizeof(buf_size)), - SyscallSucceeds()); - EXPECT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_RCVBUF, &buf_size, - sizeof(buf_size)), - SyscallSucceeds()); - - // The buffer size should be big enough to avoid many iterations in the next - // loop. Otherwise, this will slow down cooperative_save tests. - std::vector<char> buf(kPageSize); - for (;;) { - int ret; - ASSERT_THAT( - ret = RetryEINTR(send)(sockets->first_fd(), buf.data(), buf.size(), 0), - ::testing::AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EAGAIN))); - if (ret == -1) { - break; - } - } -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_non_stream.h b/test/syscalls/linux/socket_unix_non_stream.h deleted file mode 100644 index 7478ab172..000000000 --- a/test/syscalls/linux/socket_unix_non_stream.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_NON_STREAM_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_NON_STREAM_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of connected non-stream -// unix-domain sockets. -using UnixNonStreamSocketPairTest = SocketPairTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_NON_STREAM_H_ diff --git a/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc b/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc deleted file mode 100644 index fddcdf1c5..000000000 --- a/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/socket_non_stream_blocking.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return VecCat<SocketPairKind>( - ApplyVec<SocketPairKind>(UnixDomainSocketPair, - std::vector<int>{SOCK_DGRAM, SOCK_SEQPACKET}), - ApplyVec<SocketPairKind>(FilesystemBoundUnixDomainSocketPair, - std::vector<int>{SOCK_DGRAM, SOCK_SEQPACKET}), - ApplyVec<SocketPairKind>(AbstractBoundUnixDomainSocketPair, - std::vector<int>{SOCK_DGRAM, SOCK_SEQPACKET})); -} - -INSTANTIATE_TEST_SUITE_P( - BlockingNonStreamUnixSockets, BlockingNonStreamSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_pair.cc b/test/syscalls/linux/socket_unix_pair.cc deleted file mode 100644 index 85999db04..000000000 --- a/test/syscalls/linux/socket_unix_pair.cc +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/socket_unix.h" -#include "test/syscalls/linux/socket_unix_cmsg.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return VecCat<SocketPairKind>(ApplyVec<SocketPairKind>( - UnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET}, - List<int>{0, SOCK_NONBLOCK}))); -} - -INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, UnixSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, UnixSocketPairCmsgTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_pair_nonblock.cc b/test/syscalls/linux/socket_unix_pair_nonblock.cc deleted file mode 100644 index 281410a9a..000000000 --- a/test/syscalls/linux/socket_unix_pair_nonblock.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/socket_non_blocking.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return ApplyVec<SocketPairKind>( - UnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET}, - List<int>{SOCK_NONBLOCK})); -} - -INSTANTIATE_TEST_SUITE_P( - NonBlockingUnixSockets, NonBlockingSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_seqpacket.cc b/test/syscalls/linux/socket_unix_seqpacket.cc deleted file mode 100644 index 84d3a569e..000000000 --- a/test/syscalls/linux/socket_unix_seqpacket.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/socket_unix_seqpacket.h" - -#include <stdio.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST_P(SeqpacketUnixSocketPairTest, WriteOneSideClosed) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); - constexpr char kStr[] = "abc"; - ASSERT_THAT(write(sockets->second_fd(), kStr, 3), - SyscallFailsWithErrno(EPIPE)); -} - -TEST_P(SeqpacketUnixSocketPairTest, ReadOneSideClosed) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); - char data[10] = {}; - ASSERT_THAT(read(sockets->second_fd(), data, sizeof(data)), - SyscallSucceedsWithValue(0)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_seqpacket.h b/test/syscalls/linux/socket_unix_seqpacket.h deleted file mode 100644 index 30d9b9edf..000000000 --- a/test/syscalls/linux/socket_unix_seqpacket.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_SEQPACKET_H_ -#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_SEQPACKET_H_ - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// Test fixture for tests that apply to pairs of connected seqpacket unix -// sockets. -using SeqpacketUnixSocketPairTest = SocketPairTest; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_SEQPACKET_H_ diff --git a/test/syscalls/linux/socket_unix_seqpacket_local.cc b/test/syscalls/linux/socket_unix_seqpacket_local.cc deleted file mode 100644 index 69a5f150d..000000000 --- a/test/syscalls/linux/socket_unix_seqpacket_local.cc +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/socket_non_stream.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/socket_unix_non_stream.h" -#include "test/syscalls/linux/socket_unix_seqpacket.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return VecCat<SocketPairKind>(VecCat<SocketPairKind>( - ApplyVec<SocketPairKind>( - UnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_SEQPACKET}, - List<int>{0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - FilesystemBoundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_SEQPACKET}, - List<int>{0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - AbstractBoundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_SEQPACKET}, - List<int>{0, SOCK_NONBLOCK})))); -} - -INSTANTIATE_TEST_SUITE_P( - SeqpacketUnixSockets, NonStreamSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -INSTANTIATE_TEST_SUITE_P( - SeqpacketUnixSockets, SeqpacketUnixSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -INSTANTIATE_TEST_SUITE_P( - SeqpacketUnixSockets, UnixNonStreamSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc deleted file mode 100644 index 563467365..000000000 --- a/test/syscalls/linux/socket_unix_stream.cc +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <poll.h> -#include <stdio.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Test fixture for tests that apply to pairs of connected stream unix sockets. -using StreamUnixSocketPairTest = SocketPairTest; - -TEST_P(StreamUnixSocketPairTest, WriteOneSideClosed) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); - constexpr char kStr[] = "abc"; - ASSERT_THAT(write(sockets->second_fd(), kStr, 3), - SyscallFailsWithErrno(EPIPE)); -} - -TEST_P(StreamUnixSocketPairTest, ReadOneSideClosed) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); - char data[10] = {}; - ASSERT_THAT(read(sockets->second_fd(), data, sizeof(data)), - SyscallSucceedsWithValue(0)); -} - -TEST_P(StreamUnixSocketPairTest, RecvmsgOneSideClosed) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - // Set timeout so that it will not wait for ever. - struct timeval tv { - .tv_sec = 0, .tv_usec = 10 - }; - EXPECT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, - sizeof(tv)), - SyscallSucceeds()); - - ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); - - char received_data[10] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - struct msghdr msg = {}; - msg.msg_flags = -1; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(recvmsg(sockets->second_fd(), &msg, MSG_WAITALL), - SyscallSucceedsWithValue(0)); -} - -TEST_P(StreamUnixSocketPairTest, ReadOneSideClosedWithUnreadData) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char buf[10] = {}; - ASSERT_THAT(RetryEINTR(write)(sockets->second_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds()); - - ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)), - SyscallSucceedsWithValue(0)); - - ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); - - ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)), - SyscallFailsWithErrno(ECONNRESET)); -} - -INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, StreamUnixSocketPairTest, - ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>( - ApplyVec<SocketPairKind>(UnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{ - 0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>(FilesystemBoundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{ - 0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - AbstractBoundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{0, SOCK_NONBLOCK})))))); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_stream_blocking_local.cc b/test/syscalls/linux/socket_unix_stream_blocking_local.cc deleted file mode 100644 index 8429bd429..000000000 --- a/test/syscalls/linux/socket_unix_stream_blocking_local.cc +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/socket_stream_blocking.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return { - UnixDomainSocketPair(SOCK_STREAM), - FilesystemBoundUnixDomainSocketPair(SOCK_STREAM), - AbstractBoundUnixDomainSocketPair(SOCK_STREAM), - }; -} - -INSTANTIATE_TEST_SUITE_P( - BlockingStreamUnixSockets, BlockingStreamSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_stream_local.cc b/test/syscalls/linux/socket_unix_stream_local.cc deleted file mode 100644 index a7e3449a9..000000000 --- a/test/syscalls/linux/socket_unix_stream_local.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include "test/syscalls/linux/socket_stream.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return VecCat<SocketPairKind>( - ApplyVec<SocketPairKind>( - UnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - FilesystemBoundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - AbstractBoundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{0, SOCK_NONBLOCK}))); -} - -INSTANTIATE_TEST_SUITE_P( - StreamUnixSockets, StreamSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_stream_nonblock_local.cc b/test/syscalls/linux/socket_unix_stream_nonblock_local.cc deleted file mode 100644 index 4b763c8e2..000000000 --- a/test/syscalls/linux/socket_unix_stream_nonblock_local.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include <vector> - -#include "test/syscalls/linux/socket_stream_nonblock.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -std::vector<SocketPairKind> GetSocketPairs() { - return { - UnixDomainSocketPair(SOCK_STREAM | SOCK_NONBLOCK), - FilesystemBoundUnixDomainSocketPair(SOCK_STREAM | SOCK_NONBLOCK), - AbstractBoundUnixDomainSocketPair(SOCK_STREAM | SOCK_NONBLOCK), - }; -} - -INSTANTIATE_TEST_SUITE_P( - NonBlockingStreamUnixSockets, NonBlockingStreamSocketPairTest, - ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_unbound_abstract.cc b/test/syscalls/linux/socket_unix_unbound_abstract.cc deleted file mode 100644 index 8b1762000..000000000 --- a/test/syscalls/linux/socket_unix_unbound_abstract.cc +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdio.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Test fixture for tests that apply to pairs of unbound abstract unix sockets. -using UnboundAbstractUnixSocketPairTest = SocketPairTest; - -TEST_P(UnboundAbstractUnixSocketPairTest, AddressAfterNull) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct sockaddr_un addr = - *reinterpret_cast<const struct sockaddr_un*>(sockets->first_addr()); - ASSERT_EQ(addr.sun_path[sizeof(addr.sun_path) - 1], 0); - SKIP_IF(addr.sun_path[sizeof(addr.sun_path) - 2] != 0 || - addr.sun_path[sizeof(addr.sun_path) - 3] != 0); - - addr.sun_path[sizeof(addr.sun_path) - 2] = 'a'; - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), - SyscallSucceeds()); -} - -TEST_P(UnboundAbstractUnixSocketPairTest, ShortAddressNotExtended) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct sockaddr_un addr = - *reinterpret_cast<const struct sockaddr_un*>(sockets->first_addr()); - ASSERT_EQ(addr.sun_path[sizeof(addr.sun_path) - 1], 0); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size() - 1), - SyscallSucceeds()); - - ASSERT_THAT(bind(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); -} - -TEST_P(UnboundAbstractUnixSocketPairTest, BindNothing) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - struct sockaddr_un addr = {.sun_family = AF_UNIX}; - ASSERT_THAT(bind(sockets->first_fd(), - reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), - SyscallSucceeds()); -} - -TEST_P(UnboundAbstractUnixSocketPairTest, GetSockNameFullLength) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - sockaddr_storage addr = {}; - socklen_t addr_len = sizeof(addr); - ASSERT_THAT(getsockname(sockets->first_fd(), - reinterpret_cast<struct sockaddr*>(&addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, sockets->first_addr_size()); -} - -TEST_P(UnboundAbstractUnixSocketPairTest, GetSockNamePartialLength) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size() - 1), - SyscallSucceeds()); - - sockaddr_storage addr = {}; - socklen_t addr_len = sizeof(addr); - ASSERT_THAT(getsockname(sockets->first_fd(), - reinterpret_cast<struct sockaddr*>(&addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, sockets->first_addr_size() - 1); -} - -INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, UnboundAbstractUnixSocketPairTest, - ::testing::ValuesIn(ApplyVec<SocketPairKind>( - AbstractUnboundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_SEQPACKET, - SOCK_DGRAM}, - List<int>{0, SOCK_NONBLOCK})))); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_unbound_dgram.cc b/test/syscalls/linux/socket_unix_unbound_dgram.cc deleted file mode 100644 index 907dca0f1..000000000 --- a/test/syscalls/linux/socket_unix_unbound_dgram.cc +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdio.h> -#include <sys/socket.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Test fixture for tests that apply to pairs of unbound dgram unix sockets. -using UnboundDgramUnixSocketPairTest = SocketPairTest; - -TEST_P(UnboundDgramUnixSocketPairTest, BindConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); -} - -TEST_P(UnboundDgramUnixSocketPairTest, SelfConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - ASSERT_THAT(connect(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); -} - -TEST_P(UnboundDgramUnixSocketPairTest, DoubleConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); -} - -TEST_P(UnboundDgramUnixSocketPairTest, GetRemoteAddress) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - socklen_t addressLength = sockets->first_addr_size(); - struct sockaddr_storage address = {}; - ASSERT_THAT(getpeername(sockets->second_fd(), (struct sockaddr*)(&address), - &addressLength), - SyscallSucceeds()); - EXPECT_EQ( - 0, memcmp(&address, sockets->first_addr(), sockets->first_addr_size())); -} - -TEST_P(UnboundDgramUnixSocketPairTest, Sendto) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_THAT(sendto(sockets->second_fd(), sent_data, sizeof(sent_data), 0, - sockets->first_addr(), sockets->first_addr_size()), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[sizeof(sent_data)]; - ASSERT_THAT(ReadFd(sockets->first_fd(), received_data, sizeof(received_data)), - SyscallSucceedsWithValue(sizeof(received_data))); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(received_data))); -} - -TEST_P(UnboundDgramUnixSocketPairTest, ZeroWriteAllowed) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - char sent_data[3]; - // Send a zero length packet. - ASSERT_THAT(write(sockets->second_fd(), sent_data, 0), - SyscallSucceedsWithValue(0)); - // Receive the packet. - char received_data[sizeof(sent_data)]; - ASSERT_THAT(read(sockets->first_fd(), received_data, sizeof(received_data)), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UnboundDgramUnixSocketPairTest, Listen) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(listen(sockets->first_fd(), 0), SyscallFailsWithErrno(ENOTSUP)); -} - -TEST_P(UnboundDgramUnixSocketPairTest, Accept) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - ASSERT_THAT(accept(sockets->first_fd(), nullptr, nullptr), - SyscallFailsWithErrno(ENOTSUP)); -} - -TEST_P(UnboundDgramUnixSocketPairTest, SendtoWithoutConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - char data = 'a'; - ASSERT_THAT( - RetryEINTR(sendto)(sockets->second_fd(), &data, sizeof(data), 0, - sockets->first_addr(), sockets->first_addr_size()), - SyscallSucceedsWithValue(sizeof(data))); -} - -TEST_P(UnboundDgramUnixSocketPairTest, SendtoWithoutConnectPassCreds) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - SetSoPassCred(sockets->first_fd()); - char data = 'a'; - ASSERT_THAT( - RetryEINTR(sendto)(sockets->second_fd(), &data, sizeof(data), 0, - sockets->first_addr(), sockets->first_addr_size()), - SyscallSucceedsWithValue(sizeof(data))); - ucred creds; - creds.pid = -1; - char buf[sizeof(data) + 1]; - ASSERT_NO_FATAL_FAILURE( - RecvCreds(sockets->first_fd(), &creds, buf, sizeof(buf), sizeof(data))); - EXPECT_EQ(0, memcmp(&data, buf, sizeof(data))); - EXPECT_THAT(getpid(), SyscallSucceedsWithValue(creds.pid)); -} - -INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, UnboundDgramUnixSocketPairTest, - ::testing::ValuesIn(VecCat<SocketPairKind>( - ApplyVec<SocketPairKind>(FilesystemUnboundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_DGRAM}, - List<int>{ - 0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - AbstractUnboundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_DGRAM}, - List<int>{0, SOCK_NONBLOCK}))))); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_unbound_filesystem.cc b/test/syscalls/linux/socket_unix_unbound_filesystem.cc deleted file mode 100644 index cab912152..000000000 --- a/test/syscalls/linux/socket_unix_unbound_filesystem.cc +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdio.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Test fixture for tests that apply to pairs of unbound filesystem unix -// sockets. -using UnboundFilesystemUnixSocketPairTest = SocketPairTest; - -TEST_P(UnboundFilesystemUnixSocketPairTest, AddressAfterNull) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - struct sockaddr_un addr = - *reinterpret_cast<const struct sockaddr_un*>(sockets->first_addr()); - ASSERT_EQ(addr.sun_path[sizeof(addr.sun_path) - 1], 0); - SKIP_IF(addr.sun_path[sizeof(addr.sun_path) - 2] != 0 || - addr.sun_path[sizeof(addr.sun_path) - 3] != 0); - - addr.sun_path[sizeof(addr.sun_path) - 2] = 'a'; - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), - SyscallFailsWithErrno(EADDRINUSE)); -} - -TEST_P(UnboundFilesystemUnixSocketPairTest, GetSockNameLength) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - sockaddr_storage got_addr = {}; - socklen_t got_addr_len = sizeof(got_addr); - ASSERT_THAT( - getsockname(sockets->first_fd(), - reinterpret_cast<struct sockaddr*>(&got_addr), &got_addr_len), - SyscallSucceeds()); - - sockaddr_un want_addr = - *reinterpret_cast<const struct sockaddr_un*>(sockets->first_addr()); - - EXPECT_EQ(got_addr_len, - strlen(want_addr.sun_path) + 1 + sizeof(want_addr.sun_family)); -} - -INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, UnboundFilesystemUnixSocketPairTest, - ::testing::ValuesIn(ApplyVec<SocketPairKind>( - FilesystemUnboundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_SEQPACKET, - SOCK_DGRAM}, - List<int>{0, SOCK_NONBLOCK})))); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_unbound_seqpacket.cc b/test/syscalls/linux/socket_unix_unbound_seqpacket.cc deleted file mode 100644 index cb99030f5..000000000 --- a/test/syscalls/linux/socket_unix_unbound_seqpacket.cc +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdio.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Test fixture for tests that apply to pairs of unbound seqpacket unix sockets. -using UnboundUnixSeqpacketSocketPairTest = SocketPairTest; - -TEST_P(UnboundUnixSeqpacketSocketPairTest, SendtoWithoutConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - char data = 'a'; - ASSERT_THAT(sendto(sockets->second_fd(), &data, sizeof(data), 0, - sockets->first_addr(), sockets->first_addr_size()), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UnboundUnixSeqpacketSocketPairTest, SendtoWithoutConnectIgnoresAddr) { - // FIXME(b/68223466): gVisor tries to find /foo/bar and thus returns ENOENT. - if (IsRunningOnGvisor()) { - return; - } - - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - // Even a bogus address is completely ignored. - constexpr char kPath[] = "/foo/bar"; - - // Sanity check that kPath doesn't exist. - struct stat s; - ASSERT_THAT(stat(kPath, &s), SyscallFailsWithErrno(ENOENT)); - - struct sockaddr_un addr = {}; - addr.sun_family = AF_UNIX; - memcpy(addr.sun_path, kPath, sizeof(kPath)); - - char data = 'a'; - ASSERT_THAT( - sendto(sockets->second_fd(), &data, sizeof(data), 0, - reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)), - SyscallFailsWithErrno(ENOTCONN)); -} - -INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, UnboundUnixSeqpacketSocketPairTest, - ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>( - ApplyVec<SocketPairKind>( - FilesystemUnboundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_SEQPACKET}, - List<int>{0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - AbstractUnboundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_SEQPACKET}, - List<int>{0, SOCK_NONBLOCK})))))); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_unbound_stream.cc b/test/syscalls/linux/socket_unix_unbound_stream.cc deleted file mode 100644 index f185dded3..000000000 --- a/test/syscalls/linux/socket_unix_unbound_stream.cc +++ /dev/null @@ -1,733 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdio.h> -#include <sys/un.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Test fixture for tests that apply to pairs of connected unix stream sockets. -using UnixStreamSocketPairTest = SocketPairTest; - -// FDPassPartialRead checks that sent control messages cannot be read after -// any of their associated data has been read while ignoring the control message -// by using read(2) instead of recvmsg(2). -TEST_P(UnixStreamSocketPairTest, FDPassPartialRead) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char received_data[sizeof(sent_data) / 2]; - ASSERT_THAT( - ReadFd(sockets->second_fd(), received_data, sizeof(received_data)), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(received_data))); - - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data)); - EXPECT_EQ(0, memcmp(sent_data + sizeof(received_data), received_data, - sizeof(received_data))); -} - -TEST_P(UnixStreamSocketPairTest, FDPassCoalescedRead) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data1[20]; - RandomizeBuffer(sent_data1, sizeof(sent_data1)); - - auto pair1 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair1->second_fd(), - sent_data1, sizeof(sent_data1))); - - char sent_data2[20]; - RandomizeBuffer(sent_data2, sizeof(sent_data2)); - - auto pair2 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair2->second_fd(), - sent_data2, sizeof(sent_data2))); - - char received_data[sizeof(sent_data1) + sizeof(sent_data2)]; - ASSERT_THAT( - ReadFd(sockets->second_fd(), received_data, sizeof(received_data)), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1))); - EXPECT_EQ(0, memcmp(sent_data2, received_data + sizeof(sent_data1), - sizeof(sent_data2))); -} - -// ZeroLengthMessageFDDiscarded checks that control messages associated with -// zero length messages are discarded. -TEST_P(UnixStreamSocketPairTest, ZeroLengthMessageFDDiscarded) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - // Zero length arrays are invalid in ISO C++, so allocate one of size 1 and - // send a length of 0. - char sent_data1[1] = {}; - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE( - SendSingleFD(sockets->first_fd(), pair->second_fd(), sent_data1, 0)); - - char sent_data2[20]; - RandomizeBuffer(sent_data2, sizeof(sent_data2)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)), - SyscallSucceedsWithValue(sizeof(sent_data2))); - - char received_data[sizeof(sent_data2)] = {}; - - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data)); - EXPECT_EQ(0, memcmp(sent_data2, received_data, sizeof(received_data))); -} - -// FDPassCoalescedRecv checks that control messages not in the first message are -// preserved in a coalesced recv. -TEST_P(UnixStreamSocketPairTest, FDPassCoalescedRecv) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data) / 2), - SyscallSucceedsWithValue(sizeof(sent_data) / 2)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data + sizeof(sent_data) / 2, - sizeof(sent_data) / 2)); - - char received_data[sizeof(sent_data)]; - - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data, - sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); -} - -// ReadsNotCoalescedAfterFDPass checks that messages after a message containing -// an FD control message are not coalesced. -TEST_P(UnixStreamSocketPairTest, ReadsNotCoalescedAfterFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data) / 2)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data + sizeof(sent_data) / 2, - sizeof(sent_data) / 2), - SyscallSucceedsWithValue(sizeof(sent_data) / 2)); - - char received_data[sizeof(sent_data)]; - - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data, - sizeof(received_data), - sizeof(sent_data) / 2)); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data) / 2)); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); - EXPECT_THAT(close(fd), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(sent_data) / 2)); - - EXPECT_EQ(0, memcmp(sent_data + sizeof(sent_data) / 2, received_data, - sizeof(sent_data) / 2)); -} - -// FDPassNotCombined checks that FD control messages are not combined in a -// coalesced read. -TEST_P(UnixStreamSocketPairTest, FDPassNotCombined) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair1 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair1->second_fd(), - sent_data, sizeof(sent_data) / 2)); - - auto pair2 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair2->second_fd(), - sent_data + sizeof(sent_data) / 2, - sizeof(sent_data) / 2)); - - char received_data[sizeof(sent_data)]; - - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data, - sizeof(received_data), - sizeof(sent_data) / 2)); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data) / 2)); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair1->first_fd())); - - EXPECT_THAT(close(fd), SyscallSucceeds()); - fd = -1; - - ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data, - sizeof(received_data), - sizeof(sent_data) / 2)); - - EXPECT_EQ(0, memcmp(sent_data + sizeof(sent_data) / 2, received_data, - sizeof(sent_data) / 2)); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair2->first_fd())); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST_P(UnixStreamSocketPairTest, CredPassPartialRead) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - int one = 1; - ASSERT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &one, - sizeof(one)), - SyscallSucceeds()); - - for (int i = 0; i < 2; i++) { - char received_data[10]; - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data), - sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data + i * sizeof(received_data), received_data, - sizeof(received_data))); - EXPECT_EQ(sent_creds.pid, received_creds.pid); - EXPECT_EQ(sent_creds.uid, received_creds.uid); - EXPECT_EQ(sent_creds.gid, received_creds.gid); - } -} - -// Unix stream sockets peek in the same way as datagram sockets. -// -// SinglePeek checks that only a single message is peekable in a single recv. -TEST_P(UnixStreamSocketPairTest, SinglePeek) { - if (!IsRunningOnGvisor()) { - // Don't run this test on linux kernels newer than 4.3.x Linux kernel commit - // 9f389e35674f5b086edd70ed524ca0f287259725 which changes this behavior. We - // used to target 3.11 compatibility, so disable this test on newer kernels. - // - // NOTE(b/118902768): Bring this up to Linux 4.4 compatibility. - auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion()); - SKIP_IF(version.major > 4 || (version.major == 4 && version.minor >= 3)); - } - - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char sent_data[40]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), sent_data, - sizeof(sent_data) / 2, 0), - SyscallSucceedsWithValue(sizeof(sent_data) / 2)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), sent_data + sizeof(sent_data) / 2, - sizeof(sent_data) / 2, 0), - SyscallSucceedsWithValue(sizeof(sent_data) / 2)); - char received_data[sizeof(sent_data)]; - for (int i = 0; i < 3; i++) { - memset(received_data, 0, sizeof(received_data)); - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(received_data), MSG_PEEK), - SyscallSucceedsWithValue(sizeof(sent_data) / 2)); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data) / 2)); - } - memset(received_data, 0, sizeof(received_data)); - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(sent_data) / 2, 0), - SyscallSucceedsWithValue(sizeof(sent_data) / 2)); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data) / 2)); - memset(received_data, 0, sizeof(received_data)); - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data, - sizeof(sent_data) / 2, 0), - SyscallSucceedsWithValue(sizeof(sent_data) / 2)); - EXPECT_EQ(0, memcmp(sent_data + sizeof(sent_data) / 2, received_data, - sizeof(sent_data) / 2)); -} - -TEST_P(UnixStreamSocketPairTest, CredsNotCoalescedUp) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data1[20]; - RandomizeBuffer(sent_data1, sizeof(sent_data1)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data1, sizeof(sent_data1)), - SyscallSucceedsWithValue(sizeof(sent_data1))); - - SetSoPassCred(sockets->second_fd()); - - char sent_data2[20]; - RandomizeBuffer(sent_data2, sizeof(sent_data2)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)), - SyscallSucceedsWithValue(sizeof(sent_data2))); - - char received_data[sizeof(sent_data1) + sizeof(sent_data2)]; - - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data), - sizeof(sent_data1))); - - EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1))); - - struct ucred want_creds { - 0, 65534, 65534 - }; - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); - - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data), - sizeof(sent_data2))); - - EXPECT_EQ(0, memcmp(sent_data2, received_data, sizeof(sent_data2))); - - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixStreamSocketPairTest, CredsNotCoalescedDown) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - SetSoPassCred(sockets->second_fd()); - - char sent_data1[20]; - RandomizeBuffer(sent_data1, sizeof(sent_data1)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data1, sizeof(sent_data1)), - SyscallSucceedsWithValue(sizeof(sent_data1))); - - UnsetSoPassCred(sockets->second_fd()); - - char sent_data2[20]; - RandomizeBuffer(sent_data2, sizeof(sent_data2)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)), - SyscallSucceedsWithValue(sizeof(sent_data2))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[sizeof(sent_data1) + sizeof(sent_data2)]; - struct ucred received_creds; - - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data), - sizeof(sent_data1))); - - EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); - - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data), - sizeof(sent_data2))); - - EXPECT_EQ(0, memcmp(sent_data2, received_data, sizeof(sent_data2))); - - want_creds = {0, 65534, 65534}; - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixStreamSocketPairTest, CoalescedCredsNoPasscred) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - SetSoPassCred(sockets->second_fd()); - - char sent_data1[20]; - RandomizeBuffer(sent_data1, sizeof(sent_data1)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data1, sizeof(sent_data1)), - SyscallSucceedsWithValue(sizeof(sent_data1))); - - UnsetSoPassCred(sockets->second_fd()); - - char sent_data2[20]; - RandomizeBuffer(sent_data2, sizeof(sent_data2)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)), - SyscallSucceedsWithValue(sizeof(sent_data2))); - - char received_data[sizeof(sent_data1) + sizeof(sent_data2)]; - - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1))); - EXPECT_EQ(0, memcmp(sent_data2, received_data + sizeof(sent_data1), - sizeof(sent_data2))); -} - -TEST_P(UnixStreamSocketPairTest, CoalescedCreds1) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data1[20]; - RandomizeBuffer(sent_data1, sizeof(sent_data1)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data1, sizeof(sent_data1)), - SyscallSucceedsWithValue(sizeof(sent_data1))); - - char sent_data2[20]; - RandomizeBuffer(sent_data2, sizeof(sent_data2)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)), - SyscallSucceedsWithValue(sizeof(sent_data2))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[sizeof(sent_data1) + sizeof(sent_data2)]; - struct ucred received_creds; - - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1))); - EXPECT_EQ(0, memcmp(sent_data2, received_data + sizeof(sent_data1), - sizeof(sent_data2))); - - struct ucred want_creds { - 0, 65534, 65534 - }; - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixStreamSocketPairTest, CoalescedCreds2) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - SetSoPassCred(sockets->second_fd()); - - char sent_data1[20]; - RandomizeBuffer(sent_data1, sizeof(sent_data1)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data1, sizeof(sent_data1)), - SyscallSucceedsWithValue(sizeof(sent_data1))); - - char sent_data2[20]; - RandomizeBuffer(sent_data2, sizeof(sent_data2)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)), - SyscallSucceedsWithValue(sizeof(sent_data2))); - - char received_data[sizeof(sent_data1) + sizeof(sent_data2)]; - struct ucred received_creds; - - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1))); - EXPECT_EQ(0, memcmp(sent_data2, received_data + sizeof(sent_data1), - sizeof(sent_data2))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixStreamSocketPairTest, NonCoalescedDifferingCreds1) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data1[20]; - RandomizeBuffer(sent_data1, sizeof(sent_data1)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data1, sizeof(sent_data1)), - SyscallSucceedsWithValue(sizeof(sent_data1))); - - SetSoPassCred(sockets->second_fd()); - - char sent_data2[20]; - RandomizeBuffer(sent_data2, sizeof(sent_data2)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)), - SyscallSucceedsWithValue(sizeof(sent_data2))); - - char received_data1[sizeof(sent_data1) + sizeof(sent_data2)]; - struct ucred received_creds1; - - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds1, - received_data1, sizeof(sent_data1))); - - EXPECT_EQ(0, memcmp(sent_data1, received_data1, sizeof(sent_data1))); - - struct ucred want_creds1 { - 0, 65534, 65534 - }; - - EXPECT_EQ(want_creds1.pid, received_creds1.pid); - EXPECT_EQ(want_creds1.uid, received_creds1.uid); - EXPECT_EQ(want_creds1.gid, received_creds1.gid); - - char received_data2[sizeof(sent_data1) + sizeof(sent_data2)]; - struct ucred received_creds2; - - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds2, - received_data2, sizeof(sent_data2))); - - EXPECT_EQ(0, memcmp(sent_data2, received_data2, sizeof(sent_data2))); - - struct ucred want_creds2; - ASSERT_THAT(want_creds2.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds2.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds2.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds2.pid, received_creds2.pid); - EXPECT_EQ(want_creds2.uid, received_creds2.uid); - EXPECT_EQ(want_creds2.gid, received_creds2.gid); -} - -TEST_P(UnixStreamSocketPairTest, NonCoalescedDifferingCreds2) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - SetSoPassCred(sockets->second_fd()); - - char sent_data1[20]; - RandomizeBuffer(sent_data1, sizeof(sent_data1)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data1, sizeof(sent_data1)), - SyscallSucceedsWithValue(sizeof(sent_data1))); - - UnsetSoPassCred(sockets->second_fd()); - - char sent_data2[20]; - RandomizeBuffer(sent_data2, sizeof(sent_data2)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)), - SyscallSucceedsWithValue(sizeof(sent_data2))); - - SetSoPassCred(sockets->second_fd()); - - char received_data1[sizeof(sent_data1) + sizeof(sent_data2)]; - struct ucred received_creds1; - - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds1, - received_data1, sizeof(sent_data1))); - - EXPECT_EQ(0, memcmp(sent_data1, received_data1, sizeof(sent_data1))); - - struct ucred want_creds1; - ASSERT_THAT(want_creds1.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds1.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds1.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds1.pid, received_creds1.pid); - EXPECT_EQ(want_creds1.uid, received_creds1.uid); - EXPECT_EQ(want_creds1.gid, received_creds1.gid); - - char received_data2[sizeof(sent_data1) + sizeof(sent_data2)]; - struct ucred received_creds2; - - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds2, - received_data2, sizeof(sent_data2))); - - EXPECT_EQ(0, memcmp(sent_data2, received_data2, sizeof(sent_data2))); - - struct ucred want_creds2 { - 0, 65534, 65534 - }; - - EXPECT_EQ(want_creds2.pid, received_creds2.pid); - EXPECT_EQ(want_creds2.uid, received_creds2.uid); - EXPECT_EQ(want_creds2.gid, received_creds2.gid); -} - -TEST_P(UnixStreamSocketPairTest, CoalescedDifferingCreds) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - SetSoPassCred(sockets->second_fd()); - - char sent_data1[20]; - RandomizeBuffer(sent_data1, sizeof(sent_data1)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data1, sizeof(sent_data1)), - SyscallSucceedsWithValue(sizeof(sent_data1))); - - char sent_data2[20]; - RandomizeBuffer(sent_data2, sizeof(sent_data2)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)), - SyscallSucceedsWithValue(sizeof(sent_data2))); - - UnsetSoPassCred(sockets->second_fd()); - - char sent_data3[20]; - RandomizeBuffer(sent_data3, sizeof(sent_data3)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data3, sizeof(sent_data3)), - SyscallSucceedsWithValue(sizeof(sent_data3))); - - char received_data[sizeof(sent_data1) + sizeof(sent_data2) + - sizeof(sent_data3)]; - - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1))); - EXPECT_EQ(0, memcmp(sent_data2, received_data + sizeof(sent_data1), - sizeof(sent_data2))); - EXPECT_EQ(0, memcmp(sent_data3, - received_data + sizeof(sent_data1) + sizeof(sent_data2), - sizeof(sent_data3))); -} - -INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, UnixStreamSocketPairTest, - ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>( - ApplyVec<SocketPairKind>(UnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{ - 0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>(FilesystemBoundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{ - 0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - AbstractBoundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{0, SOCK_NONBLOCK})))))); - -// Test fixture for tests that apply to pairs of unbound unix stream sockets. -using UnboundUnixStreamSocketPairTest = SocketPairTest; - -TEST_P(UnboundUnixStreamSocketPairTest, SendtoWithoutConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - char data = 'a'; - ASSERT_THAT(sendto(sockets->second_fd(), &data, sizeof(data), 0, - sockets->first_addr(), sockets->first_addr_size()), - SyscallFailsWithErrno(EOPNOTSUPP)); -} - -TEST_P(UnboundUnixStreamSocketPairTest, SendtoWithoutConnectIgnoresAddr) { - // FIXME(b/68223466): gVisor tries to find /foo/bar and thus returns ENOENT. - if (IsRunningOnGvisor()) { - return; - } - - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), - sockets->first_addr_size()), - SyscallSucceeds()); - - // Even a bogus address is completely ignored. - constexpr char kPath[] = "/foo/bar"; - - // Sanity check that kPath doesn't exist. - struct stat s; - ASSERT_THAT(stat(kPath, &s), SyscallFailsWithErrno(ENOENT)); - - struct sockaddr_un addr = {}; - addr.sun_family = AF_UNIX; - memcpy(addr.sun_path, kPath, sizeof(kPath)); - - char data = 'a'; - ASSERT_THAT( - sendto(sockets->second_fd(), &data, sizeof(data), 0, - reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)), - SyscallFailsWithErrno(EOPNOTSUPP)); -} - -INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, UnboundUnixStreamSocketPairTest, - ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>( - ApplyVec<SocketPairKind>(FilesystemUnboundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{ - 0, SOCK_NONBLOCK})), - ApplyVec<SocketPairKind>( - AbstractUnboundUnixDomainSocketPair, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{0, SOCK_NONBLOCK})))))); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc deleted file mode 100644 index faa1247f6..000000000 --- a/test/syscalls/linux/splice.cc +++ /dev/null @@ -1,649 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <sys/eventfd.h> -#include <sys/resource.h> -#include <sys/sendfile.h> -#include <sys/time.h> -#include <unistd.h> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/string_view.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(SpliceTest, TwoRegularFiles) { - // Create temp files. - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // Open the input file as read only. - const FileDescriptor in_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); - - // Open the output file as write only. - const FileDescriptor out_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); - - // Verify that it is rejected as expected; regardless of offsets. - loff_t in_offset = 0; - loff_t out_offset = 0; - EXPECT_THAT(splice(in_fd.get(), &in_offset, out_fd.get(), &out_offset, 1, 0), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(splice(in_fd.get(), nullptr, out_fd.get(), &out_offset, 1, 0), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(splice(in_fd.get(), &in_offset, out_fd.get(), nullptr, 1, 0), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(splice(in_fd.get(), nullptr, out_fd.get(), nullptr, 1, 0), - SyscallFailsWithErrno(EINVAL)); -} - -int memfd_create(const std::string& name, unsigned int flags) { - return syscall(__NR_memfd_create, name.c_str(), flags); -} - -TEST(SpliceTest, NegativeOffset) { - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Fill the pipe. - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - - // Open the output file as write only. - int fd; - EXPECT_THAT(fd = memfd_create("negative", 0), SyscallSucceeds()); - const FileDescriptor out_fd(fd); - - loff_t out_offset = 0xffffffffffffffffull; - constexpr int kSize = 2; - EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kSize, 0), - SyscallFailsWithErrno(EINVAL)); -} - -// Write offset + size overflows int64. -// -// This is a regression test for b/148041624. -TEST(SpliceTest, WriteOverflow) { - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Fill the pipe. - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - - // Open the output file. - int fd; - EXPECT_THAT(fd = memfd_create("overflow", 0), SyscallSucceeds()); - const FileDescriptor out_fd(fd); - - // out_offset + kSize overflows INT64_MAX. - loff_t out_offset = 0x7ffffffffffffffeull; - constexpr int kSize = 3; - EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kSize, 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(SpliceTest, SamePipe) { - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Fill the pipe. - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - - // Attempt to splice to itself. - EXPECT_THAT(splice(rfd.get(), nullptr, wfd.get(), nullptr, kPageSize, 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(TeeTest, SamePipe) { - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Fill the pipe. - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - - // Attempt to tee to itself. - EXPECT_THAT(tee(rfd.get(), wfd.get(), kPageSize, 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(TeeTest, RegularFile) { - // Open some file. - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor in_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); - - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Attempt to tee from the file. - EXPECT_THAT(tee(in_fd.get(), wfd.get(), kPageSize, 0), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(tee(rfd.get(), in_fd.get(), kPageSize, 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(SpliceTest, PipeOffsets) { - // Create two new pipes. - int first[2], second[2]; - ASSERT_THAT(pipe(first), SyscallSucceeds()); - const FileDescriptor rfd1(first[0]); - const FileDescriptor wfd1(first[1]); - ASSERT_THAT(pipe(second), SyscallSucceeds()); - const FileDescriptor rfd2(second[0]); - const FileDescriptor wfd2(second[1]); - - // All pipe offsets should be rejected. - loff_t in_offset = 0; - loff_t out_offset = 0; - EXPECT_THAT(splice(rfd1.get(), &in_offset, wfd2.get(), &out_offset, 1, 0), - SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(splice(rfd1.get(), nullptr, wfd2.get(), &out_offset, 1, 0), - SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(splice(rfd1.get(), &in_offset, wfd2.get(), nullptr, 1, 0), - SyscallFailsWithErrno(ESPIPE)); -} - -// Event FDs may be used with splice without an offset. -TEST(SpliceTest, FromEventFD) { - // Open the input eventfd with an initial value so that it is readable. - constexpr uint64_t kEventFDValue = 1; - int efd; - ASSERT_THAT(efd = eventfd(kEventFDValue, 0), SyscallSucceeds()); - const FileDescriptor in_fd(efd); - - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Splice 8-byte eventfd value to pipe. - constexpr int kEventFDSize = 8; - EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0), - SyscallSucceedsWithValue(kEventFDSize)); - - // Contents should be equal. - std::vector<char> rbuf(kEventFDSize); - ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()), - SyscallSucceedsWithValue(kEventFDSize)); - EXPECT_EQ(memcmp(rbuf.data(), &kEventFDValue, rbuf.size()), 0); -} - -// Event FDs may not be used with splice with an offset. -TEST(SpliceTest, FromEventFDOffset) { - int efd; - ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds()); - const FileDescriptor in_fd(efd); - - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Attempt to splice 8-byte eventfd value to pipe with offset. - // - // This is not allowed because eventfd doesn't support pread. - constexpr int kEventFDSize = 8; - loff_t in_off = 0; - EXPECT_THAT(splice(in_fd.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0), - SyscallFailsWithErrno(EINVAL)); -} - -// Event FDs may not be used with splice with an offset. -TEST(SpliceTest, ToEventFDOffset) { - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Fill with a value. - constexpr int kEventFDSize = 8; - std::vector<char> buf(kEventFDSize); - buf[0] = 1; - ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kEventFDSize)); - - int efd; - ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds()); - const FileDescriptor out_fd(efd); - - // Attempt to splice 8-byte eventfd value to pipe with offset. - // - // This is not allowed because eventfd doesn't support pwrite. - loff_t out_off = 0; - EXPECT_THAT( - splice(rfd.get(), nullptr, out_fd.get(), &out_off, kEventFDSize, 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(SpliceTest, ToPipe) { - // Open the input file. - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor in_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); - - // Fill with some random data. - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(in_fd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - ASSERT_THAT(lseek(in_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); - - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Splice to the pipe. - EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kPageSize, 0), - SyscallSucceedsWithValue(kPageSize)); - - // Contents should be equal. - std::vector<char> rbuf(kPageSize); - ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()), - SyscallSucceedsWithValue(kPageSize)); - EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); -} - -TEST(SpliceTest, ToPipeOffset) { - // Open the input file. - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor in_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); - - // Fill with some random data. - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(in_fd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Splice to the pipe. - loff_t in_offset = kPageSize / 2; - EXPECT_THAT( - splice(in_fd.get(), &in_offset, wfd.get(), nullptr, kPageSize / 2, 0), - SyscallSucceedsWithValue(kPageSize / 2)); - - // Contents should be equal to only the second part. - std::vector<char> rbuf(kPageSize / 2); - ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()), - SyscallSucceedsWithValue(kPageSize / 2)); - EXPECT_EQ(memcmp(rbuf.data(), buf.data() + (kPageSize / 2), rbuf.size()), 0); -} - -TEST(SpliceTest, FromPipe) { - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Fill with some random data. - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - - // Open the input file. - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor out_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); - - // Splice to the output file. - EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, kPageSize, 0), - SyscallSucceedsWithValue(kPageSize)); - - // The offset of the output should be equal to kPageSize. We assert that and - // reset to zero so that we can read the contents and ensure they match. - EXPECT_THAT(lseek(out_fd.get(), 0, SEEK_CUR), - SyscallSucceedsWithValue(kPageSize)); - ASSERT_THAT(lseek(out_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); - - // Contents should be equal. - std::vector<char> rbuf(kPageSize); - ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()), - SyscallSucceedsWithValue(kPageSize)); - EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); -} - -TEST(SpliceTest, FromPipeOffset) { - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Fill with some random data. - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - - // Open the input file. - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor out_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); - - // Splice to the output file. - loff_t out_offset = kPageSize / 2; - EXPECT_THAT( - splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kPageSize, 0), - SyscallSucceedsWithValue(kPageSize)); - - // Content should reflect the splice. We write to a specific offset in the - // file, so the internals should now be allocated sparsely. - std::vector<char> rbuf(kPageSize); - ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()), - SyscallSucceedsWithValue(kPageSize)); - std::vector<char> zbuf(kPageSize / 2); - memset(zbuf.data(), 0, zbuf.size()); - EXPECT_EQ(memcmp(rbuf.data(), zbuf.data(), zbuf.size()), 0); - EXPECT_EQ(memcmp(rbuf.data() + kPageSize / 2, buf.data(), kPageSize / 2), 0); -} - -TEST(SpliceTest, TwoPipes) { - // Create two new pipes. - int first[2], second[2]; - ASSERT_THAT(pipe(first), SyscallSucceeds()); - const FileDescriptor rfd1(first[0]); - const FileDescriptor wfd1(first[1]); - ASSERT_THAT(pipe(second), SyscallSucceeds()); - const FileDescriptor rfd2(second[0]); - const FileDescriptor wfd2(second[1]); - - // Fill with some random data. - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(wfd1.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - - // Splice to the second pipe, using two operations. - EXPECT_THAT( - splice(rfd1.get(), nullptr, wfd2.get(), nullptr, kPageSize / 2, 0), - SyscallSucceedsWithValue(kPageSize / 2)); - EXPECT_THAT( - splice(rfd1.get(), nullptr, wfd2.get(), nullptr, kPageSize / 2, 0), - SyscallSucceedsWithValue(kPageSize / 2)); - - // Content should reflect the splice. - std::vector<char> rbuf(kPageSize); - ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()), - SyscallSucceedsWithValue(kPageSize)); - EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0); -} - -TEST(SpliceTest, Blocking) { - // Create two new pipes. - int first[2], second[2]; - ASSERT_THAT(pipe(first), SyscallSucceeds()); - const FileDescriptor rfd1(first[0]); - const FileDescriptor wfd1(first[1]); - ASSERT_THAT(pipe(second), SyscallSucceeds()); - const FileDescriptor rfd2(second[0]); - const FileDescriptor wfd2(second[1]); - - // This thread writes to the main pipe. - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ScopedThread t([&]() { - ASSERT_THAT(write(wfd1.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - }); - - // Attempt a splice immediately; it should block. - EXPECT_THAT(splice(rfd1.get(), nullptr, wfd2.get(), nullptr, kPageSize, 0), - SyscallSucceedsWithValue(kPageSize)); - - // Thread should be joinable. - t.Join(); - - // Content should reflect the splice. - std::vector<char> rbuf(kPageSize); - ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()), - SyscallSucceedsWithValue(kPageSize)); - EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0); -} - -TEST(TeeTest, Blocking) { - // Create two new pipes. - int first[2], second[2]; - ASSERT_THAT(pipe(first), SyscallSucceeds()); - const FileDescriptor rfd1(first[0]); - const FileDescriptor wfd1(first[1]); - ASSERT_THAT(pipe(second), SyscallSucceeds()); - const FileDescriptor rfd2(second[0]); - const FileDescriptor wfd2(second[1]); - - // This thread writes to the main pipe. - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ScopedThread t([&]() { - ASSERT_THAT(write(wfd1.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - }); - - // Attempt a tee immediately; it should block. - EXPECT_THAT(tee(rfd1.get(), wfd2.get(), kPageSize, 0), - SyscallSucceedsWithValue(kPageSize)); - - // Thread should be joinable. - t.Join(); - - // Content should reflect the splice, in both pipes. - std::vector<char> rbuf(kPageSize); - ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()), - SyscallSucceedsWithValue(kPageSize)); - EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0); - ASSERT_THAT(read(rfd1.get(), rbuf.data(), rbuf.size()), - SyscallSucceedsWithValue(kPageSize)); - EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0); -} - -TEST(TeeTest, BlockingWrite) { - // Create two new pipes. - int first[2], second[2]; - ASSERT_THAT(pipe(first), SyscallSucceeds()); - const FileDescriptor rfd1(first[0]); - const FileDescriptor wfd1(first[1]); - ASSERT_THAT(pipe(second), SyscallSucceeds()); - const FileDescriptor rfd2(second[0]); - const FileDescriptor wfd2(second[1]); - - // Make some data available to be read. - std::vector<char> buf1(kPageSize); - RandomizeBuffer(buf1.data(), buf1.size()); - ASSERT_THAT(write(wfd1.get(), buf1.data(), buf1.size()), - SyscallSucceedsWithValue(kPageSize)); - - // Fill up the write pipe's buffer. - int pipe_size = -1; - ASSERT_THAT(pipe_size = fcntl(wfd2.get(), F_GETPIPE_SZ), SyscallSucceeds()); - std::vector<char> buf2(pipe_size); - ASSERT_THAT(write(wfd2.get(), buf2.data(), buf2.size()), - SyscallSucceedsWithValue(pipe_size)); - - ScopedThread t([&]() { - absl::SleepFor(absl::Milliseconds(100)); - ASSERT_THAT(read(rfd2.get(), buf2.data(), buf2.size()), - SyscallSucceedsWithValue(pipe_size)); - }); - - // Attempt a tee immediately; it should block. - EXPECT_THAT(tee(rfd1.get(), wfd2.get(), kPageSize, 0), - SyscallSucceedsWithValue(kPageSize)); - - // Thread should be joinable. - t.Join(); - - // Content should reflect the tee. - std::vector<char> rbuf(kPageSize); - ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()), - SyscallSucceedsWithValue(kPageSize)); - EXPECT_EQ(memcmp(rbuf.data(), buf1.data(), kPageSize), 0); -} - -TEST(SpliceTest, NonBlocking) { - // Create two new pipes. - int first[2], second[2]; - ASSERT_THAT(pipe(first), SyscallSucceeds()); - const FileDescriptor rfd1(first[0]); - const FileDescriptor wfd1(first[1]); - ASSERT_THAT(pipe(second), SyscallSucceeds()); - const FileDescriptor rfd2(second[0]); - const FileDescriptor wfd2(second[1]); - - // Splice with no data to back it. - EXPECT_THAT(splice(rfd1.get(), nullptr, wfd2.get(), nullptr, kPageSize, - SPLICE_F_NONBLOCK), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST(TeeTest, NonBlocking) { - // Create two new pipes. - int first[2], second[2]; - ASSERT_THAT(pipe(first), SyscallSucceeds()); - const FileDescriptor rfd1(first[0]); - const FileDescriptor wfd1(first[1]); - ASSERT_THAT(pipe(second), SyscallSucceeds()); - const FileDescriptor rfd2(second[0]); - const FileDescriptor wfd2(second[1]); - - // Splice with no data to back it. - EXPECT_THAT(tee(rfd1.get(), wfd2.get(), kPageSize, SPLICE_F_NONBLOCK), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST(TeeTest, MultiPage) { - // Create two new pipes. - int first[2], second[2]; - ASSERT_THAT(pipe(first), SyscallSucceeds()); - const FileDescriptor rfd1(first[0]); - const FileDescriptor wfd1(first[1]); - ASSERT_THAT(pipe(second), SyscallSucceeds()); - const FileDescriptor rfd2(second[0]); - const FileDescriptor wfd2(second[1]); - - // Make some data available to be read. - std::vector<char> wbuf(8 * kPageSize); - RandomizeBuffer(wbuf.data(), wbuf.size()); - ASSERT_THAT(write(wfd1.get(), wbuf.data(), wbuf.size()), - SyscallSucceedsWithValue(wbuf.size())); - - // Attempt a tee immediately; it should complete. - EXPECT_THAT(tee(rfd1.get(), wfd2.get(), wbuf.size(), 0), - SyscallSucceedsWithValue(wbuf.size())); - - // Content should reflect the tee. - std::vector<char> rbuf(wbuf.size()); - ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()), - SyscallSucceedsWithValue(rbuf.size())); - EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), rbuf.size()), 0); - ASSERT_THAT(read(rfd1.get(), rbuf.data(), rbuf.size()), - SyscallSucceedsWithValue(rbuf.size())); - EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), rbuf.size()), 0); -} - -TEST(SpliceTest, FromPipeMaxFileSize) { - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Fill with some random data. - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - - // Open the input file. - const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor out_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); - - EXPECT_THAT(ftruncate(out_fd.get(), 13 << 20), SyscallSucceeds()); - EXPECT_THAT(lseek(out_fd.get(), 0, SEEK_END), - SyscallSucceedsWithValue(13 << 20)); - - // Set our file size limit. - sigset_t set; - sigemptyset(&set); - sigaddset(&set, SIGXFSZ); - TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); - rlimit rlim = {}; - rlim.rlim_cur = rlim.rlim_max = (13 << 20); - EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &rlim), SyscallSucceeds()); - - // Splice to the output file. - EXPECT_THAT( - splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3 * kPageSize, 0), - SyscallFailsWithErrno(EFBIG)); - - // Contents should be equal. - std::vector<char> rbuf(kPageSize); - ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()), - SyscallSucceedsWithValue(kPageSize)); - EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc deleted file mode 100644 index c951ac3b3..000000000 --- a/test/syscalls/linux/stat.cc +++ /dev/null @@ -1,659 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <sys/stat.h> -#include <sys/statfs.h> -#include <sys/types.h> -#include <unistd.h> - -#include <string> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "test/syscalls/linux/file_base.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -class StatTest : public FileTest {}; - -TEST_F(StatTest, FstatatAbs) { - struct stat st; - - // Check that the stat works. - EXPECT_THAT(fstatat(AT_FDCWD, test_file_name_.c_str(), &st, 0), - SyscallSucceeds()); - EXPECT_TRUE(S_ISREG(st.st_mode)); -} - -TEST_F(StatTest, FstatatEmptyPath) { - struct stat st; - const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY)); - - // Check that the stat works. - EXPECT_THAT(fstatat(fd.get(), "", &st, AT_EMPTY_PATH), SyscallSucceeds()); - EXPECT_TRUE(S_ISREG(st.st_mode)); -} - -TEST_F(StatTest, FstatatRel) { - struct stat st; - int dirfd; - auto filename = std::string(Basename(test_file_name_)); - - // Open the temporary directory read-only. - ASSERT_THAT(dirfd = open(GetAbsoluteTestTmpdir().c_str(), O_RDONLY), - SyscallSucceeds()); - - // Check that the stat works. - EXPECT_THAT(fstatat(dirfd, filename.c_str(), &st, 0), SyscallSucceeds()); - EXPECT_TRUE(S_ISREG(st.st_mode)); - close(dirfd); -} - -TEST_F(StatTest, FstatatSymlink) { - struct stat st; - - // Check that the link is followed. - EXPECT_THAT(fstatat(AT_FDCWD, "/proc/self", &st, 0), SyscallSucceeds()); - EXPECT_TRUE(S_ISDIR(st.st_mode)); - EXPECT_FALSE(S_ISLNK(st.st_mode)); - - // Check that the flag works. - EXPECT_THAT(fstatat(AT_FDCWD, "/proc/self", &st, AT_SYMLINK_NOFOLLOW), - SyscallSucceeds()); - EXPECT_TRUE(S_ISLNK(st.st_mode)); - EXPECT_FALSE(S_ISDIR(st.st_mode)); -} - -TEST_F(StatTest, Nlinks) { - TempPath basedir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - // Directory is initially empty, it should contain 2 links (one from itself, - // one from "."). - EXPECT_THAT(Links(basedir.path()), IsPosixErrorOkAndHolds(2)); - - // Create a file in the test directory. Files shouldn't increase the link - // count on the base directory. - TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(basedir.path())); - EXPECT_THAT(Links(basedir.path()), IsPosixErrorOkAndHolds(2)); - - // Create subdirectories. This should increase the link count by 1 per - // subdirectory. - TempPath dir1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(basedir.path())); - EXPECT_THAT(Links(basedir.path()), IsPosixErrorOkAndHolds(3)); - TempPath dir2 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(basedir.path())); - EXPECT_THAT(Links(basedir.path()), IsPosixErrorOkAndHolds(4)); - - // Removing directories should reduce the link count. - dir1.reset(); - EXPECT_THAT(Links(basedir.path()), IsPosixErrorOkAndHolds(3)); - dir2.reset(); - EXPECT_THAT(Links(basedir.path()), IsPosixErrorOkAndHolds(2)); - - // Removing files should have no effect on link count. - file1.reset(); - EXPECT_THAT(Links(basedir.path()), IsPosixErrorOkAndHolds(2)); -} - -TEST_F(StatTest, BlocksIncreaseOnWrite) { - struct stat st; - - // Stat the empty file. - ASSERT_THAT(fstat(test_file_fd_.get(), &st), SyscallSucceeds()); - - const int initial_blocks = st.st_blocks; - - // Write to the file, making sure to exceed the block size. - std::vector<char> buf(2 * st.st_blksize, 'a'); - ASSERT_THAT(write(test_file_fd_.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - - // Stat the file again, and verify that number of allocated blocks has - // increased. - ASSERT_THAT(fstat(test_file_fd_.get(), &st), SyscallSucceeds()); - EXPECT_GT(st.st_blocks, initial_blocks); -} - -TEST_F(StatTest, PathNotCleaned) { - TempPath basedir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - // Create a file in the basedir. - TempPath file = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(basedir.path())); - - // Stating the file directly should succeed. - struct stat buf; - EXPECT_THAT(lstat(file.path().c_str(), &buf), SyscallSucceeds()); - - // Try to stat the file using a directory that does not exist followed by - // "..". If the path is cleaned prior to stating (which it should not be) - // then this will succeed. - const std::string bad_path = JoinPath("/does_not_exist/..", file.path()); - EXPECT_THAT(lstat(bad_path.c_str(), &buf), SyscallFailsWithErrno(ENOENT)); -} - -TEST_F(StatTest, PathCanContainDotDot) { - TempPath basedir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - TempPath subdir = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(basedir.path())); - const std::string subdir_name = std::string(Basename(subdir.path())); - - // Create a file in the subdir. - TempPath file = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(subdir.path())); - const std::string file_name = std::string(Basename(file.path())); - - // Stat the file through a path that includes '..' and '.' but still resolves - // to the file. - const std::string good_path = - JoinPath(basedir.path(), subdir_name, "..", subdir_name, ".", file_name); - struct stat buf; - EXPECT_THAT(lstat(good_path.c_str(), &buf), SyscallSucceeds()); -} - -TEST_F(StatTest, PathCanContainEmptyComponent) { - TempPath basedir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - // Create a file in the basedir. - TempPath file = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(basedir.path())); - const std::string file_name = std::string(Basename(file.path())); - - // Stat the file through a path that includes an empty component. We have to - // build this ourselves because JoinPath automatically removes empty - // components. - const std::string good_path = absl::StrCat(basedir.path(), "//", file_name); - struct stat buf; - EXPECT_THAT(lstat(good_path.c_str(), &buf), SyscallSucceeds()); -} - -TEST_F(StatTest, TrailingSlashNotCleanedReturnsENOTDIR) { - TempPath basedir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - // Create a file in the basedir. - TempPath file = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(basedir.path())); - - // Stat the file with an extra "/" on the end of it. Since file is not a - // directory, this should return ENOTDIR. - const std::string bad_path = absl::StrCat(file.path(), "/"); - struct stat buf; - EXPECT_THAT(lstat(bad_path.c_str(), &buf), SyscallFailsWithErrno(ENOTDIR)); -} - -// Test fstatating a symlink directory. -TEST_F(StatTest, FstatatSymlinkDir) { - // Create a directory and symlink to it. - const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - const std::string symlink_to_dir = NewTempAbsPath(); - EXPECT_THAT(symlink(dir.path().c_str(), symlink_to_dir.c_str()), - SyscallSucceeds()); - auto cleanup = Cleanup([&symlink_to_dir]() { - EXPECT_THAT(unlink(symlink_to_dir.c_str()), SyscallSucceeds()); - }); - - // Fstatat the link with AT_SYMLINK_NOFOLLOW should return symlink data. - struct stat st = {}; - EXPECT_THAT( - fstatat(AT_FDCWD, symlink_to_dir.c_str(), &st, AT_SYMLINK_NOFOLLOW), - SyscallSucceeds()); - EXPECT_FALSE(S_ISDIR(st.st_mode)); - EXPECT_TRUE(S_ISLNK(st.st_mode)); - - // Fstatat the link should return dir data. - EXPECT_THAT(fstatat(AT_FDCWD, symlink_to_dir.c_str(), &st, 0), - SyscallSucceeds()); - EXPECT_TRUE(S_ISDIR(st.st_mode)); - EXPECT_FALSE(S_ISLNK(st.st_mode)); -} - -// Test fstatating a symlink directory with trailing slash. -TEST_F(StatTest, FstatatSymlinkDirWithTrailingSlash) { - // Create a directory and symlink to it. - const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const std::string symlink_to_dir = NewTempAbsPath(); - EXPECT_THAT(symlink(dir.path().c_str(), symlink_to_dir.c_str()), - SyscallSucceeds()); - auto cleanup = Cleanup([&symlink_to_dir]() { - EXPECT_THAT(unlink(symlink_to_dir.c_str()), SyscallSucceeds()); - }); - - // Fstatat on the symlink with a trailing slash should return the directory - // data. - struct stat st = {}; - EXPECT_THAT( - fstatat(AT_FDCWD, absl::StrCat(symlink_to_dir, "/").c_str(), &st, 0), - SyscallSucceeds()); - EXPECT_TRUE(S_ISDIR(st.st_mode)); - EXPECT_FALSE(S_ISLNK(st.st_mode)); - - // Fstatat on the symlink with a trailing slash with AT_SYMLINK_NOFOLLOW - // should return the directory data. - // Symlink to directory with trailing slash will ignore AT_SYMLINK_NOFOLLOW. - EXPECT_THAT(fstatat(AT_FDCWD, absl::StrCat(symlink_to_dir, "/").c_str(), &st, - AT_SYMLINK_NOFOLLOW), - SyscallSucceeds()); - EXPECT_TRUE(S_ISDIR(st.st_mode)); - EXPECT_FALSE(S_ISLNK(st.st_mode)); -} - -// Test fstatating a symlink directory with a trailing slash -// should return same stat data with fstatating directory. -TEST_F(StatTest, FstatatSymlinkDirWithTrailingSlashSameInode) { - // Create a directory and symlink to it. - const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - // We are going to assert that the symlink inode id is the same as the linked - // dir's inode id. In order for the inode id to be stable across - // save/restore, it must be kept open. The FileDescriptor type will do that - // for us automatically. - auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY)); - - const std::string symlink_to_dir = NewTempAbsPath(); - EXPECT_THAT(symlink(dir.path().c_str(), symlink_to_dir.c_str()), - SyscallSucceeds()); - auto cleanup = Cleanup([&symlink_to_dir]() { - EXPECT_THAT(unlink(symlink_to_dir.c_str()), SyscallSucceeds()); - }); - - // Fstatat on the symlink with a trailing slash should return the directory - // data. - struct stat st = {}; - EXPECT_THAT(fstatat(AT_FDCWD, absl::StrCat(symlink_to_dir, "/").c_str(), &st, - AT_SYMLINK_NOFOLLOW), - SyscallSucceeds()); - EXPECT_TRUE(S_ISDIR(st.st_mode)); - - // Dir and symlink should point to same inode. - struct stat st_dir = {}; - EXPECT_THAT( - fstatat(AT_FDCWD, dir.path().c_str(), &st_dir, AT_SYMLINK_NOFOLLOW), - SyscallSucceeds()); - EXPECT_EQ(st.st_ino, st_dir.st_ino); -} - -TEST_F(StatTest, LeadingDoubleSlash) { - // Create a file, and make sure we can stat it. - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - struct stat st; - ASSERT_THAT(lstat(file.path().c_str(), &st), SyscallSucceeds()); - - // Now add an extra leading slash. - const std::string double_slash_path = absl::StrCat("/", file.path()); - ASSERT_TRUE(absl::StartsWith(double_slash_path, "//")); - - // We should be able to stat the new path, and it should resolve to the same - // file (same device and inode). - struct stat double_slash_st; - ASSERT_THAT(lstat(double_slash_path.c_str(), &double_slash_st), - SyscallSucceeds()); - EXPECT_EQ(st.st_dev, double_slash_st.st_dev); - EXPECT_EQ(st.st_ino, double_slash_st.st_ino); -} - -// Test that a rename doesn't change the underlying file. -TEST_F(StatTest, StatDoesntChangeAfterRename) { - const TempPath old_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath new_path(NewTempAbsPath()); - - struct stat st_old = {}; - struct stat st_new = {}; - - ASSERT_THAT(stat(old_dir.path().c_str(), &st_old), SyscallSucceeds()); - ASSERT_THAT(rename(old_dir.path().c_str(), new_path.path().c_str()), - SyscallSucceeds()); - ASSERT_THAT(stat(new_path.path().c_str(), &st_new), SyscallSucceeds()); - - EXPECT_EQ(st_old.st_nlink, st_new.st_nlink); - EXPECT_EQ(st_old.st_dev, st_new.st_dev); - EXPECT_EQ(st_old.st_ino, st_new.st_ino); - EXPECT_EQ(st_old.st_mode, st_new.st_mode); - EXPECT_EQ(st_old.st_uid, st_new.st_uid); - EXPECT_EQ(st_old.st_gid, st_new.st_gid); - EXPECT_EQ(st_old.st_size, st_new.st_size); -} - -// Test link counts with a regular file as the child. -TEST_F(StatTest, LinkCountsWithRegularFileChild) { - const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - struct stat st_parent_before = {}; - ASSERT_THAT(stat(dir.path().c_str(), &st_parent_before), SyscallSucceeds()); - EXPECT_EQ(st_parent_before.st_nlink, 2); - - // Adding a regular file doesn't adjust the parent's link count. - const TempPath child = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); - - struct stat st_parent_after = {}; - ASSERT_THAT(stat(dir.path().c_str(), &st_parent_after), SyscallSucceeds()); - EXPECT_EQ(st_parent_after.st_nlink, 2); - - // The child should have a single link from the parent. - struct stat st_child = {}; - ASSERT_THAT(stat(child.path().c_str(), &st_child), SyscallSucceeds()); - EXPECT_TRUE(S_ISREG(st_child.st_mode)); - EXPECT_EQ(st_child.st_nlink, 1); - - // Finally unlinking the child should not affect the parent's link count. - ASSERT_THAT(unlink(child.path().c_str()), SyscallSucceeds()); - ASSERT_THAT(stat(dir.path().c_str(), &st_parent_after), SyscallSucceeds()); - EXPECT_EQ(st_parent_after.st_nlink, 2); -} - -// This test verifies that inodes remain around when there is an open fd -// after link count hits 0. -TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoRandomSave) { - // Setting the enviornment variable GVISOR_GOFER_UNCACHED to any value - // will prevent this test from running, see the tmpfs lifecycle. - // - // We need to support this because when a file is unlinked and we forward - // the stat to the gofer it would return ENOENT. - const char* uncached_gofer = getenv("GVISOR_GOFER_UNCACHED"); - SKIP_IF(uncached_gofer != nullptr); - - // We don't support saving unlinked files. - const DisableSave ds; - - const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - dir.path(), "hello", TempPath::kDefaultFileMode)); - - // The child should have a single link from the parent. - struct stat st_child_before = {}; - ASSERT_THAT(stat(child.path().c_str(), &st_child_before), SyscallSucceeds()); - EXPECT_TRUE(S_ISREG(st_child_before.st_mode)); - EXPECT_EQ(st_child_before.st_nlink, 1); - EXPECT_EQ(st_child_before.st_size, 5); // Hello is 5 bytes. - - // Open the file so we can fstat after unlinking. - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(child.path(), O_RDONLY)); - - // Now a stat should return ENOENT but we should still be able to stat - // via the open fd and fstat. - ASSERT_THAT(unlink(child.path().c_str()), SyscallSucceeds()); - - // Since the file has no more links stat should fail. - struct stat st_child_after = {}; - ASSERT_THAT(stat(child.path().c_str(), &st_child_after), - SyscallFailsWithErrno(ENOENT)); - - // Fstat should still allow us to access the same file via the fd. - struct stat st_child_fd = {}; - ASSERT_THAT(fstat(fd.get(), &st_child_fd), SyscallSucceeds()); - EXPECT_EQ(st_child_before.st_dev, st_child_fd.st_dev); - EXPECT_EQ(st_child_before.st_ino, st_child_fd.st_ino); - EXPECT_EQ(st_child_before.st_mode, st_child_fd.st_mode); - EXPECT_EQ(st_child_before.st_uid, st_child_fd.st_uid); - EXPECT_EQ(st_child_before.st_gid, st_child_fd.st_gid); - EXPECT_EQ(st_child_before.st_size, st_child_fd.st_size); - - // TODO(b/34861058): This isn't ideal but since fstatfs(2) will always return - // OVERLAYFS_SUPER_MAGIC we have no way to know if this fs is backed by a - // gofer which doesn't support links. - EXPECT_TRUE(st_child_fd.st_nlink == 0 || st_child_fd.st_nlink == 1); -} - -// Test link counts with a directory as the child. -TEST_F(StatTest, LinkCountsWithDirChild) { - const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - // Before a child is added the two links are "." and the link from the parent. - struct stat st_parent_before = {}; - ASSERT_THAT(stat(dir.path().c_str(), &st_parent_before), SyscallSucceeds()); - EXPECT_EQ(st_parent_before.st_nlink, 2); - - // Create a subdirectory and stat for the parent link counts. - const TempPath sub_dir = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir.path())); - - // The three links are ".", the link from the parent, and the link from - // the child as "..". - struct stat st_parent_after = {}; - ASSERT_THAT(stat(dir.path().c_str(), &st_parent_after), SyscallSucceeds()); - EXPECT_EQ(st_parent_after.st_nlink, 3); - - // The child will have 1 link from the parent and 1 link which represents ".". - struct stat st_child = {}; - ASSERT_THAT(stat(sub_dir.path().c_str(), &st_child), SyscallSucceeds()); - EXPECT_TRUE(S_ISDIR(st_child.st_mode)); - EXPECT_EQ(st_child.st_nlink, 2); - - // Finally delete the child dir and the parent link count should return to 2. - ASSERT_THAT(rmdir(sub_dir.path().c_str()), SyscallSucceeds()); - ASSERT_THAT(stat(dir.path().c_str(), &st_parent_after), SyscallSucceeds()); - - // Now we should only have links from the parent and "." since the subdir - // has been removed. - EXPECT_EQ(st_parent_after.st_nlink, 2); -} - -// Test statting a child of a non-directory. -TEST_F(StatTest, ChildOfNonDir) { - // Create a path that has a child of a regular file. - const std::string filename = JoinPath(test_file_name_, "child"); - - // Statting the path should return ENOTDIR. - struct stat st; - EXPECT_THAT(lstat(filename.c_str(), &st), SyscallFailsWithErrno(ENOTDIR)); -} - -// Test lstating a symlink directory. -TEST_F(StatTest, LstatSymlinkDir) { - // Create a directory and symlink to it. - const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const std::string symlink_to_dir = NewTempAbsPath(); - EXPECT_THAT(symlink(dir.path().c_str(), symlink_to_dir.c_str()), - SyscallSucceeds()); - auto cleanup = Cleanup([&symlink_to_dir]() { - EXPECT_THAT(unlink(symlink_to_dir.c_str()), SyscallSucceeds()); - }); - - // Lstat on the symlink should return symlink data. - struct stat st = {}; - ASSERT_THAT(lstat(symlink_to_dir.c_str(), &st), SyscallSucceeds()); - EXPECT_FALSE(S_ISDIR(st.st_mode)); - EXPECT_TRUE(S_ISLNK(st.st_mode)); - - // Lstat on the symlink with a trailing slash should return the directory - // data. - ASSERT_THAT(lstat(absl::StrCat(symlink_to_dir, "/").c_str(), &st), - SyscallSucceeds()); - EXPECT_TRUE(S_ISDIR(st.st_mode)); - EXPECT_FALSE(S_ISLNK(st.st_mode)); -} - -// Verify that we get an ELOOP from too many symbolic links even when there -// are directories in the middle. -TEST_F(StatTest, LstatELOOPPath) { - const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - std::string subdir_base = "subdir"; - ASSERT_THAT(mkdir(JoinPath(dir.path(), subdir_base).c_str(), 0755), - SyscallSucceeds()); - - std::string target = JoinPath(dir.path(), subdir_base, subdir_base); - std::string dst = JoinPath("..", subdir_base); - ASSERT_THAT(symlink(dst.c_str(), target.c_str()), SyscallSucceeds()); - auto cleanup = Cleanup( - [&target]() { EXPECT_THAT(unlink(target.c_str()), SyscallSucceeds()); }); - - // Now build a path which is /subdir/subdir/... repeated many times so that - // we can build a path that is shorter than PATH_MAX but can still cause - // too many symbolic links. Note: Every other subdir is actually a directory - // so we're not in a situation where it's a -> b -> a -> b, where a and b - // are symbolic links. - std::string path = dir.path(); - std::string subdir_append = absl::StrCat("/", subdir_base); - do { - absl::StrAppend(&path, subdir_append); - // Keep appending /subdir until we would overflow PATH_MAX. - } while ((path.size() + subdir_append.size()) < PATH_MAX); - - struct stat s = {}; - ASSERT_THAT(lstat(path.c_str(), &s), SyscallFailsWithErrno(ELOOP)); -} - -// Ensure that inode allocation for anonymous devices work correctly across -// save/restore. In particular, inode numbers should be unique across S/R. -TEST(SimpleStatTest, AnonDeviceAllocatesUniqueInodesAcrossSaveRestore) { - // Use sockets as a convenient way to create inodes on an anonymous device. - int fd; - ASSERT_THAT(fd = socket(AF_UNIX, SOCK_STREAM, 0), SyscallSucceeds()); - FileDescriptor fd1(fd); - MaybeSave(); - ASSERT_THAT(fd = socket(AF_UNIX, SOCK_STREAM, 0), SyscallSucceeds()); - FileDescriptor fd2(fd); - - struct stat st1; - struct stat st2; - ASSERT_THAT(fstat(fd1.get(), &st1), SyscallSucceeds()); - ASSERT_THAT(fstat(fd2.get(), &st2), SyscallSucceeds()); - - // The two fds should have different inode numbers. - EXPECT_NE(st2.st_ino, st1.st_ino); - - // Verify again after another S/R cycle. The inode numbers should remain the - // same. - MaybeSave(); - - struct stat st1_after; - struct stat st2_after; - ASSERT_THAT(fstat(fd1.get(), &st1_after), SyscallSucceeds()); - ASSERT_THAT(fstat(fd2.get(), &st2_after), SyscallSucceeds()); - - EXPECT_EQ(st1_after.st_ino, st1.st_ino); - EXPECT_EQ(st2_after.st_ino, st2.st_ino); -} - -#ifndef SYS_statx -#if defined(__x86_64__) -#define SYS_statx 332 -#elif defined(__aarch64__) -#define SYS_statx 291 -#else -#error "Unknown architecture" -#endif -#endif // SYS_statx - -#ifndef STATX_ALL -#define STATX_ALL 0x00000fffU -#endif // STATX_ALL - -// struct kernel_statx_timestamp is a Linux statx_timestamp struct. -struct kernel_statx_timestamp { - int64_t tv_sec; - uint32_t tv_nsec; - int32_t __reserved; -}; - -// struct kernel_statx is a Linux statx struct. Old versions of glibc do not -// expose it. See include/uapi/linux/stat.h -struct kernel_statx { - uint32_t stx_mask; - uint32_t stx_blksize; - uint64_t stx_attributes; - uint32_t stx_nlink; - uint32_t stx_uid; - uint32_t stx_gid; - uint16_t stx_mode; - uint16_t __spare0[1]; - uint64_t stx_ino; - uint64_t stx_size; - uint64_t stx_blocks; - uint64_t stx_attributes_mask; - struct kernel_statx_timestamp stx_atime; - struct kernel_statx_timestamp stx_btime; - struct kernel_statx_timestamp stx_ctime; - struct kernel_statx_timestamp stx_mtime; - uint32_t stx_rdev_major; - uint32_t stx_rdev_minor; - uint32_t stx_dev_major; - uint32_t stx_dev_minor; - uint64_t __spare2[14]; -}; - -int statx(int dirfd, const char* pathname, int flags, unsigned int mask, - struct kernel_statx* statxbuf) { - return syscall(SYS_statx, dirfd, pathname, flags, mask, statxbuf); -} - -TEST_F(StatTest, StatxAbsPath) { - SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 && - errno == ENOSYS); - - struct kernel_statx stx; - EXPECT_THAT(statx(-1, test_file_name_.c_str(), 0, STATX_ALL, &stx), - SyscallSucceeds()); - EXPECT_TRUE(S_ISREG(stx.stx_mode)); -} - -TEST_F(StatTest, StatxRelPathDirFD) { - SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 && - errno == ENOSYS); - - struct kernel_statx stx; - auto const dirfd = - ASSERT_NO_ERRNO_AND_VALUE(Open(GetAbsoluteTestTmpdir(), O_RDONLY)); - auto filename = std::string(Basename(test_file_name_)); - - EXPECT_THAT(statx(dirfd.get(), filename.c_str(), 0, STATX_ALL, &stx), - SyscallSucceeds()); - EXPECT_TRUE(S_ISREG(stx.stx_mode)); -} - -TEST_F(StatTest, StatxRelPathCwd) { - SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 && - errno == ENOSYS); - - ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds()); - auto filename = std::string(Basename(test_file_name_)); - struct kernel_statx stx; - EXPECT_THAT(statx(AT_FDCWD, filename.c_str(), 0, STATX_ALL, &stx), - SyscallSucceeds()); - EXPECT_TRUE(S_ISREG(stx.stx_mode)); -} - -TEST_F(StatTest, StatxEmptyPath) { - SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 && - errno == ENOSYS); - - const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY)); - struct kernel_statx stx; - EXPECT_THAT(statx(fd.get(), "", AT_EMPTY_PATH, STATX_ALL, &stx), - SyscallSucceeds()); - EXPECT_TRUE(S_ISREG(stx.stx_mode)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/stat_times.cc b/test/syscalls/linux/stat_times.cc deleted file mode 100644 index 68c0bef09..000000000 --- a/test/syscalls/linux/stat_times.cc +++ /dev/null @@ -1,303 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <sys/stat.h> - -#include <tuple> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -using ::testing::IsEmpty; -using ::testing::Not; - -std::tuple<absl::Time, absl::Time, absl::Time> GetTime(const TempPath& file) { - struct stat statbuf = {}; - EXPECT_THAT(stat(file.path().c_str(), &statbuf), SyscallSucceeds()); - - const auto atime = absl::TimeFromTimespec(statbuf.st_atim); - const auto mtime = absl::TimeFromTimespec(statbuf.st_mtim); - const auto ctime = absl::TimeFromTimespec(statbuf.st_ctim); - return std::make_tuple(atime, mtime, ctime); -} - -enum class AtimeEffect { - Unchanged, - Changed, -}; - -enum class MtimeEffect { - Unchanged, - Changed, -}; - -enum class CtimeEffect { - Unchanged, - Changed, -}; - -// Tests that fn modifies the atime/mtime/ctime of path as specified. -void CheckTimes(const TempPath& path, std::function<void()> fn, - AtimeEffect atime_effect, MtimeEffect mtime_effect, - CtimeEffect ctime_effect) { - absl::Time atime, mtime, ctime; - std::tie(atime, mtime, ctime) = GetTime(path); - - // FIXME(b/132819225): gVisor filesystem timestamps inconsistently use the - // internal or host clock, which may diverge slightly. Allow some slack on - // times to account for the difference. - // - // Here we sleep for 1s so that initial creation of path doesn't fall within - // the before slack window. - absl::SleepFor(absl::Seconds(1)); - - const absl::Time before = absl::Now() - absl::Seconds(1); - - // Perform the op. - fn(); - - const absl::Time after = absl::Now() + absl::Seconds(1); - - absl::Time atime2, mtime2, ctime2; - std::tie(atime2, mtime2, ctime2) = GetTime(path); - - if (atime_effect == AtimeEffect::Changed) { - EXPECT_LE(before, atime2); - EXPECT_GE(after, atime2); - EXPECT_GT(atime2, atime); - } else { - EXPECT_EQ(atime2, atime); - } - - if (mtime_effect == MtimeEffect::Changed) { - EXPECT_LE(before, mtime2); - EXPECT_GE(after, mtime2); - EXPECT_GT(mtime2, mtime); - } else { - EXPECT_EQ(mtime2, mtime); - } - - if (ctime_effect == CtimeEffect::Changed) { - EXPECT_LE(before, ctime2); - EXPECT_GE(after, ctime2); - EXPECT_GT(ctime2, ctime); - } else { - EXPECT_EQ(ctime2, ctime); - } -} - -// File creation time is reflected in atime, mtime, and ctime. -TEST(StatTimesTest, FileCreation) { - const DisableSave ds; // Timing-related test. - - // Get a time for when the file is created. - // - // FIXME(b/132819225): See above. - const absl::Time before = absl::Now() - absl::Seconds(1); - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const absl::Time after = absl::Now() + absl::Seconds(1); - - absl::Time atime, mtime, ctime; - std::tie(atime, mtime, ctime) = GetTime(file); - - EXPECT_LE(before, atime); - EXPECT_LE(before, mtime); - EXPECT_LE(before, ctime); - EXPECT_GE(after, atime); - EXPECT_GE(after, mtime); - EXPECT_GE(after, ctime); -} - -// Calling chmod on a file changes ctime. -TEST(StatTimesTest, FileChmod) { - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - auto fn = [&] { - EXPECT_THAT(chmod(file.path().c_str(), 0666), SyscallSucceeds()); - }; - CheckTimes(file, fn, AtimeEffect::Unchanged, MtimeEffect::Unchanged, - CtimeEffect::Changed); -} - -// Renaming a file changes ctime. -TEST(StatTimesTest, FileRename) { - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - const std::string newpath = NewTempAbsPath(); - - auto fn = [&] { - ASSERT_THAT(rename(file.release().c_str(), newpath.c_str()), - SyscallSucceeds()); - file.reset(newpath); - }; - CheckTimes(file, fn, AtimeEffect::Unchanged, MtimeEffect::Unchanged, - CtimeEffect::Changed); -} - -// Renaming a file changes ctime, even with an open FD. -// -// NOTE(b/132732387): This is a regression test for fs/gofer failing to update -// cached ctime. -TEST(StatTimesTest, FileRenameOpenFD) { - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // Holding an FD shouldn't affect behavior. - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); - - const std::string newpath = NewTempAbsPath(); - - // FIXME(b/132814682): Restore fails with an uncached gofer and an open FD - // across rename. - // - // N.B. The logic here looks backwards because it isn't possible to - // conditionally disable save, only conditionally re-enable it. - DisableSave ds; - if (!getenv("GVISOR_GOFER_UNCACHED")) { - ds.reset(); - } - - auto fn = [&] { - ASSERT_THAT(rename(file.release().c_str(), newpath.c_str()), - SyscallSucceeds()); - file.reset(newpath); - }; - CheckTimes(file, fn, AtimeEffect::Unchanged, MtimeEffect::Unchanged, - CtimeEffect::Changed); -} - -// Calling utimes on a file changes ctime and the time that we ask to change -// (atime to now in this case). -TEST(StatTimesTest, FileUtimes) { - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - auto fn = [&] { - const struct timespec ts[2] = {{0, UTIME_NOW}, {0, UTIME_OMIT}}; - ASSERT_THAT(utimensat(AT_FDCWD, file.path().c_str(), ts, 0), - SyscallSucceeds()); - }; - CheckTimes(file, fn, AtimeEffect::Changed, MtimeEffect::Unchanged, - CtimeEffect::Changed); -} - -// Truncating a file changes mtime and ctime. -TEST(StatTimesTest, FileTruncate) { - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), "yaaass", 0666)); - - auto fn = [&] { - EXPECT_THAT(truncate(file.path().c_str(), 0), SyscallSucceeds()); - }; - CheckTimes(file, fn, AtimeEffect::Unchanged, MtimeEffect::Changed, - CtimeEffect::Changed); -} - -// Writing a file changes mtime and ctime. -TEST(StatTimesTest, FileWrite) { - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), "yaaass", 0666)); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0)); - - auto fn = [&] { - const std::string contents = "all the single dollars"; - EXPECT_THAT(WriteFd(fd.get(), contents.data(), contents.size()), - SyscallSucceeds()); - }; - CheckTimes(file, fn, AtimeEffect::Unchanged, MtimeEffect::Changed, - CtimeEffect::Changed); -} - -// Reading a file changes atime. -TEST(StatTimesTest, FileRead) { - const std::string contents = "bills bills bills"; - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), contents, 0666)); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY, 0)); - - auto fn = [&] { - char buf[20]; - ASSERT_THAT(ReadFd(fd.get(), buf, sizeof(buf)), - SyscallSucceedsWithValue(contents.size())); - }; - CheckTimes(file, fn, AtimeEffect::Changed, MtimeEffect::Unchanged, - CtimeEffect::Unchanged); -} - -// Listing files in a directory changes atime. -TEST(StatTimesTest, DirList) { - const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath file = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); - - auto fn = [&] { - const auto contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(dir.path(), false)); - EXPECT_THAT(contents, Not(IsEmpty())); - }; - CheckTimes(dir, fn, AtimeEffect::Changed, MtimeEffect::Unchanged, - CtimeEffect::Unchanged); -} - -// Creating a file in a directory changes mtime and ctime. -TEST(StatTimesTest, DirCreateFile) { - const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - TempPath file; - auto fn = [&] { - file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); - }; - CheckTimes(dir, fn, AtimeEffect::Unchanged, MtimeEffect::Changed, - CtimeEffect::Changed); -} - -// Creating a directory in a directory changes mtime and ctime. -TEST(StatTimesTest, DirCreateDir) { - const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - TempPath dir2; - auto fn = [&] { - dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir.path())); - }; - CheckTimes(dir, fn, AtimeEffect::Unchanged, MtimeEffect::Changed, - CtimeEffect::Changed); -} - -// Removing a file from a directory changes mtime and ctime. -TEST(StatTimesTest, DirRemoveFile) { - const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); - auto fn = [&] { file.reset(); }; - CheckTimes(dir, fn, AtimeEffect::Unchanged, MtimeEffect::Changed, - CtimeEffect::Changed); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/statfs.cc b/test/syscalls/linux/statfs.cc deleted file mode 100644 index aca51d30f..000000000 --- a/test/syscalls/linux/statfs.cc +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <sys/statfs.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(StatfsTest, CannotStatBadPath) { - auto temp_file = NewTempAbsPathInDir("/tmp"); - - struct statfs st; - EXPECT_THAT(statfs(temp_file.c_str(), &st), SyscallFailsWithErrno(ENOENT)); -} - -TEST(StatfsTest, InternalTmpfs) { - auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - struct statfs st; - EXPECT_THAT(statfs(temp_file.path().c_str(), &st), SyscallSucceeds()); -} - -TEST(StatfsTest, InternalDevShm) { - struct statfs st; - EXPECT_THAT(statfs("/dev/shm", &st), SyscallSucceeds()); -} - -TEST(StatfsTest, NameLen) { - struct statfs st; - EXPECT_THAT(statfs("/dev/shm", &st), SyscallSucceeds()); - - // This assumes that /dev/shm is tmpfs. - EXPECT_EQ(st.f_namelen, NAME_MAX); -} - -TEST(FstatfsTest, CannotStatBadFd) { - struct statfs st; - EXPECT_THAT(fstatfs(-1, &st), SyscallFailsWithErrno(EBADF)); -} - -TEST(FstatfsTest, InternalTmpfs) { - auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file.path(), O_RDONLY)); - - struct statfs st; - EXPECT_THAT(fstatfs(fd.get(), &st), SyscallSucceeds()); -} - -TEST(FstatfsTest, InternalDevShm) { - auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/shm", O_RDONLY)); - - struct statfs st; - EXPECT_THAT(fstatfs(fd.get(), &st), SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/sticky.cc b/test/syscalls/linux/sticky.cc deleted file mode 100644 index 7e73325bf..000000000 --- a/test/syscalls/linux/sticky.cc +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <grp.h> -#include <sys/prctl.h> -#include <sys/types.h> -#include <unistd.h> - -#include <vector> - -#include "gtest/gtest.h" -#include "absl/flags/flag.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -ABSL_FLAG(int32_t, scratch_uid, 65534, "first scratch UID"); -ABSL_FLAG(int32_t, scratch_gid, 65534, "first scratch GID"); - -namespace gvisor { -namespace testing { - -namespace { - -TEST(StickyTest, StickyBitPermDenied) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); - - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); - std::string path = JoinPath(dir.path(), "NewDir"); - ASSERT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds()); - - // Drop privileges and change IDs only in child thread, or else this parent - // thread won't be able to open some log files after the test ends. - ScopedThread([&] { - // Drop privileges. - if (HaveCapability(CAP_FOWNER).ValueOrDie()) { - EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, false)); - } - - // Change EUID and EGID. - EXPECT_THAT( - syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1), - SyscallSucceeds()); - EXPECT_THAT( - syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1), - SyscallSucceeds()); - - EXPECT_THAT(rmdir(path.c_str()), SyscallFailsWithErrno(EPERM)); - }); -} - -TEST(StickyTest, StickyBitSameUID) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); - - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); - std::string path = JoinPath(dir.path(), "NewDir"); - ASSERT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds()); - - // Drop privileges and change IDs only in child thread, or else this parent - // thread won't be able to open some log files after the test ends. - ScopedThread([&] { - // Drop privileges. - if (HaveCapability(CAP_FOWNER).ValueOrDie()) { - EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, false)); - } - - // Change EGID. - EXPECT_THAT( - syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1), - SyscallSucceeds()); - - // We still have the same EUID. - EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds()); - }); -} - -TEST(StickyTest, StickyBitCapFOWNER) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); - - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); - std::string path = JoinPath(dir.path(), "NewDir"); - ASSERT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds()); - - // Drop privileges and change IDs only in child thread, or else this parent - // thread won't be able to open some log files after the test ends. - ScopedThread([&] { - // Set PR_SET_KEEPCAPS. - EXPECT_THAT(prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0), SyscallSucceeds()); - - // Change EUID and EGID. - EXPECT_THAT( - syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1), - SyscallSucceeds()); - EXPECT_THAT( - syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1), - SyscallSucceeds()); - - EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, true)); - EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds()); - }); -} -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc deleted file mode 100644 index 03ee1250d..000000000 --- a/test/syscalls/linux/symlink.cc +++ /dev/null @@ -1,377 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <string.h> -#include <unistd.h> - -#include <string> - -#include "gtest/gtest.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -mode_t FilePermission(const std::string& path) { - struct stat buf = {0}; - TEST_CHECK(lstat(path.c_str(), &buf) == 0); - return buf.st_mode & 0777; -} - -// Test that name collisions are checked on the new link path, not the source -// path. Regression test for b/31782115. -TEST(SymlinkTest, CanCreateSymlinkWithCachedSourceDirent) { - const std::string srcname = NewTempAbsPath(); - const std::string newname = NewTempAbsPath(); - const std::string basedir = std::string(Dirname(srcname)); - ASSERT_EQ(basedir, Dirname(newname)); - - ASSERT_THAT(chdir(basedir.c_str()), SyscallSucceeds()); - - // Open the source node to cause the underlying dirent to be cached. It will - // remain cached while we have the file open. - int fd; - ASSERT_THAT(fd = open(srcname.c_str(), O_CREAT | O_RDWR, 0666), - SyscallSucceeds()); - FileDescriptor fd_closer(fd); - - // Attempt to create a symlink. If the bug exists, this will fail since the - // dirent link creation code will check for a name collision on the source - // link name. - EXPECT_THAT(symlink(std::string(Basename(srcname)).c_str(), - std::string(Basename(newname)).c_str()), - SyscallSucceeds()); -} - -TEST(SymlinkTest, CanCreateSymlinkFile) { - const std::string oldname = NewTempAbsPath(); - const std::string newname = NewTempAbsPath(); - - int fd; - ASSERT_THAT(fd = open(oldname.c_str(), O_CREAT | O_RDWR, 0666), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - - EXPECT_THAT(symlink(oldname.c_str(), newname.c_str()), SyscallSucceeds()); - EXPECT_EQ(FilePermission(newname), 0777); - - auto link = ASSERT_NO_ERRNO_AND_VALUE(ReadLink(newname)); - EXPECT_EQ(oldname, link); - - EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds()); - EXPECT_THAT(unlink(oldname.c_str()), SyscallSucceeds()); -} - -TEST(SymlinkTest, CanCreateSymlinkDir) { - const std::string olddir = NewTempAbsPath(); - const std::string newdir = NewTempAbsPath(); - - EXPECT_THAT(mkdir(olddir.c_str(), 0777), SyscallSucceeds()); - EXPECT_THAT(symlink(olddir.c_str(), newdir.c_str()), SyscallSucceeds()); - EXPECT_EQ(FilePermission(newdir), 0777); - - auto link = ASSERT_NO_ERRNO_AND_VALUE(ReadLink(newdir)); - EXPECT_EQ(olddir, link); - - EXPECT_THAT(unlink(newdir.c_str()), SyscallSucceeds()); - - ASSERT_THAT(rmdir(olddir.c_str()), SyscallSucceeds()); -} - -TEST(SymlinkTest, CannotCreateSymlinkInReadOnlyDir) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - const std::string olddir = NewTempAbsPath(); - ASSERT_THAT(mkdir(olddir.c_str(), 0444), SyscallSucceeds()); - - const std::string newdir = NewTempAbsPathInDir(olddir); - EXPECT_THAT(symlink(olddir.c_str(), newdir.c_str()), - SyscallFailsWithErrno(EACCES)); - - ASSERT_THAT(rmdir(olddir.c_str()), SyscallSucceeds()); -} - -TEST(SymlinkTest, CannotSymlinkOverExistingFile) { - const auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const auto newfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - EXPECT_THAT(symlink(oldfile.path().c_str(), newfile.path().c_str()), - SyscallFailsWithErrno(EEXIST)); -} - -TEST(SymlinkTest, CannotSymlinkOverExistingDir) { - const auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const auto newdir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - EXPECT_THAT(symlink(oldfile.path().c_str(), newdir.path().c_str()), - SyscallFailsWithErrno(EEXIST)); -} - -TEST(SymlinkTest, OldnameIsEmpty) { - const std::string newname = NewTempAbsPath(); - EXPECT_THAT(symlink("", newname.c_str()), SyscallFailsWithErrno(ENOENT)); -} - -TEST(SymlinkTest, OldnameIsDangling) { - const std::string newname = NewTempAbsPath(); - EXPECT_THAT(symlink("/dangling", newname.c_str()), SyscallSucceeds()); - - // This is required for S/R random save tests, which pre-run this test - // in the same TEST_TMPDIR, which means that we need to clean it for any - // operations exclusively creating files, like symlink above. - EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds()); -} - -TEST(SymlinkTest, NewnameCannotExist) { - const std::string newname = - JoinPath(GetAbsoluteTestTmpdir(), "thisdoesnotexist", "foo"); - EXPECT_THAT(symlink("/thisdoesnotmatter", newname.c_str()), - SyscallFailsWithErrno(ENOENT)); -} - -TEST(SymlinkTest, CanEvaluateLink) { - const auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - - // We are going to assert that the symlink inode id is the same as the linked - // file's inode id. In order for the inode id to be stable across - // save/restore, it must be kept open. The FileDescriptor type will do that - // for us automatically. - auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); - struct stat file_st; - EXPECT_THAT(fstat(fd.get(), &file_st), SyscallSucceeds()); - - const std::string link = NewTempAbsPath(); - EXPECT_THAT(symlink(file.path().c_str(), link.c_str()), SyscallSucceeds()); - EXPECT_EQ(FilePermission(link), 0777); - - auto linkfd = ASSERT_NO_ERRNO_AND_VALUE(Open(link.c_str(), O_RDWR)); - struct stat link_st; - EXPECT_THAT(fstat(linkfd.get(), &link_st), SyscallSucceeds()); - - // Check that in fact newname points to the file we expect. - EXPECT_EQ(file_st.st_dev, link_st.st_dev); - EXPECT_EQ(file_st.st_ino, link_st.st_ino); -} - -TEST(SymlinkTest, TargetIsNotMapped) { - const std::string oldname = NewTempAbsPath(); - const std::string newname = NewTempAbsPath(); - - int fd; - // Create the target so that when we read the link, it exists. - ASSERT_THAT(fd = open(oldname.c_str(), O_CREAT | O_RDWR, 0666), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - - // Create a symlink called newname that points to oldname. - EXPECT_THAT(symlink(oldname.c_str(), newname.c_str()), SyscallSucceeds()); - - std::vector<char> buf(1024); - int linksize; - // Read the link and assert that the oldname is still the same. - EXPECT_THAT(linksize = readlink(newname.c_str(), buf.data(), 1024), - SyscallSucceeds()); - EXPECT_EQ(0, strncmp(oldname.c_str(), buf.data(), linksize)); - - EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds()); - EXPECT_THAT(unlink(oldname.c_str()), SyscallSucceeds()); -} - -TEST(SymlinkTest, PreadFromSymlink) { - std::string name = NewTempAbsPath(); - int fd; - ASSERT_THAT(fd = open(name.c_str(), O_CREAT, 0644), SyscallSucceeds()); - ASSERT_THAT(close(fd), SyscallSucceeds()); - - std::string linkname = NewTempAbsPath(); - ASSERT_THAT(symlink(name.c_str(), linkname.c_str()), SyscallSucceeds()); - - ASSERT_THAT(fd = open(linkname.c_str(), O_RDONLY), SyscallSucceeds()); - - char buf[1024]; - EXPECT_THAT(pread64(fd, buf, 1024, 0), SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - - EXPECT_THAT(unlink(name.c_str()), SyscallSucceeds()); - EXPECT_THAT(unlink(linkname.c_str()), SyscallSucceeds()); -} - -TEST(SymlinkTest, SymlinkAtDegradedPermissions_NoRandomSave) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); - - int dirfd; - ASSERT_THAT(dirfd = open(dir.path().c_str(), O_DIRECTORY, 0), - SyscallSucceeds()); - - const DisableSave ds; // Permissions are dropped. - EXPECT_THAT(fchmod(dirfd, 0), SyscallSucceeds()); - - std::string basename = std::string(Basename(file.path())); - EXPECT_THAT(symlinkat("/dangling", dirfd, basename.c_str()), - SyscallFailsWithErrno(EACCES)); - EXPECT_THAT(close(dirfd), SyscallSucceeds()); -} - -TEST(SymlinkTest, ReadlinkAtDegradedPermissions_NoRandomSave) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const std::string oldpath = NewTempAbsPathInDir(dir.path()); - const std::string oldbase = std::string(Basename(oldpath)); - ASSERT_THAT(symlink("/dangling", oldpath.c_str()), SyscallSucceeds()); - - int dirfd; - EXPECT_THAT(dirfd = open(dir.path().c_str(), O_DIRECTORY, 0), - SyscallSucceeds()); - - const DisableSave ds; // Permissions are dropped. - EXPECT_THAT(fchmod(dirfd, 0), SyscallSucceeds()); - - char buf[1024]; - int linksize; - EXPECT_THAT(linksize = readlinkat(dirfd, oldbase.c_str(), buf, 1024), - SyscallFailsWithErrno(EACCES)); - EXPECT_THAT(close(dirfd), SyscallSucceeds()); -} - -TEST(SymlinkTest, ChmodSymlink) { - auto target = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const std::string newpath = NewTempAbsPath(); - ASSERT_THAT(symlink(target.path().c_str(), newpath.c_str()), - SyscallSucceeds()); - EXPECT_EQ(FilePermission(newpath), 0777); - EXPECT_THAT(chmod(newpath.c_str(), 0666), SyscallSucceeds()); - EXPECT_EQ(FilePermission(newpath), 0777); -} - -class ParamSymlinkTest : public ::testing::TestWithParam<std::string> {}; - -// Test that creating an existing symlink with creat will create the target. -TEST_P(ParamSymlinkTest, CreatLinkCreatesTarget) { - const std::string target = GetParam(); - const std::string linkpath = NewTempAbsPath(); - - ASSERT_THAT(symlink(target.c_str(), linkpath.c_str()), SyscallSucceeds()); - - int fd; - EXPECT_THAT(fd = creat(linkpath.c_str(), 0666), SyscallSucceeds()); - ASSERT_THAT(close(fd), SyscallSucceeds()); - - ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds()); - struct stat st; - EXPECT_THAT(stat(target.c_str(), &st), SyscallSucceeds()); - - ASSERT_THAT(unlink(linkpath.c_str()), SyscallSucceeds()); - ASSERT_THAT(unlink(target.c_str()), SyscallSucceeds()); -} - -// Test that opening an existing symlink with O_CREAT will create the target. -TEST_P(ParamSymlinkTest, OpenLinkCreatesTarget) { - const std::string target = GetParam(); - const std::string linkpath = NewTempAbsPath(); - - ASSERT_THAT(symlink(target.c_str(), linkpath.c_str()), SyscallSucceeds()); - - int fd; - EXPECT_THAT(fd = open(linkpath.c_str(), O_CREAT, 0666), SyscallSucceeds()); - ASSERT_THAT(close(fd), SyscallSucceeds()); - - ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds()); - struct stat st; - EXPECT_THAT(stat(target.c_str(), &st), SyscallSucceeds()); - - ASSERT_THAT(unlink(linkpath.c_str()), SyscallSucceeds()); - ASSERT_THAT(unlink(target.c_str()), SyscallSucceeds()); -} - -// Test that opening a self-symlink with O_CREAT will fail with ELOOP. -TEST_P(ParamSymlinkTest, CreateExistingSelfLink) { - ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds()); - - const std::string linkpath = GetParam(); - ASSERT_THAT(symlink(linkpath.c_str(), linkpath.c_str()), SyscallSucceeds()); - - EXPECT_THAT(open(linkpath.c_str(), O_CREAT, 0666), - SyscallFailsWithErrno(ELOOP)); - - ASSERT_THAT(unlink(linkpath.c_str()), SyscallSucceeds()); -} - -// Test that opening a file that is a symlink to its parent directory fails -// with ELOOP. -TEST_P(ParamSymlinkTest, CreateExistingParentLink) { - ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds()); - - const std::string linkpath = GetParam(); - const std::string target = JoinPath(linkpath, "child"); - ASSERT_THAT(symlink(target.c_str(), linkpath.c_str()), SyscallSucceeds()); - - EXPECT_THAT(open(linkpath.c_str(), O_CREAT, 0666), - SyscallFailsWithErrno(ELOOP)); - - ASSERT_THAT(unlink(linkpath.c_str()), SyscallSucceeds()); -} - -// Test that opening an existing symlink with O_CREAT|O_EXCL will fail with -// EEXIST. -TEST_P(ParamSymlinkTest, OpenLinkExclFails) { - const std::string target = GetParam(); - const std::string linkpath = NewTempAbsPath(); - - ASSERT_THAT(symlink(target.c_str(), linkpath.c_str()), SyscallSucceeds()); - - EXPECT_THAT(open(linkpath.c_str(), O_CREAT | O_EXCL, 0666), - SyscallFailsWithErrno(EEXIST)); - - ASSERT_THAT(unlink(linkpath.c_str()), SyscallSucceeds()); -} - -// Test that opening an existing symlink with O_CREAT|O_NOFOLLOW will fail with -// ELOOP. -TEST_P(ParamSymlinkTest, OpenLinkNoFollowFails) { - const std::string target = GetParam(); - const std::string linkpath = NewTempAbsPath(); - - ASSERT_THAT(symlink(target.c_str(), linkpath.c_str()), SyscallSucceeds()); - - EXPECT_THAT(open(linkpath.c_str(), O_CREAT | O_NOFOLLOW, 0666), - SyscallFailsWithErrno(ELOOP)); - - ASSERT_THAT(unlink(linkpath.c_str()), SyscallSucceeds()); -} - -INSTANTIATE_TEST_SUITE_P(AbsAndRelTarget, ParamSymlinkTest, - ::testing::Values(NewTempAbsPath(), NewTempRelPath())); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/sync.cc b/test/syscalls/linux/sync.cc deleted file mode 100644 index 8aa2525a9..000000000 --- a/test/syscalls/linux/sync.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <stdio.h> -#include <sys/syscall.h> -#include <unistd.h> - -#include <string> - -#include "gtest/gtest.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(SyncTest, SyncEverything) { - ASSERT_THAT(syscall(SYS_sync), SyscallSucceeds()); -} - -TEST(SyncTest, SyncFileSytem) { - int fd; - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - ASSERT_THAT(fd = open(f.path().c_str(), O_RDONLY), SyscallSucceeds()); - EXPECT_THAT(syncfs(fd), SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST(SyncTest, SyncFromPipe) { - int pipes[2]; - EXPECT_THAT(pipe(pipes), SyscallSucceeds()); - EXPECT_THAT(syncfs(pipes[0]), SyscallSucceeds()); - EXPECT_THAT(syncfs(pipes[1]), SyscallSucceeds()); - EXPECT_THAT(close(pipes[0]), SyscallSucceeds()); - EXPECT_THAT(close(pipes[1]), SyscallSucceeds()); -} - -TEST(SyncTest, CannotSyncFileSytemAtBadFd) { - EXPECT_THAT(syncfs(-1), SyscallFailsWithErrno(EBADF)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/sync_file_range.cc b/test/syscalls/linux/sync_file_range.cc deleted file mode 100644 index 36cc42043..000000000 --- a/test/syscalls/linux/sync_file_range.cc +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <stdio.h> -#include <unistd.h> - -#include <string> - -#include "gtest/gtest.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(SyncFileRangeTest, TempFileSucceeds) { - auto tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - auto f = ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path(), O_RDWR)); - constexpr char data[] = "some data to sync"; - int fd = f.get(); - - EXPECT_THAT(write(fd, data, sizeof(data)), - SyscallSucceedsWithValue(sizeof(data))); - EXPECT_THAT(sync_file_range(fd, 0, 0, SYNC_FILE_RANGE_WRITE), - SyscallSucceeds()); - EXPECT_THAT(sync_file_range(fd, 0, 0, 0), SyscallSucceeds()); - EXPECT_THAT( - sync_file_range(fd, 0, 0, - SYNC_FILE_RANGE_WRITE | SYNC_FILE_RANGE_WAIT_AFTER | - SYNC_FILE_RANGE_WAIT_BEFORE), - SyscallSucceeds()); - EXPECT_THAT(sync_file_range( - fd, 0, 1, SYNC_FILE_RANGE_WRITE | SYNC_FILE_RANGE_WAIT_AFTER), - SyscallSucceeds()); - EXPECT_THAT(sync_file_range( - fd, 1, 0, SYNC_FILE_RANGE_WRITE | SYNC_FILE_RANGE_WAIT_AFTER), - SyscallSucceeds()); -} - -TEST(SyncFileRangeTest, CannotSyncFileRangeOnUnopenedFd) { - auto tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - auto f = ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path(), O_RDWR)); - constexpr char data[] = "some data to sync"; - int fd = f.get(); - - EXPECT_THAT(write(fd, data, sizeof(data)), - SyscallSucceedsWithValue(sizeof(data))); - - pid_t pid = fork(); - if (pid == 0) { - f.reset(); - - // fd is now invalid. - TEST_CHECK(sync_file_range(fd, 0, 0, SYNC_FILE_RANGE_WRITE) == -1); - TEST_PCHECK(errno == EBADF); - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - - int status = 0; - ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFEXITED(status)); - EXPECT_EQ(WEXITSTATUS(status), 0); -} - -TEST(SyncFileRangeTest, BadArgs) { - auto tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - auto f = ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path(), O_RDWR)); - int fd = f.get(); - - EXPECT_THAT(sync_file_range(fd, -1, 0, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(sync_file_range(fd, 0, -1, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(sync_file_range(fd, 8912, INT64_MAX - 4096, 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(SyncFileRangeTest, CannotSyncFileRangeWithWaitBefore) { - auto tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - auto f = ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path(), O_RDWR)); - constexpr char data[] = "some data to sync"; - int fd = f.get(); - - EXPECT_THAT(write(fd, data, sizeof(data)), - SyscallSucceedsWithValue(sizeof(data))); - if (IsRunningOnGvisor()) { - EXPECT_THAT(sync_file_range(fd, 0, 0, SYNC_FILE_RANGE_WAIT_BEFORE), - SyscallFailsWithErrno(ENOSYS)); - EXPECT_THAT( - sync_file_range(fd, 0, 0, - SYNC_FILE_RANGE_WAIT_BEFORE | SYNC_FILE_RANGE_WRITE), - SyscallFailsWithErrno(ENOSYS)); - } -} - -} // namespace -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/sysinfo.cc b/test/syscalls/linux/sysinfo.cc deleted file mode 100644 index 1a71256da..000000000 --- a/test/syscalls/linux/sysinfo.cc +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This is a very simple sanity test to validate that the sysinfo syscall is -// supported by gvisor and returns sane values. -#include <sys/syscall.h> -#include <sys/sysinfo.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(SysinfoTest, SysinfoIsCallable) { - struct sysinfo ignored = {}; - EXPECT_THAT(syscall(SYS_sysinfo, &ignored), SyscallSucceedsWithValue(0)); -} - -TEST(SysinfoTest, EfaultProducedOnBadAddress) { - // Validate that we return EFAULT when a bad address is provided. - // specified by man 2 sysinfo - EXPECT_THAT(syscall(SYS_sysinfo, nullptr), SyscallFailsWithErrno(EFAULT)); -} - -TEST(SysinfoTest, TotalRamSaneValue) { - struct sysinfo s = {}; - EXPECT_THAT(sysinfo(&s), SyscallSucceedsWithValue(0)); - EXPECT_GT(s.totalram, 0); -} - -TEST(SysinfoTest, MemunitSet) { - struct sysinfo s = {}; - EXPECT_THAT(sysinfo(&s), SyscallSucceedsWithValue(0)); - EXPECT_GE(s.mem_unit, 1); -} - -TEST(SysinfoTest, UptimeSaneValue) { - struct sysinfo s = {}; - EXPECT_THAT(sysinfo(&s), SyscallSucceedsWithValue(0)); - EXPECT_GE(s.uptime, 0); -} - -TEST(SysinfoTest, UptimeIncreasingValue) { - struct sysinfo s = {}; - EXPECT_THAT(sysinfo(&s), SyscallSucceedsWithValue(0)); - absl::SleepFor(absl::Seconds(2)); - struct sysinfo s2 = {}; - EXPECT_THAT(sysinfo(&s2), SyscallSucceedsWithValue(0)); - EXPECT_LT(s.uptime, s2.uptime); -} - -TEST(SysinfoTest, FreeRamSaneValue) { - struct sysinfo s = {}; - EXPECT_THAT(sysinfo(&s), SyscallSucceedsWithValue(0)); - EXPECT_GT(s.freeram, 0); - EXPECT_LT(s.freeram, s.totalram); -} - -TEST(SysinfoTest, NumProcsSaneValue) { - struct sysinfo s = {}; - EXPECT_THAT(sysinfo(&s), SyscallSucceedsWithValue(0)); - EXPECT_GT(s.procs, 0); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/syslog.cc b/test/syscalls/linux/syslog.cc deleted file mode 100644 index 9a7407d96..000000000 --- a/test/syscalls/linux/syslog.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/klog.h> -#include <sys/syscall.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -constexpr int SYSLOG_ACTION_READ_ALL = 3; -constexpr int SYSLOG_ACTION_SIZE_BUFFER = 10; - -int Syslog(int type, char* buf, int len) { - return syscall(__NR_syslog, type, buf, len); -} - -// Only SYSLOG_ACTION_SIZE_BUFFER and SYSLOG_ACTION_READ_ALL are implemented in -// gVisor. - -TEST(Syslog, Size) { - EXPECT_THAT(Syslog(SYSLOG_ACTION_SIZE_BUFFER, nullptr, 0), SyscallSucceeds()); -} - -TEST(Syslog, ReadAll) { - // There might not be anything to read, so we can't check the write count. - char buf[100]; - EXPECT_THAT(Syslog(SYSLOG_ACTION_READ_ALL, buf, sizeof(buf)), - SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/sysret.cc b/test/syscalls/linux/sysret.cc deleted file mode 100644 index 819fa655a..000000000 --- a/test/syscalls/linux/sysret.cc +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Tests to verify that the behavior of linux and gvisor matches when -// 'sysret' returns to bad (aka non-canonical) %rip or %rsp. -#include <sys/ptrace.h> -#include <sys/user.h> - -#include "gtest/gtest.h" -#include "test/util/logging.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -constexpr uint64_t kNonCanonicalRip = 0xCCCC000000000000; -constexpr uint64_t kNonCanonicalRsp = 0xFFFF000000000000; - -class SysretTest : public ::testing::Test { - protected: - struct user_regs_struct regs_; - pid_t child_; - - void SetUp() override { - pid_t pid = fork(); - - // Child. - if (pid == 0) { - TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0); - MaybeSave(); - TEST_PCHECK(raise(SIGSTOP) == 0); - MaybeSave(); - _exit(0); - } - - // Parent. - int status; - ASSERT_THAT(pid, SyscallSucceeds()); // Might still be < 0. - ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP); - ASSERT_THAT(ptrace(PTRACE_GETREGS, pid, 0, ®s_), SyscallSucceeds()); - - child_ = pid; - } - - void Detach() { - ASSERT_THAT(ptrace(PTRACE_DETACH, child_, 0, 0), SyscallSucceeds()); - } - - void SetRip(uint64_t newrip) { - regs_.rip = newrip; - ASSERT_THAT(ptrace(PTRACE_SETREGS, child_, 0, ®s_), SyscallSucceeds()); - } - - void SetRsp(uint64_t newrsp) { - regs_.rsp = newrsp; - ASSERT_THAT(ptrace(PTRACE_SETREGS, child_, 0, ®s_), SyscallSucceeds()); - } - - // Wait waits for the child pid and returns the exit status. - int Wait() { - int status; - while (true) { - int rval = wait4(child_, &status, 0, NULL); - if (rval < 0) { - return rval; - } - if (rval == child_) { - return status; - } - } - } -}; - -TEST_F(SysretTest, JustDetach) { - Detach(); - int status = Wait(); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << "status = " << status; -} - -TEST_F(SysretTest, BadRip) { - SetRip(kNonCanonicalRip); - Detach(); - int status = Wait(); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV) - << "status = " << status; -} - -TEST_F(SysretTest, BadRsp) { - SetRsp(kNonCanonicalRsp); - Detach(); - int status = Wait(); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGBUS) - << "status = " << status; -} -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc deleted file mode 100644 index d9c1ac0e1..000000000 --- a/test/syscalls/linux/tcp_socket.cc +++ /dev/null @@ -1,1387 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <netinet/in.h> -#include <netinet/tcp.h> -#include <poll.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <unistd.h> - -#include <limits> -#include <vector> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) { - struct sockaddr_storage addr; - memset(&addr, 0, sizeof(addr)); - addr.ss_family = family; - switch (family) { - case AF_INET: - reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr = - htonl(INADDR_LOOPBACK); - break; - case AF_INET6: - reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr = - in6addr_loopback; - break; - default: - return PosixError(EINVAL, - absl::StrCat("unknown socket family: ", family)); - } - return addr; -} - -// Fixture for tests parameterized by the address family to use (AF_INET and -// AF_INET6) when creating sockets. -class TcpSocketTest : public ::testing::TestWithParam<int> { - protected: - // Creates three sockets that will be used by test cases -- a listener, one - // that connects, and the accepted one. - void SetUp() override; - - // Closes the sockets created by SetUp(). - void TearDown() override; - - // Listening socket. - int listener_ = -1; - - // Socket connected via connect(). - int s_ = -1; - - // Socket connected via accept(). - int t_ = -1; - - // Initial size of the send buffer. - int sendbuf_size_ = -1; -}; - -void TcpSocketTest::SetUp() { - ASSERT_THAT(listener_ = socket(GetParam(), SOCK_STREAM, IPPROTO_TCP), - SyscallSucceeds()); - - ASSERT_THAT(s_ = socket(GetParam(), SOCK_STREAM, IPPROTO_TCP), - SyscallSucceeds()); - - // Initialize address to the loopback one. - sockaddr_storage addr = - ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); - socklen_t addrlen = sizeof(addr); - - // Bind to some port then start listening. - ASSERT_THAT( - bind(listener_, reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); - - ASSERT_THAT(listen(listener_, SOMAXCONN), SyscallSucceeds()); - - // Get the address we're listening on, then connect to it. We need to do this - // because we're allowing the stack to pick a port for us. - ASSERT_THAT(getsockname(listener_, reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - - ASSERT_THAT(RetryEINTR(connect)(s_, reinterpret_cast<struct sockaddr*>(&addr), - addrlen), - SyscallSucceeds()); - - // Get the initial send buffer size. - socklen_t optlen = sizeof(sendbuf_size_); - ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &sendbuf_size_, &optlen), - SyscallSucceeds()); - - // Accept the connection. - ASSERT_THAT(t_ = RetryEINTR(accept)(listener_, nullptr, nullptr), - SyscallSucceeds()); -} - -void TcpSocketTest::TearDown() { - EXPECT_THAT(close(listener_), SyscallSucceeds()); - if (s_ >= 0) { - EXPECT_THAT(close(s_), SyscallSucceeds()); - } - if (t_ >= 0) { - EXPECT_THAT(close(t_), SyscallSucceeds()); - } -} - -TEST_P(TcpSocketTest, ConnectOnEstablishedConnection) { - sockaddr_storage addr = - ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); - socklen_t addrlen = sizeof(addr); - - ASSERT_THAT( - connect(s_, reinterpret_cast<const struct sockaddr*>(&addr), addrlen), - SyscallFailsWithErrno(EISCONN)); - ASSERT_THAT( - connect(t_, reinterpret_cast<const struct sockaddr*>(&addr), addrlen), - SyscallFailsWithErrno(EISCONN)); -} - -TEST_P(TcpSocketTest, ShutdownWriteInTimeWait) { - EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); - EXPECT_THAT(shutdown(s_, SHUT_RDWR), SyscallSucceeds()); - absl::SleepFor(absl::Seconds(1)); // Wait to enter TIME_WAIT. - EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(TcpSocketTest, ShutdownWriteInFinWait1) { - EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); - EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); - absl::SleepFor(absl::Seconds(1)); // Wait to enter FIN-WAIT2. - EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); -} - -TEST_P(TcpSocketTest, DataCoalesced) { - char buf[10]; - - // Write in two steps. - ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf) / 2), - SyscallSucceedsWithValue(sizeof(buf) / 2)); - ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf) / 2), - SyscallSucceedsWithValue(sizeof(buf) / 2)); - - // Allow stack to process both packets. - absl::SleepFor(absl::Seconds(1)); - - // Read in one shot. - EXPECT_THAT(RetryEINTR(recv)(t_, buf, sizeof(buf), 0), - SyscallSucceedsWithValue(sizeof(buf))); -} - -TEST_P(TcpSocketTest, SenderAddressIgnored) { - char buf[3]; - ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - memset(&addr, 0, sizeof(addr)); - - ASSERT_THAT( - RetryEINTR(recvfrom)(t_, buf, sizeof(buf), 0, - reinterpret_cast<struct sockaddr*>(&addr), &addrlen), - SyscallSucceedsWithValue(3)); - - // Check that addr remains zeroed-out. - const char* ptr = reinterpret_cast<char*>(&addr); - for (size_t i = 0; i < sizeof(addr); i++) { - EXPECT_EQ(ptr[i], 0); - } -} - -TEST_P(TcpSocketTest, SenderAddressIgnoredOnPeek) { - char buf[3]; - ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - memset(&addr, 0, sizeof(addr)); - - ASSERT_THAT( - RetryEINTR(recvfrom)(t_, buf, sizeof(buf), MSG_PEEK, - reinterpret_cast<struct sockaddr*>(&addr), &addrlen), - SyscallSucceedsWithValue(3)); - - // Check that addr remains zeroed-out. - const char* ptr = reinterpret_cast<char*>(&addr); - for (size_t i = 0; i < sizeof(addr); i++) { - EXPECT_EQ(ptr[i], 0); - } -} - -TEST_P(TcpSocketTest, SendtoAddressIgnored) { - struct sockaddr_storage addr; - memset(&addr, 0, sizeof(addr)); - addr.ss_family = GetParam(); // FIXME(b/63803955) - - char data = '\0'; - EXPECT_THAT( - RetryEINTR(sendto)(s_, &data, sizeof(data), 0, - reinterpret_cast<sockaddr*>(&addr), sizeof(addr)), - SyscallSucceedsWithValue(1)); -} - -TEST_P(TcpSocketTest, WritevZeroIovec) { - // 2 bytes just to be safe and have vecs[1] not point to something random - // (even though length is 0). - char buf[2]; - char recv_buf[1]; - - // Construct a vec where the final vector is of length 0. - iovec vecs[2] = {}; - vecs[0].iov_base = buf; - vecs[0].iov_len = 1; - vecs[1].iov_base = buf + 1; - vecs[1].iov_len = 0; - - EXPECT_THAT(RetryEINTR(writev)(s_, vecs, 2), SyscallSucceedsWithValue(1)); - - EXPECT_THAT(RetryEINTR(recv)(t_, recv_buf, 1, 0), - SyscallSucceedsWithValue(1)); - EXPECT_EQ(memcmp(recv_buf, buf, 1), 0); -} - -TEST_P(TcpSocketTest, ZeroWriteAllowed) { - char buf[3]; - // Send a zero length packet. - ASSERT_THAT(RetryEINTR(write)(s_, buf, 0), SyscallSucceedsWithValue(0)); - // Verify that there is no packet available. - EXPECT_THAT(RetryEINTR(recv)(t_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// Test that a non-blocking write with a buffer that is larger than the send -// buffer size will not actually write the whole thing at once. Regression test -// for b/64438887. -TEST_P(TcpSocketTest, NonblockingLargeWrite) { - // Set the FD to O_NONBLOCK. - int opts; - ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds()); - opts |= O_NONBLOCK; - ASSERT_THAT(fcntl(s_, F_SETFL, opts), SyscallSucceeds()); - - // Allocate a buffer three times the size of the send buffer. We do this with - // a vector to avoid allocating on the stack. - int size = 3 * sendbuf_size_; - std::vector<char> buf(size); - - // Try to write the whole thing. - int n; - ASSERT_THAT(n = RetryEINTR(write)(s_, buf.data(), size), SyscallSucceeds()); - - // We should have written something, but not the whole thing. - EXPECT_GT(n, 0); - EXPECT_LT(n, size); -} - -// Test that a blocking write with a buffer that is larger than the send buffer -// will block until the entire buffer is sent. -TEST_P(TcpSocketTest, BlockingLargeWrite_NoRandomSave) { - // Allocate a buffer three times the size of the send buffer on the heap. We - // do this as a vector to avoid allocating on the stack. - int size = 3 * sendbuf_size_; - std::vector<char> writebuf(size); - - // Start reading the response in a loop. - int read_bytes = 0; - ScopedThread t([this, &read_bytes]() { - // Avoid interrupting the blocking write in main thread. - const DisableSave ds; - - // Take ownership of the FD so that we close it on failure. This will - // unblock the blocking write below. - FileDescriptor fd(t_); - t_ = -1; - - char readbuf[2500] = {}; - int n = -1; - while (n != 0) { - ASSERT_THAT(n = RetryEINTR(read)(fd.get(), &readbuf, sizeof(readbuf)), - SyscallSucceeds()); - read_bytes += n; - } - }); - - // Try to write the whole thing. - int n; - ASSERT_THAT(n = WriteFd(s_, writebuf.data(), size), SyscallSucceeds()); - - // We should have written the whole thing. - EXPECT_EQ(n, size); - EXPECT_THAT(close(s_), SyscallSucceedsWithValue(0)); - s_ = -1; - t.Join(); - - // We should have read the whole thing. - EXPECT_EQ(read_bytes, size); -} - -// Test that a send with MSG_DONTWAIT flag and buffer that larger than the send -// buffer size will not write the whole thing. -TEST_P(TcpSocketTest, LargeSendDontWait) { - // Allocate a buffer three times the size of the send buffer. We do this on - // with a vector to avoid allocating on the stack. - int size = 3 * sendbuf_size_; - std::vector<char> buf(size); - - // Try to write the whole thing with MSG_DONTWAIT flag, which can - // return a partial write. - int n; - ASSERT_THAT(n = RetryEINTR(send)(s_, buf.data(), size, MSG_DONTWAIT), - SyscallSucceeds()); - - // We should have written something, but not the whole thing. - EXPECT_GT(n, 0); - EXPECT_LT(n, size); -} - -// Test that a send on a non-blocking socket with a buffer that larger than the -// send buffer will not write the whole thing at once. -TEST_P(TcpSocketTest, NonblockingLargeSend) { - // Set the FD to O_NONBLOCK. - int opts; - ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds()); - opts |= O_NONBLOCK; - ASSERT_THAT(fcntl(s_, F_SETFL, opts), SyscallSucceeds()); - - // Allocate a buffer three times the size of the send buffer. We do this on - // with a vector to avoid allocating on the stack. - int size = 3 * sendbuf_size_; - std::vector<char> buf(size); - - // Try to write the whole thing. - int n; - ASSERT_THAT(n = RetryEINTR(send)(s_, buf.data(), size, 0), SyscallSucceeds()); - - // We should have written something, but not the whole thing. - EXPECT_GT(n, 0); - EXPECT_LT(n, size); -} - -// Same test as above, but calls send instead of write. -TEST_P(TcpSocketTest, BlockingLargeSend_NoRandomSave) { - // Allocate a buffer three times the size of the send buffer. We do this on - // with a vector to avoid allocating on the stack. - int size = 3 * sendbuf_size_; - std::vector<char> writebuf(size); - - // Start reading the response in a loop. - int read_bytes = 0; - ScopedThread t([this, &read_bytes]() { - // Avoid interrupting the blocking write in main thread. - const DisableSave ds; - - // Take ownership of the FD so that we close it on failure. This will - // unblock the blocking write below. - FileDescriptor fd(t_); - t_ = -1; - - char readbuf[2500] = {}; - int n = -1; - while (n != 0) { - ASSERT_THAT(n = RetryEINTR(read)(fd.get(), &readbuf, sizeof(readbuf)), - SyscallSucceeds()); - read_bytes += n; - } - }); - - // Try to send the whole thing. - int n; - ASSERT_THAT(n = SendFd(s_, writebuf.data(), size, 0), SyscallSucceeds()); - - // We should have written the whole thing. - EXPECT_EQ(n, size); - EXPECT_THAT(close(s_), SyscallSucceedsWithValue(0)); - s_ = -1; - t.Join(); - - // We should have read the whole thing. - EXPECT_EQ(read_bytes, size); -} - -// Test that polling on a socket with a full send buffer will block. -TEST_P(TcpSocketTest, PollWithFullBufferBlocks) { - // Set the FD to O_NONBLOCK. - int opts; - ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds()); - opts |= O_NONBLOCK; - ASSERT_THAT(fcntl(s_, F_SETFL, opts), SyscallSucceeds()); - - // Set TCP_NODELAY, which will cause linux to fill the receive buffer from the - // send buffer as quickly as possibly. This way we can fill up both buffers - // faster. - constexpr int tcp_nodelay_flag = 1; - ASSERT_THAT(setsockopt(s_, IPPROTO_TCP, TCP_NODELAY, &tcp_nodelay_flag, - sizeof(tcp_nodelay_flag)), - SyscallSucceeds()); - - // Set a 256KB send/receive buffer. - int buf_sz = 1 << 18; - EXPECT_THAT(setsockopt(t_, SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)), - SyscallSucceedsWithValue(0)); - EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &buf_sz, sizeof(buf_sz)), - SyscallSucceedsWithValue(0)); - - // Create a large buffer that will be used for sending. - std::vector<char> buf(1 << 16); - - // Write until we receive an error. - while (RetryEINTR(send)(s_, buf.data(), buf.size(), 0) != -1) { - // Sleep to give linux a chance to move data from the send buffer to the - // receive buffer. - usleep(10000); // 10ms. - } - // The last error should have been EWOULDBLOCK. - ASSERT_EQ(errno, EWOULDBLOCK); - - // Now polling on the FD with a timeout should return 0 corresponding to no - // FDs ready. - struct pollfd poll_fd = {s_, POLLOUT, 0}; - EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10), SyscallSucceedsWithValue(0)); -} - -TEST_P(TcpSocketTest, MsgTrunc) { - char sent_data[512]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT(RetryEINTR(send)(s_, sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - char received_data[sizeof(sent_data)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(t_, received_data, sizeof(received_data) / 2, MSG_TRUNC), - SyscallSucceedsWithValue(sizeof(sent_data) / 2)); - - // Check that we didn't get anything. - char zeros[sizeof(received_data)] = {}; - EXPECT_EQ(0, memcmp(zeros, received_data, sizeof(received_data))); -} - -// MSG_CTRUNC is a return flag but linux allows it to be set on input flags -// without returning an error. -TEST_P(TcpSocketTest, MsgTruncWithCtrunc) { - char sent_data[512]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT(RetryEINTR(send)(s_, sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - char received_data[sizeof(sent_data)] = {}; - ASSERT_THAT(RetryEINTR(recv)(t_, received_data, sizeof(received_data) / 2, - MSG_TRUNC | MSG_CTRUNC), - SyscallSucceedsWithValue(sizeof(sent_data) / 2)); - - // Check that we didn't get anything. - char zeros[sizeof(received_data)] = {}; - EXPECT_EQ(0, memcmp(zeros, received_data, sizeof(received_data))); -} - -// This test will verify that MSG_CTRUNC doesn't do anything when specified -// on input. -TEST_P(TcpSocketTest, MsgTruncWithCtruncOnly) { - char sent_data[512]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT(RetryEINTR(send)(s_, sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - char received_data[sizeof(sent_data)] = {}; - ASSERT_THAT(RetryEINTR(recv)(t_, received_data, sizeof(received_data) / 2, - MSG_CTRUNC), - SyscallSucceedsWithValue(sizeof(sent_data) / 2)); - - // Since MSG_CTRUNC here had no affect, it should not behave like MSG_TRUNC. - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data) / 2)); -} - -TEST_P(TcpSocketTest, MsgTruncLargeSize) { - char sent_data[512]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT(RetryEINTR(send)(s_, sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - char received_data[sizeof(sent_data) * 2] = {}; - ASSERT_THAT( - RetryEINTR(recv)(t_, received_data, sizeof(received_data), MSG_TRUNC), - SyscallSucceedsWithValue(sizeof(sent_data))); - - // Check that we didn't get anything. - char zeros[sizeof(received_data)] = {}; - EXPECT_EQ(0, memcmp(zeros, received_data, sizeof(received_data))); -} - -TEST_P(TcpSocketTest, MsgTruncPeek) { - char sent_data[512]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT(RetryEINTR(send)(s_, sent_data, sizeof(sent_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - char received_data[sizeof(sent_data)] = {}; - ASSERT_THAT(RetryEINTR(recv)(t_, received_data, sizeof(received_data) / 2, - MSG_TRUNC | MSG_PEEK), - SyscallSucceedsWithValue(sizeof(sent_data) / 2)); - - // Check that we didn't get anything. - char zeros[sizeof(received_data)] = {}; - EXPECT_EQ(0, memcmp(zeros, received_data, sizeof(received_data))); - - // Check that we can still get all of the data. - ASSERT_THAT(RetryEINTR(recv)(t_, received_data, sizeof(received_data), 0), - SyscallSucceedsWithValue(sizeof(sent_data))); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(TcpSocketTest, NoDelayDefault) { - int get = -1; - socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(s_, IPPROTO_TCP, TCP_NODELAY, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -TEST_P(TcpSocketTest, SetNoDelay) { - ASSERT_THAT( - setsockopt(s_, IPPROTO_TCP, TCP_NODELAY, &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(s_, IPPROTO_TCP, TCP_NODELAY, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); - - ASSERT_THAT(setsockopt(s_, IPPROTO_TCP, TCP_NODELAY, &kSockOptOff, - sizeof(kSockOptOff)), - SyscallSucceeds()); - - EXPECT_THAT(getsockopt(s_, IPPROTO_TCP, TCP_NODELAY, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -#ifndef TCP_INQ -#define TCP_INQ 36 -#endif - -TEST_P(TcpSocketTest, TcpInqSetSockOpt) { - char buf[1024]; - ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - // TCP_INQ is disabled by default. - int val = -1; - socklen_t slen = sizeof(val); - EXPECT_THAT(getsockopt(t_, SOL_TCP, TCP_INQ, &val, &slen), - SyscallSucceedsWithValue(0)); - ASSERT_EQ(val, 0); - - // Try to set TCP_INQ. - val = 1; - EXPECT_THAT(setsockopt(t_, SOL_TCP, TCP_INQ, &val, sizeof(val)), - SyscallSucceedsWithValue(0)); - val = -1; - slen = sizeof(val); - EXPECT_THAT(getsockopt(t_, SOL_TCP, TCP_INQ, &val, &slen), - SyscallSucceedsWithValue(0)); - ASSERT_EQ(val, 1); - - // Try to unset TCP_INQ. - val = 0; - EXPECT_THAT(setsockopt(t_, SOL_TCP, TCP_INQ, &val, sizeof(val)), - SyscallSucceedsWithValue(0)); - val = -1; - slen = sizeof(val); - EXPECT_THAT(getsockopt(t_, SOL_TCP, TCP_INQ, &val, &slen), - SyscallSucceedsWithValue(0)); - ASSERT_EQ(val, 0); -} - -TEST_P(TcpSocketTest, TcpInq) { - char buf[1024]; - // Write more than one TCP segment. - int size = sizeof(buf); - int kChunk = sizeof(buf) / 4; - for (int i = 0; i < size; i += kChunk) { - ASSERT_THAT(RetryEINTR(write)(s_, buf, kChunk), - SyscallSucceedsWithValue(kChunk)); - } - - int val = 1; - kChunk = sizeof(buf) / 2; - EXPECT_THAT(setsockopt(t_, SOL_TCP, TCP_INQ, &val, sizeof(val)), - SyscallSucceedsWithValue(0)); - - // Wait when all data will be in the received queue. - while (true) { - ASSERT_THAT(ioctl(t_, TIOCINQ, &size), SyscallSucceeds()); - if (size == sizeof(buf)) { - break; - } - absl::SleepFor(absl::Milliseconds(10)); - } - - struct msghdr msg = {}; - std::vector<char> control(CMSG_SPACE(sizeof(int))); - size = sizeof(buf); - struct iovec iov; - for (int i = 0; size != 0; i += kChunk) { - msg.msg_control = &control[0]; - msg.msg_controllen = control.size(); - - iov.iov_base = buf; - iov.iov_len = kChunk; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - ASSERT_THAT(RetryEINTR(recvmsg)(t_, &msg, 0), - SyscallSucceedsWithValue(kChunk)); - size -= kChunk; - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); - ASSERT_EQ(cmsg->cmsg_level, SOL_TCP); - ASSERT_EQ(cmsg->cmsg_type, TCP_INQ); - - int inq = 0; - memcpy(&inq, CMSG_DATA(cmsg), sizeof(int)); - ASSERT_EQ(inq, size); - } -} - -TEST_P(TcpSocketTest, Tiocinq) { - char buf[1024]; - size_t size = sizeof(buf); - ASSERT_THAT(RetryEINTR(write)(s_, buf, size), SyscallSucceedsWithValue(size)); - - uint32_t seed = time(nullptr); - const size_t max_chunk = size / 10; - while (size > 0) { - size_t chunk = (rand_r(&seed) % max_chunk) + 1; - ssize_t read = RetryEINTR(recvfrom)(t_, buf, chunk, 0, nullptr, nullptr); - ASSERT_THAT(read, SyscallSucceeds()); - size -= read; - - int inq = 0; - ASSERT_THAT(ioctl(t_, TIOCINQ, &inq), SyscallSucceeds()); - ASSERT_EQ(inq, size); - } -} - -TEST_P(TcpSocketTest, TcpSCMPriority) { - char buf[1024]; - ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - int val = 1; - EXPECT_THAT(setsockopt(t_, SOL_TCP, TCP_INQ, &val, sizeof(val)), - SyscallSucceedsWithValue(0)); - EXPECT_THAT(setsockopt(t_, SOL_SOCKET, SO_TIMESTAMP, &val, sizeof(val)), - SyscallSucceedsWithValue(0)); - - struct msghdr msg = {}; - std::vector<char> control( - CMSG_SPACE(sizeof(struct timeval) + CMSG_SPACE(sizeof(int)))); - struct iovec iov; - msg.msg_control = &control[0]; - msg.msg_controllen = control.size(); - - iov.iov_base = buf; - iov.iov_len = sizeof(buf); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - ASSERT_THAT(RetryEINTR(recvmsg)(t_, &msg, 0), - SyscallSucceedsWithValue(sizeof(buf))); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - // TODO(b/78348848): SO_TIMESTAMP isn't implemented for TCP sockets. - if (!IsRunningOnGvisor() || cmsg->cmsg_level == SOL_SOCKET) { - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval))); - - cmsg = CMSG_NXTHDR(&msg, cmsg); - ASSERT_NE(cmsg, nullptr); - } - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); - ASSERT_EQ(cmsg->cmsg_level, SOL_TCP); - ASSERT_EQ(cmsg->cmsg_type, TCP_INQ); - - int inq = 0; - memcpy(&inq, CMSG_DATA(cmsg), sizeof(int)); - ASSERT_EQ(inq, 0); - - cmsg = CMSG_NXTHDR(&msg, cmsg); - ASSERT_EQ(cmsg, nullptr); -} - -INSTANTIATE_TEST_SUITE_P(AllInetTests, TcpSocketTest, - ::testing::Values(AF_INET, AF_INET6)); - -// Fixture for tests parameterized by address family that don't want the fixture -// to do things. -using SimpleTcpSocketTest = ::testing::TestWithParam<int>; - -TEST_P(SimpleTcpSocketTest, SendUnconnected) { - int fd; - ASSERT_THAT(fd = socket(GetParam(), SOCK_STREAM, IPPROTO_TCP), - SyscallSucceeds()); - FileDescriptor sock_fd(fd); - - char data = '\0'; - EXPECT_THAT(RetryEINTR(send)(fd, &data, sizeof(data), 0), - SyscallFailsWithErrno(EPIPE)); -} - -TEST_P(SimpleTcpSocketTest, SendtoWithoutAddressUnconnected) { - int fd; - ASSERT_THAT(fd = socket(GetParam(), SOCK_STREAM, IPPROTO_TCP), - SyscallSucceeds()); - FileDescriptor sock_fd(fd); - - char data = '\0'; - EXPECT_THAT(RetryEINTR(sendto)(fd, &data, sizeof(data), 0, nullptr, 0), - SyscallFailsWithErrno(EPIPE)); -} - -TEST_P(SimpleTcpSocketTest, SendtoWithAddressUnconnected) { - int fd; - ASSERT_THAT(fd = socket(GetParam(), SOCK_STREAM, IPPROTO_TCP), - SyscallSucceeds()); - FileDescriptor sock_fd(fd); - - sockaddr_storage addr = - ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); - char data = '\0'; - EXPECT_THAT( - RetryEINTR(sendto)(fd, &data, sizeof(data), 0, - reinterpret_cast<sockaddr*>(&addr), sizeof(addr)), - SyscallFailsWithErrno(EPIPE)); -} - -TEST_P(SimpleTcpSocketTest, GetPeerNameUnconnected) { - int fd; - ASSERT_THAT(fd = socket(GetParam(), SOCK_STREAM, IPPROTO_TCP), - SyscallSucceeds()); - FileDescriptor sock_fd(fd); - - sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getpeername(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(TcpSocketTest, FullBuffer) { - // Set both FDs to be blocking. - int flags = 0; - ASSERT_THAT(flags = fcntl(s_, F_GETFL), SyscallSucceeds()); - EXPECT_THAT(fcntl(s_, F_SETFL, flags & ~O_NONBLOCK), SyscallSucceeds()); - flags = 0; - ASSERT_THAT(flags = fcntl(t_, F_GETFL), SyscallSucceeds()); - EXPECT_THAT(fcntl(t_, F_SETFL, flags & ~O_NONBLOCK), SyscallSucceeds()); - - // 2500 was chosen as a small value that can be set on Linux. - int set_snd = 2500; - EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &set_snd, sizeof(set_snd)), - SyscallSucceedsWithValue(0)); - int get_snd = -1; - socklen_t get_snd_len = sizeof(get_snd); - EXPECT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &get_snd, &get_snd_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_snd_len, sizeof(get_snd)); - EXPECT_GT(get_snd, 0); - - // 2500 was chosen as a small value that can be set on Linux and gVisor. - int set_rcv = 2500; - EXPECT_THAT(setsockopt(t_, SOL_SOCKET, SO_RCVBUF, &set_rcv, sizeof(set_rcv)), - SyscallSucceedsWithValue(0)); - int get_rcv = -1; - socklen_t get_rcv_len = sizeof(get_rcv); - EXPECT_THAT(getsockopt(t_, SOL_SOCKET, SO_RCVBUF, &get_rcv, &get_rcv_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_rcv_len, sizeof(get_rcv)); - EXPECT_GE(get_rcv, 2500); - - // Quick sanity test. - EXPECT_LT(get_snd + get_rcv, 2500 * IOV_MAX); - - char data[2500] = {}; - std::vector<struct iovec> iovecs; - for (int i = 0; i < IOV_MAX; i++) { - struct iovec iov = {}; - iov.iov_base = data; - iov.iov_len = sizeof(data); - iovecs.push_back(iov); - } - ScopedThread t([this, &iovecs]() { - int result = -1; - EXPECT_THAT(result = RetryEINTR(writev)(s_, iovecs.data(), iovecs.size()), - SyscallSucceeds()); - EXPECT_GT(result, 1); - EXPECT_LT(result, sizeof(data) * iovecs.size()); - }); - - char recv = 0; - EXPECT_THAT(RetryEINTR(read)(t_, &recv, 1), SyscallSucceedsWithValue(1)); - EXPECT_THAT(close(t_), SyscallSucceedsWithValue(0)); - t_ = -1; -} - -TEST_P(TcpSocketTest, PollAfterShutdown) { - ScopedThread client_thread([this]() { - EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallSucceedsWithValue(0)); - struct pollfd poll_fd = {s_, POLLIN | POLLERR | POLLHUP, 0}; - EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), - SyscallSucceedsWithValue(1)); - }); - - EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceedsWithValue(0)); - struct pollfd poll_fd = {t_, POLLIN | POLLERR | POLLHUP, 0}; - EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), - SyscallSucceedsWithValue(1)); -} - -TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) { - // Initialize address to the loopback one. - sockaddr_storage addr = - ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); - socklen_t addrlen = sizeof(addr); - - const FileDescriptor s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - - // Set the FD to O_NONBLOCK. - int opts; - ASSERT_THAT(opts = fcntl(s.get(), F_GETFL), SyscallSucceeds()); - opts |= O_NONBLOCK; - ASSERT_THAT(fcntl(s.get(), F_SETFL, opts), SyscallSucceeds()); - - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallFailsWithErrno(EINPROGRESS)); - - // Now polling on the FD with a timeout should return 0 corresponding to no - // FDs ready. - struct pollfd poll_fd = {s.get(), POLLOUT, 0}; - EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), - SyscallSucceedsWithValue(1)); - - int err; - socklen_t optlen = sizeof(err); - ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_ERROR, &err, &optlen), - SyscallSucceeds()); - - EXPECT_EQ(err, ECONNREFUSED); -} - -TEST_P(SimpleTcpSocketTest, NonBlockingConnect) { - const FileDescriptor listener = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - - // Initialize address to the loopback one. - sockaddr_storage addr = - ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); - socklen_t addrlen = sizeof(addr); - - // Bind to some port then start listening. - ASSERT_THAT( - bind(listener.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); - - ASSERT_THAT(listen(listener.get(), SOMAXCONN), SyscallSucceeds()); - - FileDescriptor s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - - // Set the FD to O_NONBLOCK. - int opts; - ASSERT_THAT(opts = fcntl(s.get(), F_GETFL), SyscallSucceeds()); - opts |= O_NONBLOCK; - ASSERT_THAT(fcntl(s.get(), F_SETFL, opts), SyscallSucceeds()); - - ASSERT_THAT(getsockname(listener.get(), - reinterpret_cast<struct sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallFailsWithErrno(EINPROGRESS)); - - int t; - ASSERT_THAT(t = RetryEINTR(accept)(listener.get(), nullptr, nullptr), - SyscallSucceeds()); - - // Now polling on the FD with a timeout should return 0 corresponding to no - // FDs ready. - struct pollfd poll_fd = {s.get(), POLLOUT, 0}; - EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), - SyscallSucceedsWithValue(1)); - - int err; - socklen_t optlen = sizeof(err); - ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_ERROR, &err, &optlen), - SyscallSucceeds()); - - EXPECT_EQ(err, 0); - - EXPECT_THAT(close(t), SyscallSucceeds()); -} - -TEST_P(SimpleTcpSocketTest, NonBlockingConnectRemoteClose) { - const FileDescriptor listener = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - - // Initialize address to the loopback one. - sockaddr_storage addr = - ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); - socklen_t addrlen = sizeof(addr); - - // Bind to some port then start listening. - ASSERT_THAT( - bind(listener.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); - - ASSERT_THAT(listen(listener.get(), SOMAXCONN), SyscallSucceeds()); - - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); - - ASSERT_THAT(getsockname(listener.get(), - reinterpret_cast<struct sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallFailsWithErrno(EINPROGRESS)); - - int t; - ASSERT_THAT(t = RetryEINTR(accept)(listener.get(), nullptr, nullptr), - SyscallSucceeds()); - - EXPECT_THAT(close(t), SyscallSucceeds()); - - // Now polling on the FD with a timeout should return 0 corresponding to no - // FDs ready. - struct pollfd poll_fd = {s.get(), POLLOUT, 0}; - EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), - SyscallSucceedsWithValue(1)); - - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); - - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallFailsWithErrno(EISCONN)); -} - -// Test that we get an ECONNREFUSED with a blocking socket when no one is -// listening on the other end. -TEST_P(SimpleTcpSocketTest, BlockingConnectRefused) { - FileDescriptor s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - - // Initialize address to the loopback one. - sockaddr_storage addr = - ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); - socklen_t addrlen = sizeof(addr); - - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallFailsWithErrno(ECONNREFUSED)); - - // Avoiding triggering save in destructor of s. - EXPECT_THAT(close(s.release()), SyscallSucceeds()); -} - -// Test that connecting to a non-listening port and thus receiving a RST is -// handled appropriately by the socket - the port that the socket was bound to -// is released and the expected error is returned. -TEST_P(SimpleTcpSocketTest, CleanupOnConnectionRefused) { - // Create a socket that is known to not be listening. As is it bound but not - // listening, when another socket connects to the port, it will refuse.. - FileDescriptor bound_s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - - sockaddr_storage bound_addr = - ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); - socklen_t bound_addrlen = sizeof(bound_addr); - - ASSERT_THAT( - bind(bound_s.get(), reinterpret_cast<struct sockaddr*>(&bound_addr), - bound_addrlen), - SyscallSucceeds()); - - // Get the addresses the socket is bound to because the port is chosen by the - // stack. - ASSERT_THAT(getsockname(bound_s.get(), - reinterpret_cast<struct sockaddr*>(&bound_addr), - &bound_addrlen), - SyscallSucceeds()); - - // Create, initialize, and bind the socket that is used to test connecting to - // the non-listening port. - FileDescriptor client_s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - // Initialize client address to the loopback one. - sockaddr_storage client_addr = - ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); - socklen_t client_addrlen = sizeof(client_addr); - - ASSERT_THAT( - bind(client_s.get(), reinterpret_cast<struct sockaddr*>(&client_addr), - client_addrlen), - SyscallSucceeds()); - - ASSERT_THAT(getsockname(client_s.get(), - reinterpret_cast<struct sockaddr*>(&client_addr), - &client_addrlen), - SyscallSucceeds()); - - // Now the test: connect to the bound but not listening socket with the - // client socket. The bound socket should return a RST and cause the client - // socket to return an error and clean itself up immediately. - // The error being ECONNREFUSED diverges with RFC 793, page 37, but does what - // Linux does. - ASSERT_THAT(connect(client_s.get(), - reinterpret_cast<const struct sockaddr*>(&bound_addr), - bound_addrlen), - SyscallFailsWithErrno(ECONNREFUSED)); - - FileDescriptor new_s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - - // Test binding to the address from the client socket. This should be okay - // if it was dropped correctly. - ASSERT_THAT( - bind(new_s.get(), reinterpret_cast<struct sockaddr*>(&client_addr), - client_addrlen), - SyscallSucceeds()); - - // Attempt #2, with the new socket and reused addr our connect should fail in - // the same way as before, not with an EADDRINUSE. - ASSERT_THAT(connect(client_s.get(), - reinterpret_cast<const struct sockaddr*>(&bound_addr), - bound_addrlen), - SyscallFailsWithErrno(ECONNREFUSED)); -} - -// Test that we get an ECONNREFUSED with a nonblocking socket. -TEST_P(SimpleTcpSocketTest, NonBlockingConnectRefused) { - FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); - - // Initialize address to the loopback one. - sockaddr_storage addr = - ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); - socklen_t addrlen = sizeof(addr); - - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallFailsWithErrno(EINPROGRESS)); - - // We don't need to specify any events to get POLLHUP or POLLERR as these - // are added before the poll. - struct pollfd poll_fd = {s.get(), /*events=*/0, 0}; - EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 1000), SyscallSucceedsWithValue(1)); - - // The ECONNREFUSED should cause us to be woken up with POLLHUP. - EXPECT_NE(poll_fd.revents & (POLLHUP | POLLERR), 0); - - // Avoiding triggering save in destructor of s. - EXPECT_THAT(close(s.release()), SyscallSucceeds()); -} - -// Test that setting a supported congestion control algorithm succeeds for an -// unconnected TCP socket -TEST_P(SimpleTcpSocketTest, SetCongestionControlSucceedsForSupported) { - // This is Linux's net/tcp.h TCP_CA_NAME_MAX. - const int kTcpCaNameMax = 16; - - FileDescriptor s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - { - const char kSetCC[kTcpCaNameMax] = "reno"; - ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC, - strlen(kSetCC)), - SyscallSucceedsWithValue(0)); - - char got_cc[kTcpCaNameMax]; - memset(got_cc, '1', sizeof(got_cc)); - socklen_t optlen = sizeof(got_cc); - ASSERT_THAT( - getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), - SyscallSucceedsWithValue(0)); - // We ignore optlen here as the linux kernel sets optlen to the lower of the - // size of the buffer passed in or kTcpCaNameMax and not the length of the - // congestion control algorithm's actual name. - EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kTcpCaNameMax))); - } - { - const char kSetCC[kTcpCaNameMax] = "cubic"; - ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC, - strlen(kSetCC)), - SyscallSucceedsWithValue(0)); - - char got_cc[kTcpCaNameMax]; - memset(got_cc, '1', sizeof(got_cc)); - socklen_t optlen = sizeof(got_cc); - ASSERT_THAT( - getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), - SyscallSucceedsWithValue(0)); - // We ignore optlen here as the linux kernel sets optlen to the lower of the - // size of the buffer passed in or kTcpCaNameMax and not the length of the - // congestion control algorithm's actual name. - EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kTcpCaNameMax))); - } -} - -// This test verifies that a getsockopt(...TCP_CONGESTION) behaviour is -// consistent between linux and gvisor when the passed in buffer is smaller than -// kTcpCaNameMax. -TEST_P(SimpleTcpSocketTest, SetGetTCPCongestionShortReadBuffer) { - FileDescriptor s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - { - // Verify that getsockopt/setsockopt work with buffers smaller than - // kTcpCaNameMax. - const char kSetCC[] = "cubic"; - ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC, - strlen(kSetCC)), - SyscallSucceedsWithValue(0)); - - char got_cc[sizeof(kSetCC)]; - socklen_t optlen = sizeof(got_cc); - ASSERT_THAT( - getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(sizeof(got_cc), optlen); - EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(got_cc))); - } -} - -// This test verifies that a getsockopt(...TCP_CONGESTION) behaviour is -// consistent between linux and gvisor when the passed in buffer is larger than -// kTcpCaNameMax. -TEST_P(SimpleTcpSocketTest, SetGetTCPCongestionLargeReadBuffer) { - // This is Linux's net/tcp.h TCP_CA_NAME_MAX. - const int kTcpCaNameMax = 16; - - FileDescriptor s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - { - // Verify that getsockopt works with buffers larger than - // kTcpCaNameMax. - const char kSetCC[] = "cubic"; - ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC, - strlen(kSetCC)), - SyscallSucceedsWithValue(0)); - - char got_cc[kTcpCaNameMax + 5]; - socklen_t optlen = sizeof(got_cc); - ASSERT_THAT( - getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), - SyscallSucceedsWithValue(0)); - // Linux copies the minimum of kTcpCaNameMax or the length of the passed in - // buffer and sets optlen to the number of bytes actually copied - // irrespective of the actual length of the congestion control name. - EXPECT_EQ(kTcpCaNameMax, optlen); - EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC))); - } -} - -// Test that setting an unsupported congestion control algorithm fails for an -// unconnected TCP socket. -TEST_P(SimpleTcpSocketTest, SetCongestionControlFailsForUnsupported) { - // This is Linux's net/tcp.h TCP_CA_NAME_MAX. - const int kTcpCaNameMax = 16; - - FileDescriptor s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - char old_cc[kTcpCaNameMax]; - socklen_t optlen = sizeof(old_cc); - ASSERT_THAT( - getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &old_cc, &optlen), - SyscallSucceedsWithValue(0)); - - const char kSetCC[] = "invalid_ca_kSetCC"; - ASSERT_THAT( - setsockopt(s.get(), SOL_TCP, TCP_CONGESTION, &kSetCC, strlen(kSetCC)), - SyscallFailsWithErrno(ENOENT)); - - char got_cc[kTcpCaNameMax]; - ASSERT_THAT( - getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), - SyscallSucceedsWithValue(0)); - // We ignore optlen here as the linux kernel sets optlen to the lower of the - // size of the buffer passed in or kTcpCaNameMax and not the length of the - // congestion control algorithm's actual name. - EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(kTcpCaNameMax))); -} - -TEST_P(SimpleTcpSocketTest, MaxSegDefault) { - FileDescriptor s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - - constexpr int kDefaultMSS = 536; - int tcp_max_seg; - socklen_t optlen = sizeof(tcp_max_seg); - ASSERT_THAT( - getsockopt(s.get(), IPPROTO_TCP, TCP_MAXSEG, &tcp_max_seg, &optlen), - SyscallSucceedsWithValue(0)); - - EXPECT_EQ(kDefaultMSS, tcp_max_seg); - EXPECT_EQ(sizeof(tcp_max_seg), optlen); -} - -TEST_P(SimpleTcpSocketTest, SetMaxSeg) { - FileDescriptor s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - - constexpr int kDefaultMSS = 536; - constexpr int kTCPMaxSeg = 1024; - ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_MAXSEG, &kTCPMaxSeg, - sizeof(kTCPMaxSeg)), - SyscallSucceedsWithValue(0)); - - // Linux actually never returns the user_mss value. It will always return the - // default MSS value defined above for an unconnected socket and always return - // the actual current MSS for a connected one. - int optval; - socklen_t optlen = sizeof(optval); - ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_MAXSEG, &optval, &optlen), - SyscallSucceedsWithValue(0)); - - EXPECT_EQ(kDefaultMSS, optval); - EXPECT_EQ(sizeof(optval), optlen); -} - -TEST_P(SimpleTcpSocketTest, SetMaxSegFailsForInvalidMSSValues) { - FileDescriptor s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - - { - constexpr int tcp_max_seg = 10; - ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_MAXSEG, &tcp_max_seg, - sizeof(tcp_max_seg)), - SyscallFailsWithErrno(EINVAL)); - } - { - constexpr int tcp_max_seg = 75000; - ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_MAXSEG, &tcp_max_seg, - sizeof(tcp_max_seg)), - SyscallFailsWithErrno(EINVAL)); - } -} - -TEST_P(SimpleTcpSocketTest, SetTCPUserTimeout) { - FileDescriptor s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - - { - constexpr int kTCPUserTimeout = -1; - EXPECT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, - &kTCPUserTimeout, sizeof(kTCPUserTimeout)), - SyscallFailsWithErrno(EINVAL)); - } - - // kTCPUserTimeout is in milliseconds. - constexpr int kTCPUserTimeout = 100; - ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, - &kTCPUserTimeout, sizeof(kTCPUserTimeout)), - SyscallSucceedsWithValue(0)); - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kTCPUserTimeout); -} - -TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptNeg) { - FileDescriptor s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - - // -ve TCP_DEFER_ACCEPT is same as setting it to zero. - constexpr int kNeg = -1; - EXPECT_THAT( - setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &kNeg, sizeof(kNeg)), - SyscallSucceeds()); - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, 0); -} - -TEST_P(SimpleTcpSocketTest, GetTCPDeferAcceptDefault) { - FileDescriptor s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, 0); -} - -TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptGreaterThanZero) { - FileDescriptor s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - // kTCPDeferAccept is in seconds. - // NOTE: linux translates seconds to # of retries and back from - // #of retries to seconds. Which means only certain values - // translate back exactly. That's why we use 3 here, a value of - // 5 will result in us getting back 7 instead of 5 in the - // getsockopt. - constexpr int kTCPDeferAccept = 3; - ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, - &kTCPDeferAccept, sizeof(kTCPDeferAccept)), - SyscallSucceeds()); - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len), - SyscallSucceeds()); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kTCPDeferAccept); -} - -TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) { - auto s = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - char buf[1]; - EXPECT_THAT(recv(s.get(), buf, 0, 0), SyscallFailsWithErrno(ENOTCONN)); - EXPECT_THAT(recv(s.get(), buf, sizeof(buf), 0), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(SimpleTcpSocketTest, TCPConnectSoRcvBufRace) { - auto s = ASSERT_NO_ERRNO_AND_VALUE( - Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); - sockaddr_storage addr = - ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); - socklen_t addrlen = sizeof(addr); - - RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr), - addrlen); - int buf_sz = 1 << 18; - EXPECT_THAT( - setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)), - SyscallSucceedsWithValue(0)); -} - -INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, - ::testing::Values(AF_INET, AF_INET6)); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/tgkill.cc b/test/syscalls/linux/tgkill.cc deleted file mode 100644 index 80acae5de..000000000 --- a/test/syscalls/linux/tgkill.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <sys/syscall.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(TgkillTest, InvalidTID) { - EXPECT_THAT(tgkill(getpid(), -1, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(tgkill(getpid(), 0, 0), SyscallFailsWithErrno(EINVAL)); -} - -TEST(TgkillTest, InvalidTGID) { - EXPECT_THAT(tgkill(-1, gettid(), 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(tgkill(0, gettid(), 0), SyscallFailsWithErrno(EINVAL)); -} - -TEST(TgkillTest, ValidInput) { - EXPECT_THAT(tgkill(getpid(), gettid(), 0), SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/time.cc b/test/syscalls/linux/time.cc deleted file mode 100644 index e75bba669..000000000 --- a/test/syscalls/linux/time.cc +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <time.h> - -#include "gtest/gtest.h" -#include "test/util/proc_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -constexpr long kFudgeSeconds = 5; - -#if defined(__x86_64__) || defined(__i386__) -// Mimics the time(2) wrapper from glibc prior to 2.15. -time_t vsyscall_time(time_t* t) { - constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400; - return reinterpret_cast<time_t (*)(time_t*)>(kVsyscallTimeEntry)(t); -} - -TEST(TimeTest, VsyscallTime_Succeeds) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsVsyscallEnabled())); - - time_t t1, t2; - - { - const DisableSave ds; // Timing assertions. - EXPECT_THAT(time(&t1), SyscallSucceeds()); - EXPECT_THAT(vsyscall_time(&t2), SyscallSucceeds()); - } - - // Time should be monotonic. - EXPECT_LE(static_cast<long>(t1), static_cast<long>(t2)); - - // Check that it's within kFudge seconds. - EXPECT_LE(static_cast<long>(t2), static_cast<long>(t1) + kFudgeSeconds); - - // Redo with save. - EXPECT_THAT(time(&t1), SyscallSucceeds()); - EXPECT_THAT(vsyscall_time(&t2), SyscallSucceeds()); - - // Time should be monotonic. - EXPECT_LE(static_cast<long>(t1), static_cast<long>(t2)); -} - -TEST(TimeTest, VsyscallTime_InvalidAddressSIGSEGV) { - EXPECT_EXIT(vsyscall_time(reinterpret_cast<time_t*>(0x1)), - ::testing::KilledBySignal(SIGSEGV), ""); -} - -// Mimics the gettimeofday(2) wrapper from the Go runtime <= 1.2. -int vsyscall_gettimeofday(struct timeval* tv, struct timezone* tz) { - constexpr uint64_t kVsyscallGettimeofdayEntry = 0xffffffffff600000; - return reinterpret_cast<int (*)(struct timeval*, struct timezone*)>( - kVsyscallGettimeofdayEntry)(tv, tz); -} - -TEST(TimeTest, VsyscallGettimeofday_Succeeds) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsVsyscallEnabled())); - - struct timeval tv1, tv2; - struct timezone tz1, tz2; - - { - const DisableSave ds; // Timing assertions. - EXPECT_THAT(gettimeofday(&tv1, &tz1), SyscallSucceeds()); - EXPECT_THAT(vsyscall_gettimeofday(&tv2, &tz2), SyscallSucceeds()); - } - - // See above. - EXPECT_LE(static_cast<long>(tv1.tv_sec), static_cast<long>(tv2.tv_sec)); - EXPECT_LE(static_cast<long>(tv2.tv_sec), - static_cast<long>(tv1.tv_sec) + kFudgeSeconds); - - // Redo with save. - EXPECT_THAT(gettimeofday(&tv1, &tz1), SyscallSucceeds()); - EXPECT_THAT(vsyscall_gettimeofday(&tv2, &tz2), SyscallSucceeds()); -} - -TEST(TimeTest, VsyscallGettimeofday_InvalidAddressSIGSEGV) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsVsyscallEnabled())); - - EXPECT_EXIT(vsyscall_gettimeofday(reinterpret_cast<struct timeval*>(0x1), - reinterpret_cast<struct timezone*>(0x1)), - ::testing::KilledBySignal(SIGSEGV), ""); -} -#endif - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/timerfd.cc b/test/syscalls/linux/timerfd.cc deleted file mode 100644 index 86ed87b7c..000000000 --- a/test/syscalls/linux/timerfd.cc +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <poll.h> -#include <sys/timerfd.h> -#include <time.h> - -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Wrapper around timerfd_create(2) that returns a FileDescriptor. -PosixErrorOr<FileDescriptor> TimerfdCreate(int clockid, int flags) { - int fd = timerfd_create(clockid, flags); - MaybeSave(); - if (fd < 0) { - return PosixError(errno, "timerfd_create failed"); - } - return FileDescriptor(fd); -} - -// In tests that race a timerfd with a sleep, some slack is required because: -// -// - Timerfd expirations are asynchronous with respect to nanosleeps. -// -// - Because clock_gettime(CLOCK_MONOTONIC) is implemented through the VDSO, -// it technically uses a closely-related, but distinct, time domain from the -// CLOCK_MONOTONIC used to trigger timerfd expirations. The same applies to -// CLOCK_BOOTTIME which is an alias for CLOCK_MONOTONIC. -absl::Duration TimerSlack() { return absl::Milliseconds(500); } - -class TimerfdTest : public ::testing::TestWithParam<int> {}; - -TEST_P(TimerfdTest, IsInitiallyStopped) { - auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0)); - struct itimerspec its = {}; - ASSERT_THAT(timerfd_gettime(tfd.get(), &its), SyscallSucceeds()); - EXPECT_EQ(0, its.it_value.tv_sec); - EXPECT_EQ(0, its.it_value.tv_nsec); -} - -TEST_P(TimerfdTest, SingleShot) { - constexpr absl::Duration kDelay = absl::Seconds(1); - - auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0)); - struct itimerspec its = {}; - its.it_value = absl::ToTimespec(kDelay); - ASSERT_THAT(timerfd_settime(tfd.get(), /* flags = */ 0, &its, nullptr), - SyscallSucceeds()); - - // The timer should fire exactly once since the interval is zero. - absl::SleepFor(kDelay + TimerSlack()); - uint64_t val = 0; - ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)), - SyscallSucceedsWithValue(sizeof(uint64_t))); - EXPECT_EQ(1, val); -} - -TEST_P(TimerfdTest, Periodic) { - constexpr absl::Duration kDelay = absl::Seconds(1); - constexpr int kPeriods = 3; - - auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0)); - struct itimerspec its = {}; - its.it_value = absl::ToTimespec(kDelay); - its.it_interval = absl::ToTimespec(kDelay); - ASSERT_THAT(timerfd_settime(tfd.get(), /* flags = */ 0, &its, nullptr), - SyscallSucceeds()); - - // Expect to see at least kPeriods expirations. More may occur due to the - // timer slack, or due to delays from scheduling or save/restore. - absl::SleepFor(kPeriods * kDelay + TimerSlack()); - uint64_t val = 0; - ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)), - SyscallSucceedsWithValue(sizeof(uint64_t))); - EXPECT_GE(val, kPeriods); -} - -TEST_P(TimerfdTest, BlockingRead) { - constexpr absl::Duration kDelay = absl::Seconds(3); - - auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0)); - struct itimerspec its = {}; - its.it_value.tv_sec = absl::ToInt64Seconds(kDelay); - auto const start_time = absl::Now(); - ASSERT_THAT(timerfd_settime(tfd.get(), /* flags = */ 0, &its, nullptr), - SyscallSucceeds()); - - // read should block until the timer fires. - uint64_t val = 0; - ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)), - SyscallSucceedsWithValue(sizeof(uint64_t))); - auto const end_time = absl::Now(); - EXPECT_EQ(1, val); - EXPECT_GE((end_time - start_time) + TimerSlack(), kDelay); -} - -TEST_P(TimerfdTest, NonblockingRead_NoRandomSave) { - constexpr absl::Duration kDelay = absl::Seconds(5); - - auto const tfd = - ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), TFD_NONBLOCK)); - - // Since the timer is initially disabled and has never fired, read should - // return EAGAIN. - uint64_t val = 0; - ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)), - SyscallFailsWithErrno(EAGAIN)); - - DisableSave ds; // Timing-sensitive. - - // Arm the timer. - struct itimerspec its = {}; - its.it_value.tv_sec = absl::ToInt64Seconds(kDelay); - ASSERT_THAT(timerfd_settime(tfd.get(), /* flags = */ 0, &its, nullptr), - SyscallSucceeds()); - - // Since the timer has not yet fired, read should return EAGAIN. - ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)), - SyscallFailsWithErrno(EAGAIN)); - - ds.reset(); // No longer timing-sensitive. - - // After the timer fires, read should indicate 1 expiration. - absl::SleepFor(kDelay + TimerSlack()); - ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)), - SyscallSucceedsWithValue(sizeof(uint64_t))); - EXPECT_EQ(1, val); - - // The successful read should have reset the number of expirations. - ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(TimerfdTest, BlockingPoll_SetTimeResetsExpirations) { - constexpr absl::Duration kDelay = absl::Seconds(3); - - auto const tfd = - ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), TFD_NONBLOCK)); - struct itimerspec its = {}; - its.it_value.tv_sec = absl::ToInt64Seconds(kDelay); - auto const start_time = absl::Now(); - ASSERT_THAT(timerfd_settime(tfd.get(), /* flags = */ 0, &its, nullptr), - SyscallSucceeds()); - - // poll should block until the timer fires. - struct pollfd pfd = {}; - pfd.fd = tfd.get(); - pfd.events = POLLIN; - ASSERT_THAT(poll(&pfd, /* nfds = */ 1, - /* timeout = */ 2 * absl::ToInt64Seconds(kDelay) * 1000), - SyscallSucceedsWithValue(1)); - auto const end_time = absl::Now(); - EXPECT_EQ(POLLIN, pfd.revents); - EXPECT_GE((end_time - start_time) + TimerSlack(), kDelay); - - // Call timerfd_settime again with a value of 0. This should reset the number - // of expirations to 0, causing read to return EAGAIN since the timerfd is - // non-blocking. - its.it_value.tv_sec = 0; - ASSERT_THAT(timerfd_settime(tfd.get(), /* flags = */ 0, &its, nullptr), - SyscallSucceeds()); - uint64_t val = 0; - ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(TimerfdTest, SetAbsoluteTime) { - constexpr absl::Duration kDelay = absl::Seconds(3); - - // Use a non-blocking timerfd so that if TFD_TIMER_ABSTIME is incorrectly - // non-functional, we get EAGAIN rather than a test timeout. - auto const tfd = - ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), TFD_NONBLOCK)); - struct itimerspec its = {}; - ASSERT_THAT(clock_gettime(GetParam(), &its.it_value), SyscallSucceeds()); - its.it_value.tv_sec += absl::ToInt64Seconds(kDelay); - ASSERT_THAT(timerfd_settime(tfd.get(), TFD_TIMER_ABSTIME, &its, nullptr), - SyscallSucceeds()); - - absl::SleepFor(kDelay + TimerSlack()); - uint64_t val = 0; - ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)), - SyscallSucceedsWithValue(sizeof(uint64_t))); - EXPECT_EQ(1, val); -} - -TEST_P(TimerfdTest, IllegalReadWrite) { - auto const tfd = - ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), TFD_NONBLOCK)); - uint64_t val = 0; - EXPECT_THAT(PreadFd(tfd.get(), &val, sizeof(val), 0), - SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(WriteFd(tfd.get(), &val, sizeof(val)), - SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(PwriteFd(tfd.get(), &val, sizeof(val), 0), - SyscallFailsWithErrno(ESPIPE)); -} - -std::string PrintClockId(::testing::TestParamInfo<int> info) { - switch (info.param) { - case CLOCK_MONOTONIC: - return "CLOCK_MONOTONIC"; - case CLOCK_BOOTTIME: - return "CLOCK_BOOTTIME"; - default: - return absl::StrCat(info.param); - } -} - -INSTANTIATE_TEST_SUITE_P(AllTimerTypes, TimerfdTest, - ::testing::Values(CLOCK_MONOTONIC, CLOCK_BOOTTIME), - PrintClockId); - -TEST(TimerfdClockRealtimeTest, ClockRealtime) { - // Since CLOCK_REALTIME can, by definition, change, we can't make any - // non-flaky assertions about the amount of time it takes for a - // CLOCK_REALTIME-based timer to expire. Just check that it expires at all, - // and hope it happens before the test times out. - constexpr int kDelaySecs = 1; - - auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(CLOCK_REALTIME, 0)); - struct itimerspec its = {}; - its.it_value.tv_sec = kDelaySecs; - ASSERT_THAT(timerfd_settime(tfd.get(), /* flags = */ 0, &its, nullptr), - SyscallSucceeds()); - - uint64_t val = 0; - ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)), - SyscallSucceedsWithValue(sizeof(uint64_t))); - EXPECT_EQ(1, val); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc deleted file mode 100644 index 4b3c44527..000000000 --- a/test/syscalls/linux/timers.cc +++ /dev/null @@ -1,662 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <signal.h> -#include <sys/resource.h> -#include <sys/time.h> -#include <syscall.h> -#include <time.h> -#include <unistd.h> - -#include <atomic> - -#include "gtest/gtest.h" -#include "absl/flags/flag.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/cleanup.h" -#include "test/util/logging.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -ABSL_FLAG(bool, timers_test_sleep, false, - "If true, sleep forever instead of running tests."); - -using ::testing::_; -using ::testing::AnyOf; - -namespace gvisor { -namespace testing { -namespace { - -#ifndef CPUCLOCK_PROF -#define CPUCLOCK_PROF 0 -#endif // CPUCLOCK_PROF - -PosixErrorOr<absl::Duration> ProcessCPUTime(pid_t pid) { - // Use pid-specific CPUCLOCK_PROF, which is the clock used to enforce - // RLIMIT_CPU. - clockid_t clockid = (~static_cast<clockid_t>(pid) << 3) | CPUCLOCK_PROF; - - struct timespec ts; - int ret = clock_gettime(clockid, &ts); - if (ret < 0) { - return PosixError(errno, "clock_gettime failed"); - } - - return absl::DurationFromTimespec(ts); -} - -void NoopSignalHandler(int signo) { - TEST_CHECK_MSG(SIGXCPU == signo, - "NoopSigHandler did not receive expected signal"); -} - -void UninstallingSignalHandler(int signo) { - TEST_CHECK_MSG(SIGXCPU == signo, - "UninstallingSignalHandler did not receive expected signal"); - struct sigaction rev_action; - rev_action.sa_handler = SIG_DFL; - rev_action.sa_flags = 0; - sigemptyset(&rev_action.sa_mask); - sigaction(SIGXCPU, &rev_action, nullptr); -} - -TEST(TimerTest, ProcessKilledOnCPUSoftLimit) { - constexpr absl::Duration kSoftLimit = absl::Seconds(1); - constexpr absl::Duration kHardLimit = absl::Seconds(3); - - struct rlimit cpu_limits; - cpu_limits.rlim_cur = absl::ToInt64Seconds(kSoftLimit); - cpu_limits.rlim_max = absl::ToInt64Seconds(kHardLimit); - - int pid = fork(); - MaybeSave(); - if (pid == 0) { - TEST_PCHECK(setrlimit(RLIMIT_CPU, &cpu_limits) == 0); - MaybeSave(); - for (;;) { - } - } - ASSERT_THAT(pid, SyscallSucceeds()); - auto c = Cleanup([pid] { - int status; - EXPECT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFSIGNALED(status)); - EXPECT_EQ(WTERMSIG(status), SIGXCPU); - }); - - // Wait for the child to exit, but do not reap it. This will allow us to check - // its CPU usage while it is zombied. - EXPECT_THAT(waitid(P_PID, pid, nullptr, WEXITED | WNOWAIT), - SyscallSucceeds()); - - // Assert that the child spent 1s of CPU before getting killed. - // - // We must be careful to use CPUCLOCK_PROF, the same clock used for RLIMIT_CPU - // enforcement, to get correct results. Note that this is slightly different - // from rusage-reported CPU usage: - // - // RLIMIT_CPU, CPUCLOCK_PROF use kernel/sched/cputime.c:thread_group_cputime. - // rusage uses kernel/sched/cputime.c:thread_group_cputime_adjusted. - absl::Duration cpu = ASSERT_NO_ERRNO_AND_VALUE(ProcessCPUTime(pid)); - EXPECT_GE(cpu, kSoftLimit); - - // Child did not make it to the hard limit. - // - // Linux sends SIGXCPU synchronously with CPU tick updates. See - // kernel/time/timer.c:update_process_times: - // => account_process_tick // update task CPU usage. - // => run_posix_cpu_timers // enforce RLIMIT_CPU, sending signal. - // - // Thus, only chance for this to flake is if the system time required to - // deliver the signal exceeds 2s. - EXPECT_LT(cpu, kHardLimit); -} - -TEST(TimerTest, ProcessPingedRepeatedlyAfterCPUSoftLimit) { - struct sigaction new_action; - new_action.sa_handler = UninstallingSignalHandler; - new_action.sa_flags = 0; - sigemptyset(&new_action.sa_mask); - - constexpr absl::Duration kSoftLimit = absl::Seconds(1); - constexpr absl::Duration kHardLimit = absl::Seconds(10); - - struct rlimit cpu_limits; - cpu_limits.rlim_cur = absl::ToInt64Seconds(kSoftLimit); - cpu_limits.rlim_max = absl::ToInt64Seconds(kHardLimit); - - int pid = fork(); - MaybeSave(); - if (pid == 0) { - TEST_PCHECK(sigaction(SIGXCPU, &new_action, nullptr) == 0); - MaybeSave(); - TEST_PCHECK(setrlimit(RLIMIT_CPU, &cpu_limits) == 0); - MaybeSave(); - for (;;) { - } - } - ASSERT_THAT(pid, SyscallSucceeds()); - auto c = Cleanup([pid] { - int status; - EXPECT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFSIGNALED(status)); - EXPECT_EQ(WTERMSIG(status), SIGXCPU); - }); - - // Wait for the child to exit, but do not reap it. This will allow us to check - // its CPU usage while it is zombied. - EXPECT_THAT(waitid(P_PID, pid, nullptr, WEXITED | WNOWAIT), - SyscallSucceeds()); - - absl::Duration cpu = ASSERT_NO_ERRNO_AND_VALUE(ProcessCPUTime(pid)); - // Following signals come every CPU second. - EXPECT_GE(cpu, kSoftLimit + absl::Seconds(1)); - - // Child did not make it to the hard limit. - // - // As above, should not flake. - EXPECT_LT(cpu, kHardLimit); -} - -TEST(TimerTest, ProcessKilledOnCPUHardLimit) { - struct sigaction new_action; - new_action.sa_handler = NoopSignalHandler; - new_action.sa_flags = 0; - sigemptyset(&new_action.sa_mask); - - constexpr absl::Duration kSoftLimit = absl::Seconds(1); - constexpr absl::Duration kHardLimit = absl::Seconds(3); - - struct rlimit cpu_limits; - cpu_limits.rlim_cur = absl::ToInt64Seconds(kSoftLimit); - cpu_limits.rlim_max = absl::ToInt64Seconds(kHardLimit); - - int pid = fork(); - MaybeSave(); - if (pid == 0) { - TEST_PCHECK(sigaction(SIGXCPU, &new_action, nullptr) == 0); - MaybeSave(); - TEST_PCHECK(setrlimit(RLIMIT_CPU, &cpu_limits) == 0); - MaybeSave(); - for (;;) { - } - } - ASSERT_THAT(pid, SyscallSucceeds()); - auto c = Cleanup([pid] { - int status; - EXPECT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); - EXPECT_TRUE(WIFSIGNALED(status)); - EXPECT_EQ(WTERMSIG(status), SIGKILL); - }); - - // Wait for the child to exit, but do not reap it. This will allow us to check - // its CPU usage while it is zombied. - EXPECT_THAT(waitid(P_PID, pid, nullptr, WEXITED | WNOWAIT), - SyscallSucceeds()); - - absl::Duration cpu = ASSERT_NO_ERRNO_AND_VALUE(ProcessCPUTime(pid)); - EXPECT_GE(cpu, kHardLimit); -} - -// RAII type for a kernel "POSIX" interval timer. (The kernel provides system -// calls such as timer_create that behave very similarly, but not identically, -// to those described by timer_create(2); in particular, the kernel does not -// implement SIGEV_THREAD. glibc builds POSIX-compliant interval timers based on -// these kernel interval timers.) -// -// Compare implementation to FileDescriptor. -class IntervalTimer { - public: - IntervalTimer() = default; - - explicit IntervalTimer(int id) { set_id(id); } - - IntervalTimer(IntervalTimer&& orig) : id_(orig.release()) {} - - IntervalTimer& operator=(IntervalTimer&& orig) { - if (this == &orig) return *this; - reset(orig.release()); - return *this; - } - - IntervalTimer(const IntervalTimer& other) = delete; - IntervalTimer& operator=(const IntervalTimer& other) = delete; - - ~IntervalTimer() { reset(); } - - int get() const { return id_; } - - int release() { - int const id = id_; - id_ = -1; - return id; - } - - void reset() { reset(-1); } - - void reset(int id) { - if (id_ >= 0) { - TEST_PCHECK(syscall(SYS_timer_delete, id_) == 0); - MaybeSave(); - } - set_id(id); - } - - PosixErrorOr<struct itimerspec> Set( - int flags, const struct itimerspec& new_value) const { - struct itimerspec old_value = {}; - if (syscall(SYS_timer_settime, id_, flags, &new_value, &old_value) < 0) { - return PosixError(errno, "timer_settime"); - } - MaybeSave(); - return old_value; - } - - PosixErrorOr<struct itimerspec> Get() const { - struct itimerspec curr_value = {}; - if (syscall(SYS_timer_gettime, id_, &curr_value) < 0) { - return PosixError(errno, "timer_gettime"); - } - MaybeSave(); - return curr_value; - } - - PosixErrorOr<int> Overruns() const { - int rv = syscall(SYS_timer_getoverrun, id_); - if (rv < 0) { - return PosixError(errno, "timer_getoverrun"); - } - MaybeSave(); - return rv; - } - - private: - void set_id(int id) { id_ = std::max(id, -1); } - - // Kernel timer_t is int; glibc timer_t is void*. - int id_ = -1; -}; - -PosixErrorOr<IntervalTimer> TimerCreate(clockid_t clockid, - const struct sigevent& sev) { - int timerid; - int ret = syscall(SYS_timer_create, clockid, &sev, &timerid); - if (ret < 0) { - return PosixError(errno, "timer_create"); - } - if (ret > 0) { - return PosixError(EINVAL, "timer_create should never return positive"); - } - MaybeSave(); - return IntervalTimer(timerid); -} - -// See timerfd.cc:TimerSlack() for rationale. -constexpr absl::Duration kTimerSlack = absl::Milliseconds(500); - -TEST(IntervalTimerTest, IsInitiallyStopped) { - struct sigevent sev = {}; - sev.sigev_notify = SIGEV_NONE; - const auto timer = - ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); - const struct itimerspec its = ASSERT_NO_ERRNO_AND_VALUE(timer.Get()); - EXPECT_EQ(0, its.it_value.tv_sec); - EXPECT_EQ(0, its.it_value.tv_nsec); -} - -// Kernel can create multiple timers without issue. -// -// Regression test for gvisor.dev/issue/1738. -TEST(IntervalTimerTest, MultipleTimers) { - struct sigevent sev = {}; - sev.sigev_notify = SIGEV_NONE; - const auto timer1 = - ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); - const auto timer2 = - ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); -} - -TEST(IntervalTimerTest, SingleShotSilent) { - struct sigevent sev = {}; - sev.sigev_notify = SIGEV_NONE; - const auto timer = - ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); - - constexpr absl::Duration kDelay = absl::Seconds(1); - struct itimerspec its = {}; - its.it_value = absl::ToTimespec(kDelay); - ASSERT_NO_ERRNO(timer.Set(0, its)); - - // The timer should count down to 0 and stop since the interval is zero. No - // overruns should be counted. - absl::SleepFor(kDelay + kTimerSlack); - its = ASSERT_NO_ERRNO_AND_VALUE(timer.Get()); - EXPECT_EQ(0, its.it_value.tv_sec); - EXPECT_EQ(0, its.it_value.tv_nsec); - EXPECT_THAT(timer.Overruns(), IsPosixErrorOkAndHolds(0)); -} - -TEST(IntervalTimerTest, PeriodicSilent) { - struct sigevent sev = {}; - sev.sigev_notify = SIGEV_NONE; - const auto timer = - ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); - - constexpr absl::Duration kPeriod = absl::Seconds(1); - struct itimerspec its = {}; - its.it_value = its.it_interval = absl::ToTimespec(kPeriod); - ASSERT_NO_ERRNO(timer.Set(0, its)); - - absl::SleepFor(kPeriod * 3 + kTimerSlack); - - // The timer should still be running. - its = ASSERT_NO_ERRNO_AND_VALUE(timer.Get()); - EXPECT_TRUE(its.it_value.tv_nsec != 0 || its.it_value.tv_sec != 0); - - // Timer expirations are not counted as overruns under SIGEV_NONE. - EXPECT_THAT(timer.Overruns(), IsPosixErrorOkAndHolds(0)); -} - -std::atomic<int> counted_signals; - -void IntervalTimerCountingSignalHandler(int sig, siginfo_t* info, - void* ucontext) { - counted_signals.fetch_add(1 + info->si_overrun); -} - -TEST(IntervalTimerTest, PeriodicGroupDirectedSignal) { - constexpr int kSigno = SIGUSR1; - constexpr int kSigvalue = 42; - - // Install our signal handler. - counted_signals.store(0); - struct sigaction sa = {}; - sa.sa_sigaction = IntervalTimerCountingSignalHandler; - sigemptyset(&sa.sa_mask); - sa.sa_flags = SA_SIGINFO; - const auto scoped_sigaction = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(kSigno, sa)); - - // Ensure that kSigno is unblocked on at least one thread. - const auto scoped_sigmask = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, kSigno)); - - struct sigevent sev = {}; - sev.sigev_notify = SIGEV_SIGNAL; - sev.sigev_signo = kSigno; - sev.sigev_value.sival_int = kSigvalue; - auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); - - constexpr absl::Duration kPeriod = absl::Seconds(1); - constexpr int kCycles = 3; - struct itimerspec its = {}; - its.it_value = its.it_interval = absl::ToTimespec(kPeriod); - ASSERT_NO_ERRNO(timer.Set(0, its)); - - absl::SleepFor(kPeriod * kCycles + kTimerSlack); - EXPECT_GE(counted_signals.load(), kCycles); -} - -// From Linux's include/uapi/asm-generic/siginfo.h. -#ifndef sigev_notify_thread_id -#define sigev_notify_thread_id _sigev_un._tid -#endif - -TEST(IntervalTimerTest, PeriodicThreadDirectedSignal) { - constexpr int kSigno = SIGUSR1; - constexpr int kSigvalue = 42; - - // Block kSigno so that we can accumulate overruns. - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, kSigno); - const auto scoped_sigmask = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, mask)); - - struct sigevent sev = {}; - sev.sigev_notify = SIGEV_THREAD_ID; - sev.sigev_signo = kSigno; - sev.sigev_value.sival_int = kSigvalue; - sev.sigev_notify_thread_id = gettid(); - auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); - - constexpr absl::Duration kPeriod = absl::Seconds(1); - constexpr int kCycles = 3; - struct itimerspec its = {}; - its.it_value = its.it_interval = absl::ToTimespec(kPeriod); - ASSERT_NO_ERRNO(timer.Set(0, its)); - absl::SleepFor(kPeriod * kCycles + kTimerSlack); - - // At least kCycles expirations should have occurred, resulting in kCycles-1 - // overruns (the first expiration sent the signal successfully). - siginfo_t si; - struct timespec zero_ts = absl::ToTimespec(absl::ZeroDuration()); - ASSERT_THAT(sigtimedwait(&mask, &si, &zero_ts), - SyscallSucceedsWithValue(kSigno)); - EXPECT_EQ(si.si_signo, kSigno); - EXPECT_EQ(si.si_code, SI_TIMER); - EXPECT_EQ(si.si_timerid, timer.get()); - EXPECT_GE(si.si_overrun, kCycles - 1); - EXPECT_EQ(si.si_int, kSigvalue); - - // Kill the timer, then drain any additional signal it may have enqueued. We - // can't do this before the preceding sigtimedwait because stopping or - // deleting the timer resets si_overrun to 0. - timer.reset(); - sigtimedwait(&mask, &si, &zero_ts); -} - -TEST(IntervalTimerTest, OtherThreadGroup) { - constexpr int kSigno = SIGUSR1; - - // Create a subprocess that does nothing until killed. - pid_t child_pid; - const auto sp = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec( - "/proc/self/exe", ExecveArray({"timers", "--timers_test_sleep"}), - ExecveArray(), &child_pid, nullptr)); - - // Verify that we can't create a timer that would send signals to it. - struct sigevent sev = {}; - sev.sigev_notify = SIGEV_THREAD_ID; - sev.sigev_signo = kSigno; - sev.sigev_notify_thread_id = child_pid; - EXPECT_THAT(TimerCreate(CLOCK_MONOTONIC, sev), PosixErrorIs(EINVAL, _)); -} - -TEST(IntervalTimerTest, RealTimeSignalsAreNotDuplicated) { - const int kSigno = SIGRTMIN; - constexpr int kSigvalue = 42; - - // Block signo so that we can accumulate overruns. - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, kSigno); - const auto scoped_sigmask = ScopedSignalMask(SIG_BLOCK, mask); - - struct sigevent sev = {}; - sev.sigev_notify = SIGEV_THREAD_ID; - sev.sigev_signo = kSigno; - sev.sigev_value.sival_int = kSigvalue; - sev.sigev_notify_thread_id = gettid(); - const auto timer = - ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); - - constexpr absl::Duration kPeriod = absl::Seconds(1); - constexpr int kCycles = 3; - struct itimerspec its = {}; - its.it_value = its.it_interval = absl::ToTimespec(kPeriod); - ASSERT_NO_ERRNO(timer.Set(0, its)); - absl::SleepFor(kPeriod * kCycles + kTimerSlack); - - // Stop the timer so that no further signals are enqueued after sigtimedwait. - struct timespec zero_ts = absl::ToTimespec(absl::ZeroDuration()); - its.it_value = its.it_interval = zero_ts; - ASSERT_NO_ERRNO(timer.Set(0, its)); - - // The timer should have sent only a single signal, even though the kernel - // supports enqueueing of multiple RT signals. - siginfo_t si; - ASSERT_THAT(sigtimedwait(&mask, &si, &zero_ts), - SyscallSucceedsWithValue(kSigno)); - EXPECT_EQ(si.si_signo, kSigno); - EXPECT_EQ(si.si_code, SI_TIMER); - EXPECT_EQ(si.si_timerid, timer.get()); - // si_overrun was reset by timer_settime. - EXPECT_EQ(si.si_overrun, 0); - EXPECT_EQ(si.si_int, kSigvalue); - EXPECT_THAT(sigtimedwait(&mask, &si, &zero_ts), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST(IntervalTimerTest, AlreadyPendingSignal) { - constexpr int kSigno = SIGUSR1; - constexpr int kSigvalue = 42; - - // Block kSigno so that we can accumulate overruns. - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, kSigno); - const auto scoped_sigmask = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, mask)); - - // Send ourselves a signal, preventing the timer from enqueuing. - ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds()); - - struct sigevent sev = {}; - sev.sigev_notify = SIGEV_THREAD_ID; - sev.sigev_signo = kSigno; - sev.sigev_value.sival_int = kSigvalue; - sev.sigev_notify_thread_id = gettid(); - auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); - - constexpr absl::Duration kPeriod = absl::Seconds(1); - constexpr int kCycles = 3; - struct itimerspec its = {}; - its.it_value = its.it_interval = absl::ToTimespec(kPeriod); - ASSERT_NO_ERRNO(timer.Set(0, its)); - - // End the sleep one cycle short; we will sleep for one more cycle below. - absl::SleepFor(kPeriod * (kCycles - 1)); - - // Dequeue the first signal, which we sent to ourselves with tgkill. - siginfo_t si; - struct timespec zero_ts = absl::ToTimespec(absl::ZeroDuration()); - ASSERT_THAT(sigtimedwait(&mask, &si, &zero_ts), - SyscallSucceedsWithValue(kSigno)); - EXPECT_EQ(si.si_signo, kSigno); - // glibc sigtimedwait silently replaces SI_TKILL with SI_USER: - // sysdeps/unix/sysv/linux/sigtimedwait.c:__sigtimedwait(). This isn't - // documented, so we don't depend on it. - EXPECT_THAT(si.si_code, AnyOf(SI_USER, SI_TKILL)); - - // Sleep for 1 more cycle to give the timer time to send a signal. - absl::SleepFor(kPeriod + kTimerSlack); - - // At least kCycles expirations should have occurred, resulting in kCycles-1 - // overruns (the last expiration sent the signal successfully). - ASSERT_THAT(sigtimedwait(&mask, &si, &zero_ts), - SyscallSucceedsWithValue(kSigno)); - EXPECT_EQ(si.si_signo, kSigno); - EXPECT_EQ(si.si_code, SI_TIMER); - EXPECT_EQ(si.si_timerid, timer.get()); - EXPECT_GE(si.si_overrun, kCycles - 1); - EXPECT_EQ(si.si_int, kSigvalue); - - // Kill the timer, then drain any additional signal it may have enqueued. We - // can't do this before the preceding sigtimedwait because stopping or - // deleting the timer resets si_overrun to 0. - timer.reset(); - sigtimedwait(&mask, &si, &zero_ts); -} - -TEST(IntervalTimerTest, IgnoredSignalCountsAsOverrun) { - constexpr int kSigno = SIGUSR1; - constexpr int kSigvalue = 42; - - // Ignore kSigno. - struct sigaction sa = {}; - sa.sa_handler = SIG_IGN; - const auto scoped_sigaction = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(kSigno, sa)); - - // Unblock kSigno so that ignored signals will be discarded. - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, kSigno); - auto scoped_sigmask = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, mask)); - - struct sigevent sev = {}; - sev.sigev_notify = SIGEV_THREAD_ID; - sev.sigev_signo = kSigno; - sev.sigev_value.sival_int = kSigvalue; - sev.sigev_notify_thread_id = gettid(); - auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); - - constexpr absl::Duration kPeriod = absl::Seconds(1); - constexpr int kCycles = 3; - struct itimerspec its = {}; - its.it_value = its.it_interval = absl::ToTimespec(kPeriod); - ASSERT_NO_ERRNO(timer.Set(0, its)); - - // End the sleep one cycle short; we will sleep for one more cycle below. - absl::SleepFor(kPeriod * (kCycles - 1)); - - // Block kSigno so that ignored signals will be enqueued. - scoped_sigmask.Release()(); - scoped_sigmask = ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, mask)); - - // Sleep for 1 more cycle to give the timer time to send a signal. - absl::SleepFor(kPeriod + kTimerSlack); - - // At least kCycles expirations should have occurred, resulting in kCycles-1 - // overruns (the last expiration sent the signal successfully). - siginfo_t si; - struct timespec zero_ts = absl::ToTimespec(absl::ZeroDuration()); - ASSERT_THAT(sigtimedwait(&mask, &si, &zero_ts), - SyscallSucceedsWithValue(kSigno)); - EXPECT_EQ(si.si_signo, kSigno); - EXPECT_EQ(si.si_code, SI_TIMER); - EXPECT_EQ(si.si_timerid, timer.get()); - EXPECT_GE(si.si_overrun, kCycles - 1); - EXPECT_EQ(si.si_int, kSigvalue); - - // Kill the timer, then drain any additional signal it may have enqueued. We - // can't do this before the preceding sigtimedwait because stopping or - // deleting the timer resets si_overrun to 0. - timer.reset(); - sigtimedwait(&mask, &si, &zero_ts); -} - -} // namespace -} // namespace testing -} // namespace gvisor - -int main(int argc, char** argv) { - gvisor::testing::TestInit(&argc, &argv); - - if (absl::GetFlag(FLAGS_timers_test_sleep)) { - while (true) { - absl::SleepFor(absl::Seconds(10)); - } - } - - return gvisor::testing::RunAllTests(); -} diff --git a/test/syscalls/linux/tkill.cc b/test/syscalls/linux/tkill.cc deleted file mode 100644 index 8d8ebbb24..000000000 --- a/test/syscalls/linux/tkill.cc +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/syscall.h> -#include <sys/types.h> -#include <unistd.h> - -#include <cerrno> -#include <csignal> - -#include "gtest/gtest.h" -#include "test/util/logging.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -static int tkill(pid_t tid, int sig) { - int ret; - do { - // NOTE(b/25434735): tkill(2) could return EAGAIN for RT signals. - ret = syscall(SYS_tkill, tid, sig); - } while (ret == -1 && errno == EAGAIN); - return ret; -} - -TEST(TkillTest, InvalidTID) { - EXPECT_THAT(tkill(-1, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(tkill(0, 0), SyscallFailsWithErrno(EINVAL)); -} - -TEST(TkillTest, ValidTID) { - EXPECT_THAT(tkill(gettid(), 0), SyscallSucceeds()); -} - -void SigHandler(int sig, siginfo_t* info, void* context) { - TEST_CHECK(sig == SIGRTMAX); - TEST_CHECK(info->si_pid == getpid()); - TEST_CHECK(info->si_uid == getuid()); - TEST_CHECK(info->si_code == SI_TKILL); -} - -// Test with a real signal. Regression test for b/24790092. -TEST(TkillTest, ValidTIDAndRealSignal) { - struct sigaction sa; - sa.sa_sigaction = SigHandler; - sigfillset(&sa.sa_mask); - sa.sa_flags = SA_SIGINFO; - ASSERT_THAT(sigaction(SIGRTMAX, &sa, nullptr), SyscallSucceeds()); - // InitGoogle blocks all RT signals, so we need undo it. - sigset_t unblock; - sigemptyset(&unblock); - sigaddset(&unblock, SIGRTMAX); - ASSERT_THAT(sigprocmask(SIG_UNBLOCK, &unblock, nullptr), SyscallSucceeds()); - EXPECT_THAT(tkill(gettid(), SIGRTMAX), SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/truncate.cc b/test/syscalls/linux/truncate.cc deleted file mode 100644 index c988c6380..000000000 --- a/test/syscalls/linux/truncate.cc +++ /dev/null @@ -1,218 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <signal.h> -#include <sys/resource.h> -#include <sys/stat.h> -#include <sys/vfs.h> -#include <time.h> -#include <unistd.h> - -#include <iostream> -#include <string> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/string_view.h" -#include "test/syscalls/linux/file_base.h" -#include "test/util/capability_util.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -class FixtureTruncateTest : public FileTest { - void SetUp() override { FileTest::SetUp(); } -}; - -TEST_F(FixtureTruncateTest, Truncate) { - // Get the current rlimit and restore after test run. - struct rlimit initial_lim; - ASSERT_THAT(getrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds()); - auto cleanup = Cleanup([&initial_lim] { - EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds()); - }); - - // Check that it starts at size zero. - struct stat buf; - ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); - EXPECT_EQ(buf.st_size, 0); - - // Stay at size zero. - EXPECT_THAT(truncate(test_file_name_.c_str(), 0), SyscallSucceeds()); - ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); - EXPECT_EQ(buf.st_size, 0); - - // Grow to ten bytes. - EXPECT_THAT(truncate(test_file_name_.c_str(), 10), SyscallSucceeds()); - ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); - EXPECT_EQ(buf.st_size, 10); - - // Can't be truncated to a negative number. - EXPECT_THAT(truncate(test_file_name_.c_str(), -1), - SyscallFailsWithErrno(EINVAL)); - - // Try growing past the file size limit. - sigset_t new_mask; - sigemptyset(&new_mask); - sigaddset(&new_mask, SIGXFSZ); - sigprocmask(SIG_BLOCK, &new_mask, nullptr); - struct timespec timelimit; - timelimit.tv_sec = 10; - timelimit.tv_nsec = 0; - - struct rlimit setlim; - setlim.rlim_cur = 1024; - setlim.rlim_max = RLIM_INFINITY; - ASSERT_THAT(setrlimit(RLIMIT_FSIZE, &setlim), SyscallSucceeds()); - EXPECT_THAT(truncate(test_file_name_.c_str(), 1025), - SyscallFailsWithErrno(EFBIG)); - EXPECT_EQ(sigtimedwait(&new_mask, nullptr, &timelimit), SIGXFSZ); - ASSERT_THAT(sigprocmask(SIG_UNBLOCK, &new_mask, nullptr), SyscallSucceeds()); - - // Shrink back down to zero. - EXPECT_THAT(truncate(test_file_name_.c_str(), 0), SyscallSucceeds()); - ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); - EXPECT_EQ(buf.st_size, 0); -} - -TEST_F(FixtureTruncateTest, Ftruncate) { - // Get the current rlimit and restore after test run. - struct rlimit initial_lim; - ASSERT_THAT(getrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds()); - auto cleanup = Cleanup([&initial_lim] { - EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds()); - }); - - // Check that it starts at size zero. - struct stat buf; - ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); - EXPECT_EQ(buf.st_size, 0); - - // Stay at size zero. - EXPECT_THAT(ftruncate(test_file_fd_.get(), 0), SyscallSucceeds()); - ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); - EXPECT_EQ(buf.st_size, 0); - - // Grow to ten bytes. - EXPECT_THAT(ftruncate(test_file_fd_.get(), 10), SyscallSucceeds()); - ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); - EXPECT_EQ(buf.st_size, 10); - - // Can't be truncated to a negative number. - EXPECT_THAT(ftruncate(test_file_fd_.get(), -1), - SyscallFailsWithErrno(EINVAL)); - - // Try growing past the file size limit. - sigset_t new_mask; - sigemptyset(&new_mask); - sigaddset(&new_mask, SIGXFSZ); - sigprocmask(SIG_BLOCK, &new_mask, nullptr); - struct timespec timelimit; - timelimit.tv_sec = 10; - timelimit.tv_nsec = 0; - - struct rlimit setlim; - setlim.rlim_cur = 1024; - setlim.rlim_max = RLIM_INFINITY; - ASSERT_THAT(setrlimit(RLIMIT_FSIZE, &setlim), SyscallSucceeds()); - EXPECT_THAT(ftruncate(test_file_fd_.get(), 1025), - SyscallFailsWithErrno(EFBIG)); - EXPECT_EQ(sigtimedwait(&new_mask, nullptr, &timelimit), SIGXFSZ); - ASSERT_THAT(sigprocmask(SIG_UNBLOCK, &new_mask, nullptr), SyscallSucceeds()); - - // Shrink back down to zero. - EXPECT_THAT(ftruncate(test_file_fd_.get(), 0), SyscallSucceeds()); - ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); - EXPECT_EQ(buf.st_size, 0); -} - -// Truncating a file down clears that portion of the file. -TEST_F(FixtureTruncateTest, FtruncateShrinkGrow) { - std::vector<char> buf(10, 'a'); - EXPECT_THAT(WriteFd(test_file_fd_.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - - // Shrink then regrow the file. This should clear the second half of the file. - EXPECT_THAT(ftruncate(test_file_fd_.get(), 5), SyscallSucceeds()); - EXPECT_THAT(ftruncate(test_file_fd_.get(), 10), SyscallSucceeds()); - - EXPECT_THAT(lseek(test_file_fd_.get(), 0, SEEK_SET), SyscallSucceeds()); - - std::vector<char> buf2(10); - EXPECT_THAT(ReadFd(test_file_fd_.get(), buf2.data(), buf2.size()), - SyscallSucceedsWithValue(buf2.size())); - - std::vector<char> expect = {'a', 'a', 'a', 'a', 'a', - '\0', '\0', '\0', '\0', '\0'}; - EXPECT_EQ(expect, buf2); -} - -TEST(TruncateTest, TruncateDir) { - auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(truncate(temp_dir.path().c_str(), 0), - SyscallFailsWithErrno(EISDIR)); -} - -TEST(TruncateTest, FtruncateDir) { - auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(temp_dir.path(), O_DIRECTORY | O_RDONLY)); - EXPECT_THAT(ftruncate(fd.get(), 0), SyscallFailsWithErrno(EINVAL)); -} - -TEST(TruncateTest, TruncateNonWriteable) { - // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to - // always override write permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::string_view(), 0555 /* mode */)); - EXPECT_THAT(truncate(temp_file.path().c_str(), 0), - SyscallFailsWithErrno(EACCES)); -} - -TEST(TruncateTest, FtruncateNonWriteable) { - auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), absl::string_view(), 0555 /* mode */)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file.path(), O_RDONLY)); - EXPECT_THAT(ftruncate(fd.get(), 0), SyscallFailsWithErrno(EINVAL)); -} - -TEST(TruncateTest, TruncateNonExist) { - EXPECT_THAT(truncate("/foo/bar", 0), SyscallFailsWithErrno(ENOENT)); -} - -TEST(TruncateTest, FtruncateVirtualTmp_NoRandomSave) { - auto temp_file = NewTempAbsPathInDir("/dev/shm"); - const DisableSave ds; // Incompatible permissions. - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file, O_RDWR | O_CREAT | O_EXCL, 0)); - EXPECT_THAT(ftruncate(fd.get(), 100), SyscallSucceeds()); -} - -// NOTE: There are additional truncate(2)/ftruncate(2) tests in mknod.cc -// which are there to avoid running the tests on a number of different -// filesystems which may not support mknod. - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc deleted file mode 100644 index f734511d6..000000000 --- a/test/syscalls/linux/tuntap.cc +++ /dev/null @@ -1,353 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <arpa/inet.h> -#include <linux/capability.h> -#include <linux/if_arp.h> -#include <linux/if_ether.h> -#include <linux/if_tun.h> -#include <netinet/ip.h> -#include <netinet/ip_icmp.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/ascii.h" -#include "absl/strings/str_split.h" -#include "test/syscalls/linux/socket_netlink_route_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { -namespace { - -constexpr int kIPLen = 4; - -constexpr const char kDevNetTun[] = "/dev/net/tun"; -constexpr const char kTapName[] = "tap0"; - -constexpr const uint8_t kMacA[ETH_ALEN] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}; -constexpr const uint8_t kMacB[ETH_ALEN] = {0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}; - -PosixErrorOr<std::set<std::string>> DumpLinkNames() { - ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); - std::set<std::string> names; - for (const auto& link : links) { - names.emplace(link.name); - } - return names; -} - -PosixErrorOr<absl::optional<Link>> GetLinkByName(const std::string& name) { - ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); - for (const auto& link : links) { - if (link.name == name) { - return absl::optional<Link>(link); - } - } - return absl::optional<Link>(); -} - -struct pihdr { - uint16_t pi_flags; - uint16_t pi_protocol; -} __attribute__((packed)); - -struct ping_pkt { - pihdr pi; - struct ethhdr eth; - struct iphdr ip; - struct icmphdr icmp; - char payload[64]; -} __attribute__((packed)); - -ping_pkt CreatePingPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip, - const uint8_t dstmac[ETH_ALEN], const char* dstip) { - ping_pkt pkt = {}; - - pkt.pi.pi_protocol = htons(ETH_P_IP); - - memcpy(pkt.eth.h_dest, dstmac, sizeof(pkt.eth.h_dest)); - memcpy(pkt.eth.h_source, srcmac, sizeof(pkt.eth.h_source)); - pkt.eth.h_proto = htons(ETH_P_IP); - - pkt.ip.ihl = 5; - pkt.ip.version = 4; - pkt.ip.tos = 0; - pkt.ip.tot_len = htons(sizeof(struct iphdr) + sizeof(struct icmphdr) + - sizeof(pkt.payload)); - pkt.ip.id = 1; - pkt.ip.frag_off = 1 << 6; // Do not fragment - pkt.ip.ttl = 64; - pkt.ip.protocol = IPPROTO_ICMP; - inet_pton(AF_INET, dstip, &pkt.ip.daddr); - inet_pton(AF_INET, srcip, &pkt.ip.saddr); - pkt.ip.check = IPChecksum(pkt.ip); - - pkt.icmp.type = ICMP_ECHO; - pkt.icmp.code = 0; - pkt.icmp.checksum = 0; - pkt.icmp.un.echo.sequence = 1; - pkt.icmp.un.echo.id = 1; - - strncpy(pkt.payload, "abcd", sizeof(pkt.payload)); - pkt.icmp.checksum = ICMPChecksum(pkt.icmp, pkt.payload, sizeof(pkt.payload)); - - return pkt; -} - -struct arp_pkt { - pihdr pi; - struct ethhdr eth; - struct arphdr arp; - uint8_t arp_sha[ETH_ALEN]; - uint8_t arp_spa[kIPLen]; - uint8_t arp_tha[ETH_ALEN]; - uint8_t arp_tpa[kIPLen]; -} __attribute__((packed)); - -std::string CreateArpPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip, - const uint8_t dstmac[ETH_ALEN], const char* dstip) { - std::string buffer; - buffer.resize(sizeof(arp_pkt)); - - arp_pkt* pkt = reinterpret_cast<arp_pkt*>(&buffer[0]); - { - pkt->pi.pi_protocol = htons(ETH_P_ARP); - - memcpy(pkt->eth.h_dest, kMacA, sizeof(pkt->eth.h_dest)); - memcpy(pkt->eth.h_source, kMacB, sizeof(pkt->eth.h_source)); - pkt->eth.h_proto = htons(ETH_P_ARP); - - pkt->arp.ar_hrd = htons(ARPHRD_ETHER); - pkt->arp.ar_pro = htons(ETH_P_IP); - pkt->arp.ar_hln = ETH_ALEN; - pkt->arp.ar_pln = kIPLen; - pkt->arp.ar_op = htons(ARPOP_REPLY); - - memcpy(pkt->arp_sha, srcmac, sizeof(pkt->arp_sha)); - inet_pton(AF_INET, srcip, pkt->arp_spa); - memcpy(pkt->arp_tha, dstmac, sizeof(pkt->arp_tha)); - inet_pton(AF_INET, dstip, pkt->arp_tpa); - } - return buffer; -} - -} // namespace - -TEST(TuntapStaticTest, NetTunExists) { - struct stat statbuf; - ASSERT_THAT(stat(kDevNetTun, &statbuf), SyscallSucceeds()); - // Check that it's a character device with rw-rw-rw- permissions. - EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); -} - -class TuntapTest : public ::testing::Test { - protected: - void TearDown() override { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))) { - // Bring back capability if we had dropped it in test case. - ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, true)); - } - } -}; - -TEST_F(TuntapTest, CreateInterfaceNoCap) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - - ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, false)); - - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); - - struct ifreq ifr = {}; - ifr.ifr_flags = IFF_TAP; - strncpy(ifr.ifr_name, kTapName, IFNAMSIZ); - - EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallFailsWithErrno(EPERM)); -} - -TEST_F(TuntapTest, CreateFixedNameInterface) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); - - struct ifreq ifr_set = {}; - ifr_set.ifr_flags = IFF_TAP; - strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ); - EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set), - SyscallSucceedsWithValue(0)); - - struct ifreq ifr_get = {}; - EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get), - SyscallSucceedsWithValue(0)); - - struct ifreq ifr_expect = ifr_set; - // See __tun_chr_ioctl() in net/drivers/tun.c. - ifr_expect.ifr_flags |= IFF_NOFILTER; - - EXPECT_THAT(DumpLinkNames(), - IsPosixErrorOkAndHolds(::testing::Contains(kTapName))); - EXPECT_THAT(memcmp(&ifr_expect, &ifr_get, sizeof(ifr_get)), ::testing::Eq(0)); -} - -TEST_F(TuntapTest, CreateInterface) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); - - struct ifreq ifr = {}; - ifr.ifr_flags = IFF_TAP; - // Empty ifr.ifr_name. Let kernel assign. - - EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0)); - - struct ifreq ifr_get = {}; - EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get), - SyscallSucceedsWithValue(0)); - - std::string ifname = ifr_get.ifr_name; - EXPECT_THAT(ifname, ::testing::StartsWith("tap")); - EXPECT_THAT(DumpLinkNames(), - IsPosixErrorOkAndHolds(::testing::Contains(ifname))); -} - -TEST_F(TuntapTest, InvalidReadWrite) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); - - char buf[128] = {}; - EXPECT_THAT(read(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD)); - EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD)); -} - -TEST_F(TuntapTest, WriteToDownDevice) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - - // FIXME: gVisor always creates enabled/up'd interfaces. - SKIP_IF(IsRunningOnGvisor()); - - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); - - // Device created should be down by default. - struct ifreq ifr = {}; - ifr.ifr_flags = IFF_TAP; - EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0)); - - char buf[128] = {}; - EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EIO)); -} - -// This test sets up a TAP device and pings kernel by sending ICMP echo request. -// -// It works as the following: -// * Open /dev/net/tun, and create kTapName interface. -// * Use rtnetlink to do initial setup of the interface: -// * Assign IP address 10.0.0.1/24 to kernel. -// * MAC address: kMacA -// * Bring up the interface. -// * Send an ICMP echo reqest (ping) packet from 10.0.0.2 (kMacB) to kernel. -// * Loop to receive packets from TAP device/fd: -// * If packet is an ICMP echo reply, it stops and passes the test. -// * If packet is an ARP request, it responds with canned reply and resends -// the -// ICMP request packet. -TEST_F(TuntapTest, PingKernel) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - - // Interface creation. - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); - - struct ifreq ifr_set = {}; - ifr_set.ifr_flags = IFF_TAP; - strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ); - EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set), - SyscallSucceedsWithValue(0)); - - absl::optional<Link> link = - ASSERT_NO_ERRNO_AND_VALUE(GetLinkByName(kTapName)); - ASSERT_TRUE(link.has_value()); - - // Interface setup. - struct in_addr addr; - inet_pton(AF_INET, "10.0.0.1", &addr); - EXPECT_NO_ERRNO(LinkAddLocalAddr(link->index, AF_INET, /*prefixlen=*/24, - &addr, sizeof(addr))); - - if (!IsRunningOnGvisor()) { - // FIXME: gVisor doesn't support setting MAC address on interfaces yet. - EXPECT_NO_ERRNO(LinkSetMacAddr(link->index, kMacA, sizeof(kMacA))); - - // FIXME: gVisor always creates enabled/up'd interfaces. - EXPECT_NO_ERRNO(LinkChangeFlags(link->index, IFF_UP, IFF_UP)); - } - - ping_pkt ping_req = CreatePingPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1"); - std::string arp_rep = CreateArpPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1"); - - // Send ping, this would trigger an ARP request on Linux. - EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)), - SyscallSucceedsWithValue(sizeof(ping_req))); - - // Receive loop to process inbound packets. - struct inpkt { - union { - pihdr pi; - ping_pkt ping; - arp_pkt arp; - }; - }; - while (1) { - inpkt r = {}; - int n = read(fd.get(), &r, sizeof(r)); - EXPECT_THAT(n, SyscallSucceeds()); - - if (n < sizeof(pihdr)) { - std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol - << " len: " << n << std::endl; - continue; - } - - // Process ARP packet. - if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) { - // Respond with canned ARP reply. - EXPECT_THAT(write(fd.get(), arp_rep.data(), arp_rep.size()), - SyscallSucceedsWithValue(arp_rep.size())); - // First ping request might have been dropped due to mac address not in - // ARP cache. Send it again. - EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)), - SyscallSucceedsWithValue(sizeof(ping_req))); - } - - // Process ping response packet. - if (n >= sizeof(ping_pkt) && r.pi.pi_protocol == ping_req.pi.pi_protocol && - r.ping.ip.protocol == ping_req.ip.protocol && - !memcmp(&r.ping.ip.saddr, &ping_req.ip.daddr, kIPLen) && - !memcmp(&r.ping.ip.daddr, &ping_req.ip.saddr, kIPLen) && - r.ping.icmp.type == 0 && r.ping.icmp.code == 0) { - // Ends and passes the test. - break; - } - } -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/tuntap_hostinet.cc b/test/syscalls/linux/tuntap_hostinet.cc deleted file mode 100644 index 1513fb9d5..000000000 --- a/test/syscalls/linux/tuntap_hostinet.cc +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(TuntapHostInetTest, NoNetTun) { - SKIP_IF(!IsRunningOnGvisor()); - SKIP_IF(!IsRunningWithHostinet()); - - struct stat statbuf; - ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallFailsWithErrno(ENOENT)); -} - -} // namespace -} // namespace testing - -} // namespace gvisor diff --git a/test/syscalls/linux/udp_bind.cc b/test/syscalls/linux/udp_bind.cc deleted file mode 100644 index 6d92bdbeb..000000000 --- a/test/syscalls/linux/udp_bind.cc +++ /dev/null @@ -1,316 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <arpa/inet.h> -#include <sys/socket.h> -#include <sys/types.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -struct sockaddr_in_common { - sa_family_t sin_family; - in_port_t sin_port; -}; - -struct SendtoTestParam { - // Human readable description of test parameter. - std::string description; - - // Test is broken in gVisor, skip. - bool skip_on_gvisor; - - // Domain for the socket that will do the sending. - int send_domain; - - // Address to bind for the socket that will do the sending. - struct sockaddr_storage send_addr; - socklen_t send_addr_len; // 0 for unbound. - - // Address to connect to for the socket that will do the sending. - struct sockaddr_storage connect_addr; - socklen_t connect_addr_len; // 0 for no connection. - - // Domain for the socket that will do the receiving. - int recv_domain; - - // Address to bind for the socket that will do the receiving. - struct sockaddr_storage recv_addr; - socklen_t recv_addr_len; - - // Address to send to. - struct sockaddr_storage sendto_addr; - socklen_t sendto_addr_len; - - // Expected errno for the sendto call. - std::vector<int> sendto_errnos; // empty on success. -}; - -class SendtoTest : public ::testing::TestWithParam<SendtoTestParam> { - protected: - SendtoTest() { - // gUnit uses printf, so so will we. - printf("Testing with %s\n", GetParam().description.c_str()); - } -}; - -TEST_P(SendtoTest, Sendto) { - auto param = GetParam(); - - SKIP_IF(param.skip_on_gvisor && IsRunningOnGvisor()); - - const FileDescriptor s1 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(param.send_domain, SOCK_DGRAM, 0)); - const FileDescriptor s2 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(param.recv_domain, SOCK_DGRAM, 0)); - - if (param.send_addr_len > 0) { - ASSERT_THAT(bind(s1.get(), reinterpret_cast<sockaddr*>(¶m.send_addr), - param.send_addr_len), - SyscallSucceeds()); - } - - if (param.connect_addr_len > 0) { - ASSERT_THAT( - connect(s1.get(), reinterpret_cast<sockaddr*>(¶m.connect_addr), - param.connect_addr_len), - SyscallSucceeds()); - } - - ASSERT_THAT(bind(s2.get(), reinterpret_cast<sockaddr*>(¶m.recv_addr), - param.recv_addr_len), - SyscallSucceeds()); - - struct sockaddr_storage real_recv_addr = {}; - socklen_t real_recv_addr_len = param.recv_addr_len; - ASSERT_THAT( - getsockname(s2.get(), reinterpret_cast<sockaddr*>(&real_recv_addr), - &real_recv_addr_len), - SyscallSucceeds()); - - ASSERT_EQ(real_recv_addr_len, param.recv_addr_len); - - int recv_port = - reinterpret_cast<sockaddr_in_common*>(&real_recv_addr)->sin_port; - - struct sockaddr_storage sendto_addr = param.sendto_addr; - reinterpret_cast<sockaddr_in_common*>(&sendto_addr)->sin_port = recv_port; - - char buf[20] = {}; - if (!param.sendto_errnos.empty()) { - ASSERT_THAT(RetryEINTR(sendto)(s1.get(), buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr), - param.sendto_addr_len), - SyscallFailsWithErrno(ElementOf(param.sendto_errnos))); - return; - } - - ASSERT_THAT(RetryEINTR(sendto)(s1.get(), buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr), - param.sendto_addr_len), - SyscallSucceedsWithValue(sizeof(buf))); - - struct sockaddr_storage got_addr = {}; - socklen_t got_addr_len = sizeof(sockaddr_storage); - ASSERT_THAT(RetryEINTR(recvfrom)(s2.get(), buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&got_addr), - &got_addr_len), - SyscallSucceedsWithValue(sizeof(buf))); - - ASSERT_GT(got_addr_len, sizeof(sockaddr_in_common)); - int got_port = reinterpret_cast<sockaddr_in_common*>(&got_addr)->sin_port; - - struct sockaddr_storage sender_addr = {}; - socklen_t sender_addr_len = sizeof(sockaddr_storage); - ASSERT_THAT(getsockname(s1.get(), reinterpret_cast<sockaddr*>(&sender_addr), - &sender_addr_len), - SyscallSucceeds()); - - ASSERT_GT(sender_addr_len, sizeof(sockaddr_in_common)); - int sender_port = - reinterpret_cast<sockaddr_in_common*>(&sender_addr)->sin_port; - - EXPECT_EQ(got_port, sender_port); -} - -socklen_t Ipv4Addr(sockaddr_storage* addr, int port = 0) { - auto addr4 = reinterpret_cast<sockaddr_in*>(addr); - addr4->sin_family = AF_INET; - addr4->sin_port = port; - inet_pton(AF_INET, "127.0.0.1", &addr4->sin_addr.s_addr); - return sizeof(struct sockaddr_in); -} - -socklen_t Ipv6Addr(sockaddr_storage* addr, int port = 0) { - auto addr6 = reinterpret_cast<sockaddr_in6*>(addr); - addr6->sin6_family = AF_INET6; - addr6->sin6_port = port; - inet_pton(AF_INET6, "::1", &addr6->sin6_addr.s6_addr); - return sizeof(struct sockaddr_in6); -} - -socklen_t Ipv4MappedIpv6Addr(sockaddr_storage* addr, int port = 0) { - auto addr6 = reinterpret_cast<sockaddr_in6*>(addr); - addr6->sin6_family = AF_INET6; - addr6->sin6_port = port; - inet_pton(AF_INET6, "::ffff:127.0.0.1", &addr6->sin6_addr.s6_addr); - return sizeof(struct sockaddr_in6); -} - -INSTANTIATE_TEST_SUITE_P( - UdpBindTest, SendtoTest, - ::testing::Values( - []() { - SendtoTestParam param = {}; - param.description = "IPv4 mapped IPv6 sendto IPv4 mapped IPv6"; - param.send_domain = AF_INET6; - param.send_addr_len = Ipv4MappedIpv6Addr(¶m.send_addr); - param.recv_domain = AF_INET6; - param.recv_addr_len = Ipv4MappedIpv6Addr(¶m.recv_addr); - param.sendto_addr_len = Ipv4MappedIpv6Addr(¶m.sendto_addr); - return param; - }(), - []() { - SendtoTestParam param = {}; - param.description = "IPv6 sendto IPv6"; - param.send_domain = AF_INET6; - param.send_addr_len = Ipv6Addr(¶m.send_addr); - param.recv_domain = AF_INET6; - param.recv_addr_len = Ipv6Addr(¶m.recv_addr); - param.sendto_addr_len = Ipv6Addr(¶m.sendto_addr); - return param; - }(), - []() { - SendtoTestParam param = {}; - param.description = "IPv4 sendto IPv4"; - param.send_domain = AF_INET; - param.send_addr_len = Ipv4Addr(¶m.send_addr); - param.recv_domain = AF_INET; - param.recv_addr_len = Ipv4Addr(¶m.recv_addr); - param.sendto_addr_len = Ipv4Addr(¶m.sendto_addr); - return param; - }(), - []() { - SendtoTestParam param = {}; - param.description = "IPv4 mapped IPv6 sendto IPv4"; - param.send_domain = AF_INET6; - param.send_addr_len = Ipv4MappedIpv6Addr(¶m.send_addr); - param.recv_domain = AF_INET; - param.recv_addr_len = Ipv4Addr(¶m.recv_addr); - param.sendto_addr_len = Ipv4MappedIpv6Addr(¶m.sendto_addr); - return param; - }(), - []() { - SendtoTestParam param = {}; - param.description = "IPv4 sendto IPv4 mapped IPv6"; - param.send_domain = AF_INET; - param.send_addr_len = Ipv4Addr(¶m.send_addr); - param.recv_domain = AF_INET6; - param.recv_addr_len = Ipv4MappedIpv6Addr(¶m.recv_addr); - param.sendto_addr_len = Ipv4Addr(¶m.sendto_addr); - return param; - }(), - []() { - SendtoTestParam param = {}; - param.description = "unbound IPv6 sendto IPv4 mapped IPv6"; - param.send_domain = AF_INET6; - param.recv_domain = AF_INET6; - param.recv_addr_len = Ipv4MappedIpv6Addr(¶m.recv_addr); - param.sendto_addr_len = Ipv4MappedIpv6Addr(¶m.sendto_addr); - return param; - }(), - []() { - SendtoTestParam param = {}; - param.description = "unbound IPv6 sendto IPv4"; - param.send_domain = AF_INET6; - param.recv_domain = AF_INET; - param.recv_addr_len = Ipv4Addr(¶m.recv_addr); - param.sendto_addr_len = Ipv4MappedIpv6Addr(¶m.sendto_addr); - return param; - }(), - []() { - SendtoTestParam param = {}; - param.description = "IPv6 sendto IPv4"; - param.send_domain = AF_INET6; - param.send_addr_len = Ipv6Addr(¶m.send_addr); - param.recv_domain = AF_INET; - param.recv_addr_len = Ipv4Addr(¶m.recv_addr); - param.sendto_addr_len = Ipv4MappedIpv6Addr(¶m.sendto_addr); - param.sendto_errnos = {ENETUNREACH}; - return param; - }(), - []() { - SendtoTestParam param = {}; - param.description = "IPv4 mapped IPv6 sendto IPv6"; - param.send_domain = AF_INET6; - param.send_addr_len = Ipv4MappedIpv6Addr(¶m.send_addr); - param.recv_domain = AF_INET6; - param.recv_addr_len = Ipv6Addr(¶m.recv_addr); - param.sendto_addr_len = Ipv6Addr(¶m.sendto_addr); - param.sendto_errnos = {EAFNOSUPPORT}; - // The errno returned changed in Linux commit c8e6ad0829a723. - param.sendto_errnos = {EINVAL, EAFNOSUPPORT}; - return param; - }(), - []() { - SendtoTestParam param = {}; - param.description = "connected IPv4 mapped IPv6 sendto IPv6"; - param.send_domain = AF_INET6; - param.connect_addr_len = - Ipv4MappedIpv6Addr(¶m.connect_addr, 5000); - param.recv_domain = AF_INET6; - param.recv_addr_len = Ipv6Addr(¶m.recv_addr); - param.sendto_addr_len = Ipv6Addr(¶m.sendto_addr); - // The errno returned changed in Linux commit c8e6ad0829a723. - param.sendto_errnos = {EINVAL, EAFNOSUPPORT}; - return param; - }(), - []() { - SendtoTestParam param = {}; - param.description = "connected IPv6 sendto IPv4 mapped IPv6"; - // TODO(igudger): Determine if this inconsistent behavior is worth - // implementing. - param.skip_on_gvisor = true; - param.send_domain = AF_INET6; - param.connect_addr_len = Ipv6Addr(¶m.connect_addr, 5000); - param.recv_domain = AF_INET6; - param.recv_addr_len = Ipv4MappedIpv6Addr(¶m.recv_addr); - param.sendto_addr_len = Ipv4MappedIpv6Addr(¶m.sendto_addr); - return param; - }(), - []() { - SendtoTestParam param = {}; - param.description = "connected IPv6 sendto IPv4"; - // TODO(igudger): Determine if this inconsistent behavior is worth - // implementing. - param.skip_on_gvisor = true; - param.send_domain = AF_INET6; - param.connect_addr_len = Ipv6Addr(¶m.connect_addr, 5000); - param.recv_domain = AF_INET; - param.recv_addr_len = Ipv4Addr(¶m.recv_addr); - param.sendto_addr_len = Ipv4MappedIpv6Addr(¶m.sendto_addr); - return param; - }())); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc deleted file mode 100644 index 7a8ac30a4..000000000 --- a/test/syscalls/linux/udp_socket.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/udp_socket_test_cases.h" - -namespace gvisor { -namespace testing { - -namespace { - -INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest, - ::testing::Values(AddressFamily::kIpv4, - AddressFamily::kIpv6, - AddressFamily::kDualStack)); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/udp_socket_errqueue_test_case.cc b/test/syscalls/linux/udp_socket_errqueue_test_case.cc deleted file mode 100644 index fcdba7279..000000000 --- a/test/syscalls/linux/udp_socket_errqueue_test_case.cc +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef __fuchsia__ - -#include <arpa/inet.h> -#include <fcntl.h> -#include <linux/errqueue.h> -#include <netinet/in.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> - -#include "gtest/gtest.h" -#include "absl/base/macros.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/udp_socket_test_cases.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -TEST_P(UdpSocketTest, ErrorQueue) { - char cmsgbuf[CMSG_SPACE(sizeof(sock_extended_err))]; - msghdr msg; - memset(&msg, 0, sizeof(msg)); - iovec iov; - memset(&iov, 0, sizeof(iov)); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - - // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. - EXPECT_THAT(RetryEINTR(recvmsg)(s_, &msg, MSG_ERRQUEUE), - SyscallFailsWithErrno(EAGAIN)); -} - -} // namespace testing -} // namespace gvisor - -#endif // __fuchsia__ diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc deleted file mode 100644 index 740c7986d..000000000 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ /dev/null @@ -1,1499 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/udp_socket_test_cases.h" - -#include <arpa/inet.h> -#include <fcntl.h> -#include <netinet/in.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> - -#ifndef SIOCGSTAMP -#include <linux/sockios.h> -#endif - -#include "gtest/gtest.h" -#include "absl/base/macros.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -// Gets a pointer to the port component of the given address. -uint16_t* Port(struct sockaddr_storage* addr) { - switch (addr->ss_family) { - case AF_INET: { - auto sin = reinterpret_cast<struct sockaddr_in*>(addr); - return &sin->sin_port; - } - case AF_INET6: { - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); - return &sin6->sin6_port; - } - } - - return nullptr; -} - -void UdpSocketTest::SetUp() { - int type; - if (GetParam() == AddressFamily::kIpv4) { - type = AF_INET; - auto sin = reinterpret_cast<struct sockaddr_in*>(&anyaddr_storage_); - addrlen_ = sizeof(*sin); - sin->sin_addr.s_addr = htonl(INADDR_ANY); - } else { - type = AF_INET6; - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&anyaddr_storage_); - addrlen_ = sizeof(*sin6); - if (GetParam() == AddressFamily::kIpv6) { - sin6->sin6_addr = IN6ADDR_ANY_INIT; - } else { - TestAddress const& v4_mapped_any = V4MappedAny(); - sin6->sin6_addr = - reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) - ->sin6_addr; - } - } - ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); - - ASSERT_THAT(t_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); - - memset(&anyaddr_storage_, 0, sizeof(anyaddr_storage_)); - anyaddr_ = reinterpret_cast<struct sockaddr*>(&anyaddr_storage_); - anyaddr_->sa_family = type; - - if (gvisor::testing::IsRunningOnGvisor()) { - for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { - ports_[i] = TestPort + i; - } - } else { - // When not under gvisor, use utility function to pick port. Assert that - // all ports are different. - std::string error; - for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { - // Find an unused port, we specify port 0 to allow the kernel to provide - // the port. - bool unique = true; - do { - ports_[i] = ASSERT_NO_ERRNO_AND_VALUE(PortAvailable( - 0, AddressFamily::kDualStack, SocketType::kUdp, false)); - ASSERT_GT(ports_[i], 0); - for (size_t j = 0; j < i; ++j) { - if (ports_[j] == ports_[i]) { - unique = false; - break; - } - } - } while (!unique); - } - } - - // Initialize the sockaddrs. - for (size_t i = 0; i < ABSL_ARRAYSIZE(addr_); ++i) { - memset(&addr_storage_[i], 0, sizeof(addr_storage_[i])); - - addr_[i] = reinterpret_cast<struct sockaddr*>(&addr_storage_[i]); - addr_[i]->sa_family = type; - - switch (type) { - case AF_INET: { - auto sin = reinterpret_cast<struct sockaddr_in*>(addr_[i]); - sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); - sin->sin_port = htons(ports_[i]); - break; - } - case AF_INET6: { - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr_[i]); - sin6->sin6_addr = in6addr_loopback; - sin6->sin6_port = htons(ports_[i]); - break; - } - } - } -} - -TEST_P(UdpSocketTest, Creation) { - int type = AF_INET6; - if (GetParam() == AddressFamily::kIpv4) { - type = AF_INET; - } - - int s_; - - ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); - EXPECT_THAT(close(s_), SyscallSucceeds()); - - ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, 0), SyscallSucceeds()); - EXPECT_THAT(close(s_), SyscallSucceeds()); - - ASSERT_THAT(s_ = socket(type, SOCK_STREAM, IPPROTO_UDP), SyscallFails()); -} - -TEST_P(UdpSocketTest, Getsockname) { - // Check that we're not bound. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, anyaddr_, addrlen_), 0); - - // Bind, then check that we get the right address. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0); -} - -TEST_P(UdpSocketTest, Getpeername) { - // Check that we're not connected. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - - // Connect, then check that we get the right address. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0); -} - -TEST_P(UdpSocketTest, SendNotConnected) { - // Do send & write, they must fail. - char buf[512]; - EXPECT_THAT(send(s_, buf, sizeof(buf), 0), - SyscallFailsWithErrno(EDESTADDRREQ)); - - EXPECT_THAT(write(s_, buf, sizeof(buf)), SyscallFailsWithErrno(EDESTADDRREQ)); - - // Use sendto. - ASSERT_THAT(sendto(s_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Check that we're bound now. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_NE(*Port(&addr), 0); -} - -TEST_P(UdpSocketTest, ConnectBinds) { - // Connect the socket. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Check that we're bound now. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_NE(*Port(&addr), 0); -} - -TEST_P(UdpSocketTest, ReceiveNotBound) { - char buf[512]; - EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, Bind) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Try to bind again. - EXPECT_THAT(bind(s_, addr_[1], addrlen_), SyscallFailsWithErrno(EINVAL)); - - // Check that we're still bound to the original address. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0); -} - -TEST_P(UdpSocketTest, BindInUse) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Try to bind again. - EXPECT_THAT(bind(t_, addr_[0], addrlen_), SyscallFailsWithErrno(EADDRINUSE)); -} - -TEST_P(UdpSocketTest, ReceiveAfterConnect) { - // Connect s_ to loopback:TestPort, and bind t_ to loopback:TestPort. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(bind(t_, addr_[0], addrlen_), SyscallSucceeds()); - - // Get the address s_ was bound to during connect. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - - // Send from t_ to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, ReceiveAfterDisconnect) { - // Connect s_ to loopback:TestPort, and bind t_ to loopback:TestPort. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(bind(t_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[1], addrlen_), SyscallSucceeds()); - - // Get the address s_ was bound to during connect. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - - for (int i = 0; i < 2; i++) { - // Send from t_ to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - - // Disconnect s_. - struct sockaddr addr = {}; - addr.sa_family = AF_UNSPEC; - ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), SyscallSucceeds()); - // Connect s_ loopback:TestPort. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - } -} - -TEST_P(UdpSocketTest, Connect) { - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Check that we're connected to the right peer. - struct sockaddr_storage peer; - socklen_t peerlen = sizeof(peer); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, addr_[0], addrlen_), 0); - - // Try to bind after connect. - EXPECT_THAT(bind(s_, addr_[1], addrlen_), SyscallFailsWithErrno(EINVAL)); - - // Try to connect again. - EXPECT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); - - // Check that peer name changed. - peerlen = sizeof(peer); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); -} - -void ConnectAny(AddressFamily family, int sockfd, uint16_t port) { - struct sockaddr_storage addr = {}; - - // Precondition check. - { - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - if (family == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - struct in6_addr any = IN6ADDR_ANY_INIT; - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &any, sizeof(in6_addr)), 0); - } - - { - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - } - - struct sockaddr_storage baddr = {}; - if (family == AddressFamily::kIpv4) { - auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); - addrlen = sizeof(*addr_in); - addr_in->sin_family = AF_INET; - addr_in->sin_addr.s_addr = htonl(INADDR_ANY); - addr_in->sin_port = port; - } else { - auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); - addrlen = sizeof(*addr_in); - addr_in->sin6_family = AF_INET6; - addr_in->sin6_port = port; - if (family == AddressFamily::kIpv6) { - addr_in->sin6_addr = IN6ADDR_ANY_INIT; - } else { - TestAddress const& v4_mapped_any = V4MappedAny(); - addr_in->sin6_addr = - reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) - ->sin6_addr; - } - } - - // TODO(b/138658473): gVisor doesn't allow connecting to the zero port. - if (port == 0) { - SKIP_IF(IsRunningOnGvisor()); - } - - ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen), - SyscallSucceeds()); - } - - // Postcondition check. - { - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - if (family == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - struct in6_addr loopback; - if (family == AddressFamily::kIpv6) { - loopback = IN6ADDR_LOOPBACK_INIT; - } else { - TestAddress const& v4_mapped_loopback = V4MappedLoopback(); - loopback = reinterpret_cast<const struct sockaddr_in6*>( - &v4_mapped_loopback.addr) - ->sin6_addr; - } - - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); - } - - addrlen = sizeof(addr); - if (port == 0) { - EXPECT_THAT( - getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - } else { - EXPECT_THAT( - getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - } - } -} - -TEST_P(UdpSocketTest, ConnectAny) { ConnectAny(GetParam(), s_, 0); } - -TEST_P(UdpSocketTest, ConnectAnyWithPort) { - auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); - ConnectAny(GetParam(), s_, port); -} - -void DisconnectAfterConnectAny(AddressFamily family, int sockfd, int port) { - struct sockaddr_storage addr = {}; - - socklen_t addrlen = sizeof(addr); - struct sockaddr_storage baddr = {}; - if (family == AddressFamily::kIpv4) { - auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); - addrlen = sizeof(*addr_in); - addr_in->sin_family = AF_INET; - addr_in->sin_addr.s_addr = htonl(INADDR_ANY); - addr_in->sin_port = port; - } else { - auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); - addrlen = sizeof(*addr_in); - addr_in->sin6_family = AF_INET6; - addr_in->sin6_port = port; - if (family == AddressFamily::kIpv6) { - addr_in->sin6_addr = IN6ADDR_ANY_INIT; - } else { - TestAddress const& v4_mapped_any = V4MappedAny(); - addr_in->sin6_addr = - reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) - ->sin6_addr; - } - } - - // TODO(b/138658473): gVisor doesn't allow connecting to the zero port. - if (port == 0) { - SKIP_IF(IsRunningOnGvisor()); - } - - ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen), - SyscallSucceeds()); - // Now the socket is bound to the loopback address. - - // Disconnect - addrlen = sizeof(addr); - addr.ss_family = AF_UNSPEC; - ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceeds()); - - // Check that after disconnect the socket is bound to the ANY address. - EXPECT_THAT(getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - if (family == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - struct in6_addr loopback = IN6ADDR_ANY_INIT; - - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); - } -} - -TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { - DisconnectAfterConnectAny(GetParam(), s_, 0); -} - -TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { - auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); - DisconnectAfterConnectAny(GetParam(), s_, port); -} - -TEST_P(UdpSocketTest, DisconnectAfterBind) { - ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); - // Connect the socket. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - struct sockaddr_storage addr = {}; - addr.ss_family = AF_UNSPEC; - EXPECT_THAT( - connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), - SyscallSucceeds()); - - // Check that we're still bound. - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[1], addrlen_), 0); - - addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) { - struct sockaddr_storage baddr = {}; - auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); - if (GetParam() == AddressFamily::kIpv4) { - auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); - addr_in->sin_family = AF_INET; - addr_in->sin_port = port; - addr_in->sin_addr.s_addr = htonl(INADDR_ANY); - } else { - auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); - addr_in->sin6_family = AF_INET6; - addr_in->sin6_port = port; - addr_in->sin6_scope_id = 0; - addr_in->sin6_addr = IN6ADDR_ANY_INIT; - } - ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_), - SyscallSucceeds()); - // Connect the socket. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - struct sockaddr_storage addr = {}; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - // If the socket is bound to ANY and connected to a loopback address, - // getsockname() has to return the loopback address. - if (GetParam() == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); - struct in6_addr loopback = IN6ADDR_LOOPBACK_INIT; - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); - } -} - -TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { - struct sockaddr_storage baddr = {}; - socklen_t addrlen; - auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); - if (GetParam() == AddressFamily::kIpv4) { - auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); - addr_in->sin_family = AF_INET; - addr_in->sin_port = port; - addr_in->sin_addr.s_addr = htonl(INADDR_ANY); - } else { - auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); - addr_in->sin6_family = AF_INET6; - addr_in->sin6_port = port; - addr_in->sin6_scope_id = 0; - addr_in->sin6_addr = IN6ADDR_ANY_INIT; - } - ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_), - SyscallSucceeds()); - // Connect the socket. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - struct sockaddr_storage addr = {}; - addr.ss_family = AF_UNSPEC; - EXPECT_THAT( - connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), - SyscallSucceeds()); - - // Check that we're still bound. - addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, &baddr, addrlen), 0); - - addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, Disconnect) { - for (int i = 0; i < 2; i++) { - // Try to connect again. - EXPECT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); - - // Check that we're connected to the right peer. - struct sockaddr_storage peer; - socklen_t peerlen = sizeof(peer); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); - - // Try to disconnect. - struct sockaddr_storage addr = {}; - addr.ss_family = AF_UNSPEC; - EXPECT_THAT( - connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), - SyscallSucceeds()); - - peerlen = sizeof(peer); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallFailsWithErrno(ENOTCONN)); - - // Check that we're still bound. - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(*Port(&addr), 0); - } -} - -TEST_P(UdpSocketTest, ConnectBadAddress) { - struct sockaddr addr = {}; - addr.sa_family = addr_[0]->sa_family; - ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send to a different destination than we're connected to. - char buf[512]; - EXPECT_THAT(sendto(s_, buf, sizeof(buf), 0, addr_[1], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); -} - -TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { - // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. - SKIP_IF(IsRunningWithHostinet()); - - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+1. - ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from s_ to t_. - ASSERT_THAT(write(s_, buf, 0), SyscallSucceedsWithValue(0)); - // Receive the packet. - char received[3]; - EXPECT_THAT(read(t_, received, sizeof(received)), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) { - // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. - SKIP_IF(IsRunningWithHostinet()); - - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+1. - ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); - - // Set t_ to non-blocking. - int opts = 0; - ASSERT_THAT(opts = fcntl(t_, F_GETFL), SyscallSucceeds()); - ASSERT_THAT(fcntl(t_, F_SETFL, opts | O_NONBLOCK), SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from s_ to t_. - ASSERT_THAT(write(s_, buf, 0), SyscallSucceedsWithValue(0)); - // Receive the packet. - char received[3]; - EXPECT_THAT(read(t_, received, sizeof(received)), - SyscallSucceedsWithValue(0)); - EXPECT_THAT(read(t_, received, sizeof(received)), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(UdpSocketTest, SendAndReceiveNotConnected) { - // Bind s_ to loopback. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send some data to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, SendAndReceiveConnected) { - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+1. - ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); - - // Send some data from t_ to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, ReceiveFromNotConnected) { - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+2. - ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds()); - - // Send some data from t_ to s_. - char buf[512]; - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Check that the data isn't_ received because it was sent from a different - // address than we're connected. - EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReceiveBeforeConnect) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+2. - ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds()); - - // Send some data from t_ to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Connect to loopback:TestPort+1. - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Receive the data. It works because it was sent before the connect. - char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - - // Send again. This time it should not be received. - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReceiveFrom) { - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+1. - ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); - - // Send some data from t_ to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data and sender address. - char received[sizeof(buf)]; - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(recvfrom(s_, received, sizeof(received), 0, - reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[1], addrlen_), 0); -} - -TEST_P(UdpSocketTest, Listen) { - ASSERT_THAT(listen(s_, SOMAXCONN), SyscallFailsWithErrno(EOPNOTSUPP)); -} - -TEST_P(UdpSocketTest, Accept) { - ASSERT_THAT(accept(s_, nullptr, nullptr), SyscallFailsWithErrno(EOPNOTSUPP)); -} - -// This test validates that a read shutdown with pending data allows the read -// to proceed with the data before returning EAGAIN. -TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) { - char received[512]; - - // Bind t_ to loopback:TestPort+2. - ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[1], addrlen_), SyscallSucceeds()); - - // Connect the socket, then try to shutdown again. - ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); - - // Verify that we get EWOULDBLOCK when there is nothing to read. - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - const char* buf = "abc"; - EXPECT_THAT(write(t_, buf, 3), SyscallSucceedsWithValue(3)); - - int opts = 0; - ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds()); - ASSERT_THAT(fcntl(s_, F_SETFL, opts | O_NONBLOCK), SyscallSucceeds()); - ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds()); - ASSERT_NE(opts & O_NONBLOCK, 0); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - // We should get the data even though read has been shutdown. - EXPECT_THAT(recv(s_, received, 2, 0), SyscallSucceedsWithValue(2)); - - // Because we read less than the entire packet length, since it's a packet - // based socket any subsequent reads should return EWOULDBLOCK. - EXPECT_THAT(recv(s_, received, 1, 0), SyscallFailsWithErrno(EWOULDBLOCK)); -} - -// This test is validating that even after a socket is shutdown if it's -// reconnected it will reset the shutdown state. -TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) { - char received[512]; - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); - - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then try to shutdown again. - ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReadShutdown) { - // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without - // MSG_DONTWAIT blocks indefinitely. - SKIP_IF(IsRunningWithHostinet()); - - char received[512]; - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); - - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then try to shutdown again. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, ReadShutdownDifferentThread) { - // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without - // MSG_DONTWAIT blocks indefinitely. - SKIP_IF(IsRunningWithHostinet()); - - char received[512]; - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then shutdown from another thread. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - ScopedThread t([&] { - absl::SleepFor(absl::Milliseconds(200)); - EXPECT_THAT(shutdown(this->s_, SHUT_RD), SyscallSucceeds()); - }); - EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); - t.Join(); - - EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, WriteShutdown) { - EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds()); -} - -TEST_P(UdpSocketTest, SynchronousReceive) { - // Bind s_ to loopback. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send some data to s_ from another thread. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - // Receive the data prior to actually starting the other thread. - char received[512]; - EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Start the thread. - ScopedThread t([&] { - absl::SleepFor(absl::Milliseconds(200)); - ASSERT_THAT( - sendto(this->t_, buf, sizeof(buf), 0, this->addr_[0], this->addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - }); - - EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(512)); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_SendRecv) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send 3 packets from t_ to s_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 3; ++i) { - ASSERT_THAT(sendto(t_, buf + i * psize, psize, 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(psize)); - } - - // Receive the data as 3 separate packets. - char received[6 * psize]; - for (int i = 0; i < 3; ++i) { - EXPECT_THAT(recv(s_, received + i * psize, 3 * psize, 0), - SyscallSucceedsWithValue(psize)); - } - EXPECT_EQ(memcmp(buf, received, 3 * psize), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Direct writes from t_ to s_. - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send 2 packets from t_ to s_, where each packet's data consists of 2 - // discontiguous iovecs. - constexpr size_t kPieceSize = 100; - char buf[4 * kPieceSize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 2; i++) { - struct iovec iov[2]; - for (int j = 0; j < 2; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - ASSERT_THAT(writev(t_, iov, 2), SyscallSucceedsWithValue(2 * kPieceSize)); - } - - // Receive the data as 2 separate packets. - char received[6 * kPieceSize]; - for (int i = 0; i < 2; i++) { - struct iovec iov[3]; - for (int j = 0; j < 3; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - ASSERT_THAT(readv(s_, iov, 3), SyscallSucceedsWithValue(2 * kPieceSize)); - } - EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send 2 packets from t_ to s_, where each packet's data consists of 2 - // discontiguous iovecs. - constexpr size_t kPieceSize = 100; - char buf[4 * kPieceSize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 2; i++) { - struct iovec iov[2]; - for (int j = 0; j < 2; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - struct msghdr msg = {}; - msg.msg_name = addr_[0]; - msg.msg_namelen = addrlen_; - msg.msg_iov = iov; - msg.msg_iovlen = 2; - ASSERT_THAT(sendmsg(t_, &msg, 0), SyscallSucceedsWithValue(2 * kPieceSize)); - } - - // Receive the data as 2 separate packets. - char received[6 * kPieceSize]; - for (int i = 0; i < 2; i++) { - struct iovec iov[3]; - for (int j = 0; j < 3; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - struct msghdr msg = {}; - msg.msg_iov = iov; - msg.msg_iovlen = 3; - ASSERT_THAT(recvmsg(s_, &msg, 0), SyscallSucceedsWithValue(2 * kPieceSize)); - } - EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); -} - -TEST_P(UdpSocketTest, FIONREADShutdown) { - int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); -} - -TEST_P(UdpSocketTest, FIONREADWriteShutdown) { - int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - const char str[] = "abc"; - ASSERT_THAT(send(s_, str, sizeof(str), 0), - SyscallSucceedsWithValue(sizeof(str))); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, sizeof(str)); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, sizeof(str)); -} - -TEST_P(UdpSocketTest, Fionread) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Check that the bound socket with an empty buffer reports an empty first - // packet. - int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Send 3 packets from t_ to s_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 3; ++i) { - ASSERT_THAT(sendto(t_, buf + i * psize, psize, 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(psize)); - - // Check that regardless of how many packets are in the queue, the size - // reported is that of a single packet. - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, psize); - } -} - -TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Check that the bound socket with an empty buffer reports an empty first - // packet. - int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Send 3 packets from t_ to s_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 3; ++i) { - ASSERT_THAT(sendto(t_, buf + i * psize, 0, 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(0)); - - // Check that regardless of how many packets are in the queue, the size - // reported is that of a single packet. - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - } -} - -TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) { - int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - const char str[] = "abc"; - ASSERT_THAT(send(s_, str, 0, 0), SyscallSucceedsWithValue(0)); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); -} - -TEST_P(UdpSocketTest, SoTimestampOffByDefault) { - // TODO(gvisor.dev/issue/1202): SO_TIMESTAMP socket option not supported by - // hostinet. - SKIP_IF(IsRunningWithHostinet()); - - int v = -1; - socklen_t optlen = sizeof(v); - ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, &optlen), - SyscallSucceeds()); - ASSERT_EQ(v, kSockOptOff); - ASSERT_EQ(optlen, sizeof(v)); -} - -TEST_P(UdpSocketTest, SoTimestamp) { - // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not - // supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - int v = 1; - ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), - SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from t_ to s_. - ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); - - char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; - msghdr msg; - memset(&msg, 0, sizeof(msg)); - iovec iov; - memset(&iov, 0, sizeof(iov)); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - - ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, 0), SyscallSucceedsWithValue(0)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval))); - - struct timeval tv = {}; - memcpy(&tv, CMSG_DATA(cmsg), sizeof(struct timeval)); - - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); - - // There should be nothing to get via ioctl. - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallFailsWithErrno(ENOENT)); -} - -TEST_P(UdpSocketTest, WriteShutdownNotConnected) { - EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, TimestampIoctl) { - // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send packet from t_ to s_. - ASSERT_THAT(RetryEINTR(write)(t_, buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - // There should be no control messages. - char recv_buf[sizeof(buf)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, sizeof(recv_buf))); - - // A nonzero timeval should be available via ioctl. - struct timeval tv = {}; - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallSucceeds()); - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); -} - -TEST_P(UdpSocketTest, TimestampIoctlNothingRead) { - // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - struct timeval tv = {}; - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallFailsWithErrno(ENOENT)); -} - -// Test that the timestamp accessed via SIOCGSTAMP is still accessible after -// SO_TIMESTAMP is enabled and used to retrieve a timestamp. -TEST_P(UdpSocketTest, TimestampIoctlPersistence) { - // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not - // supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send packet from t_ to s_. - ASSERT_THAT(RetryEINTR(write)(t_, buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); - - // There should be no control messages. - char recv_buf[sizeof(buf)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, sizeof(recv_buf))); - - // A nonzero timeval should be available via ioctl. - struct timeval tv = {}; - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallSucceeds()); - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); - - // Enable SO_TIMESTAMP and send a message. - int v = 1; - EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), - SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); - - // There should be a message for SO_TIMESTAMP. - char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; - msghdr msg = {}; - iovec iov = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, 0), SyscallSucceedsWithValue(0)); - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - - // The ioctl should return the exact same values as before. - struct timeval tv2 = {}; - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv2), SyscallSucceeds()); - ASSERT_EQ(tv.tv_sec, tv2.tv_sec); - ASSERT_EQ(tv.tv_usec, tv2.tv_usec); -} - -// Test that a socket with IP_TOS or IPV6_TCLASS set will set the TOS byte on -// outgoing packets, and that a receiving socket with IP_RECVTOS or -// IPV6_RECVTCLASS will create the corresponding control message. -TEST_P(UdpSocketTest, SetAndReceiveTOS) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - // Allow socket to receive control message. - int recv_level = SOL_IP; - int recv_type = IP_RECVTOS; - if (GetParam() != AddressFamily::kIpv4) { - recv_level = SOL_IPV6; - recv_type = IPV6_RECVTCLASS; - } - ASSERT_THAT( - setsockopt(s_, recv_level, recv_type, &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Set socket TOS. - int sent_level = recv_level; - int sent_type = IP_TOS; - if (sent_level == SOL_IPV6) { - sent_type = IPV6_TCLASS; - } - int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. - ASSERT_THAT( - setsockopt(t_, sent_level, sent_type, &sent_tos, sizeof(sent_tos)), - SyscallSucceeds()); - - // Prepare message to send. - constexpr size_t kDataLength = 1024; - struct msghdr sent_msg = {}; - struct iovec sent_iov = {}; - char sent_data[kDataLength]; - sent_iov.iov_base = &sent_data[0]; - sent_iov.iov_len = kDataLength; - sent_msg.msg_iov = &sent_iov; - sent_msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(sendmsg)(t_, &sent_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - // Receive message. - struct msghdr received_msg = {}; - struct iovec received_iov = {}; - char received_data[kDataLength]; - received_iov.iov_base = &received_data[0]; - received_iov.iov_len = kDataLength; - received_msg.msg_iov = &received_iov; - received_msg.msg_iovlen = 1; - size_t cmsg_data_len = sizeof(int8_t); - if (sent_type == IPV6_TCLASS) { - cmsg_data_len = sizeof(int); - } - std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); - received_msg.msg_control = &received_cmsgbuf[0]; - received_msg.msg_controllen = received_cmsgbuf.size(); - ASSERT_THAT(RetryEINTR(recvmsg)(s_, &received_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); - EXPECT_EQ(cmsg->cmsg_level, sent_level); - EXPECT_EQ(cmsg->cmsg_type, sent_type); - int8_t received_tos = 0; - memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); - EXPECT_EQ(received_tos, sent_tos); -} - -// Test that sendmsg with IP_TOS and IPV6_TCLASS control messages will set the -// TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or -// IPV6_RECVTCLASS will create the corresponding control message. -TEST_P(UdpSocketTest, SendAndReceiveTOS) { - // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. - SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - // Allow socket to receive control message. - int recv_level = SOL_IP; - int recv_type = IP_RECVTOS; - if (GetParam() != AddressFamily::kIpv4) { - recv_level = SOL_IPV6; - recv_type = IPV6_RECVTCLASS; - } - int recv_opt = kSockOptOn; - ASSERT_THAT( - setsockopt(s_, recv_level, recv_type, &recv_opt, sizeof(recv_opt)), - SyscallSucceeds()); - - // Prepare message to send. - constexpr size_t kDataLength = 1024; - int sent_level = recv_level; - int sent_type = IP_TOS; - int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. - - struct msghdr sent_msg = {}; - struct iovec sent_iov = {}; - char sent_data[kDataLength]; - sent_iov.iov_base = &sent_data[0]; - sent_iov.iov_len = kDataLength; - sent_msg.msg_iov = &sent_iov; - sent_msg.msg_iovlen = 1; - size_t cmsg_data_len = sizeof(int8_t); - if (sent_level == SOL_IPV6) { - sent_type = IPV6_TCLASS; - cmsg_data_len = sizeof(int); - } - std::vector<char> sent_cmsgbuf(CMSG_SPACE(cmsg_data_len)); - sent_msg.msg_control = &sent_cmsgbuf[0]; - sent_msg.msg_controllen = CMSG_LEN(cmsg_data_len); - - // Manually add control message. - struct cmsghdr* sent_cmsg = CMSG_FIRSTHDR(&sent_msg); - sent_cmsg->cmsg_len = CMSG_LEN(cmsg_data_len); - sent_cmsg->cmsg_level = sent_level; - sent_cmsg->cmsg_type = sent_type; - *(int8_t*)CMSG_DATA(sent_cmsg) = sent_tos; - - ASSERT_THAT(RetryEINTR(sendmsg)(t_, &sent_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - // Receive message. - struct msghdr received_msg = {}; - struct iovec received_iov = {}; - char received_data[kDataLength]; - received_iov.iov_base = &received_data[0]; - received_iov.iov_len = kDataLength; - received_msg.msg_iov = &received_iov; - received_msg.msg_iovlen = 1; - std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); - received_msg.msg_control = &received_cmsgbuf[0]; - received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); - ASSERT_THAT(RetryEINTR(recvmsg)(s_, &received_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); - EXPECT_EQ(cmsg->cmsg_level, sent_level); - EXPECT_EQ(cmsg->cmsg_type, sent_type); - int8_t received_tos = 0; - memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); - EXPECT_EQ(received_tos, sent_tos); -} -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/udp_socket_test_cases.h b/test/syscalls/linux/udp_socket_test_cases.h deleted file mode 100644 index 2fd79d99e..000000000 --- a/test/syscalls/linux/udp_socket_test_cases.h +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ -#define THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// The initial port to be be used on gvisor. -constexpr int TestPort = 40000; - -// Fixture for tests parameterized by the address family to use (AF_INET and -// AF_INET6) when creating sockets. -class UdpSocketTest - : public ::testing::TestWithParam<gvisor::testing::AddressFamily> { - protected: - // Creates two sockets that will be used by test cases. - void SetUp() override; - - // Closes the sockets created by SetUp(). - void TearDown() override { - EXPECT_THAT(close(s_), SyscallSucceeds()); - EXPECT_THAT(close(t_), SyscallSucceeds()); - - for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { - ASSERT_NO_ERRNO(FreeAvailablePort(ports_[i])); - } - } - - // First UDP socket. - int s_; - - // Second UDP socket. - int t_; - - // The length of the socket address. - socklen_t addrlen_; - - // Initialized address pointing to loopback and port TestPort+i. - struct sockaddr* addr_[3]; - - // Initialize "any" address. - struct sockaddr* anyaddr_; - - // Used ports. - int ports_[3]; - - private: - // Storage for the loopback addresses. - struct sockaddr_storage addr_storage_[3]; - - // Storage for the "any" address. - struct sockaddr_storage anyaddr_storage_; -}; - -} // namespace testing -} // namespace gvisor - -#endif // THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc deleted file mode 100644 index 6218fbce1..000000000 --- a/test/syscalls/linux/uidgid.cc +++ /dev/null @@ -1,255 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <grp.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "absl/flags/flag.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "test/util/capability_util.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" -#include "test/util/uid_util.h" - -ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID"); -ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID"); -ABSL_FLAG(int32_t, scratch_gid1, 65534, "first scratch GID"); -ABSL_FLAG(int32_t, scratch_gid2, 65533, "second scratch GID"); - -using ::testing::UnorderedElementsAreArray; - -namespace gvisor { -namespace testing { - -namespace { - -TEST(UidGidTest, Getuid) { - uid_t ruid, euid, suid; - EXPECT_THAT(getresuid(&ruid, &euid, &suid), SyscallSucceeds()); - EXPECT_THAT(getuid(), SyscallSucceedsWithValue(ruid)); - EXPECT_THAT(geteuid(), SyscallSucceedsWithValue(euid)); -} - -TEST(UidGidTest, Getgid) { - gid_t rgid, egid, sgid; - EXPECT_THAT(getresgid(&rgid, &egid, &sgid), SyscallSucceeds()); - EXPECT_THAT(getgid(), SyscallSucceedsWithValue(rgid)); - EXPECT_THAT(getegid(), SyscallSucceedsWithValue(egid)); -} - -TEST(UidGidTest, Getgroups) { - // "If size is zero, list is not modified, but the total number of - // supplementary group IDs for the process is returned." - getgroups(2) - int nr_groups; - ASSERT_THAT(nr_groups = getgroups(0, nullptr), SyscallSucceeds()); - std::vector<gid_t> list(nr_groups); - EXPECT_THAT(getgroups(list.size(), list.data()), SyscallSucceeds()); - - // "EINVAL: size is less than the number of supplementary group IDs, but is - // not zero." - EXPECT_THAT(getgroups(-1, nullptr), SyscallFailsWithErrno(EINVAL)); - - // Testing for EFAULT requires actually having groups, which isn't guaranteed - // here; see the setgroups test below. -} - -// Checks that the calling process' real/effective/saved user IDs are -// ruid/euid/suid respectively. -PosixError CheckUIDs(uid_t ruid, uid_t euid, uid_t suid) { - uid_t actual_ruid, actual_euid, actual_suid; - int rc = getresuid(&actual_ruid, &actual_euid, &actual_suid); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "getresuid"); - } - if (ruid != actual_ruid || euid != actual_euid || suid != actual_suid) { - return PosixError( - EPERM, absl::StrCat( - "incorrect user IDs: got (", - absl::StrJoin({actual_ruid, actual_euid, actual_suid}, ", "), - ", wanted (", absl::StrJoin({ruid, euid, suid}, ", "), ")")); - } - return NoError(); -} - -PosixError CheckGIDs(gid_t rgid, gid_t egid, gid_t sgid) { - gid_t actual_rgid, actual_egid, actual_sgid; - int rc = getresgid(&actual_rgid, &actual_egid, &actual_sgid); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "getresgid"); - } - if (rgid != actual_rgid || egid != actual_egid || sgid != actual_sgid) { - return PosixError( - EPERM, absl::StrCat( - "incorrect group IDs: got (", - absl::StrJoin({actual_rgid, actual_egid, actual_sgid}, ", "), - ", wanted (", absl::StrJoin({rgid, egid, sgid}, ", "), ")")); - } - return NoError(); -} - -// N.B. These tests may break horribly unless run via a gVisor test runner, -// because changing UID in one test may forfeit permissions required by other -// tests. (The test runner runs each test in a separate process.) - -TEST(UidGidRootTest, Setuid) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); - - // Do setuid in a separate thread so that after finishing this test, the - // process can still open files the test harness created before starting this - // test. Otherwise, the files are created by root (UID before the test), but - // cannot be opened by the `uid` set below after the test. After calling - // setuid(non-zero-UID), there is no way to get root privileges back. - ScopedThread([&] { - // Use syscall instead of glibc setuid wrapper because we want this setuid - // call to only apply to this task. POSIX threads, however, require that all - // threads have the same UIDs, so using the setuid wrapper sets all threads' - // real UID. - EXPECT_THAT(syscall(SYS_setuid, -1), SyscallFailsWithErrno(EINVAL)); - - const uid_t uid = absl::GetFlag(FLAGS_scratch_uid1); - EXPECT_THAT(syscall(SYS_setuid, uid), SyscallSucceeds()); - // "If the effective UID of the caller is root (more precisely: if the - // caller has the CAP_SETUID capability), the real UID and saved set-user-ID - // are also set." - setuid(2) - EXPECT_NO_ERRNO(CheckUIDs(uid, uid, uid)); - }); -} - -TEST(UidGidRootTest, Setgid) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); - - EXPECT_THAT(setgid(-1), SyscallFailsWithErrno(EINVAL)); - - const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1); - ASSERT_THAT(setgid(gid), SyscallSucceeds()); - EXPECT_NO_ERRNO(CheckGIDs(gid, gid, gid)); -} - -TEST(UidGidRootTest, SetgidNotFromThreadGroupLeader) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); - - const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1); - // NOTE(b/64676707): Do setgid in a separate thread so that we can test if - // info.si_pid is set correctly. - ScopedThread([gid] { ASSERT_THAT(setgid(gid), SyscallSucceeds()); }); - EXPECT_NO_ERRNO(CheckGIDs(gid, gid, gid)); -} - -TEST(UidGidRootTest, Setreuid) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); - - // "Supplying a value of -1 for either the real or effective user ID forces - // the system to leave that ID unchanged." - setreuid(2) - EXPECT_THAT(setreuid(-1, -1), SyscallSucceeds()); - EXPECT_NO_ERRNO(CheckUIDs(0, 0, 0)); - - // Do setuid in a separate thread so that after finishing this test, the - // process can still open files the test harness created before starting this - // test. Otherwise, the files are created by root (UID before the test), but - // cannot be opened by the `uid` set below after the test. After calling - // setuid(non-zero-UID), there is no way to get root privileges back. - ScopedThread([&] { - const uid_t ruid = absl::GetFlag(FLAGS_scratch_uid1); - const uid_t euid = absl::GetFlag(FLAGS_scratch_uid2); - - // Use syscall instead of glibc setuid wrapper because we want this setuid - // call to only apply to this task. posix threads, however, require that all - // threads have the same UIDs, so using the setuid wrapper sets all threads' - // real UID. - EXPECT_THAT(syscall(SYS_setreuid, ruid, euid), SyscallSucceeds()); - - // "If the real user ID is set or the effective user ID is set to a value - // not equal to the previous real user ID, the saved set-user-ID will be set - // to the new effective user ID." - setreuid(2) - EXPECT_NO_ERRNO(CheckUIDs(ruid, euid, euid)); - }); -} - -TEST(UidGidRootTest, Setregid) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); - - EXPECT_THAT(setregid(-1, -1), SyscallSucceeds()); - EXPECT_NO_ERRNO(CheckGIDs(0, 0, 0)); - - const gid_t rgid = absl::GetFlag(FLAGS_scratch_gid1); - const gid_t egid = absl::GetFlag(FLAGS_scratch_gid2); - ASSERT_THAT(setregid(rgid, egid), SyscallSucceeds()); - EXPECT_NO_ERRNO(CheckGIDs(rgid, egid, egid)); -} - -TEST(UidGidRootTest, Setresuid) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); - - // "If one of the arguments equals -1, the corresponding value is not - // changed." - setresuid(2) - EXPECT_THAT(setresuid(-1, -1, -1), SyscallSucceeds()); - EXPECT_NO_ERRNO(CheckUIDs(0, 0, 0)); - - // Do setuid in a separate thread so that after finishing this test, the - // process can still open files the test harness created before starting this - // test. Otherwise, the files are created by root (UID before the test), but - // cannot be opened by the `uid` set below after the test. After calling - // setuid(non-zero-UID), there is no way to get root privileges back. - ScopedThread([&] { - const uid_t ruid = 12345; - const uid_t euid = 23456; - const uid_t suid = 34567; - - // Use syscall instead of glibc setuid wrapper because we want this setuid - // call to only apply to this task. posix threads, however, require that all - // threads have the same UIDs, so using the setuid wrapper sets all threads' - // real UID. - EXPECT_THAT(syscall(SYS_setresuid, ruid, euid, suid), SyscallSucceeds()); - EXPECT_NO_ERRNO(CheckUIDs(ruid, euid, suid)); - }); -} - -TEST(UidGidRootTest, Setresgid) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); - - EXPECT_THAT(setresgid(-1, -1, -1), SyscallSucceeds()); - EXPECT_NO_ERRNO(CheckGIDs(0, 0, 0)); - - const gid_t rgid = 12345; - const gid_t egid = 23456; - const gid_t sgid = 34567; - ASSERT_THAT(setresgid(rgid, egid, sgid), SyscallSucceeds()); - EXPECT_NO_ERRNO(CheckGIDs(rgid, egid, sgid)); -} - -TEST(UidGidRootTest, Setgroups) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); - - std::vector<gid_t> list = {123, 500}; - ASSERT_THAT(setgroups(list.size(), list.data()), SyscallSucceeds()); - std::vector<gid_t> list2(list.size()); - ASSERT_THAT(getgroups(list2.size(), list2.data()), SyscallSucceeds()); - EXPECT_THAT(list, UnorderedElementsAreArray(list2)); - - // "EFAULT: list has an invalid address." - EXPECT_THAT(getgroups(100, reinterpret_cast<gid_t*>(-1)), - SyscallFailsWithErrno(EFAULT)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/uname.cc b/test/syscalls/linux/uname.cc deleted file mode 100644 index d8824b171..000000000 --- a/test/syscalls/linux/uname.cc +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sched.h> -#include <sys/utsname.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "absl/strings/string_view.h" -#include "test/util/capability_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(UnameTest, Sanity) { - struct utsname buf; - ASSERT_THAT(uname(&buf), SyscallSucceeds()); - EXPECT_NE(strlen(buf.release), 0); - EXPECT_NE(strlen(buf.version), 0); - EXPECT_NE(strlen(buf.machine), 0); - EXPECT_NE(strlen(buf.sysname), 0); - EXPECT_NE(strlen(buf.nodename), 0); - EXPECT_NE(strlen(buf.domainname), 0); -} - -TEST(UnameTest, SetNames) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - char hostname[65]; - ASSERT_THAT(sethostname("0123456789", 3), SyscallSucceeds()); - EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds()); - EXPECT_EQ(absl::string_view(hostname), "012"); - - ASSERT_THAT(sethostname("0123456789\0xxx", 11), SyscallSucceeds()); - EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds()); - EXPECT_EQ(absl::string_view(hostname), "0123456789"); - - ASSERT_THAT(sethostname("0123456789\0xxx", 12), SyscallSucceeds()); - EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds()); - EXPECT_EQ(absl::string_view(hostname), "0123456789"); - - constexpr char kHostname[] = "wubbalubba"; - ASSERT_THAT(sethostname(kHostname, sizeof(kHostname)), SyscallSucceeds()); - - constexpr char kDomainname[] = "dubdub.com"; - ASSERT_THAT(setdomainname(kDomainname, sizeof(kDomainname)), - SyscallSucceeds()); - - struct utsname buf; - EXPECT_THAT(uname(&buf), SyscallSucceeds()); - EXPECT_EQ(absl::string_view(buf.nodename), kHostname); - EXPECT_EQ(absl::string_view(buf.domainname), kDomainname); - - // These should just be glibc wrappers that also call uname(2). - EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds()); - EXPECT_EQ(absl::string_view(hostname), kHostname); - - char domainname[65]; - EXPECT_THAT(getdomainname(domainname, sizeof(domainname)), SyscallSucceeds()); - EXPECT_EQ(absl::string_view(domainname), kDomainname); -} - -TEST(UnameTest, UnprivilegedSetNames) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))) { - EXPECT_NO_ERRNO(SetCapability(CAP_SYS_ADMIN, false)); - } - - EXPECT_THAT(sethostname("", 0), SyscallFailsWithErrno(EPERM)); - EXPECT_THAT(setdomainname("", 0), SyscallFailsWithErrno(EPERM)); -} - -TEST(UnameTest, UnshareUTS) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - struct utsname init; - ASSERT_THAT(uname(&init), SyscallSucceeds()); - - ScopedThread([&]() { - EXPECT_THAT(unshare(CLONE_NEWUTS), SyscallSucceeds()); - - constexpr char kHostname[] = "wubbalubba"; - EXPECT_THAT(sethostname(kHostname, sizeof(kHostname)), SyscallSucceeds()); - - char hostname[65]; - EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds()); - }); - - struct utsname after; - EXPECT_THAT(uname(&after), SyscallSucceeds()); - EXPECT_EQ(absl::string_view(after.nodename), init.nodename); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/unix_domain_socket_test_util.cc b/test/syscalls/linux/unix_domain_socket_test_util.cc deleted file mode 100644 index b05ab2900..000000000 --- a/test/syscalls/linux/unix_domain_socket_test_util.cc +++ /dev/null @@ -1,351 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/unix_domain_socket_test_util.h" - -#include <sys/un.h> - -#include <vector> - -#include "gtest/gtest.h" -#include "absl/strings/str_cat.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -std::string DescribeUnixDomainSocketType(int type) { - const char* type_str = nullptr; - switch (type & ~(SOCK_NONBLOCK | SOCK_CLOEXEC)) { - case SOCK_STREAM: - type_str = "SOCK_STREAM"; - break; - case SOCK_DGRAM: - type_str = "SOCK_DGRAM"; - break; - case SOCK_SEQPACKET: - type_str = "SOCK_SEQPACKET"; - break; - } - if (!type_str) { - return absl::StrCat("Unix domain socket with unknown type ", type); - } else { - return absl::StrCat(((type & SOCK_NONBLOCK) != 0) ? "non-blocking " : "", - ((type & SOCK_CLOEXEC) != 0) ? "close-on-exec " : "", - type_str, " Unix domain socket"); - } -} - -SocketPairKind UnixDomainSocketPair(int type) { - return SocketPairKind{DescribeUnixDomainSocketType(type), AF_UNIX, type, 0, - SyscallSocketPairCreator(AF_UNIX, type, 0)}; -} - -SocketPairKind FilesystemBoundUnixDomainSocketPair(int type) { - std::string description = absl::StrCat(DescribeUnixDomainSocketType(type), - " created with filesystem binding"); - if ((type & SOCK_DGRAM) == SOCK_DGRAM) { - return SocketPairKind{ - description, AF_UNIX, type, 0, - FilesystemBidirectionalBindSocketPairCreator(AF_UNIX, type, 0)}; - } - return SocketPairKind{ - description, AF_UNIX, type, 0, - FilesystemAcceptBindSocketPairCreator(AF_UNIX, type, 0)}; -} - -SocketPairKind AbstractBoundUnixDomainSocketPair(int type) { - std::string description = - absl::StrCat(DescribeUnixDomainSocketType(type), - " created with abstract namespace binding"); - if ((type & SOCK_DGRAM) == SOCK_DGRAM) { - return SocketPairKind{ - description, AF_UNIX, type, 0, - AbstractBidirectionalBindSocketPairCreator(AF_UNIX, type, 0)}; - } - return SocketPairKind{description, AF_UNIX, type, 0, - AbstractAcceptBindSocketPairCreator(AF_UNIX, type, 0)}; -} - -SocketPairKind SocketpairGoferUnixDomainSocketPair(int type) { - std::string description = absl::StrCat(DescribeUnixDomainSocketType(type), - " created with the socketpair gofer"); - return SocketPairKind{description, AF_UNIX, type, 0, - SocketpairGoferSocketPairCreator(AF_UNIX, type, 0)}; -} - -SocketPairKind SocketpairGoferFileSocketPair(int type) { - std::string description = - absl::StrCat(((type & O_NONBLOCK) != 0) ? "non-blocking " : "", - ((type & O_CLOEXEC) != 0) ? "close-on-exec " : "", - "file socket created with the socketpair gofer"); - // The socketpair gofer always creates SOCK_STREAM sockets on open(2). - return SocketPairKind{description, AF_UNIX, SOCK_STREAM, 0, - SocketpairGoferFileSocketPairCreator(type)}; -} - -SocketPairKind FilesystemUnboundUnixDomainSocketPair(int type) { - return SocketPairKind{absl::StrCat(DescribeUnixDomainSocketType(type), - " unbound with a filesystem address"), - AF_UNIX, type, 0, - FilesystemUnboundSocketPairCreator(AF_UNIX, type, 0)}; -} - -SocketPairKind AbstractUnboundUnixDomainSocketPair(int type) { - return SocketPairKind{ - absl::StrCat(DescribeUnixDomainSocketType(type), - " unbound with an abstract namespace address"), - AF_UNIX, type, 0, AbstractUnboundSocketPairCreator(AF_UNIX, type, 0)}; -} - -void SendSingleFD(int sock, int fd, char buf[], int buf_size) { - ASSERT_NO_FATAL_FAILURE(SendFDs(sock, &fd, 1, buf, buf_size)); -} - -void SendFDs(int sock, int fds[], int fds_size, char buf[], int buf_size) { - struct msghdr msg = {}; - std::vector<char> control(CMSG_SPACE(fds_size * sizeof(int))); - msg.msg_control = &control[0]; - msg.msg_controllen = control.size(); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_len = CMSG_LEN(fds_size * sizeof(int)); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - for (int i = 0; i < fds_size; i++) { - memcpy(CMSG_DATA(cmsg) + i * sizeof(int), &fds[i], sizeof(int)); - } - - ASSERT_THAT(SendMsg(sock, &msg, buf, buf_size), - IsPosixErrorOkAndHolds(buf_size)); -} - -void RecvSingleFD(int sock, int* fd, char buf[], int buf_size) { - ASSERT_NO_FATAL_FAILURE(RecvFDs(sock, fd, 1, buf, buf_size, buf_size)); -} - -void RecvSingleFD(int sock, int* fd, char buf[], int buf_size, - int expected_size) { - ASSERT_NO_FATAL_FAILURE(RecvFDs(sock, fd, 1, buf, buf_size, expected_size)); -} - -void RecvFDs(int sock, int fds[], int fds_size, char buf[], int buf_size) { - ASSERT_NO_FATAL_FAILURE( - RecvFDs(sock, fds, fds_size, buf, buf_size, buf_size)); -} - -void RecvFDs(int sock, int fds[], int fds_size, char buf[], int buf_size, - int expected_size, bool peek) { - struct msghdr msg = {}; - std::vector<char> control(CMSG_SPACE(fds_size * sizeof(int))); - msg.msg_control = &control[0]; - msg.msg_controllen = control.size(); - - struct iovec iov; - iov.iov_base = buf; - iov.iov_len = buf_size; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - int flags = 0; - if (peek) { - flags |= MSG_PEEK; - } - - ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, flags), - SyscallSucceedsWithValue(expected_size)); - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(fds_size * sizeof(int))); - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS); - - for (int i = 0; i < fds_size; i++) { - memcpy(&fds[i], CMSG_DATA(cmsg) + i * sizeof(int), sizeof(int)); - } -} - -void RecvFDs(int sock, int fds[], int fds_size, char buf[], int buf_size, - int expected_size) { - ASSERT_NO_FATAL_FAILURE( - RecvFDs(sock, fds, fds_size, buf, buf_size, expected_size, false)); -} - -void PeekSingleFD(int sock, int* fd, char buf[], int buf_size) { - ASSERT_NO_FATAL_FAILURE(RecvFDs(sock, fd, 1, buf, buf_size, buf_size, true)); -} - -void RecvNoCmsg(int sock, char buf[], int buf_size, int expected_size) { - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(int)) + CMSG_SPACE(sizeof(struct ucred))]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct iovec iov; - iov.iov_base = buf; - iov.iov_len = buf_size; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, 0), - SyscallSucceedsWithValue(expected_size)); - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - EXPECT_EQ(cmsg, nullptr); -} - -void SendNullCmsg(int sock, char buf[], int buf_size) { - struct msghdr msg = {}; - msg.msg_control = nullptr; - msg.msg_controllen = 0; - - ASSERT_THAT(SendMsg(sock, &msg, buf, buf_size), - IsPosixErrorOkAndHolds(buf_size)); -} - -void SendCreds(int sock, ucred creds, char buf[], int buf_size) { - struct msghdr msg = {}; - - char control[CMSG_SPACE(sizeof(struct ucred))]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_CREDENTIALS; - cmsg->cmsg_len = CMSG_LEN(sizeof(struct ucred)); - memcpy(CMSG_DATA(cmsg), &creds, sizeof(struct ucred)); - - ASSERT_THAT(SendMsg(sock, &msg, buf, buf_size), - IsPosixErrorOkAndHolds(buf_size)); -} - -void SendCredsAndFD(int sock, ucred creds, int fd, char buf[], int buf_size) { - struct msghdr msg = {}; - - char control[CMSG_SPACE(sizeof(struct ucred)) + CMSG_SPACE(sizeof(int))] = {}; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct cmsghdr* cmsg1 = CMSG_FIRSTHDR(&msg); - cmsg1->cmsg_level = SOL_SOCKET; - cmsg1->cmsg_type = SCM_CREDENTIALS; - cmsg1->cmsg_len = CMSG_LEN(sizeof(struct ucred)); - memcpy(CMSG_DATA(cmsg1), &creds, sizeof(struct ucred)); - - struct cmsghdr* cmsg2 = CMSG_NXTHDR(&msg, cmsg1); - cmsg2->cmsg_level = SOL_SOCKET; - cmsg2->cmsg_type = SCM_RIGHTS; - cmsg2->cmsg_len = CMSG_LEN(sizeof(int)); - memcpy(CMSG_DATA(cmsg2), &fd, sizeof(int)); - - ASSERT_THAT(SendMsg(sock, &msg, buf, buf_size), - IsPosixErrorOkAndHolds(buf_size)); -} - -void RecvCreds(int sock, ucred* creds, char buf[], int buf_size) { - ASSERT_NO_FATAL_FAILURE(RecvCreds(sock, creds, buf, buf_size, buf_size)); -} - -void RecvCreds(int sock, ucred* creds, char buf[], int buf_size, - int expected_size) { - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(struct ucred))]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct iovec iov; - iov.iov_base = buf; - iov.iov_len = buf_size; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, 0), - SyscallSucceedsWithValue(expected_size)); - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct ucred))); - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); - - memcpy(creds, CMSG_DATA(cmsg), sizeof(struct ucred)); -} - -void RecvCredsAndFD(int sock, ucred* creds, int* fd, char buf[], int buf_size) { - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(struct ucred)) + CMSG_SPACE(sizeof(int))]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct iovec iov; - iov.iov_base = buf; - iov.iov_len = buf_size; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, 0), - SyscallSucceedsWithValue(buf_size)); - - struct cmsghdr* cmsg1 = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg1, nullptr); - ASSERT_EQ(cmsg1->cmsg_len, CMSG_LEN(sizeof(struct ucred))); - ASSERT_EQ(cmsg1->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg1->cmsg_type, SCM_CREDENTIALS); - memcpy(creds, CMSG_DATA(cmsg1), sizeof(struct ucred)); - - struct cmsghdr* cmsg2 = CMSG_NXTHDR(&msg, cmsg1); - ASSERT_NE(cmsg2, nullptr); - ASSERT_EQ(cmsg2->cmsg_len, CMSG_LEN(sizeof(int))); - ASSERT_EQ(cmsg2->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg2->cmsg_type, SCM_RIGHTS); - memcpy(fd, CMSG_DATA(cmsg2), sizeof(int)); -} - -void RecvSingleFDUnaligned(int sock, int* fd, char buf[], int buf_size) { - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(int)) - sizeof(int)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct iovec iov; - iov.iov_base = buf; - iov.iov_len = buf_size; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, 0), - SyscallSucceedsWithValue(buf_size)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS); - - memcpy(fd, CMSG_DATA(cmsg), sizeof(int)); -} - -void SetSoPassCred(int sock) { - int one = 1; - EXPECT_THAT(setsockopt(sock, SOL_SOCKET, SO_PASSCRED, &one, sizeof(one)), - SyscallSucceeds()); -} - -void UnsetSoPassCred(int sock) { - int zero = 0; - EXPECT_THAT(setsockopt(sock, SOL_SOCKET, SO_PASSCRED, &zero, sizeof(zero)), - SyscallSucceeds()); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/unix_domain_socket_test_util.h b/test/syscalls/linux/unix_domain_socket_test_util.h deleted file mode 100644 index b8073db17..000000000 --- a/test/syscalls/linux/unix_domain_socket_test_util.h +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_UNIX_DOMAIN_SOCKET_TEST_UTIL_H_ -#define GVISOR_TEST_SYSCALLS_UNIX_DOMAIN_SOCKET_TEST_UTIL_H_ - -#include <string> - -#include "test/syscalls/linux/socket_test_util.h" - -namespace gvisor { -namespace testing { - -// DescribeUnixDomainSocketType returns a human-readable string explaining the -// given Unix domain socket type. -std::string DescribeUnixDomainSocketType(int type); - -// UnixDomainSocketPair returns a SocketPairKind that represents SocketPairs -// created by invoking the socketpair() syscall with AF_UNIX and the given type. -SocketPairKind UnixDomainSocketPair(int type); - -// FilesystemBoundUnixDomainSocketPair returns a SocketPairKind that represents -// SocketPairs created with bind() and accept() syscalls with a temp file path, -// AF_UNIX and the given type. -SocketPairKind FilesystemBoundUnixDomainSocketPair(int type); - -// AbstractBoundUnixDomainSocketPair returns a SocketPairKind that represents -// SocketPairs created with bind() and accept() syscalls with a temp abstract -// path, AF_UNIX and the given type. -SocketPairKind AbstractBoundUnixDomainSocketPair(int type); - -// SocketpairGoferUnixDomainSocketPair returns a SocketPairKind that was created -// with two sockets connected to the socketpair gofer. -SocketPairKind SocketpairGoferUnixDomainSocketPair(int type); - -// SocketpairGoferFileSocketPair returns a SocketPairKind that was created with -// two open() calls on paths backed by the socketpair gofer. -SocketPairKind SocketpairGoferFileSocketPair(int type); - -// FilesystemUnboundUnixDomainSocketPair returns a SocketPairKind that -// represents two unbound sockets and a filesystem path for binding. -SocketPairKind FilesystemUnboundUnixDomainSocketPair(int type); - -// AbstractUnboundUnixDomainSocketPair returns a SocketPairKind that represents -// two unbound sockets and an abstract namespace path for binding. -SocketPairKind AbstractUnboundUnixDomainSocketPair(int type); - -// SendSingleFD sends both a single FD and some data over a unix domain socket -// specified by an FD. Note that calls to this function must be wrapped in -// ASSERT_NO_FATAL_FAILURE for internal assertions to halt the test. -void SendSingleFD(int sock, int fd, char buf[], int buf_size); - -// SendFDs sends an arbitrary number of FDs and some data over a unix domain -// socket specified by an FD. Note that calls to this function must be wrapped -// in ASSERT_NO_FATAL_FAILURE for internal assertions to halt the test. -void SendFDs(int sock, int fds[], int fds_size, char buf[], int buf_size); - -// RecvSingleFD receives both a single FD and some data over a unix domain -// socket specified by an FD. Note that calls to this function must be wrapped -// in ASSERT_NO_FATAL_FAILURE for internal assertions to halt the test. -void RecvSingleFD(int sock, int* fd, char buf[], int buf_size); - -// RecvSingleFD receives both a single FD and some data over a unix domain -// socket specified by an FD. This version allows the expected amount of data -// received to be different than the buffer size. Note that calls to this -// function must be wrapped in ASSERT_NO_FATAL_FAILURE for internal assertions -// to halt the test. -void RecvSingleFD(int sock, int* fd, char buf[], int buf_size, - int expected_size); - -// PeekSingleFD peeks at both a single FD and some data over a unix domain -// socket specified by an FD. Note that calls to this function must be wrapped -// in ASSERT_NO_FATAL_FAILURE for internal assertions to halt the test. -void PeekSingleFD(int sock, int* fd, char buf[], int buf_size); - -// RecvFDs receives both an arbitrary number of FDs and some data over a unix -// domain socket specified by an FD. Note that calls to this function must be -// wrapped in ASSERT_NO_FATAL_FAILURE for internal assertions to halt the test. -void RecvFDs(int sock, int fds[], int fds_size, char buf[], int buf_size); - -// RecvFDs receives both an arbitrary number of FDs and some data over a unix -// domain socket specified by an FD. This version allows the expected amount of -// data received to be different than the buffer size. Note that calls to this -// function must be wrapped in ASSERT_NO_FATAL_FAILURE for internal assertions -// to halt the test. -void RecvFDs(int sock, int fds[], int fds_size, char buf[], int buf_size, - int expected_size); - -// RecvNoCmsg receives some data over a unix domain socket specified by an FD -// and asserts that no control messages are available for receiving. Note that -// calls to this function must be wrapped in ASSERT_NO_FATAL_FAILURE for -// internal assertions to halt the test. -void RecvNoCmsg(int sock, char buf[], int buf_size, int expected_size); - -inline void RecvNoCmsg(int sock, char buf[], int buf_size) { - RecvNoCmsg(sock, buf, buf_size, buf_size); -} - -// SendCreds sends the credentials of the current process and some data over a -// unix domain socket specified by an FD. Note that calls to this function must -// be wrapped in ASSERT_NO_FATAL_FAILURE for internal assertions to halt the -// test. -void SendCreds(int sock, ucred creds, char buf[], int buf_size); - -// SendCredsAndFD sends the credentials of the current process, a single FD, and -// some data over a unix domain socket specified by an FD. Note that calls to -// this function must be wrapped in ASSERT_NO_FATAL_FAILURE for internal -// assertions to halt the test. -void SendCredsAndFD(int sock, ucred creds, int fd, char buf[], int buf_size); - -// RecvCreds receives some credentials and some data over a unix domain socket -// specified by an FD. Note that calls to this function must be wrapped in -// ASSERT_NO_FATAL_FAILURE for internal assertions to halt the test. -void RecvCreds(int sock, ucred* creds, char buf[], int buf_size); - -// RecvCreds receives some credentials and some data over a unix domain socket -// specified by an FD. This version allows the expected amount of data received -// to be different than the buffer size. Note that calls to this function must -// be wrapped in ASSERT_NO_FATAL_FAILURE for internal assertions to halt the -// test. -void RecvCreds(int sock, ucred* creds, char buf[], int buf_size, - int expected_size); - -// RecvCredsAndFD receives some credentials, a single FD, and some data over a -// unix domain socket specified by an FD. Note that calls to this function must -// be wrapped in ASSERT_NO_FATAL_FAILURE for internal assertions to halt the -// test. -void RecvCredsAndFD(int sock, ucred* creds, int* fd, char buf[], int buf_size); - -// SendNullCmsg sends a null control message and some data over a unix domain -// socket specified by an FD. Note that calls to this function must be wrapped -// in ASSERT_NO_FATAL_FAILURE for internal assertions to halt the test. -void SendNullCmsg(int sock, char buf[], int buf_size); - -// RecvSingleFDUnaligned sends both a single FD and some data over a unix domain -// socket specified by an FD. This function does not obey the spec, but Linux -// allows it and the apphosting code depends on this quirk. Note that calls to -// this function must be wrapped in ASSERT_NO_FATAL_FAILURE for internal -// assertions to halt the test. -void RecvSingleFDUnaligned(int sock, int* fd, char buf[], int buf_size); - -// SetSoPassCred sets the SO_PASSCRED option on the specified socket. -void SetSoPassCred(int sock); - -// UnsetSoPassCred clears the SO_PASSCRED option on the specified socket. -void UnsetSoPassCred(int sock); - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_UNIX_DOMAIN_SOCKET_TEST_UTIL_H_ diff --git a/test/syscalls/linux/unlink.cc b/test/syscalls/linux/unlink.cc deleted file mode 100644 index 2040375c9..000000000 --- a/test/syscalls/linux/unlink.cc +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <unistd.h> - -#include "gtest/gtest.h" -#include "absl/strings/str_cat.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(UnlinkTest, IsDir) { - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - EXPECT_THAT(unlink(dir.path().c_str()), SyscallFailsWithErrno(EISDIR)); -} - -TEST(UnlinkTest, DirNotEmpty) { - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - int fd; - std::string path = JoinPath(dir.path(), "ExistingFile"); - EXPECT_THAT(fd = open(path.c_str(), O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - EXPECT_THAT(rmdir(dir.path().c_str()), SyscallFailsWithErrno(ENOTEMPTY)); -} - -TEST(UnlinkTest, Rmdir) { - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(rmdir(dir.path().c_str()), SyscallSucceeds()); -} - -TEST(UnlinkTest, AtDir) { - int dirfd; - auto tmpdir = GetAbsoluteTestTmpdir(); - EXPECT_THAT(dirfd = open(tmpdir.c_str(), O_DIRECTORY, 0), SyscallSucceeds()); - - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(tmpdir)); - auto dir_relpath = - ASSERT_NO_ERRNO_AND_VALUE(GetRelativePath(tmpdir, dir.path())); - EXPECT_THAT(unlinkat(dirfd, dir_relpath.c_str(), AT_REMOVEDIR), - SyscallSucceeds()); - ASSERT_THAT(close(dirfd), SyscallSucceeds()); -} - -TEST(UnlinkTest, AtDirDegradedPermissions_NoRandomSave) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - int dirfd; - ASSERT_THAT(dirfd = open(dir.path().c_str(), O_DIRECTORY, 0), - SyscallSucceeds()); - - std::string sub_dir = JoinPath(dir.path(), "NewDir"); - EXPECT_THAT(mkdir(sub_dir.c_str(), 0755), SyscallSucceeds()); - EXPECT_THAT(fchmod(dirfd, 0444), SyscallSucceeds()); - EXPECT_THAT(unlinkat(dirfd, "NewDir", AT_REMOVEDIR), - SyscallFailsWithErrno(EACCES)); - ASSERT_THAT(close(dirfd), SyscallSucceeds()); -} - -// Files cannot be unlinked if the parent is not writable and executable. -TEST(UnlinkTest, ParentDegradedPermissions) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); - - ASSERT_THAT(chmod(dir.path().c_str(), 0000), SyscallSucceeds()); - - struct stat st; - ASSERT_THAT(stat(file.path().c_str(), &st), SyscallFailsWithErrno(EACCES)); - ASSERT_THAT(unlinkat(AT_FDCWD, file.path().c_str(), 0), - SyscallFailsWithErrno(EACCES)); - - // Non-existent files also return EACCES. - const std::string nonexist = JoinPath(dir.path(), "doesnotexist"); - ASSERT_THAT(stat(nonexist.c_str(), &st), SyscallFailsWithErrno(EACCES)); - ASSERT_THAT(unlinkat(AT_FDCWD, nonexist.c_str(), 0), - SyscallFailsWithErrno(EACCES)); -} - -TEST(UnlinkTest, AtBad) { - int dirfd; - EXPECT_THAT(dirfd = open(GetAbsoluteTestTmpdir().c_str(), O_DIRECTORY, 0), - SyscallSucceeds()); - - // Try removing a directory as a file. - std::string path = JoinPath(GetAbsoluteTestTmpdir(), "NewDir"); - EXPECT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds()); - EXPECT_THAT(unlinkat(dirfd, "NewDir", 0), SyscallFailsWithErrno(EISDIR)); - EXPECT_THAT(unlinkat(dirfd, "NewDir", AT_REMOVEDIR), SyscallSucceeds()); - - // Try removing a file as a directory. - int fd; - EXPECT_THAT(fd = openat(dirfd, "UnlinkAtFile", O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile", AT_REMOVEDIR), - SyscallFailsWithErrno(ENOTDIR)); - EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile/", 0), - SyscallFailsWithErrno(ENOTDIR)); - ASSERT_THAT(close(fd), SyscallSucceeds()); - EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile", 0), SyscallSucceeds()); - - // Cleanup. - ASSERT_THAT(close(dirfd), SyscallSucceeds()); -} - -TEST(UnlinkTest, AbsTmpFile) { - int fd; - std::string path = JoinPath(GetAbsoluteTestTmpdir(), "ExistingFile"); - EXPECT_THAT(fd = open(path.c_str(), O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - EXPECT_THAT(unlink(path.c_str()), SyscallSucceeds()); -} - -TEST(UnlinkTest, TooLongName) { - EXPECT_THAT(unlink(std::vector<char>(16384, '0').data()), - SyscallFailsWithErrno(ENAMETOOLONG)); -} - -TEST(UnlinkTest, BadNamePtr) { - EXPECT_THAT(unlink(reinterpret_cast<char*>(1)), - SyscallFailsWithErrno(EFAULT)); -} - -TEST(UnlinkTest, AtFile) { - int dirfd; - EXPECT_THAT(dirfd = open(GetAbsoluteTestTmpdir().c_str(), O_DIRECTORY, 0666), - SyscallSucceeds()); - int fd; - EXPECT_THAT(fd = openat(dirfd, "UnlinkAtFile", O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile", 0), SyscallSucceeds()); -} - -TEST(UnlinkTest, OpenFile_NoRandomSave) { - // We can't save unlinked file unless they are on tmpfs. - const DisableSave ds; - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - int fd; - EXPECT_THAT(fd = open(file.path().c_str(), O_RDWR, 0666), SyscallSucceeds()); - EXPECT_THAT(unlink(file.path().c_str()), SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST(UnlinkTest, CannotRemoveDots) { - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const std::string self = JoinPath(file.path(), "."); - ASSERT_THAT(unlink(self.c_str()), SyscallFailsWithErrno(ENOTDIR)); - const std::string parent = JoinPath(file.path(), ".."); - ASSERT_THAT(unlink(parent.c_str()), SyscallFailsWithErrno(ENOTDIR)); -} - -TEST(UnlinkTest, CannotRemoveRoot) { - ASSERT_THAT(unlinkat(-1, "/", AT_REMOVEDIR), SyscallFailsWithErrno(EBUSY)); -} - -TEST(UnlinkTest, CannotRemoveRootWithAtDir) { - const FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE( - Open(GetAbsoluteTestTmpdir(), O_DIRECTORY, 0666)); - ASSERT_THAT(unlinkat(dirfd.get(), "/", AT_REMOVEDIR), - SyscallFailsWithErrno(EBUSY)); -} - -TEST(RmdirTest, CannotRemoveDots) { - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const std::string self = JoinPath(dir.path(), "."); - ASSERT_THAT(rmdir(self.c_str()), SyscallFailsWithErrno(EINVAL)); - const std::string parent = JoinPath(dir.path(), ".."); - ASSERT_THAT(rmdir(parent.c_str()), SyscallFailsWithErrno(ENOTEMPTY)); -} - -TEST(RmdirTest, CanRemoveWithTrailingSlashes) { - auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const std::string slash = absl::StrCat(dir1.path(), "/"); - ASSERT_THAT(rmdir(slash.c_str()), SyscallSucceeds()); - auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const std::string slashslash = absl::StrCat(dir2.path(), "//"); - ASSERT_THAT(rmdir(slashslash.c_str()), SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/unshare.cc b/test/syscalls/linux/unshare.cc deleted file mode 100644 index e32619efe..000000000 --- a/test/syscalls/linux/unshare.cc +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <sched.h> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/synchronization/mutex.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(UnshareTest, AllowsZeroFlags) { - ASSERT_THAT(unshare(0), SyscallSucceeds()); -} - -TEST(UnshareTest, ThreadFlagFailsIfMultithreaded) { - absl::Mutex mu; - bool finished = false; - ScopedThread t([&] { - mu.Lock(); - mu.Await(absl::Condition(&finished)); - mu.Unlock(); - }); - ASSERT_THAT(unshare(CLONE_THREAD), SyscallFailsWithErrno(EINVAL)); - mu.Lock(); - finished = true; - mu.Unlock(); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/utimes.cc b/test/syscalls/linux/utimes.cc deleted file mode 100644 index 3a927a430..000000000 --- a/test/syscalls/linux/utimes.cc +++ /dev/null @@ -1,332 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <fcntl.h> -#include <sys/stat.h> -#include <sys/syscall.h> -#include <sys/time.h> -#include <sys/types.h> -#include <time.h> -#include <unistd.h> -#include <utime.h> - -#include <string> - -#include "absl/time/time.h" -#include "test/util/file_descriptor.h" -#include "test/util/fs_util.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// TODO(b/36516566): utimes(nullptr) does not pick the "now" time in the -// application's time domain, so when asserting that times are within a window, -// we expand the window to allow for differences between the time domains. -constexpr absl::Duration kClockSlack = absl::Milliseconds(100); - -// TimeBoxed runs fn, setting before and after to (coarse realtime) times -// guaranteed* to come before and after fn started and completed, respectively. -// -// fn may be called more than once if the clock is adjusted. -// -// * See the comment on kClockSlack. gVisor breaks this guarantee. -void TimeBoxed(absl::Time* before, absl::Time* after, - std::function<void()> const& fn) { - do { - // N.B. utimes and friends use CLOCK_REALTIME_COARSE for setting time (i.e., - // current_kernel_time()). See fs/attr.c:notify_change. - // - // notify_change truncates the time to a multiple of s_time_gran, but most - // filesystems set it to 1, so we don't do any truncation. - struct timespec ts; - EXPECT_THAT(clock_gettime(CLOCK_REALTIME_COARSE, &ts), SyscallSucceeds()); - *before = absl::TimeFromTimespec(ts); - - fn(); - - EXPECT_THAT(clock_gettime(CLOCK_REALTIME_COARSE, &ts), SyscallSucceeds()); - *after = absl::TimeFromTimespec(ts); - - if (*after < *before) { - // Clock jumped backwards; retry. - // - // Technically this misses jumps small enough to keep after > before, - // which could lead to test failures, but that is very unlikely to happen. - continue; - } - - if (IsRunningOnGvisor()) { - // See comment on kClockSlack. - *before -= kClockSlack; - *after += kClockSlack; - } - } while (*after < *before); -} - -void TestUtimesOnPath(std::string const& path) { - struct stat statbuf; - - struct timeval times[2] = {{1, 0}, {2, 0}}; - EXPECT_THAT(utimes(path.c_str(), times), SyscallSucceeds()); - EXPECT_THAT(stat(path.c_str(), &statbuf), SyscallSucceeds()); - EXPECT_EQ(1, statbuf.st_atime); - EXPECT_EQ(2, statbuf.st_mtime); - - absl::Time before; - absl::Time after; - TimeBoxed(&before, &after, [&] { - EXPECT_THAT(utimes(path.c_str(), nullptr), SyscallSucceeds()); - }); - - EXPECT_THAT(stat(path.c_str(), &statbuf), SyscallSucceeds()); - - absl::Time atime = absl::TimeFromTimespec(statbuf.st_atim); - EXPECT_GE(atime, before); - EXPECT_LE(atime, after); - - absl::Time mtime = absl::TimeFromTimespec(statbuf.st_mtim); - EXPECT_GE(mtime, before); - EXPECT_LE(mtime, after); -} - -TEST(UtimesTest, OnFile) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - TestUtimesOnPath(f.path()); -} - -TEST(UtimesTest, OnDir) { - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - TestUtimesOnPath(dir.path()); -} - -TEST(UtimesTest, MissingPath) { - auto path = NewTempAbsPath(); - struct timeval times[2] = {{1, 0}, {2, 0}}; - EXPECT_THAT(utimes(path.c_str(), times), SyscallFailsWithErrno(ENOENT)); -} - -void TestFutimesat(int dirFd, std::string const& path) { - struct stat statbuf; - - struct timeval times[2] = {{1, 0}, {2, 0}}; - EXPECT_THAT(futimesat(dirFd, path.c_str(), times), SyscallSucceeds()); - EXPECT_THAT(fstatat(dirFd, path.c_str(), &statbuf, 0), SyscallSucceeds()); - EXPECT_EQ(1, statbuf.st_atime); - EXPECT_EQ(2, statbuf.st_mtime); - - absl::Time before; - absl::Time after; - TimeBoxed(&before, &after, [&] { - EXPECT_THAT(futimesat(dirFd, path.c_str(), nullptr), SyscallSucceeds()); - }); - - EXPECT_THAT(fstatat(dirFd, path.c_str(), &statbuf, 0), SyscallSucceeds()); - - absl::Time atime = absl::TimeFromTimespec(statbuf.st_atim); - EXPECT_GE(atime, before); - EXPECT_LE(atime, after); - - absl::Time mtime = absl::TimeFromTimespec(statbuf.st_mtim); - EXPECT_GE(mtime, before); - EXPECT_LE(mtime, after); -} - -TEST(FutimesatTest, OnAbsPath) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - TestFutimesat(0, f.path()); -} - -TEST(FutimesatTest, OnRelPath) { - auto d = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(d.path())); - auto basename = std::string(Basename(f.path())); - const FileDescriptor dirFd = - ASSERT_NO_ERRNO_AND_VALUE(Open(d.path(), O_RDONLY | O_DIRECTORY)); - TestFutimesat(dirFd.get(), basename); -} - -TEST(FutimesatTest, InvalidNsec) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - struct timeval times[4][2] = {{ - {0, 1}, // Valid - {1, static_cast<int64_t>(1e7)} // Invalid - }, - { - {1, static_cast<int64_t>(1e7)}, // Invalid - {0, 1} // Valid - }, - { - {0, 1}, // Valid - {1, -1} // Invalid - }, - { - {1, -1}, // Invalid - {0, 1} // Valid - }}; - - for (unsigned int i = 0; i < sizeof(times) / sizeof(times[0]); i++) { - std::cout << "test:" << i << "\n"; - EXPECT_THAT(futimesat(0, f.path().c_str(), times[i]), - SyscallFailsWithErrno(EINVAL)); - } -} - -void TestUtimensat(int dirFd, std::string const& path) { - struct stat statbuf; - const struct timespec times[2] = {{1, 0}, {2, 0}}; - EXPECT_THAT(utimensat(dirFd, path.c_str(), times, 0), SyscallSucceeds()); - EXPECT_THAT(fstatat(dirFd, path.c_str(), &statbuf, 0), SyscallSucceeds()); - EXPECT_EQ(1, statbuf.st_atime); - EXPECT_EQ(2, statbuf.st_mtime); - - // Test setting with UTIME_NOW and UTIME_OMIT. - struct stat statbuf2; - const struct timespec times2[2] = { - {0, UTIME_NOW}, // Should set atime to now. - {0, UTIME_OMIT} // Should not change mtime. - }; - - absl::Time before; - absl::Time after; - TimeBoxed(&before, &after, [&] { - EXPECT_THAT(utimensat(dirFd, path.c_str(), times2, 0), SyscallSucceeds()); - }); - - EXPECT_THAT(fstatat(dirFd, path.c_str(), &statbuf2, 0), SyscallSucceeds()); - - absl::Time atime2 = absl::TimeFromTimespec(statbuf2.st_atim); - EXPECT_GE(atime2, before); - EXPECT_LE(atime2, after); - - absl::Time mtime = absl::TimeFromTimespec(statbuf.st_mtim); - absl::Time mtime2 = absl::TimeFromTimespec(statbuf2.st_mtim); - // mtime should not be changed. - EXPECT_EQ(mtime, mtime2); - - // Test setting with times = NULL. Should set both atime and mtime to the - // current system time. - struct stat statbuf3; - TimeBoxed(&before, &after, [&] { - EXPECT_THAT(utimensat(dirFd, path.c_str(), nullptr, 0), SyscallSucceeds()); - }); - - EXPECT_THAT(fstatat(dirFd, path.c_str(), &statbuf3, 0), SyscallSucceeds()); - - absl::Time atime3 = absl::TimeFromTimespec(statbuf3.st_atim); - EXPECT_GE(atime3, before); - EXPECT_LE(atime3, after); - - absl::Time mtime3 = absl::TimeFromTimespec(statbuf3.st_mtim); - EXPECT_GE(mtime3, before); - EXPECT_LE(mtime3, after); - - if (!IsRunningOnGvisor()) { - // FIXME(b/36516566): Gofers set atime and mtime to different "now" times. - EXPECT_EQ(atime3, mtime3); - } -} - -TEST(UtimensatTest, OnAbsPath) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - TestUtimensat(0, f.path()); -} - -TEST(UtimensatTest, OnRelPath) { - auto d = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(d.path())); - auto basename = std::string(Basename(f.path())); - const FileDescriptor dirFd = - ASSERT_NO_ERRNO_AND_VALUE(Open(d.path(), O_RDONLY | O_DIRECTORY)); - TestUtimensat(dirFd.get(), basename); -} - -TEST(UtimensatTest, OmitNoop) { - // Setting both timespecs to UTIME_OMIT on a nonexistant path should succeed. - auto path = NewTempAbsPath(); - const struct timespec times[2] = {{0, UTIME_OMIT}, {0, UTIME_OMIT}}; - EXPECT_THAT(utimensat(0, path.c_str(), times, 0), SyscallSucceeds()); -} - -// Verify that we can actually set atime and mtime to 0. -TEST(UtimeTest, ZeroAtimeandMtime) { - const auto tmp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const auto tmp_file = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(tmp_dir.path())); - - // Stat the file before and after updating atime and mtime. - struct stat stat_before = {}; - EXPECT_THAT(stat(tmp_file.path().c_str(), &stat_before), SyscallSucceeds()); - - ASSERT_NE(stat_before.st_atime, 0); - ASSERT_NE(stat_before.st_mtime, 0); - - const struct utimbuf times = {}; // Zero for both atime and mtime. - EXPECT_THAT(utime(tmp_file.path().c_str(), ×), SyscallSucceeds()); - - struct stat stat_after = {}; - EXPECT_THAT(stat(tmp_file.path().c_str(), &stat_after), SyscallSucceeds()); - - // We should see the atime and mtime changed when we set them to 0. - ASSERT_EQ(stat_after.st_atime, 0); - ASSERT_EQ(stat_after.st_mtime, 0); -} - -TEST(UtimensatTest, InvalidNsec) { - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - struct timespec times[2][2] = { - { - {0, UTIME_OMIT}, // Valid - {2, static_cast<int64_t>(1e10)} // Invalid - }, - { - {2, static_cast<int64_t>(1e10)}, // Invalid - {0, UTIME_OMIT} // Valid - }}; - - for (unsigned int i = 0; i < sizeof(times) / sizeof(times[0]); i++) { - std::cout << "test:" << i << "\n"; - EXPECT_THAT(utimensat(0, f.path().c_str(), times[i], 0), - SyscallFailsWithErrno(EINVAL)); - } -} - -TEST(Utimensat, NullPath) { - // From man utimensat(2): - // "the Linux utimensat() system call implements a nonstandard feature: if - // pathname is NULL, then the call modifies the timestamps of the file - // referred to by the file descriptor dirfd (which may refer to any type of - // file). - // Note, however, that the glibc wrapper for utimensat() disallows - // passing NULL as the value for file: the wrapper function returns the error - // EINVAL in this case." - auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDWR)); - struct stat statbuf; - const struct timespec times[2] = {{1, 0}, {2, 0}}; - // Call syscall directly. - EXPECT_THAT(syscall(SYS_utimensat, fd.get(), NULL, times, 0), - SyscallSucceeds()); - EXPECT_THAT(fstatat(0, f.path().c_str(), &statbuf, 0), SyscallSucceeds()); - EXPECT_EQ(1, statbuf.st_atime); - EXPECT_EQ(2, statbuf.st_mtime); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/vdso.cc b/test/syscalls/linux/vdso.cc deleted file mode 100644 index 19c80add8..000000000 --- a/test/syscalls/linux/vdso.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <string.h> -#include <sys/mman.h> - -#include <algorithm> - -#include "gtest/gtest.h" -#include "test/util/fs_util.h" -#include "test/util/posix_error.h" -#include "test/util/proc_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// Ensure that the vvar page cannot be made writable. -TEST(VvarTest, WriteVvar) { - auto contents = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps")); - auto maps = ASSERT_NO_ERRNO_AND_VALUE(ParseProcMaps(contents)); - auto it = std::find_if(maps.begin(), maps.end(), [](const ProcMapsEntry& e) { - return e.filename == "[vvar]"; - }); - - SKIP_IF(it == maps.end()); - EXPECT_THAT(mprotect(reinterpret_cast<void*>(it->start), kPageSize, - PROT_READ | PROT_WRITE), - SyscallFailsWithErrno(EACCES)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/vdso_clock_gettime.cc b/test/syscalls/linux/vdso_clock_gettime.cc deleted file mode 100644 index ce1899f45..000000000 --- a/test/syscalls/linux/vdso_clock_gettime.cc +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdint.h> -#include <sys/time.h> -#include <syscall.h> -#include <time.h> -#include <unistd.h> - -#include <map> -#include <string> -#include <utility> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/numbers.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -std::string PrintClockId(::testing::TestParamInfo<clockid_t> info) { - switch (info.param) { - case CLOCK_MONOTONIC: - return "CLOCK_MONOTONIC"; - case CLOCK_REALTIME: - return "CLOCK_REALTIME"; - case CLOCK_BOOTTIME: - return "CLOCK_BOOTTIME"; - default: - return absl::StrCat(info.param); - } -} - -class CorrectVDSOClockTest : public ::testing::TestWithParam<clockid_t> {}; - -TEST_P(CorrectVDSOClockTest, IsCorrect) { - struct timespec tvdso, tsys; - absl::Time vdso_time, sys_time; - uint64_t total_calls = 0; - - // It is expected that 82.5% of clock_gettime calls will be less than 100us - // skewed from the system time. - // Unfortunately this is not only influenced by the VDSO clock skew, but also - // by arbitrary scheduling delays and the like. The test is therefore - // regularly disabled. - std::map<absl::Duration, std::tuple<double, uint64_t, uint64_t>> confidence = - { - {absl::Microseconds(100), std::make_tuple(0.825, 0, 0)}, - {absl::Microseconds(250), std::make_tuple(0.94, 0, 0)}, - {absl::Milliseconds(1), std::make_tuple(0.999, 0, 0)}, - }; - - absl::Time start = absl::Now(); - while (absl::Now() < start + absl::Seconds(30)) { - EXPECT_THAT(clock_gettime(GetParam(), &tvdso), SyscallSucceeds()); - EXPECT_THAT(syscall(__NR_clock_gettime, GetParam(), &tsys), - SyscallSucceeds()); - - vdso_time = absl::TimeFromTimespec(tvdso); - - for (auto const& conf : confidence) { - std::get<1>(confidence[conf.first]) += - (sys_time - vdso_time) < conf.first; - } - - sys_time = absl::TimeFromTimespec(tsys); - - for (auto const& conf : confidence) { - std::get<2>(confidence[conf.first]) += - (vdso_time - sys_time) < conf.first; - } - - ++total_calls; - } - - for (auto const& conf : confidence) { - EXPECT_GE(std::get<1>(conf.second) / static_cast<double>(total_calls), - std::get<0>(conf.second)); - EXPECT_GE(std::get<2>(conf.second) / static_cast<double>(total_calls), - std::get<0>(conf.second)); - } -} - -INSTANTIATE_TEST_SUITE_P(ClockGettime, CorrectVDSOClockTest, - ::testing::Values(CLOCK_MONOTONIC, CLOCK_REALTIME, - CLOCK_BOOTTIME), - PrintClockId); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/vfork.cc b/test/syscalls/linux/vfork.cc deleted file mode 100644 index 19d05998e..000000000 --- a/test/syscalls/linux/vfork.cc +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <sys/types.h> -#include <sys/wait.h> -#include <unistd.h> - -#include <string> -#include <utility> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/flags/flag.h" -#include "absl/time/time.h" -#include "test/util/logging.h" -#include "test/util/multiprocess_util.h" -#include "test/util/test_util.h" -#include "test/util/time_util.h" - -ABSL_FLAG(bool, vfork_test_child, false, - "If true, run the VforkTest child workload."); - -namespace gvisor { -namespace testing { - -namespace { - -// We don't test with raw CLONE_VFORK to avoid interacting with glibc's use of -// TLS. -// -// Even with vfork(2), we must be careful to do little more in the child than -// call execve(2). We use the simplest sleep function possible, though this is -// still precarious, as we're officially only allowed to call execve(2) and -// _exit(2). -constexpr absl::Duration kChildDelay = absl::Seconds(10); - -// Exit code for successful child subprocesses. We don't want to use 0 since -// it's too common, and an execve(2) failure causes the child to exit with the -// errno, so kChildExitCode is chosen to be an unlikely errno: -constexpr int kChildExitCode = 118; // ENOTNAM: Not a XENIX named type file - -int64_t MonotonicNow() { - struct timespec now; - TEST_PCHECK(clock_gettime(CLOCK_MONOTONIC, &now) == 0); - return now.tv_sec * 1000000000ll + now.tv_nsec; -} - -TEST(VforkTest, ParentStopsUntilChildExits) { - const auto test = [] { - // N.B. Run the test in a single-threaded subprocess because - // vfork is not safe in a multi-threaded process. - - const int64_t start = MonotonicNow(); - - pid_t pid = vfork(); - if (pid == 0) { - SleepSafe(kChildDelay); - _exit(kChildExitCode); - } - TEST_PCHECK_MSG(pid > 0, "vfork failed"); - MaybeSave(); - - const int64_t end = MonotonicNow(); - - absl::Duration dur = absl::Nanoseconds(end - start); - - TEST_CHECK(dur >= kChildDelay); - - int status = 0; - TEST_PCHECK(RetryEINTR(waitpid)(pid, &status, 0)); - TEST_CHECK(WIFEXITED(status)); - TEST_CHECK(WEXITSTATUS(status) == kChildExitCode); - }; - - EXPECT_THAT(InForkedProcess(test), IsPosixErrorOkAndHolds(0)); -} - -TEST(VforkTest, ParentStopsUntilChildExecves_NoRandomSave) { - ExecveArray const owned_child_argv = {"/proc/self/exe", "--vfork_test_child"}; - char* const* const child_argv = owned_child_argv.get(); - - const auto test = [&] { - const int64_t start = MonotonicNow(); - - pid_t pid = vfork(); - if (pid == 0) { - SleepSafe(kChildDelay); - execve(child_argv[0], child_argv, /* envp = */ nullptr); - _exit(errno); - } - // Don't attempt save/restore until after recording end_time, - // since the test expects an upper bound on the time spent - // stopped. - int saved_errno = errno; - const int64_t end = MonotonicNow(); - errno = saved_errno; - TEST_PCHECK_MSG(pid > 0, "vfork failed"); - MaybeSave(); - - absl::Duration dur = absl::Nanoseconds(end - start); - - // The parent should resume execution after execve, but before - // the post-execve test child exits. - TEST_CHECK(dur >= kChildDelay); - TEST_CHECK(dur <= 2 * kChildDelay); - - int status = 0; - TEST_PCHECK(RetryEINTR(waitpid)(pid, &status, 0)); - TEST_CHECK(WIFEXITED(status)); - TEST_CHECK(WEXITSTATUS(status) == kChildExitCode); - }; - - EXPECT_THAT(InForkedProcess(test), IsPosixErrorOkAndHolds(0)); -} - -// A vfork child does not unstop the parent a second time when it exits after -// exec. -TEST(VforkTest, ExecedChildExitDoesntUnstopParent_NoRandomSave) { - ExecveArray const owned_child_argv = {"/proc/self/exe", "--vfork_test_child"}; - char* const* const child_argv = owned_child_argv.get(); - - const auto test = [&] { - pid_t pid1 = vfork(); - if (pid1 == 0) { - execve(child_argv[0], child_argv, /* envp = */ nullptr); - _exit(errno); - } - TEST_PCHECK_MSG(pid1 > 0, "vfork failed"); - MaybeSave(); - - // pid1 exec'd and is now sleeping. - SleepSafe(kChildDelay / 2); - - const int64_t start = MonotonicNow(); - - pid_t pid2 = vfork(); - if (pid2 == 0) { - SleepSafe(kChildDelay); - _exit(kChildExitCode); - } - TEST_PCHECK_MSG(pid2 > 0, "vfork failed"); - MaybeSave(); - - const int64_t end = MonotonicNow(); - - absl::Duration dur = absl::Nanoseconds(end - start); - - // The parent should resume execution only after pid2 exits, not - // when pid1 exits. - TEST_CHECK(dur >= kChildDelay); - - int status = 0; - TEST_PCHECK(RetryEINTR(waitpid)(pid1, &status, 0)); - TEST_CHECK(WIFEXITED(status)); - TEST_CHECK(WEXITSTATUS(status) == kChildExitCode); - - TEST_PCHECK(RetryEINTR(waitpid)(pid2, &status, 0)); - TEST_CHECK(WIFEXITED(status)); - TEST_CHECK(WEXITSTATUS(status) == kChildExitCode); - }; - - EXPECT_THAT(InForkedProcess(test), IsPosixErrorOkAndHolds(0)); -} - -int RunChild() { - SleepSafe(kChildDelay); - return kChildExitCode; -} - -} // namespace - -} // namespace testing -} // namespace gvisor - -int main(int argc, char** argv) { - gvisor::testing::TestInit(&argc, &argv); - - if (absl::GetFlag(FLAGS_vfork_test_child)) { - return gvisor::testing::RunChild(); - } - - return gvisor::testing::RunAllTests(); -} diff --git a/test/syscalls/linux/vsyscall.cc b/test/syscalls/linux/vsyscall.cc deleted file mode 100644 index ae4377108..000000000 --- a/test/syscalls/linux/vsyscall.cc +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <time.h> - -#include "gtest/gtest.h" -#include "test/util/proc_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -#if defined(__x86_64__) || defined(__i386__) -time_t vsyscall_time(time_t* t) { - constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400; - return reinterpret_cast<time_t (*)(time_t*)>(kVsyscallTimeEntry)(t); -} - -TEST(VsyscallTest, VsyscallAlwaysAvailableOnGvisor) { - SKIP_IF(!IsRunningOnGvisor()); - // Vsyscall is always advertised by gvisor. - EXPECT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(IsVsyscallEnabled())); - // Vsyscall should always works on gvisor. - time_t t; - EXPECT_THAT(vsyscall_time(&t), SyscallSucceeds()); -} -#endif - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/wait.cc b/test/syscalls/linux/wait.cc deleted file mode 100644 index 944149d5e..000000000 --- a/test/syscalls/linux/wait.cc +++ /dev/null @@ -1,913 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <signal.h> -#include <sys/mman.h> -#include <sys/ptrace.h> -#include <sys/resource.h> -#include <sys/time.h> -#include <sys/types.h> -#include <sys/wait.h> -#include <unistd.h> - -#include <functional> -#include <tuple> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/str_cat.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/logging.h" -#include "test/util/multiprocess_util.h" -#include "test/util/posix_error.h" -#include "test/util/signal_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" -#include "test/util/time_util.h" - -using ::testing::UnorderedElementsAre; - -// These unit tests focus on the wait4(2) system call, but include a basic -// checks for the i386 waitpid(2) syscall, which is a subset of wait4(2). -// -// NOTE(b/22640830,b/27680907,b/29049891): Some functionality is not tested as -// it is not currently supported by gVisor: -// * Process groups. -// * Core dump status (WCOREDUMP). -// -// Tests for waiting on stopped/continued children are in sigstop.cc. - -namespace gvisor { -namespace testing { - -namespace { - -// The CloneChild function seems to need more than one page of stack space. -static const size_t kStackSize = 2 * kPageSize; - -// The child thread created in CloneAndExit runs this function. -// This child does not have the TLS setup, so it must not use glibc functions. -int CloneChild(void* priv) { - int64_t sleep = reinterpret_cast<int64_t>(priv); - SleepSafe(absl::Seconds(sleep)); - - // glibc's _exit(2) function wrapper will helpfully call exit_group(2), - // exiting the entire process. - syscall(__NR_exit, 0); - return 1; -} - -// ForkAndExit forks a child process which exits with exit_code, after -// sleeping for the specified duration (seconds). -pid_t ForkAndExit(int exit_code, int64_t sleep) { - pid_t child = fork(); - if (child == 0) { - SleepSafe(absl::Seconds(sleep)); - _exit(exit_code); - } - return child; -} - -int64_t clock_gettime_nsecs(clockid_t id) { - struct timespec ts; - TEST_PCHECK(clock_gettime(id, &ts) == 0); - return (ts.tv_sec * 1000000000 + ts.tv_nsec); -} - -void spin(int64_t sec) { - int64_t ns = sec * 1000000000; - int64_t start = clock_gettime_nsecs(CLOCK_THREAD_CPUTIME_ID); - int64_t end = start + ns; - - do { - constexpr int kLoopCount = 1000000; // large and arbitrary - // volatile to prevent the compiler from skipping this loop. - for (volatile int i = 0; i < kLoopCount; i++) { - } - } while (clock_gettime_nsecs(CLOCK_THREAD_CPUTIME_ID) < end); -} - -// ForkSpinAndExit forks a child process which exits with exit_code, after -// spinning for the specified duration (seconds). -pid_t ForkSpinAndExit(int exit_code, int64_t spintime) { - pid_t child = fork(); - if (child == 0) { - spin(spintime); - _exit(exit_code); - } - return child; -} - -absl::Duration RusageCpuTime(const struct rusage& ru) { - return absl::DurationFromTimeval(ru.ru_utime) + - absl::DurationFromTimeval(ru.ru_stime); -} - -// Returns the address of the top of the stack. -// Free with FreeStack. -uintptr_t AllocStack() { - void* addr = mmap(nullptr, kStackSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - - if (addr == MAP_FAILED) { - return reinterpret_cast<uintptr_t>(MAP_FAILED); - } - - return reinterpret_cast<uintptr_t>(addr) + kStackSize; -} - -// Frees a stack page allocated with AllocStack. -int FreeStack(uintptr_t addr) { - addr -= kStackSize; - return munmap(reinterpret_cast<void*>(addr), kPageSize); -} - -// CloneAndExit clones a child thread, which exits with 0 after sleeping for -// the specified duration (must be in seconds). extra_flags are ORed against -// the standard clone(2) flags. -int CloneAndExit(int64_t sleep, uintptr_t stack, int extra_flags) { - return clone(CloneChild, reinterpret_cast<void*>(stack), - CLONE_FILES | CLONE_FS | CLONE_SIGHAND | CLONE_VM | extra_flags, - reinterpret_cast<void*>(sleep)); -} - -// Simple wrappers around wait4(2) and waitid(2) that ignore interrupts. -constexpr auto Wait4 = RetryEINTR(wait4); -constexpr auto Waitid = RetryEINTR(waitid); - -// Fixture for tests parameterized by a function that waits for any child to -// exit with the given options, checks that it exited with the given code, and -// then returns its PID. -// -// N.B. These tests run in a multi-threaded environment. We assume that -// background threads do not create child processes and are not themselves -// created with clone(... | SIGCHLD). Either may cause these tests to -// erroneously wait on child processes/threads. -class WaitAnyChildTest : public ::testing::TestWithParam< - std::function<PosixErrorOr<pid_t>(int, int)>> { - protected: - PosixErrorOr<pid_t> WaitAny(int code) { return WaitAnyWithOptions(code, 0); } - - PosixErrorOr<pid_t> WaitAnyWithOptions(int code, int options) { - return GetParam()(code, options); - } -}; - -// Wait for any child to exit. -TEST_P(WaitAnyChildTest, Fork) { - pid_t child; - ASSERT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds()); - - EXPECT_THAT(WaitAny(0), IsPosixErrorOkAndHolds(child)); -} - -// Call wait4 for any process after the child has already exited. -TEST_P(WaitAnyChildTest, AfterExit) { - pid_t child; - ASSERT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds()); - - absl::SleepFor(absl::Seconds(5)); - - EXPECT_THAT(WaitAny(0), IsPosixErrorOkAndHolds(child)); -} - -// Wait for multiple children to exit, waiting for either at a time. -TEST_P(WaitAnyChildTest, MultipleFork) { - pid_t child1, child2; - ASSERT_THAT(child1 = ForkAndExit(0, 0), SyscallSucceeds()); - ASSERT_THAT(child2 = ForkAndExit(0, 0), SyscallSucceeds()); - - std::vector<pid_t> pids; - pids.push_back(ASSERT_NO_ERRNO_AND_VALUE(WaitAny(0))); - pids.push_back(ASSERT_NO_ERRNO_AND_VALUE(WaitAny(0))); - EXPECT_THAT(pids, UnorderedElementsAre(child1, child2)); -} - -// Wait for any child to exit. -// A non-CLONE_THREAD child which sends SIGCHLD upon exit behaves much like -// a forked process. -TEST_P(WaitAnyChildTest, CloneSIGCHLD) { - uintptr_t stack; - ASSERT_THAT(stack = AllocStack(), SyscallSucceeds()); - auto free = - Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); }); - - int child; - ASSERT_THAT(child = CloneAndExit(0, stack, SIGCHLD), SyscallSucceeds()); - - EXPECT_THAT(WaitAny(0), IsPosixErrorOkAndHolds(child)); -} - -// Wait for a child thread and process. -TEST_P(WaitAnyChildTest, ForkAndClone) { - pid_t process; - ASSERT_THAT(process = ForkAndExit(0, 0), SyscallSucceeds()); - - uintptr_t stack; - ASSERT_THAT(stack = AllocStack(), SyscallSucceeds()); - auto free = - Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); }); - - int thread; - // Send SIGCHLD for normal wait semantics. - ASSERT_THAT(thread = CloneAndExit(0, stack, SIGCHLD), SyscallSucceeds()); - - std::vector<pid_t> pids; - pids.push_back(ASSERT_NO_ERRNO_AND_VALUE(WaitAny(0))); - pids.push_back(ASSERT_NO_ERRNO_AND_VALUE(WaitAny(0))); - EXPECT_THAT(pids, UnorderedElementsAre(process, thread)); -} - -// Return immediately if no child has exited. -TEST_P(WaitAnyChildTest, WaitWNOHANG) { - EXPECT_THAT(WaitAnyWithOptions(0, WNOHANG), - PosixErrorIs(ECHILD, ::testing::_)); -} - -// Bad options passed -TEST_P(WaitAnyChildTest, BadOption) { - EXPECT_THAT(WaitAnyWithOptions(0, 123456), - PosixErrorIs(EINVAL, ::testing::_)); -} - -TEST_P(WaitAnyChildTest, WaitedChildRusage) { - struct rusage before; - ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &before), SyscallSucceeds()); - - pid_t child; - constexpr absl::Duration kSpin = absl::Seconds(3); - ASSERT_THAT(child = ForkSpinAndExit(0, absl::ToInt64Seconds(kSpin)), - SyscallSucceeds()); - ASSERT_THAT(WaitAny(0), IsPosixErrorOkAndHolds(child)); - - struct rusage after; - ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &after), SyscallSucceeds()); - - EXPECT_GE(RusageCpuTime(after) - RusageCpuTime(before), kSpin); -} - -TEST_P(WaitAnyChildTest, IgnoredChildRusage) { - // "POSIX.1-2001 specifies that if the disposition of SIGCHLD is - // set to SIG_IGN or the SA_NOCLDWAIT flag is set for SIGCHLD (see - // sigaction(2)), then children that terminate do not become zombies and a - // call to wait() or waitpid() will block until all children have terminated, - // and then fail with errno set to ECHILD." - waitpid(2) - // - // "RUSAGE_CHILDREN: Return resource usage statistics for all children of the - // calling process that have terminated *and been waited for*." - - // getrusage(2), emphasis added - - struct sigaction sa; - sa.sa_handler = SIG_IGN; - const auto cleanup_sigact = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGCHLD, sa)); - - struct rusage before; - ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &before), SyscallSucceeds()); - - const absl::Duration start = - absl::Nanoseconds(clock_gettime_nsecs(CLOCK_MONOTONIC)); - - constexpr absl::Duration kSpin = absl::Seconds(3); - - // ForkAndSpin uses CLOCK_THREAD_CPUTIME_ID, which is lower resolution than, - // and may diverge from, CLOCK_MONOTONIC, so we allow a small grace period but - // still check that we blocked for a while. - constexpr absl::Duration kSpinGrace = absl::Milliseconds(100); - - pid_t child; - ASSERT_THAT(child = ForkSpinAndExit(0, absl::ToInt64Seconds(kSpin)), - SyscallSucceeds()); - ASSERT_THAT(WaitAny(0), PosixErrorIs(ECHILD, ::testing::_)); - const absl::Duration end = - absl::Nanoseconds(clock_gettime_nsecs(CLOCK_MONOTONIC)); - EXPECT_GE(end - start, kSpin - kSpinGrace); - - struct rusage after; - ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &after), SyscallSucceeds()); - EXPECT_EQ(before.ru_utime.tv_sec, after.ru_utime.tv_sec); - EXPECT_EQ(before.ru_utime.tv_usec, after.ru_utime.tv_usec); - EXPECT_EQ(before.ru_stime.tv_sec, after.ru_stime.tv_sec); - EXPECT_EQ(before.ru_stime.tv_usec, after.ru_stime.tv_usec); -} - -INSTANTIATE_TEST_SUITE_P( - Waiters, WaitAnyChildTest, - ::testing::Values( - [](int code, int options) -> PosixErrorOr<pid_t> { - int status; - auto const pid = Wait4(-1, &status, options, nullptr); - MaybeSave(); - if (pid < 0) { - return PosixError(errno, "wait4"); - } - if (!WIFEXITED(status) || WEXITSTATUS(status) != code) { - return PosixError( - EINVAL, absl::StrCat("unexpected wait status: got ", status, - ", wanted ", code)); - } - return static_cast<pid_t>(pid); - }, - [](int code, int options) -> PosixErrorOr<pid_t> { - siginfo_t si; - auto const rv = Waitid(P_ALL, 0, &si, WEXITED | options); - MaybeSave(); - if (rv < 0) { - return PosixError(errno, "waitid"); - } - if (si.si_signo != SIGCHLD) { - return PosixError( - EINVAL, absl::StrCat("unexpected signo: got ", si.si_signo, - ", wanted ", SIGCHLD)); - } - if (si.si_status != code) { - return PosixError( - EINVAL, absl::StrCat("unexpected status: got ", si.si_status, - ", wanted ", code)); - } - if (si.si_code != CLD_EXITED) { - return PosixError(EINVAL, - absl::StrCat("unexpected code: got ", si.si_code, - ", wanted ", CLD_EXITED)); - } - auto const uid = getuid(); - if (si.si_uid != uid) { - return PosixError(EINVAL, - absl::StrCat("unexpected uid: got ", si.si_uid, - ", wanted ", uid)); - } - return static_cast<pid_t>(si.si_pid); - })); - -// Fixture for tests parameterized by a (sysno, function) tuple. The function -// takes the PID of a specific child to wait for, waits for it to exit, and -// checks that it exits with the given code. -class WaitSpecificChildTest - : public ::testing::TestWithParam< - std::tuple<int, std::function<PosixError(pid_t, int, int)>>> { - protected: - int Sysno() { return std::get<0>(GetParam()); } - - PosixError WaitForWithOptions(pid_t pid, int options, int code) { - return std::get<1>(GetParam())(pid, options, code); - } - - PosixError WaitFor(pid_t pid, int code) { - return std::get<1>(GetParam())(pid, 0, code); - } -}; - -// Wait for specific child to exit. -TEST_P(WaitSpecificChildTest, Fork) { - pid_t child; - ASSERT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds()); - - EXPECT_NO_ERRNO(WaitFor(child, 0)); -} - -// Non-zero exit codes are correctly propagated. -TEST_P(WaitSpecificChildTest, NormalExit) { - pid_t child; - ASSERT_THAT(child = ForkAndExit(42, 0), SyscallSucceeds()); - - EXPECT_NO_ERRNO(WaitFor(child, 42)); -} - -// Wait for multiple children to exit. -TEST_P(WaitSpecificChildTest, MultipleFork) { - pid_t child1, child2; - ASSERT_THAT(child1 = ForkAndExit(0, 0), SyscallSucceeds()); - ASSERT_THAT(child2 = ForkAndExit(0, 0), SyscallSucceeds()); - - EXPECT_NO_ERRNO(WaitFor(child1, 0)); - EXPECT_NO_ERRNO(WaitFor(child2, 0)); -} - -// Wait for multiple children to exit, out of the order they were created. -TEST_P(WaitSpecificChildTest, MultipleForkOutOfOrder) { - pid_t child1, child2; - ASSERT_THAT(child1 = ForkAndExit(0, 0), SyscallSucceeds()); - ASSERT_THAT(child2 = ForkAndExit(0, 0), SyscallSucceeds()); - - EXPECT_NO_ERRNO(WaitFor(child2, 0)); - EXPECT_NO_ERRNO(WaitFor(child1, 0)); -} - -// Wait for specific child to exit, entering wait4 before the exit occurs. -TEST_P(WaitSpecificChildTest, ForkSleep) { - pid_t child; - ASSERT_THAT(child = ForkAndExit(0, 5), SyscallSucceeds()); - - EXPECT_NO_ERRNO(WaitFor(child, 0)); -} - -// Wait should block until the child exits. -TEST_P(WaitSpecificChildTest, ForkBlock) { - pid_t child; - - auto start = absl::Now(); - ASSERT_THAT(child = ForkAndExit(0, 5), SyscallSucceeds()); - - EXPECT_NO_ERRNO(WaitFor(child, 0)); - - EXPECT_GE(absl::Now() - start, absl::Seconds(5)); -} - -// Waiting after the child has already exited returns immediately. -TEST_P(WaitSpecificChildTest, AfterExit) { - pid_t child; - ASSERT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds()); - - absl::SleepFor(absl::Seconds(5)); - - EXPECT_NO_ERRNO(WaitFor(child, 0)); -} - -// Wait for child of sibling thread. -TEST_P(WaitSpecificChildTest, SiblingChildren) { - absl::Mutex mu; - pid_t child; - bool ready = false; - bool stop = false; - - ScopedThread t([&] { - absl::MutexLock ml(&mu); - EXPECT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds()); - ready = true; - mu.Await(absl::Condition(&stop)); - }); - - // N.B. This must be declared after ScopedThread, so it is destructed first, - // thus waking the thread. - absl::MutexLock ml(&mu); - mu.Await(absl::Condition(&ready)); - - EXPECT_NO_ERRNO(WaitFor(child, 0)); - - // Keep the sibling alive until after we've waited so the child isn't - // reparented. - stop = true; -} - -// Waiting for child of sibling thread not allowed with WNOTHREAD. -TEST_P(WaitSpecificChildTest, SiblingChildrenWNOTHREAD) { - // Linux added WNOTHREAD support to waitid(2) in - // 91c4e8ea8f05916df0c8a6f383508ac7c9e10dba ("wait: allow sys_waitid() to - // accept __WNOTHREAD/__WCLONE/__WALL"). i.e., Linux 4.7. - // - // Skip the test if it isn't supported yet. - if (Sysno() == SYS_waitid) { - int ret = waitid(P_ALL, 0, nullptr, WEXITED | WNOHANG | __WNOTHREAD); - SKIP_IF(ret < 0 && errno == EINVAL); - } - - absl::Mutex mu; - pid_t child; - bool ready = false; - bool stop = false; - - ScopedThread t([&] { - absl::MutexLock ml(&mu); - EXPECT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds()); - ready = true; - mu.Await(absl::Condition(&stop)); - - // This thread can wait on child. - EXPECT_NO_ERRNO(WaitForWithOptions(child, __WNOTHREAD, 0)); - }); - - // N.B. This must be declared after ScopedThread, so it is destructed first, - // thus waking the thread. - absl::MutexLock ml(&mu); - mu.Await(absl::Condition(&ready)); - - // This thread can't wait on child. - EXPECT_THAT(WaitForWithOptions(child, __WNOTHREAD, 0), - PosixErrorIs(ECHILD, ::testing::_)); - - // Keep the sibling alive until after we've waited so the child isn't - // reparented. - stop = true; -} - -// Wait for specific child to exit. -// A non-CLONE_THREAD child which sends SIGCHLD upon exit behaves much like -// a forked process. -TEST_P(WaitSpecificChildTest, CloneSIGCHLD) { - uintptr_t stack; - ASSERT_THAT(stack = AllocStack(), SyscallSucceeds()); - auto free = - Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); }); - - int child; - ASSERT_THAT(child = CloneAndExit(0, stack, SIGCHLD), SyscallSucceeds()); - - EXPECT_NO_ERRNO(WaitFor(child, 0)); -} - -// Wait for specific child to exit. -// A non-CLONE_THREAD child which does not send SIGCHLD upon exit can be waited -// on, but returns ECHILD. -TEST_P(WaitSpecificChildTest, CloneNoSIGCHLD) { - uintptr_t stack; - ASSERT_THAT(stack = AllocStack(), SyscallSucceeds()); - auto free = - Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); }); - - int child; - ASSERT_THAT(child = CloneAndExit(0, stack, 0), SyscallSucceeds()); - - EXPECT_THAT(WaitFor(child, 0), PosixErrorIs(ECHILD, ::testing::_)); -} - -// Waiting after the child has already exited returns immediately. -TEST_P(WaitSpecificChildTest, CloneAfterExit) { - uintptr_t stack; - ASSERT_THAT(stack = AllocStack(), SyscallSucceeds()); - auto free = - Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); }); - - int child; - // Send SIGCHLD for normal wait semantics. - ASSERT_THAT(child = CloneAndExit(0, stack, SIGCHLD), SyscallSucceeds()); - - absl::SleepFor(absl::Seconds(5)); - - EXPECT_NO_ERRNO(WaitFor(child, 0)); -} - -// A CLONE_THREAD child cannot be waited on. -TEST_P(WaitSpecificChildTest, CloneThread) { - uintptr_t stack; - ASSERT_THAT(stack = AllocStack(), SyscallSucceeds()); - auto free = - Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); }); - - int child; - ASSERT_THAT(child = CloneAndExit(15, stack, CLONE_THREAD), SyscallSucceeds()); - auto start = absl::Now(); - - EXPECT_THAT(WaitFor(child, 0), PosixErrorIs(ECHILD, ::testing::_)); - - // Ensure wait4 didn't block. - EXPECT_LE(absl::Now() - start, absl::Seconds(10)); - - // Since we can't wait on the child, we sleep to try to avoid freeing its - // stack before it exits. - absl::SleepFor(absl::Seconds(5)); -} - -// A child that does not send a SIGCHLD on exit may be waited on with -// the __WCLONE flag. -TEST_P(WaitSpecificChildTest, CloneWCLONE) { - // Linux added WCLONE support to waitid(2) in - // 91c4e8ea8f05916df0c8a6f383508ac7c9e10dba ("wait: allow sys_waitid() to - // accept __WNOTHREAD/__WCLONE/__WALL"). i.e., Linux 4.7. - // - // Skip the test if it isn't supported yet. - if (Sysno() == SYS_waitid) { - int ret = waitid(P_ALL, 0, nullptr, WEXITED | WNOHANG | __WCLONE); - SKIP_IF(ret < 0 && errno == EINVAL); - } - - uintptr_t stack; - ASSERT_THAT(stack = AllocStack(), SyscallSucceeds()); - auto free = - Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); }); - - int child; - ASSERT_THAT(child = CloneAndExit(0, stack, 0), SyscallSucceeds()); - - EXPECT_NO_ERRNO(WaitForWithOptions(child, __WCLONE, 0)); -} - -// A forked child cannot be waited on with WCLONE. -TEST_P(WaitSpecificChildTest, ForkWCLONE) { - // Linux added WCLONE support to waitid(2) in - // 91c4e8ea8f05916df0c8a6f383508ac7c9e10dba ("wait: allow sys_waitid() to - // accept __WNOTHREAD/__WCLONE/__WALL"). i.e., Linux 4.7. - // - // Skip the test if it isn't supported yet. - if (Sysno() == SYS_waitid) { - int ret = waitid(P_ALL, 0, nullptr, WEXITED | WNOHANG | __WCLONE); - SKIP_IF(ret < 0 && errno == EINVAL); - } - - pid_t child; - ASSERT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds()); - - EXPECT_THAT(WaitForWithOptions(child, WNOHANG | __WCLONE, 0), - PosixErrorIs(ECHILD, ::testing::_)); - - EXPECT_NO_ERRNO(WaitFor(child, 0)); -} - -// Any type of child can be waited on with WALL. -TEST_P(WaitSpecificChildTest, WALL) { - // Linux added WALL support to waitid(2) in - // 91c4e8ea8f05916df0c8a6f383508ac7c9e10dba ("wait: allow sys_waitid() to - // accept __WNOTHREAD/__WCLONE/__WALL"). i.e., Linux 4.7. - // - // Skip the test if it isn't supported yet. - if (Sysno() == SYS_waitid) { - int ret = waitid(P_ALL, 0, nullptr, WEXITED | WNOHANG | __WALL); - SKIP_IF(ret < 0 && errno == EINVAL); - } - - pid_t child; - ASSERT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds()); - - EXPECT_NO_ERRNO(WaitForWithOptions(child, __WALL, 0)); - - uintptr_t stack; - ASSERT_THAT(stack = AllocStack(), SyscallSucceeds()); - auto free = - Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); }); - - ASSERT_THAT(child = CloneAndExit(0, stack, 0), SyscallSucceeds()); - - EXPECT_NO_ERRNO(WaitForWithOptions(child, __WALL, 0)); -} - -// Return ECHILD for bad child. -TEST_P(WaitSpecificChildTest, BadChild) { - EXPECT_THAT(WaitFor(42, 0), PosixErrorIs(ECHILD, ::testing::_)); -} - -// Wait for a child process that only exits after calling execve(2) from a -// non-leader thread. -TEST_P(WaitSpecificChildTest, AfterChildExecve) { - ExecveArray const owned_child_argv = {"/bin/true"}; - char* const* const child_argv = owned_child_argv.get(); - - uintptr_t stack; - ASSERT_THAT(stack = AllocStack(), SyscallSucceeds()); - auto free = - Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); }); - - pid_t const child = fork(); - if (child == 0) { - // Give the parent some time to start waiting. - SleepSafe(absl::Seconds(5)); - // Pass CLONE_VFORK to block the original thread in the child process until - // the clone thread calls execve, annihilating them both. (This means that - // if clone returns at all, something went wrong.) - // - // N.B. clone(2) is not officially async-signal-safe, but at minimum glibc's - // x86_64 implementation is safe. See glibc - // sysdeps/unix/sysv/linux/x86_64/clone.S. - clone( - +[](void* arg) { - auto child_argv = static_cast<char* const*>(arg); - execve(child_argv[0], child_argv, /* envp = */ nullptr); - return errno; - }, - reinterpret_cast<void*>(stack), - CLONE_FILES | CLONE_FS | CLONE_SIGHAND | CLONE_THREAD | CLONE_VM | - CLONE_VFORK, - const_cast<char**>(child_argv)); - _exit(errno); - } - ASSERT_THAT(child, SyscallSucceeds()); - EXPECT_NO_ERRNO(WaitFor(child, 0)); -} - -PosixError CheckWait4(pid_t pid, int options, int code) { - int status; - auto const rv = Wait4(pid, &status, options, nullptr); - MaybeSave(); - if (rv < 0) { - return PosixError(errno, "wait4"); - } else if (rv != pid) { - return PosixError( - EINVAL, absl::StrCat("unexpected pid: got ", rv, ", wanted ", pid)); - } - if (!WIFEXITED(status) || WEXITSTATUS(status) != code) { - return PosixError(EINVAL, absl::StrCat("unexpected wait status: got ", - status, ", wanted ", code)); - } - return NoError(); -}; - -PosixError CheckWaitid(pid_t pid, int options, int code) { - siginfo_t si; - auto const rv = Waitid(P_PID, pid, &si, options | WEXITED); - MaybeSave(); - if (rv < 0) { - return PosixError(errno, "waitid"); - } - if (si.si_pid != pid) { - return PosixError(EINVAL, absl::StrCat("unexpected pid: got ", si.si_pid, - ", wanted ", pid)); - } - if (si.si_signo != SIGCHLD) { - return PosixError(EINVAL, absl::StrCat("unexpected signo: got ", - si.si_signo, ", wanted ", SIGCHLD)); - } - if (si.si_status != code) { - return PosixError(EINVAL, absl::StrCat("unexpected status: got ", - si.si_status, ", wanted ", code)); - } - if (si.si_code != CLD_EXITED) { - return PosixError(EINVAL, absl::StrCat("unexpected code: got ", si.si_code, - ", wanted ", CLD_EXITED)); - } - return NoError(); -} - -INSTANTIATE_TEST_SUITE_P( - Waiters, WaitSpecificChildTest, - ::testing::Values(std::make_tuple(SYS_wait4, CheckWait4), - std::make_tuple(SYS_waitid, CheckWaitid))); - -// WIFEXITED, WIFSIGNALED, WTERMSIG indicate signal exit. -TEST(WaitTest, SignalExit) { - pid_t child; - ASSERT_THAT(child = ForkAndExit(0, 10), SyscallSucceeds()); - - EXPECT_THAT(kill(child, SIGKILL), SyscallSucceeds()); - - int status; - EXPECT_THAT(Wait4(child, &status, 0, nullptr), - SyscallSucceedsWithValue(child)); - - EXPECT_FALSE(WIFEXITED(status)); - EXPECT_TRUE(WIFSIGNALED(status)); - EXPECT_EQ(SIGKILL, WTERMSIG(status)); -} - -// waitid requires at least one option. -TEST(WaitTest, WaitidOptions) { - EXPECT_THAT(Waitid(P_ALL, 0, nullptr, 0), SyscallFailsWithErrno(EINVAL)); -} - -// waitid does not wait for a child to exit if not passed WEXITED. -TEST(WaitTest, WaitidNoWEXITED) { - pid_t child; - ASSERT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds()); - EXPECT_THAT(Waitid(P_ALL, 0, nullptr, WSTOPPED), - SyscallFailsWithErrno(ECHILD)); - EXPECT_THAT(Waitid(P_ALL, 0, nullptr, WEXITED), SyscallSucceeds()); -} - -// WNOWAIT allows the same wait result to be returned again. -TEST(WaitTest, WaitidWNOWAIT) { - pid_t child; - ASSERT_THAT(child = ForkAndExit(42, 0), SyscallSucceeds()); - - siginfo_t info; - ASSERT_THAT(Waitid(P_PID, child, &info, WEXITED | WNOWAIT), - SyscallSucceeds()); - EXPECT_EQ(child, info.si_pid); - EXPECT_EQ(SIGCHLD, info.si_signo); - EXPECT_EQ(CLD_EXITED, info.si_code); - EXPECT_EQ(42, info.si_status); - - ASSERT_THAT(Waitid(P_PID, child, &info, WEXITED), SyscallSucceeds()); - EXPECT_EQ(child, info.si_pid); - EXPECT_EQ(SIGCHLD, info.si_signo); - EXPECT_EQ(CLD_EXITED, info.si_code); - EXPECT_EQ(42, info.si_status); - - EXPECT_THAT(Waitid(P_PID, child, &info, WEXITED), - SyscallFailsWithErrno(ECHILD)); -} - -// waitpid(pid, status, options) is equivalent to -// wait4(pid, status, options, nullptr). -// This is a dedicated syscall on i386, glibc maps it to wait4 on amd64. -TEST(WaitTest, WaitPid) { - pid_t child; - ASSERT_THAT(child = ForkAndExit(42, 0), SyscallSucceeds()); - - int status; - EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), - SyscallSucceedsWithValue(child)); - - EXPECT_TRUE(WIFEXITED(status)); - EXPECT_EQ(42, WEXITSTATUS(status)); -} - -// Test that signaling a zombie succeeds. This is a signals test that is in this -// file for some reason. -TEST(WaitTest, KillZombie) { - pid_t child; - ASSERT_THAT(child = ForkAndExit(42, 0), SyscallSucceeds()); - - // Sleep for three seconds to ensure the child has exited. - absl::SleepFor(absl::Seconds(3)); - - // The child is now a zombie. Check that killing it returns 0. - EXPECT_THAT(kill(child, SIGTERM), SyscallSucceeds()); - EXPECT_THAT(kill(child, 0), SyscallSucceeds()); - - EXPECT_THAT(Wait4(child, nullptr, 0, nullptr), - SyscallSucceedsWithValue(child)); -} - -TEST(WaitTest, Wait4Rusage) { - pid_t child; - constexpr absl::Duration kSpin = absl::Seconds(3); - ASSERT_THAT(child = ForkSpinAndExit(21, absl::ToInt64Seconds(kSpin)), - SyscallSucceeds()); - - int status; - struct rusage rusage = {}; - ASSERT_THAT(Wait4(child, &status, 0, &rusage), - SyscallSucceedsWithValue(child)); - - EXPECT_TRUE(WIFEXITED(status)); - EXPECT_EQ(21, WEXITSTATUS(status)); - - EXPECT_GE(RusageCpuTime(rusage), kSpin); -} - -TEST(WaitTest, WaitidRusage) { - pid_t child; - constexpr absl::Duration kSpin = absl::Seconds(3); - ASSERT_THAT(child = ForkSpinAndExit(27, absl::ToInt64Seconds(kSpin)), - SyscallSucceeds()); - - siginfo_t si = {}; - struct rusage rusage = {}; - - // From waitid(2): - // The raw waitid() system call takes a fifth argument, of type - // struct rusage *. If this argument is non-NULL, then it is used - // to return resource usage information about the child, in the - // same manner as wait4(2). - EXPECT_THAT( - RetryEINTR(syscall)(SYS_waitid, P_PID, child, &si, WEXITED, &rusage), - SyscallSucceeds()); - EXPECT_EQ(si.si_signo, SIGCHLD); - EXPECT_EQ(si.si_code, CLD_EXITED); - EXPECT_EQ(si.si_status, 27); - EXPECT_EQ(si.si_pid, child); - - EXPECT_GE(RusageCpuTime(rusage), kSpin); -} - -// After bf959931ddb88c4e4366e96dd22e68fa0db9527c ("wait/ptrace: assume __WALL -// if the child is traced") (Linux 4.7), tracees are always eligible for -// waiting, regardless of type. -TEST(WaitTest, TraceeWALL) { - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - FileDescriptor rfd(fds[0]); - FileDescriptor wfd(fds[1]); - - pid_t child = fork(); - if (child == 0) { - // Child. - rfd.reset(); - - TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, nullptr, nullptr) == 0); - - // Notify parent that we're now a tracee. - wfd.reset(); - - _exit(0); - } - ASSERT_THAT(child, SyscallSucceeds()); - - wfd.reset(); - - // Wait for child to become tracee. - char c; - EXPECT_THAT(ReadFd(rfd.get(), &c, sizeof(c)), SyscallSucceedsWithValue(0)); - - // We can wait on the fork child with WCLONE, as it is a tracee. - int status; - if (IsRunningOnGvisor()) { - ASSERT_THAT(Wait4(child, &status, __WCLONE, nullptr), - SyscallSucceedsWithValue(child)); - - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) << status; - } else { - // On older versions of Linux, we may get ECHILD. - ASSERT_THAT(Wait4(child, &status, __WCLONE, nullptr), - ::testing::AnyOf(SyscallSucceedsWithValue(child), - SyscallFailsWithErrno(ECHILD))); - } -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc deleted file mode 100644 index 9b219cfd6..000000000 --- a/test/syscalls/linux/write.cc +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <signal.h> -#include <sys/resource.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <time.h> -#include <unistd.h> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/util/cleanup.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { -// This test is currently very rudimentary. -// -// TODO(edahlgren): -// * bad buffer states (EFAULT). -// * bad fds (wrong permission, wrong type of file, EBADF). -// * check offset is incremented. -// * check for EOF. -// * writing to pipes, symlinks, special files. -class WriteTest : public ::testing::Test { - public: - ssize_t WriteBytes(int fd, int bytes) { - std::vector<char> buf(bytes); - std::fill(buf.begin(), buf.end(), 'a'); - return WriteFd(fd, buf.data(), buf.size()); - } -}; - -TEST_F(WriteTest, WriteNoExceedsRLimit) { - // Get the current rlimit and restore after test run. - struct rlimit initial_lim; - ASSERT_THAT(getrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds()); - auto cleanup = Cleanup([&initial_lim] { - EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds()); - }); - - int fd; - struct rlimit setlim; - const int target_lim = 1024; - setlim.rlim_cur = target_lim; - setlim.rlim_max = RLIM_INFINITY; - const std::string pathname = NewTempAbsPath(); - ASSERT_THAT(fd = open(pathname.c_str(), O_WRONLY | O_CREAT, S_IRWXU), - SyscallSucceeds()); - ASSERT_THAT(setrlimit(RLIMIT_FSIZE, &setlim), SyscallSucceeds()); - - EXPECT_THAT(WriteBytes(fd, target_lim), SyscallSucceedsWithValue(target_lim)); - - std::vector<char> buf(target_lim + 1); - std::fill(buf.begin(), buf.end(), 'a'); - EXPECT_THAT(pwrite(fd, buf.data(), target_lim, 1), SyscallSucceeds()); - EXPECT_THAT(pwrite64(fd, buf.data(), target_lim, 1), SyscallSucceeds()); - - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST_F(WriteTest, WriteExceedsRLimit) { - // Get the current rlimit and restore after test run. - struct rlimit initial_lim; - ASSERT_THAT(getrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds()); - auto cleanup = Cleanup([&initial_lim] { - EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds()); - }); - - int fd; - sigset_t filesize_mask; - sigemptyset(&filesize_mask); - sigaddset(&filesize_mask, SIGXFSZ); - - struct rlimit setlim; - const int target_lim = 1024; - setlim.rlim_cur = target_lim; - setlim.rlim_max = RLIM_INFINITY; - - const std::string pathname = NewTempAbsPath(); - ASSERT_THAT(fd = open(pathname.c_str(), O_WRONLY | O_CREAT, S_IRWXU), - SyscallSucceeds()); - ASSERT_THAT(setrlimit(RLIMIT_FSIZE, &setlim), SyscallSucceeds()); - ASSERT_THAT(sigprocmask(SIG_BLOCK, &filesize_mask, nullptr), - SyscallSucceeds()); - std::vector<char> buf(target_lim + 2); - std::fill(buf.begin(), buf.end(), 'a'); - - EXPECT_THAT(write(fd, buf.data(), target_lim + 1), - SyscallSucceedsWithValue(target_lim)); - EXPECT_THAT(write(fd, buf.data(), 1), SyscallFailsWithErrno(EFBIG)); - siginfo_t info; - struct timespec timelimit = {0, 0}; - ASSERT_THAT(RetryEINTR(sigtimedwait)(&filesize_mask, &info, &timelimit), - SyscallSucceedsWithValue(SIGXFSZ)); - EXPECT_EQ(info.si_code, SI_USER); - EXPECT_EQ(info.si_pid, getpid()); - EXPECT_EQ(info.si_uid, getuid()); - - EXPECT_THAT(pwrite(fd, buf.data(), target_lim + 1, 1), - SyscallSucceedsWithValue(target_lim - 1)); - EXPECT_THAT(pwrite(fd, buf.data(), 1, target_lim), - SyscallFailsWithErrno(EFBIG)); - ASSERT_THAT(RetryEINTR(sigtimedwait)(&filesize_mask, &info, &timelimit), - SyscallSucceedsWithValue(SIGXFSZ)); - EXPECT_EQ(info.si_code, SI_USER); - EXPECT_EQ(info.si_pid, getpid()); - EXPECT_EQ(info.si_uid, getuid()); - - EXPECT_THAT(pwrite64(fd, buf.data(), target_lim + 1, 1), - SyscallSucceedsWithValue(target_lim - 1)); - EXPECT_THAT(pwrite64(fd, buf.data(), 1, target_lim), - SyscallFailsWithErrno(EFBIG)); - ASSERT_THAT(RetryEINTR(sigtimedwait)(&filesize_mask, &info, &timelimit), - SyscallSucceedsWithValue(SIGXFSZ)); - EXPECT_EQ(info.si_code, SI_USER); - EXPECT_EQ(info.si_pid, getpid()); - EXPECT_EQ(info.si_uid, getuid()); - - ASSERT_THAT(sigprocmask(SIG_UNBLOCK, &filesize_mask, nullptr), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc deleted file mode 100644 index 8b00ef44c..000000000 --- a/test/syscalls/linux/xattr.cc +++ /dev/null @@ -1,609 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <errno.h> -#include <fcntl.h> -#include <limits.h> -#include <sys/types.h> -#include <sys/xattr.h> -#include <unistd.h> - -#include <string> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/container/flat_hash_set.h" -#include "test/syscalls/linux/file_base.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -class XattrTest : public FileTest {}; - -TEST_F(XattrTest, XattrNonexistentFile) { - const char* path = "/does/not/exist"; - EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, /*flags=*/0), - SyscallFailsWithErrno(ENOENT)); - EXPECT_THAT(getxattr(path, nullptr, nullptr, 0), - SyscallFailsWithErrno(ENOENT)); - EXPECT_THAT(listxattr(path, nullptr, 0), SyscallFailsWithErrno(ENOENT)); - EXPECT_THAT(removexattr(path, nullptr), SyscallFailsWithErrno(ENOENT)); -} - -TEST_F(XattrTest, XattrNullName) { - const char* path = test_file_name_.c_str(); - - EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, /*flags=*/0), - SyscallFailsWithErrno(EFAULT)); - EXPECT_THAT(getxattr(path, nullptr, nullptr, 0), - SyscallFailsWithErrno(EFAULT)); - EXPECT_THAT(removexattr(path, nullptr), SyscallFailsWithErrno(EFAULT)); -} - -TEST_F(XattrTest, XattrEmptyName) { - const char* path = test_file_name_.c_str(); - - EXPECT_THAT(setxattr(path, "", nullptr, 0, /*flags=*/0), - SyscallFailsWithErrno(ERANGE)); - EXPECT_THAT(getxattr(path, "", nullptr, 0), SyscallFailsWithErrno(ERANGE)); - EXPECT_THAT(removexattr(path, ""), SyscallFailsWithErrno(ERANGE)); -} - -TEST_F(XattrTest, XattrLargeName) { - const char* path = test_file_name_.c_str(); - std::string name = "user."; - name += std::string(XATTR_NAME_MAX - name.length(), 'a'); - - // An xattr should be whitelisted before it can be accessed--do not allow - // arbitrary xattrs to be read/written in gVisor. - if (!IsRunningOnGvisor()) { - EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0), - SyscallSucceeds()); - EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0), - SyscallSucceedsWithValue(0)); - } - - name += "a"; - EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0), - SyscallFailsWithErrno(ERANGE)); - EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0), - SyscallFailsWithErrno(ERANGE)); - EXPECT_THAT(removexattr(path, name.c_str()), SyscallFailsWithErrno(ERANGE)); -} - -TEST_F(XattrTest, XattrInvalidPrefix) { - const char* path = test_file_name_.c_str(); - std::string name(XATTR_NAME_MAX, 'a'); - EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0), - SyscallFailsWithErrno(EOPNOTSUPP)); - EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0), - SyscallFailsWithErrno(EOPNOTSUPP)); - EXPECT_THAT(removexattr(path, name.c_str()), - SyscallFailsWithErrno(EOPNOTSUPP)); -} - -// Do not allow save/restore cycles after making the test file read-only, as -// the restore will fail to open it with r/w permissions. -TEST_F(XattrTest, XattrReadOnly_NoRandomSave) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - char val = 'a'; - size_t size = sizeof(val); - - EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); - - DisableSave ds; - ASSERT_NO_ERRNO(testing::Chmod(test_file_name_, S_IRUSR)); - - EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), - SyscallFailsWithErrno(EACCES)); - EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EACCES)); - - char buf = '-'; - EXPECT_THAT(getxattr(path, name, &buf, size), SyscallSucceedsWithValue(size)); - EXPECT_EQ(buf, val); - - char list[sizeof(name)]; - EXPECT_THAT(listxattr(path, list, sizeof(list)), - SyscallSucceedsWithValue(sizeof(name))); - EXPECT_STREQ(list, name); -} - -// Do not allow save/restore cycles after making the test file write-only, as -// the restore will fail to open it with r/w permissions. -TEST_F(XattrTest, XattrWriteOnly_NoRandomSave) { - // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - - DisableSave ds; - ASSERT_NO_ERRNO(testing::Chmod(test_file_name_, S_IWUSR)); - - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - char val = 'a'; - size_t size = sizeof(val); - - EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); - - EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(EACCES)); - - // listxattr will succeed even without read permissions. - char list[sizeof(name)]; - EXPECT_THAT(listxattr(path, list, sizeof(list)), - SyscallSucceedsWithValue(sizeof(name))); - EXPECT_STREQ(list, name); - - EXPECT_THAT(removexattr(path, name), SyscallSucceeds()); -} - -TEST_F(XattrTest, XattrTrustedWithNonadmin) { - // TODO(b/148380782): Support setxattr and getxattr with "trusted" prefix. - SKIP_IF(IsRunningOnGvisor()); - SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - - const char* path = test_file_name_.c_str(); - const char name[] = "trusted.abc"; - EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), - SyscallFailsWithErrno(EPERM)); - EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM)); - EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); -} - -TEST_F(XattrTest, XattrOnDirectory) { - TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const char name[] = "user.test"; - EXPECT_THAT(setxattr(dir.path().c_str(), name, nullptr, 0, /*flags=*/0), - SyscallSucceeds()); - EXPECT_THAT(getxattr(dir.path().c_str(), name, nullptr, 0), - SyscallSucceedsWithValue(0)); - - char list[sizeof(name)]; - EXPECT_THAT(listxattr(dir.path().c_str(), list, sizeof(list)), - SyscallSucceedsWithValue(sizeof(name))); - EXPECT_STREQ(list, name); - - EXPECT_THAT(removexattr(dir.path().c_str(), name), SyscallSucceeds()); -} - -TEST_F(XattrTest, XattrOnSymlink) { - TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo(dir.path(), test_file_name_)); - const char name[] = "user.test"; - EXPECT_THAT(setxattr(link.path().c_str(), name, nullptr, 0, /*flags=*/0), - SyscallSucceeds()); - EXPECT_THAT(getxattr(link.path().c_str(), name, nullptr, 0), - SyscallSucceedsWithValue(0)); - - char list[sizeof(name)]; - EXPECT_THAT(listxattr(link.path().c_str(), list, sizeof(list)), - SyscallSucceedsWithValue(sizeof(name))); - EXPECT_STREQ(list, name); - - EXPECT_THAT(removexattr(link.path().c_str(), name), SyscallSucceeds()); -} - -TEST_F(XattrTest, XattrOnInvalidFileTypes) { - const char name[] = "user.test"; - - char char_device[] = "/dev/zero"; - EXPECT_THAT(setxattr(char_device, name, nullptr, 0, /*flags=*/0), - SyscallFailsWithErrno(EPERM)); - EXPECT_THAT(getxattr(char_device, name, nullptr, 0), - SyscallFailsWithErrno(ENODATA)); - EXPECT_THAT(listxattr(char_device, nullptr, 0), SyscallSucceedsWithValue(0)); - - // Use tmpfs, where creation of named pipes is supported. - const std::string fifo = NewTempAbsPathInDir("/dev/shm"); - const char* path = fifo.c_str(); - EXPECT_THAT(mknod(path, S_IFIFO | S_IRUSR | S_IWUSR, 0), SyscallSucceeds()); - EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), - SyscallFailsWithErrno(EPERM)); - EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); - EXPECT_THAT(listxattr(path, nullptr, 0), SyscallSucceedsWithValue(0)); - EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM)); -} - -TEST_F(XattrTest, SetxattrSizeSmallerThanValue) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - std::vector<char> val = {'a', 'a'}; - size_t size = 1; - EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0), - SyscallSucceeds()); - - std::vector<char> buf = {'-', '-'}; - std::vector<char> expected_buf = {'a', '-'}; - EXPECT_THAT(getxattr(path, name, buf.data(), buf.size()), - SyscallSucceedsWithValue(size)); - EXPECT_EQ(buf, expected_buf); -} - -TEST_F(XattrTest, SetxattrZeroSize) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - char val = 'a'; - EXPECT_THAT(setxattr(path, name, &val, 0, /*flags=*/0), SyscallSucceeds()); - - char buf = '-'; - EXPECT_THAT(getxattr(path, name, &buf, XATTR_SIZE_MAX), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(buf, '-'); -} - -TEST_F(XattrTest, SetxattrSizeTooLarge) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - - // Note that each particular fs implementation may stipulate a lower size - // limit, in which case we actually may fail (e.g. error with ENOSPC) for - // some sizes under XATTR_SIZE_MAX. - size_t size = XATTR_SIZE_MAX + 1; - std::vector<char> val(size); - EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0), - SyscallFailsWithErrno(E2BIG)); - - EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); -} - -TEST_F(XattrTest, SetxattrNullValueAndNonzeroSize) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - EXPECT_THAT(setxattr(path, name, nullptr, 1, /*flags=*/0), - SyscallFailsWithErrno(EFAULT)); - - EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); -} - -TEST_F(XattrTest, SetxattrNullValueAndZeroSize) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); - - EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); -} - -TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - std::vector<char> val(XATTR_SIZE_MAX + 1); - std::fill(val.begin(), val.end(), 'a'); - size_t size = 1; - EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0), - SyscallSucceeds()); - - std::vector<char> buf = {'-', '-'}; - std::vector<char> expected_buf = {'a', '-'}; - EXPECT_THAT(getxattr(path, name, buf.data(), size), - SyscallSucceedsWithValue(size)); - EXPECT_EQ(buf, expected_buf); -} - -TEST_F(XattrTest, SetxattrReplaceWithSmaller) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - std::vector<char> val = {'a', 'a'}; - EXPECT_THAT(setxattr(path, name, val.data(), 2, /*flags=*/0), - SyscallSucceeds()); - EXPECT_THAT(setxattr(path, name, val.data(), 1, /*flags=*/0), - SyscallSucceeds()); - - std::vector<char> buf = {'-', '-'}; - std::vector<char> expected_buf = {'a', '-'}; - EXPECT_THAT(getxattr(path, name, buf.data(), 2), SyscallSucceedsWithValue(1)); - EXPECT_EQ(buf, expected_buf); -} - -TEST_F(XattrTest, SetxattrReplaceWithLarger) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - std::vector<char> val = {'a', 'a'}; - EXPECT_THAT(setxattr(path, name, val.data(), 1, /*flags=*/0), - SyscallSucceeds()); - EXPECT_THAT(setxattr(path, name, val.data(), 2, /*flags=*/0), - SyscallSucceeds()); - - std::vector<char> buf = {'-', '-'}; - EXPECT_THAT(getxattr(path, name, buf.data(), 2), SyscallSucceedsWithValue(2)); - EXPECT_EQ(buf, val); -} - -TEST_F(XattrTest, SetxattrCreateFlag) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE), - SyscallSucceeds()); - EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE), - SyscallFailsWithErrno(EEXIST)); - - EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); -} - -TEST_F(XattrTest, SetxattrReplaceFlag) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_REPLACE), - SyscallFailsWithErrno(ENODATA)); - EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); - EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_REPLACE), - SyscallSucceeds()); - - EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); -} - -TEST_F(XattrTest, SetxattrInvalidFlags) { - const char* path = test_file_name_.c_str(); - int invalid_flags = 0xff; - EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, invalid_flags), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_F(XattrTest, Getxattr) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - int val = 1234; - size_t size = sizeof(val); - EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); - - int buf = 0; - EXPECT_THAT(getxattr(path, name, &buf, size), SyscallSucceedsWithValue(size)); - EXPECT_EQ(buf, val); -} - -TEST_F(XattrTest, GetxattrSizeSmallerThanValue) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - std::vector<char> val = {'a', 'a'}; - size_t size = val.size(); - EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); - - char buf = '-'; - EXPECT_THAT(getxattr(path, name, &buf, 1), SyscallFailsWithErrno(ERANGE)); - EXPECT_EQ(buf, '-'); -} - -TEST_F(XattrTest, GetxattrSizeLargerThanValue) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - char val = 'a'; - EXPECT_THAT(setxattr(path, name, &val, 1, /*flags=*/0), SyscallSucceeds()); - - std::vector<char> buf(XATTR_SIZE_MAX); - std::fill(buf.begin(), buf.end(), '-'); - std::vector<char> expected_buf = buf; - expected_buf[0] = 'a'; - EXPECT_THAT(getxattr(path, name, buf.data(), buf.size()), - SyscallSucceedsWithValue(1)); - EXPECT_EQ(buf, expected_buf); -} - -TEST_F(XattrTest, GetxattrZeroSize) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - char val = 'a'; - EXPECT_THAT(setxattr(path, name, &val, sizeof(val), /*flags=*/0), - SyscallSucceeds()); - - char buf = '-'; - EXPECT_THAT(getxattr(path, name, &buf, 0), - SyscallSucceedsWithValue(sizeof(val))); - EXPECT_EQ(buf, '-'); -} - -TEST_F(XattrTest, GetxattrSizeTooLarge) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - char val = 'a'; - EXPECT_THAT(setxattr(path, name, &val, sizeof(val), /*flags=*/0), - SyscallSucceeds()); - - std::vector<char> buf(XATTR_SIZE_MAX + 1); - std::fill(buf.begin(), buf.end(), '-'); - std::vector<char> expected_buf = buf; - expected_buf[0] = 'a'; - EXPECT_THAT(getxattr(path, name, buf.data(), buf.size()), - SyscallSucceedsWithValue(sizeof(val))); - EXPECT_EQ(buf, expected_buf); -} - -TEST_F(XattrTest, GetxattrNullValue) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - char val = 'a'; - size_t size = sizeof(val); - EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); - - EXPECT_THAT(getxattr(path, name, nullptr, size), - SyscallFailsWithErrno(EFAULT)); -} - -TEST_F(XattrTest, GetxattrNullValueAndZeroSize) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - char val = 'a'; - size_t size = sizeof(val); - // Set value with zero size. - EXPECT_THAT(setxattr(path, name, &val, 0, /*flags=*/0), SyscallSucceeds()); - // Get value with nonzero size. - EXPECT_THAT(getxattr(path, name, nullptr, size), SyscallSucceedsWithValue(0)); - - // Set value with nonzero size. - EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); - // Get value with zero size. - EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(size)); -} - -TEST_F(XattrTest, GetxattrNonexistentName) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); -} - -TEST_F(XattrTest, Listxattr) { - const char* path = test_file_name_.c_str(); - const std::string name = "user.test"; - const std::string name2 = "user.test2"; - const std::string name3 = "user.test3"; - EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0), - SyscallSucceeds()); - EXPECT_THAT(setxattr(path, name2.c_str(), nullptr, 0, /*flags=*/0), - SyscallSucceeds()); - EXPECT_THAT(setxattr(path, name3.c_str(), nullptr, 0, /*flags=*/0), - SyscallSucceeds()); - - std::vector<char> list(name.size() + 1 + name2.size() + 1 + name3.size() + 1); - char* buf = list.data(); - EXPECT_THAT(listxattr(path, buf, XATTR_SIZE_MAX), - SyscallSucceedsWithValue(list.size())); - - absl::flat_hash_set<std::string> got = {}; - for (char* p = buf; p < buf + list.size(); p += strlen(p) + 1) { - got.insert(std::string{p}); - } - - absl::flat_hash_set<std::string> expected = {name, name2, name3}; - EXPECT_EQ(got, expected); -} - -TEST_F(XattrTest, ListxattrNoXattrs) { - const char* path = test_file_name_.c_str(); - - std::vector<char> list, expected; - EXPECT_THAT(listxattr(path, list.data(), sizeof(list)), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(list, expected); - - // Listxattr should succeed if there are no attributes, even if the buffer - // passed in is a nullptr. - EXPECT_THAT(listxattr(path, nullptr, sizeof(list)), - SyscallSucceedsWithValue(0)); -} - -TEST_F(XattrTest, ListxattrNullBuffer) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); - - EXPECT_THAT(listxattr(path, nullptr, sizeof(name)), - SyscallFailsWithErrno(EFAULT)); -} - -TEST_F(XattrTest, ListxattrSizeTooSmall) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); - - char list[sizeof(name) - 1]; - EXPECT_THAT(listxattr(path, list, sizeof(list)), - SyscallFailsWithErrno(ERANGE)); -} - -TEST_F(XattrTest, ListxattrZeroSize) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); - EXPECT_THAT(listxattr(path, nullptr, 0), - SyscallSucceedsWithValue(sizeof(name))); -} - -TEST_F(XattrTest, RemoveXattr) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); - EXPECT_THAT(removexattr(path, name), SyscallSucceeds()); - EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); -} - -TEST_F(XattrTest, RemoveXattrNonexistentName) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(ENODATA)); -} - -TEST_F(XattrTest, LXattrOnSymlink) { - const char name[] = "user.test"; - TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo(dir.path(), test_file_name_)); - - EXPECT_THAT(lsetxattr(link.path().c_str(), name, nullptr, 0, 0), - SyscallFailsWithErrno(EPERM)); - EXPECT_THAT(lgetxattr(link.path().c_str(), name, nullptr, 0), - SyscallFailsWithErrno(ENODATA)); - EXPECT_THAT(llistxattr(link.path().c_str(), nullptr, 0), - SyscallSucceedsWithValue(0)); - EXPECT_THAT(lremovexattr(link.path().c_str(), name), - SyscallFailsWithErrno(EPERM)); -} - -TEST_F(XattrTest, LXattrOnNonsymlink) { - const char* path = test_file_name_.c_str(); - const char name[] = "user.test"; - int val = 1234; - size_t size = sizeof(val); - EXPECT_THAT(lsetxattr(path, name, &val, size, /*flags=*/0), - SyscallSucceeds()); - - int buf = 0; - EXPECT_THAT(lgetxattr(path, name, &buf, size), - SyscallSucceedsWithValue(size)); - EXPECT_EQ(buf, val); - - char list[sizeof(name)]; - EXPECT_THAT(llistxattr(path, list, sizeof(list)), - SyscallSucceedsWithValue(sizeof(name))); - EXPECT_STREQ(list, name); - - EXPECT_THAT(lremovexattr(path, name), SyscallSucceeds()); -} - -TEST_F(XattrTest, XattrWithFD) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_.c_str(), 0)); - const char name[] = "user.test"; - int val = 1234; - size_t size = sizeof(val); - EXPECT_THAT(fsetxattr(fd.get(), name, &val, size, /*flags=*/0), - SyscallSucceeds()); - - int buf = 0; - EXPECT_THAT(fgetxattr(fd.get(), name, &buf, size), - SyscallSucceedsWithValue(size)); - EXPECT_EQ(buf, val); - - char list[sizeof(name)]; - EXPECT_THAT(flistxattr(fd.get(), list, sizeof(list)), - SyscallSucceedsWithValue(sizeof(name))); - EXPECT_STREQ(list, name); - - EXPECT_THAT(fremovexattr(fd.get(), name), SyscallSucceeds()); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/uds/BUILD b/test/uds/BUILD deleted file mode 100644 index 51e2c7ce8..000000000 --- a/test/uds/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -go_library( - name = "uds", - testonly = 1, - srcs = ["uds.go"], - deps = [ - "//pkg/log", - "//pkg/unet", - ], -) diff --git a/test/uds/uds.go b/test/uds/uds.go deleted file mode 100644 index b714c61b0..000000000 --- a/test/uds/uds.go +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package uds contains helpers for testing external UDS functionality. -package uds - -import ( - "fmt" - "io" - "io/ioutil" - "os" - "path/filepath" - "syscall" - - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/unet" -) - -// createEchoSocket creates a socket that echoes back anything received. -// -// Only works for stream, seqpacket sockets. -func createEchoSocket(path string, protocol int) (cleanup func(), err error) { - fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0) - if err != nil { - return nil, fmt.Errorf("error creating echo(%d) socket: %v", protocol, err) - } - - if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil { - return nil, fmt.Errorf("error binding echo(%d) socket: %v", protocol, err) - } - - if err := syscall.Listen(fd, 0); err != nil { - return nil, fmt.Errorf("error listening echo(%d) socket: %v", protocol, err) - } - - server, err := unet.NewServerSocket(fd) - if err != nil { - return nil, fmt.Errorf("error creating echo(%d) unet socket: %v", protocol, err) - } - - acceptAndEchoOne := func() error { - s, err := server.Accept() - if err != nil { - return fmt.Errorf("failed to accept: %v", err) - } - defer s.Close() - - for { - buf := make([]byte, 512) - for { - n, err := s.Read(buf) - if err == io.EOF { - return nil - } - if err != nil { - return fmt.Errorf("failed to read: %d, %v", n, err) - } - - n, err = s.Write(buf[:n]) - if err != nil { - return fmt.Errorf("failed to write: %d, %v", n, err) - } - } - } - } - - go func() { - for { - if err := acceptAndEchoOne(); err != nil { - log.Warningf("Failed to handle echo(%d) socket: %v", protocol, err) - return - } - } - }() - - cleanup = func() { - if err := server.Close(); err != nil { - log.Warningf("Failed to close echo(%d) socket: %v", protocol, err) - } - } - - return cleanup, nil -} - -// createNonListeningSocket creates a socket that is bound but not listening. -// -// Only relevant for stream, seqpacket sockets. -func createNonListeningSocket(path string, protocol int) (cleanup func(), err error) { - fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0) - if err != nil { - return nil, fmt.Errorf("error creating nonlistening(%d) socket: %v", protocol, err) - } - - if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil { - return nil, fmt.Errorf("error binding nonlistening(%d) socket: %v", protocol, err) - } - - cleanup = func() { - if err := syscall.Close(fd); err != nil { - log.Warningf("Failed to close nonlistening(%d) socket: %v", protocol, err) - } - } - - return cleanup, nil -} - -// createNullSocket creates a socket that reads anything received. -// -// Only works for dgram sockets. -func createNullSocket(path string, protocol int) (cleanup func(), err error) { - fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0) - if err != nil { - return nil, fmt.Errorf("error creating null(%d) socket: %v", protocol, err) - } - - if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil { - return nil, fmt.Errorf("error binding null(%d) socket: %v", protocol, err) - } - - s, err := unet.NewSocket(fd) - if err != nil { - return nil, fmt.Errorf("error creating null(%d) unet socket: %v", protocol, err) - } - - go func() { - buf := make([]byte, 512) - for { - n, err := s.Read(buf) - if err != nil { - log.Warningf("failed to read: %d, %v", n, err) - return - } - } - }() - - cleanup = func() { - if err := s.Close(); err != nil { - log.Warningf("Failed to close null(%d) socket: %v", protocol, err) - } - } - - return cleanup, nil -} - -type socketCreator func(path string, proto int) (cleanup func(), err error) - -// CreateSocketTree creates a local tree of unix domain sockets for use in -// testing: -// * /stream/echo -// * /stream/nonlistening -// * /seqpacket/echo -// * /seqpacket/nonlistening -// * /dgram/null -func CreateSocketTree(baseDir string) (dir string, cleanup func(), err error) { - dir, err = ioutil.TempDir(baseDir, "sockets") - if err != nil { - return "", nil, fmt.Errorf("error creating temp dir: %v", err) - } - - var protocols = []struct { - protocol int - name string - sockets map[string]socketCreator - }{ - { - protocol: syscall.SOCK_STREAM, - name: "stream", - sockets: map[string]socketCreator{ - "echo": createEchoSocket, - "nonlistening": createNonListeningSocket, - }, - }, - { - protocol: syscall.SOCK_SEQPACKET, - name: "seqpacket", - sockets: map[string]socketCreator{ - "echo": createEchoSocket, - "nonlistening": createNonListeningSocket, - }, - }, - { - protocol: syscall.SOCK_DGRAM, - name: "dgram", - sockets: map[string]socketCreator{ - "null": createNullSocket, - }, - }, - } - - var cleanups []func() - for _, proto := range protocols { - protoDir := filepath.Join(dir, proto.name) - if err := os.Mkdir(protoDir, 0755); err != nil { - return "", nil, fmt.Errorf("error creating %s dir: %v", proto.name, err) - } - - for name, fn := range proto.sockets { - path := filepath.Join(protoDir, name) - cleanup, err := fn(path, proto.protocol) - if err != nil { - return "", nil, fmt.Errorf("error creating %s %s socket: %v", proto.name, name, err) - } - - cleanups = append(cleanups, cleanup) - } - } - - cleanup = func() { - for _, c := range cleanups { - c() - } - - os.RemoveAll(dir) - } - - return dir, cleanup, nil -} diff --git a/test/util/BUILD b/test/util/BUILD deleted file mode 100644 index 2a17c33ee..000000000 --- a/test/util/BUILD +++ /dev/null @@ -1,358 +0,0 @@ -load("//tools:defs.bzl", "cc_library", "cc_test", "gbenchmark", "gtest", "select_system") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -cc_library( - name = "capability_util", - testonly = 1, - srcs = ["capability_util.cc"], - hdrs = ["capability_util.h"], - deps = [ - ":cleanup", - ":memory_util", - ":posix_error", - ":save_util", - ":test_util", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "eventfd_util", - testonly = 1, - hdrs = ["eventfd_util.h"], - deps = [ - ":file_descriptor", - ":posix_error", - ":save_util", - ], -) - -cc_library( - name = "file_descriptor", - testonly = 1, - hdrs = ["file_descriptor.h"], - deps = [ - ":logging", - ":posix_error", - ":save_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - gtest, - ], -) - -cc_library( - name = "proc_util", - testonly = 1, - srcs = ["proc_util.cc"], - hdrs = ["proc_util.h"], - deps = [ - ":fs_util", - ":posix_error", - ":test_util", - "@com_google_absl//absl/strings", - gtest, - ], -) - -cc_test( - name = "proc_util_test", - size = "small", - srcs = ["proc_util_test.cc"], - deps = [ - ":proc_util", - ":test_main", - ":test_util", - gtest, - ], -) - -cc_library( - name = "cleanup", - testonly = 1, - hdrs = ["cleanup.h"], -) - -cc_library( - name = "fs_util", - testonly = 1, - srcs = ["fs_util.cc"], - hdrs = ["fs_util.h"], - deps = [ - ":cleanup", - ":file_descriptor", - ":posix_error", - "@com_google_absl//absl/strings", - gtest, - ], -) - -cc_test( - name = "fs_util_test", - size = "small", - srcs = ["fs_util_test.cc"], - deps = [ - ":fs_util", - ":posix_error", - ":temp_path", - ":test_main", - ":test_util", - gtest, - ], -) - -cc_library( - name = "logging", - testonly = 1, - srcs = ["logging.cc"], - hdrs = ["logging.h"], -) - -cc_library( - name = "memory_util", - testonly = 1, - hdrs = ["memory_util.h"], - deps = [ - ":logging", - ":posix_error", - ":save_util", - ":test_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - ], -) - -cc_library( - name = "mount_util", - testonly = 1, - hdrs = ["mount_util.h"], - deps = [ - ":cleanup", - ":posix_error", - ":test_util", - gtest, - ], -) - -cc_library( - name = "save_util", - testonly = 1, - srcs = [ - "save_util.cc", - "save_util_linux.cc", - "save_util_other.cc", - ], - hdrs = ["save_util.h"], - defines = select_system(), -) - -cc_library( - name = "multiprocess_util", - testonly = 1, - srcs = ["multiprocess_util.cc"], - hdrs = ["multiprocess_util.h"], - deps = [ - ":cleanup", - ":file_descriptor", - ":posix_error", - ":save_util", - ":test_util", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "platform_util", - testonly = 1, - srcs = ["platform_util.cc"], - hdrs = ["platform_util.h"], - deps = [":test_util"], -) - -cc_library( - name = "posix_error", - testonly = 1, - srcs = ["posix_error.cc"], - hdrs = ["posix_error.h"], - deps = [ - ":logging", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:variant", - gtest, - ], -) - -cc_test( - name = "posix_error_test", - size = "small", - srcs = ["posix_error_test.cc"], - deps = [ - ":posix_error", - ":test_main", - gtest, - ], -) - -cc_library( - name = "pty_util", - testonly = 1, - srcs = ["pty_util.cc"], - hdrs = ["pty_util.h"], - deps = [ - ":file_descriptor", - ":posix_error", - ], -) - -cc_library( - name = "signal_util", - testonly = 1, - srcs = ["signal_util.cc"], - hdrs = ["signal_util.h"], - deps = [ - ":cleanup", - ":posix_error", - ":test_util", - gtest, - ], -) - -cc_library( - name = "temp_path", - testonly = 1, - srcs = ["temp_path.cc"], - hdrs = ["temp_path.h"], - deps = [ - ":fs_util", - ":posix_error", - ":test_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - gtest, - ], -) - -cc_library( - name = "test_util", - testonly = 1, - srcs = [ - "test_util.cc", - "test_util_impl.cc", - "test_util_runfiles.cc", - ], - hdrs = ["test_util.h"], - defines = select_system(), - deps = [ - ":fs_util", - ":logging", - ":posix_error", - ":save_util", - "@bazel_tools//tools/cpp/runfiles", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/flags:parse", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/time", - gtest, - gbenchmark, - ], -) - -cc_library( - name = "thread_util", - testonly = 1, - hdrs = ["thread_util.h"], - deps = [":logging"], -) - -cc_library( - name = "time_util", - testonly = 1, - srcs = ["time_util.cc"], - hdrs = ["time_util.h"], - deps = [ - "@com_google_absl//absl/time", - ], -) - -cc_library( - name = "timer_util", - testonly = 1, - srcs = ["timer_util.cc"], - hdrs = ["timer_util.h"], - deps = [ - ":cleanup", - ":logging", - ":posix_error", - ":test_util", - "@com_google_absl//absl/time", - gtest, - ], -) - -cc_test( - name = "test_util_test", - size = "small", - srcs = ["test_util_test.cc"], - deps = [ - ":test_main", - ":test_util", - gtest, - ], -) - -cc_library( - name = "test_main", - testonly = 1, - srcs = ["test_main.cc"], - deps = [":test_util"], -) - -cc_library( - name = "epoll_util", - testonly = 1, - srcs = ["epoll_util.cc"], - hdrs = ["epoll_util.h"], - deps = [ - ":file_descriptor", - ":posix_error", - ":save_util", - gtest, - ], -) - -cc_library( - name = "rlimit_util", - testonly = 1, - srcs = ["rlimit_util.cc"], - hdrs = ["rlimit_util.h"], - deps = [ - ":cleanup", - ":logging", - ":posix_error", - ":test_util", - ], -) - -cc_library( - name = "uid_util", - testonly = 1, - srcs = ["uid_util.cc"], - hdrs = ["uid_util.h"], - deps = [ - ":posix_error", - ":save_util", - ], -) - -cc_library( - name = "temp_umask", - testonly = 1, - hdrs = ["temp_umask.h"], -) diff --git a/test/util/capability_util.cc b/test/util/capability_util.cc deleted file mode 100644 index 9fee52fbb..000000000 --- a/test/util/capability_util.cc +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/capability_util.h" - -#include <linux/capability.h> -#include <sched.h> -#include <sys/mman.h> -#include <sys/wait.h> - -#include <iostream> - -#include "absl/strings/str_cat.h" -#include "test/util/memory_util.h" -#include "test/util/posix_error.h" -#include "test/util/save_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -PosixErrorOr<bool> CanCreateUserNamespace() { - // The most reliable way to determine if userns creation is possible is by - // trying to create one; see below. - ASSIGN_OR_RETURN_ERRNO( - auto child_stack, - MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - int const child_pid = clone( - +[](void*) { return 0; }, - reinterpret_cast<void*>(child_stack.addr() + kPageSize), - CLONE_NEWUSER | SIGCHLD, /* arg = */ nullptr); - if (child_pid > 0) { - int status; - int const ret = waitpid(child_pid, &status, /* options = */ 0); - MaybeSave(); - if (ret < 0) { - return PosixError(errno, "waitpid"); - } - if (!WIFEXITED(status) || WEXITSTATUS(status) != 0) { - return PosixError( - ESRCH, absl::StrCat("child process exited with status ", status)); - } - return true; - } else if (errno == EPERM) { - // Per clone(2), EPERM can be returned if: - // - // - "CLONE_NEWUSER was specified in flags, but either the effective user ID - // or the effective group ID of the caller does not have a mapping in the - // parent namespace (see user_namespaces(7))." - // - // - "(since Linux 3.9) CLONE_NEWUSER was specified in flags and the caller - // is in a chroot environment (i.e., the caller's root directory does - // not match the root directory of the mount namespace in which it - // resides)." - std::cerr << "clone(CLONE_NEWUSER) failed with EPERM"; - return false; - } else if (errno == EUSERS) { - // "(since Linux 3.11) CLONE_NEWUSER was specified in flags, and the call - // would cause the limit on the number of nested user namespaces to be - // exceeded. See user_namespaces(7)." - std::cerr << "clone(CLONE_NEWUSER) failed with EUSERS"; - return false; - } else { - // Unexpected error code; indicate an actual error. - return PosixError(errno, "clone(CLONE_NEWUSER)"); - } -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/capability_util.h b/test/util/capability_util.h deleted file mode 100644 index bb9ea1fe5..000000000 --- a/test/util/capability_util.h +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Utilities for testing capabilities. - -#ifndef GVISOR_TEST_UTIL_CAPABILITY_UTIL_H_ -#define GVISOR_TEST_UTIL_CAPABILITY_UTIL_H_ - -#include <errno.h> -#include <linux/capability.h> -#include <sys/syscall.h> -#include <unistd.h> - -#include "test/util/cleanup.h" -#include "test/util/posix_error.h" -#include "test/util/save_util.h" -#include "test/util/test_util.h" - -#ifndef _LINUX_CAPABILITY_VERSION_3 -#error Expecting _LINUX_CAPABILITY_VERSION_3 support -#endif - -namespace gvisor { -namespace testing { - -// HaveCapability returns true if the process has the specified EFFECTIVE -// capability. -inline PosixErrorOr<bool> HaveCapability(int cap) { - if (!cap_valid(cap)) { - return PosixError(EINVAL, "Invalid capability"); - } - - struct __user_cap_header_struct header = {_LINUX_CAPABILITY_VERSION_3, 0}; - struct __user_cap_data_struct caps[_LINUX_CAPABILITY_U32S_3] = {}; - RETURN_ERROR_IF_SYSCALL_FAIL(syscall(__NR_capget, &header, &caps)); - MaybeSave(); - - return (caps[CAP_TO_INDEX(cap)].effective & CAP_TO_MASK(cap)) != 0; -} - -// SetCapability sets the specified EFFECTIVE capability. -inline PosixError SetCapability(int cap, bool set) { - if (!cap_valid(cap)) { - return PosixError(EINVAL, "Invalid capability"); - } - - struct __user_cap_header_struct header = {_LINUX_CAPABILITY_VERSION_3, 0}; - struct __user_cap_data_struct caps[_LINUX_CAPABILITY_U32S_3] = {}; - RETURN_ERROR_IF_SYSCALL_FAIL(syscall(__NR_capget, &header, &caps)); - MaybeSave(); - - if (set) { - caps[CAP_TO_INDEX(cap)].effective |= CAP_TO_MASK(cap); - } else { - caps[CAP_TO_INDEX(cap)].effective &= ~CAP_TO_MASK(cap); - } - header = {_LINUX_CAPABILITY_VERSION_3, 0}; - RETURN_ERROR_IF_SYSCALL_FAIL(syscall(__NR_capset, &header, &caps)); - MaybeSave(); - - return NoError(); -} - -// DropPermittedCapability drops the specified PERMITTED. The EFFECTIVE -// capabilities must be a subset of PERMITTED, so those are dropped as well. -inline PosixError DropPermittedCapability(int cap) { - if (!cap_valid(cap)) { - return PosixError(EINVAL, "Invalid capability"); - } - - struct __user_cap_header_struct header = {_LINUX_CAPABILITY_VERSION_3, 0}; - struct __user_cap_data_struct caps[_LINUX_CAPABILITY_U32S_3] = {}; - RETURN_ERROR_IF_SYSCALL_FAIL(syscall(__NR_capget, &header, &caps)); - MaybeSave(); - - caps[CAP_TO_INDEX(cap)].effective &= ~CAP_TO_MASK(cap); - caps[CAP_TO_INDEX(cap)].permitted &= ~CAP_TO_MASK(cap); - - header = {_LINUX_CAPABILITY_VERSION_3, 0}; - RETURN_ERROR_IF_SYSCALL_FAIL(syscall(__NR_capset, &header, &caps)); - MaybeSave(); - - return NoError(); -} - -PosixErrorOr<bool> CanCreateUserNamespace(); - -} // namespace testing -} // namespace gvisor -#endif // GVISOR_TEST_UTIL_CAPABILITY_UTIL_H_ diff --git a/test/util/cleanup.h b/test/util/cleanup.h deleted file mode 100644 index c76482ef4..000000000 --- a/test/util/cleanup.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_CLEANUP_H_ -#define GVISOR_TEST_UTIL_CLEANUP_H_ - -#include <functional> -#include <utility> - -namespace gvisor { -namespace testing { - -class Cleanup { - public: - Cleanup() : released_(true) {} - explicit Cleanup(std::function<void()>&& callback) : cb_(callback) {} - - Cleanup(Cleanup&& other) { - released_ = other.released_; - cb_ = other.Release(); - } - - Cleanup& operator=(Cleanup&& other) { - released_ = other.released_; - cb_ = other.Release(); - return *this; - } - - ~Cleanup() { - if (!released_) { - cb_(); - } - } - - std::function<void()>&& Release() { - released_ = true; - return std::move(cb_); - } - - private: - Cleanup(Cleanup const& other) = delete; - - bool released_ = false; - std::function<void(void)> cb_; -}; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_CLEANUP_H_ diff --git a/test/util/epoll_util.cc b/test/util/epoll_util.cc deleted file mode 100644 index 2e5051468..000000000 --- a/test/util/epoll_util.cc +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/epoll_util.h" - -#include <sys/epoll.h> - -#include "gmock/gmock.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/save_util.h" - -namespace gvisor { -namespace testing { - -PosixErrorOr<FileDescriptor> NewEpollFD(int size) { - // "Since Linux 2.6.8, the size argument is ignored, but must be greater than - // zero." - epoll_create(2) - int fd = epoll_create(size); - MaybeSave(); - if (fd < 0) { - return PosixError(errno, "epoll_create"); - } - return FileDescriptor(fd); -} - -PosixError RegisterEpollFD(int epoll_fd, int target_fd, int events, - uint64_t data) { - struct epoll_event event; - event.events = events; - event.data.u64 = data; - int rc = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, target_fd, &event); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "epoll_ctl"); - } - return NoError(); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/epoll_util.h b/test/util/epoll_util.h deleted file mode 100644 index f233b37d5..000000000 --- a/test/util/epoll_util.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_EPOLL_UTIL_H_ -#define GVISOR_TEST_UTIL_EPOLL_UTIL_H_ - -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" - -namespace gvisor { -namespace testing { - -// Returns a new epoll file descriptor. -PosixErrorOr<FileDescriptor> NewEpollFD(int size = 1); - -// Registers `target_fd` with the epoll instance represented by `epoll_fd` for -// the epoll events `events`. Events on `target_fd` will be indicated by setting -// data.u64 to `data` in the returned epoll_event. -PosixError RegisterEpollFD(int epoll_fd, int target_fd, int events, - uint64_t data); - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_EPOLL_UTIL_H_ diff --git a/test/util/eventfd_util.h b/test/util/eventfd_util.h deleted file mode 100644 index cb9ce829c..000000000 --- a/test/util/eventfd_util.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_EVENTFD_UTIL_H_ -#define GVISOR_TEST_UTIL_EVENTFD_UTIL_H_ - -#include <sys/eventfd.h> - -#include <cerrno> - -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/save_util.h" - -namespace gvisor { -namespace testing { - -// Returns a new eventfd with the given initial value and flags. -inline PosixErrorOr<FileDescriptor> NewEventFD(unsigned int initval = 0, - int flags = 0) { - int fd = eventfd(initval, flags); - MaybeSave(); - if (fd < 0) { - return PosixError(errno, "eventfd"); - } - return FileDescriptor(fd); -} - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_EVENTFD_UTIL_H_ diff --git a/test/util/file_descriptor.h b/test/util/file_descriptor.h deleted file mode 100644 index fc5caa55b..000000000 --- a/test/util/file_descriptor.h +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_FILE_DESCRIPTOR_H_ -#define GVISOR_TEST_UTIL_FILE_DESCRIPTOR_H_ - -#include <fcntl.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include <algorithm> -#include <string> - -#include "gmock/gmock.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "test/util/logging.h" -#include "test/util/posix_error.h" -#include "test/util/save_util.h" - -namespace gvisor { -namespace testing { - -// FileDescriptor is an RAII type class which takes ownership of a file -// descriptor. It will close the FD when this object goes out of scope. -class FileDescriptor { - public: - // Constructs an empty FileDescriptor (one that does not own a file - // descriptor). - FileDescriptor() = default; - - // Constructs a FileDescriptor that owns fd. If fd is negative, constructs an - // empty FileDescriptor. - explicit FileDescriptor(int fd) { set_fd(fd); } - - FileDescriptor(FileDescriptor&& orig) : fd_(orig.release()) {} - - FileDescriptor& operator=(FileDescriptor&& orig) { - reset(orig.release()); - return *this; - } - - PosixErrorOr<FileDescriptor> Dup() const { - if (fd_ < 0) { - return PosixError(EINVAL, "Attempting to Dup unset fd"); - } - - int fd = dup(fd_); - if (fd < 0) { - return PosixError(errno, absl::StrCat("dup ", fd_)); - } - MaybeSave(); - return FileDescriptor(fd); - } - - FileDescriptor(FileDescriptor const& other) = delete; - FileDescriptor& operator=(FileDescriptor const& other) = delete; - - ~FileDescriptor() { reset(); } - - // If this object is non-empty, returns the owned file descriptor. (Ownership - // is retained by the FileDescriptor.) Otherwise returns -1. - int get() const { return fd_; } - - // If this object is non-empty, transfers ownership of the file descriptor to - // the caller and returns it. Otherwise returns -1. - int release() { - int const fd = fd_; - fd_ = -1; - return fd; - } - - // If this object is non-empty, closes the owned file descriptor (recording a - // test failure if the close fails). - void reset() { reset(-1); } - - // Like no-arg reset(), but the FileDescriptor takes ownership of fd after - // closing its existing file descriptor. - void reset(int fd) { - if (fd_ >= 0) { - TEST_PCHECK(close(fd_) == 0); - MaybeSave(); - } - set_fd(fd); - } - - private: - // Wrapper that coerces negative fd values other than -1 to -1 so that get() - // etc. return -1. - void set_fd(int fd) { fd_ = std::max(fd, -1); } - - int fd_ = -1; -}; - -// Wrapper around open(2) that returns a FileDescriptor. -inline PosixErrorOr<FileDescriptor> Open(std::string const& path, int flags, - mode_t mode = 0) { - int fd = open(path.c_str(), flags, mode); - if (fd < 0) { - return PosixError(errno, absl::StrFormat("open(%s, %#x, %#o)", path.c_str(), - flags, mode)); - } - MaybeSave(); - return FileDescriptor(fd); -} - -// Wrapper around openat(2) that returns a FileDescriptor. -inline PosixErrorOr<FileDescriptor> OpenAt(int dirfd, std::string const& path, - int flags, mode_t mode = 0) { - int fd = openat(dirfd, path.c_str(), flags, mode); - if (fd < 0) { - return PosixError(errno, absl::StrFormat("openat(%d, %s, %#x, %#o)", dirfd, - path, flags, mode)); - } - MaybeSave(); - return FileDescriptor(fd); -} - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_FILE_DESCRIPTOR_H_ diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc deleted file mode 100644 index 052781445..000000000 --- a/test/util/fs_util.cc +++ /dev/null @@ -1,633 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/fs_util.h" - -#include <dirent.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include "gmock/gmock.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" - -namespace gvisor { -namespace testing { - -namespace { -PosixError WriteContentsToFD(int fd, absl::string_view contents) { - int written = 0; - while (static_cast<absl::string_view::size_type>(written) < contents.size()) { - int wrote = write(fd, contents.data() + written, contents.size() - written); - if (wrote < 0) { - if (errno == EINTR) { - continue; - } - return PosixError( - errno, absl::StrCat("WriteContentsToFD fd: ", fd, " write failure.")); - } - written += wrote; - } - return NoError(); -} -} // namespace - -namespace internal { - -// Given a collection of file paths, append them all together, -// ensuring that the proper path separators are inserted between them. -std::string JoinPathImpl(std::initializer_list<absl::string_view> paths) { - std::string result; - - if (paths.size() != 0) { - // This size calculation is worst-case: it assumes one extra "/" for every - // path other than the first. - size_t total_size = paths.size() - 1; - for (const absl::string_view path : paths) total_size += path.size(); - result.resize(total_size); - - auto begin = result.begin(); - auto out = begin; - bool trailing_slash = false; - for (absl::string_view path : paths) { - if (path.empty()) continue; - if (path.front() == '/') { - if (trailing_slash) { - path.remove_prefix(1); - } - } else { - if (!trailing_slash && out != begin) *out++ = '/'; - } - const size_t this_size = path.size(); - memcpy(&*out, path.data(), this_size); - out += this_size; - trailing_slash = out[-1] == '/'; - } - result.erase(out - begin); - } - return result; -} -} // namespace internal - -// Returns a status or the current working directory. -PosixErrorOr<std::string> GetCWD() { - char buffer[PATH_MAX + 1] = {}; - if (getcwd(buffer, PATH_MAX) == nullptr) { - return PosixError(errno, "GetCWD() failed"); - } - - return std::string(buffer); -} - -PosixErrorOr<struct stat> Stat(absl::string_view path) { - struct stat stat_buf; - int res = stat(std::string(path).c_str(), &stat_buf); - if (res < 0) { - return PosixError(errno, absl::StrCat("stat ", path)); - } - return stat_buf; -} - -PosixErrorOr<struct stat> Lstat(absl::string_view path) { - struct stat stat_buf; - int res = lstat(std::string(path).c_str(), &stat_buf); - if (res < 0) { - return PosixError(errno, absl::StrCat("lstat ", path)); - } - return stat_buf; -} - -PosixErrorOr<struct stat> Fstat(int fd) { - struct stat stat_buf; - int res = fstat(fd, &stat_buf); - if (res < 0) { - return PosixError(errno, absl::StrCat("fstat ", fd)); - } - return stat_buf; -} - -PosixErrorOr<bool> Exists(absl::string_view path) { - struct stat stat_buf; - int res = stat(std::string(path).c_str(), &stat_buf); - if (res < 0) { - if (errno == ENOENT) { - return false; - } - return PosixError(errno, absl::StrCat("stat ", path)); - } - return true; -} - -PosixErrorOr<bool> IsDirectory(absl::string_view path) { - ASSIGN_OR_RETURN_ERRNO(struct stat stat_buf, Lstat(path)); - if (S_ISDIR(stat_buf.st_mode)) { - return true; - } - - return false; -} - -PosixError Delete(absl::string_view path) { - int res = unlink(std::string(path).c_str()); - if (res < 0) { - return PosixError(errno, absl::StrCat("unlink ", path)); - } - - return NoError(); -} - -PosixError Truncate(absl::string_view path, int length) { - int res = truncate(std::string(path).c_str(), length); - if (res < 0) { - return PosixError(errno, - absl::StrCat("truncate ", path, " to length ", length)); - } - - return NoError(); -} - -PosixError Chmod(absl::string_view path, int mode) { - int res = chmod(std::string(path).c_str(), mode); - if (res < 0) { - return PosixError(errno, absl::StrCat("chmod ", path)); - } - - return NoError(); -} - -PosixError MknodAt(const FileDescriptor& dfd, absl::string_view path, int mode, - dev_t dev) { - int res = mknodat(dfd.get(), std::string(path).c_str(), mode, dev); - if (res < 0) { - return PosixError(errno, absl::StrCat("mknod ", path)); - } - - return NoError(); -} - -PosixError UnlinkAt(const FileDescriptor& dfd, absl::string_view path, - int flags) { - int res = unlinkat(dfd.get(), std::string(path).c_str(), flags); - if (res < 0) { - return PosixError(errno, absl::StrCat("unlink ", path)); - } - - return NoError(); -} - -PosixError Mkdir(absl::string_view path, int mode) { - int res = mkdir(std::string(path).c_str(), mode); - if (res < 0) { - return PosixError(errno, absl::StrCat("mkdir ", path, " mode ", mode)); - } - - return NoError(); -} - -PosixError Rmdir(absl::string_view path) { - int res = rmdir(std::string(path).c_str()); - if (res < 0) { - return PosixError(errno, absl::StrCat("rmdir ", path)); - } - - return NoError(); -} - -PosixError SetContents(absl::string_view path, absl::string_view contents) { - ASSIGN_OR_RETURN_ERRNO(bool exists, Exists(path)); - if (!exists) { - return PosixError( - ENOENT, absl::StrCat("SetContents file ", path, " doesn't exist.")); - } - - ASSIGN_OR_RETURN_ERRNO(auto fd, Open(std::string(path), O_WRONLY | O_TRUNC)); - return WriteContentsToFD(fd.get(), contents); -} - -// Create a file with the given contents (if it does not already exist with the -// given mode) and then set the contents. -PosixError CreateWithContents(absl::string_view path, - absl::string_view contents, int mode) { - ASSIGN_OR_RETURN_ERRNO( - auto fd, Open(std::string(path), O_WRONLY | O_CREAT | O_TRUNC, mode)); - return WriteContentsToFD(fd.get(), contents); -} - -PosixError GetContents(absl::string_view path, std::string* output) { - ASSIGN_OR_RETURN_ERRNO(auto fd, Open(std::string(path), O_RDONLY)); - output->clear(); - - // Keep reading until we hit an EOF or an error. - return GetContentsFD(fd.get(), output); -} - -PosixErrorOr<std::string> GetContents(absl::string_view path) { - std::string ret; - RETURN_IF_ERRNO(GetContents(path, &ret)); - return ret; -} - -PosixErrorOr<std::string> GetContentsFD(int fd) { - std::string ret; - RETURN_IF_ERRNO(GetContentsFD(fd, &ret)); - return ret; -} - -PosixError GetContentsFD(int fd, std::string* output) { - // Keep reading until we hit an EOF or an error. - while (true) { - char buf[16 * 1024] = {}; // Read in 16KB chunks. - int bytes_read = read(fd, buf, sizeof(buf)); - if (bytes_read < 0) { - if (errno == EINTR) { - continue; - } - return PosixError(errno, "GetContentsFD read failure."); - } - - if (bytes_read == 0) { - break; // EOF. - } - - output->append(buf, bytes_read); - } - return NoError(); -} - -PosixErrorOr<std::string> ReadLink(absl::string_view path) { - char buf[PATH_MAX + 1] = {}; - int ret = readlink(std::string(path).c_str(), buf, PATH_MAX); - if (ret < 0) { - return PosixError(errno, absl::StrCat("readlink ", path)); - } - - return std::string(buf, ret); -} - -PosixError WalkTree( - absl::string_view path, bool recursive, - const std::function<void(absl::string_view, const struct stat&)>& cb) { - DIR* dir = opendir(std::string(path).c_str()); - if (dir == nullptr) { - return PosixError(errno, absl::StrCat("opendir ", path)); - } - auto dir_closer = Cleanup([&dir]() { closedir(dir); }); - while (true) { - // Readdir(3): If the end of the directory stream is reached, NULL is - // returned and errno is not changed. If an error occurs, NULL is returned - // and errno is set appropriately. To distinguish end of stream and from an - // error, set errno to zero before calling readdir() and then check the - // value of errno if NULL is returned. - errno = 0; - struct dirent* dp = readdir(dir); - if (dp == nullptr) { - if (errno != 0) { - return PosixError(errno, absl::StrCat("readdir ", path)); - } - break; // We're done. - } - - if (strcmp(dp->d_name, ".") == 0 || strcmp(dp->d_name, "..") == 0) { - // Skip dots. - continue; - } - - auto full_path = JoinPath(path, dp->d_name); - ASSIGN_OR_RETURN_ERRNO(struct stat s, Stat(full_path)); - if (S_ISDIR(s.st_mode) && recursive) { - RETURN_IF_ERRNO(WalkTree(full_path, recursive, cb)); - } else { - cb(full_path, s); - } - } - // We're done walking so let's invoke our cleanup callback now. - dir_closer.Release()(); - - // And we have to dispatch the callback on the base directory. - ASSIGN_OR_RETURN_ERRNO(struct stat s, Stat(path)); - cb(path, s); - - return NoError(); -} - -PosixErrorOr<std::vector<std::string>> ListDir(absl::string_view abspath, - bool skipdots) { - std::vector<std::string> files; - - DIR* dir = opendir(std::string(abspath).c_str()); - if (dir == nullptr) { - return PosixError(errno, absl::StrCat("opendir ", abspath)); - } - auto dir_closer = Cleanup([&dir]() { closedir(dir); }); - while (true) { - // Readdir(3): If the end of the directory stream is reached, NULL is - // returned and errno is not changed. If an error occurs, NULL is returned - // and errno is set appropriately. To distinguish end of stream and from an - // error, set errno to zero before calling readdir() and then check the - // value of errno if NULL is returned. - errno = 0; - struct dirent* dp = readdir(dir); - if (dp == nullptr) { - if (errno != 0) { - return PosixError(errno, absl::StrCat("readdir ", abspath)); - } - break; // We're done. - } - - if (strcmp(dp->d_name, ".") == 0 || strcmp(dp->d_name, "..") == 0) { - if (skipdots) { - continue; - } - } - files.push_back(std::string(dp->d_name)); - } - - return files; -} - -PosixError RecursivelyDelete(absl::string_view path, int* undeleted_dirs, - int* undeleted_files) { - ASSIGN_OR_RETURN_ERRNO(bool exists, Exists(path)); - if (!exists) { - return PosixError(ENOENT, absl::StrCat(path, " does not exist")); - } - - ASSIGN_OR_RETURN_ERRNO(bool dir, IsDirectory(path)); - if (!dir) { - // Nothing recursive needs to happen we can just call Delete. - auto status = Delete(path); - if (!status.ok() && undeleted_files) { - (*undeleted_files)++; - } - return status; - } - - return WalkTree(path, /*recursive=*/true, - [&](absl::string_view absolute_path, const struct stat& s) { - if (S_ISDIR(s.st_mode)) { - auto rm_status = Rmdir(absolute_path); - if (!rm_status.ok() && undeleted_dirs) { - (*undeleted_dirs)++; - } - } else { - auto delete_status = Delete(absolute_path); - if (!delete_status.ok() && undeleted_files) { - (*undeleted_files)++; - } - } - }); -} - -PosixError RecursivelyCreateDir(absl::string_view path) { - if (path.empty() || path == "/") { - return PosixError(EINVAL, "Cannot create root!"); - } - - // Does it already exist, if so we're done. - ASSIGN_OR_RETURN_ERRNO(bool exists, Exists(path)); - if (exists) { - return NoError(); - } - - // Do we need to create directories under us? - auto dirname = Dirname(path); - ASSIGN_OR_RETURN_ERRNO(exists, Exists(dirname)); - if (!exists) { - RETURN_IF_ERRNO(RecursivelyCreateDir(dirname)); - } - - return Mkdir(path); -} - -// Makes a path absolute with respect to an optional base. If no base is -// provided it will use the current working directory. -PosixErrorOr<std::string> MakeAbsolute(absl::string_view filename, - absl::string_view base) { - if (filename.empty()) { - return PosixError(EINVAL, "filename cannot be empty."); - } - - if (filename[0] == '/') { - // This path is already absolute. - return std::string(filename); - } - - std::string actual_base; - if (!base.empty()) { - actual_base = std::string(base); - } else { - auto cwd_or = GetCWD(); - RETURN_IF_ERRNO(cwd_or.error()); - actual_base = cwd_or.ValueOrDie(); - } - - // Reverse iterate removing trailing slashes, effectively right trim '/'. - for (int i = actual_base.size() - 1; i >= 0 && actual_base[i] == '/'; --i) { - actual_base.erase(i, 1); - } - - if (filename == ".") { - return actual_base.empty() ? "/" : actual_base; - } - - return absl::StrCat(actual_base, "/", filename); -} - -std::string CleanPath(const absl::string_view unclean_path) { - std::string path = std::string(unclean_path); - const char* src = path.c_str(); - std::string::iterator dst = path.begin(); - - // Check for absolute path and determine initial backtrack limit. - const bool is_absolute_path = *src == '/'; - if (is_absolute_path) { - *dst++ = *src++; - while (*src == '/') ++src; - } - std::string::const_iterator backtrack_limit = dst; - - // Process all parts - while (*src) { - bool parsed = false; - - if (src[0] == '.') { - // 1dot ".<whateverisnext>", check for END or SEP. - if (src[1] == '/' || !src[1]) { - if (*++src) { - ++src; - } - parsed = true; - } else if (src[1] == '.' && (src[2] == '/' || !src[2])) { - // 2dot END or SEP (".." | "../<whateverisnext>"). - src += 2; - if (dst != backtrack_limit) { - // We can backtrack the previous part - for (--dst; dst != backtrack_limit && dst[-1] != '/'; --dst) { - // Empty. - } - } else if (!is_absolute_path) { - // Failed to backtrack and we can't skip it either. Rewind and copy. - src -= 2; - *dst++ = *src++; - *dst++ = *src++; - if (*src) { - *dst++ = *src; - } - // We can never backtrack over a copied "../" part so set new limit. - backtrack_limit = dst; - } - if (*src) { - ++src; - } - parsed = true; - } - } - - // If not parsed, copy entire part until the next SEP or EOS. - if (!parsed) { - while (*src && *src != '/') { - *dst++ = *src++; - } - if (*src) { - *dst++ = *src++; - } - } - - // Skip consecutive SEP occurrences - while (*src == '/') { - ++src; - } - } - - // Calculate and check the length of the cleaned path. - int path_length = dst - path.begin(); - if (path_length != 0) { - // Remove trailing '/' except if it is root path ("/" ==> path_length := 1) - if (path_length > 1 && path[path_length - 1] == '/') { - --path_length; - } - path.resize(path_length); - } else { - // The cleaned path is empty; assign "." as per the spec. - path.assign(1, '.'); - } - return path; -} - -PosixErrorOr<std::string> GetRelativePath(absl::string_view source, - absl::string_view dest) { - if (!absl::StartsWith(source, "/") || !absl::StartsWith(dest, "/")) { - // At least one of the inputs is not an absolute path. - return PosixError( - EINVAL, - "GetRelativePath: At least one of the inputs is not an absolute path."); - } - const std::string clean_source = CleanPath(source); - const std::string clean_dest = CleanPath(dest); - auto source_parts = absl::StrSplit(clean_source, '/', absl::SkipEmpty()); - auto dest_parts = absl::StrSplit(clean_dest, '/', absl::SkipEmpty()); - auto source_iter = source_parts.begin(); - auto dest_iter = dest_parts.begin(); - - // Advance past common prefix. - while (source_iter != source_parts.end() && dest_iter != dest_parts.end() && - *source_iter == *dest_iter) { - ++source_iter; - ++dest_iter; - } - - // Build result backtracking. - std::string result = ""; - while (source_iter != source_parts.end()) { - absl::StrAppend(&result, "../"); - ++source_iter; - } - - // Add remaining path to dest. - while (dest_iter != dest_parts.end()) { - absl::StrAppend(&result, *dest_iter, "/"); - ++dest_iter; - } - - if (result.empty()) { - return std::string("."); - } - - // Remove trailing slash. - result.erase(result.size() - 1); - return result; -} - -absl::string_view Dirname(absl::string_view path) { - return SplitPath(path).first; -} - -absl::string_view Basename(absl::string_view path) { - return SplitPath(path).second; -} - -std::pair<absl::string_view, absl::string_view> SplitPath( - absl::string_view path) { - std::string::size_type pos = path.find_last_of('/'); - - // Handle the case with no '/' in 'path'. - if (pos == absl::string_view::npos) { - return std::make_pair(path.substr(0, 0), path); - } - - // Handle the case with a single leading '/' in 'path'. - if (pos == 0) { - return std::make_pair(path.substr(0, 1), absl::ClippedSubstr(path, 1)); - } - - return std::make_pair(path.substr(0, pos), - absl::ClippedSubstr(path, pos + 1)); -} - -std::string JoinPath(absl::string_view path1, absl::string_view path2) { - if (path1.empty()) { - return std::string(path2); - } - if (path2.empty()) { - return std::string(path1); - } - - if (path1.back() == '/') { - if (path2.front() == '/') { - return absl::StrCat(path1, absl::ClippedSubstr(path2, 1)); - } - } else { - if (path2.front() != '/') { - return absl::StrCat(path1, "/", path2); - } - } - return absl::StrCat(path1, path2); -} - -PosixErrorOr<std::string> ProcessExePath(int pid) { - if (pid <= 0) { - return PosixError(EINVAL, "Invalid pid specified"); - } - - return ReadLink(absl::StrCat("/proc/", pid, "/exe")); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/fs_util.h b/test/util/fs_util.h deleted file mode 100644 index caf19b24d..000000000 --- a/test/util/fs_util.h +++ /dev/null @@ -1,205 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_FS_UTIL_H_ -#define GVISOR_TEST_UTIL_FS_UTIL_H_ - -#include <dirent.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include "absl/strings/string_view.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" - -namespace gvisor { -namespace testing { - -// O_LARGEFILE as defined by Linux. glibc tries to be clever by setting it to 0 -// because "it isn't needed", even though Linux can return it via F_GETFL. -#if defined(__x86_64__) -constexpr int kOLargeFile = 00100000; -#elif defined(__aarch64__) -constexpr int kOLargeFile = 00400000; -#else -#error "Unknown architecture" -#endif - -// Returns a status or the current working directory. -PosixErrorOr<std::string> GetCWD(); - -// Returns true/false depending on whether or not path exists, or an error if it -// can't be determined. -PosixErrorOr<bool> Exists(absl::string_view path); - -// Returns a stat structure for the given path or an error. -PosixErrorOr<struct stat> Stat(absl::string_view path); - -// Returns a stat struct for the given fd. -PosixErrorOr<struct stat> Fstat(int fd); - -// Deletes the file or directory at path or returns an error. -PosixError Delete(absl::string_view path); - -// Changes the mode of a file or returns an error. -PosixError Chmod(absl::string_view path, int mode); - -// Create a special or ordinary file. -PosixError MknodAt(const FileDescriptor& dfd, absl::string_view path, int mode, - dev_t dev); - -// Unlink the file. -PosixError UnlinkAt(const FileDescriptor& dfd, absl::string_view path, - int flags); - -// Truncates a file to the given length or returns an error. -PosixError Truncate(absl::string_view path, int length); - -// Returns true/false depending on whether or not the path is a directory or -// returns an error. -PosixErrorOr<bool> IsDirectory(absl::string_view path); - -// Makes a directory or returns an error. -PosixError Mkdir(absl::string_view path, int mode = 0755); - -// Removes a directory or returns an error. -PosixError Rmdir(absl::string_view path); - -// Attempts to set the contents of a file or returns an error. -PosixError SetContents(absl::string_view path, absl::string_view contents); - -// Creates a file with the given contents and mode or returns an error. -PosixError CreateWithContents(absl::string_view path, - absl::string_view contents, int mode = 0666); - -// Attempts to read the entire contents of the file into the provided string -// buffer or returns an error. -PosixError GetContents(absl::string_view path, std::string* output); - -// Attempts to read the entire contents of the file or returns an error. -PosixErrorOr<std::string> GetContents(absl::string_view path); - -// Attempts to read the entire contents of the provided fd into the provided -// string or returns an error. -PosixError GetContentsFD(int fd, std::string* output); - -// Attempts to read the entire contents of the provided fd or returns an error. -PosixErrorOr<std::string> GetContentsFD(int fd); - -// Executes the readlink(2) system call or returns an error. -PosixErrorOr<std::string> ReadLink(absl::string_view path); - -// WalkTree will walk a directory tree in a depth first search manner (if -// recursive). It will invoke a provided callback for each file and directory, -// the parent will always be invoked last making this appropriate for things -// such as deleting an entire directory tree. -// -// This method will return an error when it's unable to access the provided -// path, or when the path is not a directory. -PosixError WalkTree( - absl::string_view path, bool recursive, - const std::function<void(absl::string_view, const struct stat&)>& cb); - -// Returns the base filenames for all files under a given absolute path. If -// skipdots is true the returned vector will not contain "." or "..". This -// method does not walk the tree recursively it only returns the elements -// in that directory. -PosixErrorOr<std::vector<std::string>> ListDir(absl::string_view abspath, - bool skipdots); - -// Attempt to recursively delete a directory or file. Returns an error and -// the number of undeleted directories and files. If either -// undeleted_dirs or undeleted_files is nullptr then it will not be used. -PosixError RecursivelyDelete(absl::string_view path, int* undeleted_dirs, - int* undeleted_files); - -// Recursively create the directory provided or return an error. -PosixError RecursivelyCreateDir(absl::string_view path); - -// Makes a path absolute with respect to an optional base. If no base is -// provided it will use the current working directory. -PosixErrorOr<std::string> MakeAbsolute(absl::string_view filename, - absl::string_view base); - -// Generates a relative path from the source directory to the destination -// (dest) file or directory. This uses ../ when necessary for destinations -// which are not nested within the source. Both source and dest are required -// to be absolute paths, and an empty string will be returned if they are not. -PosixErrorOr<std::string> GetRelativePath(absl::string_view source, - absl::string_view dest); - -// Returns the part of the path before the final "/", EXCEPT: -// * If there is a single leading "/" in the path, the result will be the -// leading "/". -// * If there is no "/" in the path, the result is the empty prefix of the -// input string. -absl::string_view Dirname(absl::string_view path); - -// Return the parts of the path, split on the final "/". If there is no -// "/" in the path, the first part of the output is empty and the second -// is the input. If the only "/" in the path is the first character, it is -// the first part of the output. -std::pair<absl::string_view, absl::string_view> SplitPath( - absl::string_view path); - -// Returns the part of the path after the final "/". If there is no -// "/" in the path, the result is the same as the input. -// Note that this function's behavior differs from the Unix basename -// command if path ends with "/". For such paths, this function returns the -// empty string. -absl::string_view Basename(absl::string_view path); - -// Collapse duplicate "/"s, resolve ".." and "." path elements, remove -// trailing "/". -// -// NOTE: This respects relative vs. absolute paths, but does not -// invoke any system calls (getcwd(2)) in order to resolve relative -// paths wrt actual working directory. That is, this is purely a -// string manipulation, completely independent of process state. -std::string CleanPath(absl::string_view path); - -// Returns the full path to the executable of the given pid or a PosixError. -PosixErrorOr<std::string> ProcessExePath(int pid); - -namespace internal { -// Not part of the public API. -std::string JoinPathImpl(std::initializer_list<absl::string_view> paths); -} // namespace internal - -// Join multiple paths together. -// All paths will be treated as relative paths, regardless of whether or not -// they start with a leading '/'. That is, all paths will be concatenated -// together, with the appropriate path separator inserted in between. -// Arguments must be convertible to absl::string_view. -// -// Usage: -// std::string path = JoinPath("/foo", dirname, filename); -// std::string path = JoinPath(FLAGS_test_srcdir, filename); -// -// 0, 1, 2-path specializations exist to optimize common cases. -inline std::string JoinPath() { return std::string(); } -inline std::string JoinPath(absl::string_view path) { - return std::string(path.data(), path.size()); -} - -std::string JoinPath(absl::string_view path1, absl::string_view path2); -template <typename... T> -inline std::string JoinPath(absl::string_view path1, absl::string_view path2, - absl::string_view path3, const T&... args) { - return internal::JoinPathImpl({path1, path2, path3, args...}); -} -} // namespace testing -} // namespace gvisor -#endif // GVISOR_TEST_UTIL_FS_UTIL_H_ diff --git a/test/util/fs_util_test.cc b/test/util/fs_util_test.cc deleted file mode 100644 index 657b6a46e..000000000 --- a/test/util/fs_util_test.cc +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/fs_util.h" - -#include <errno.h> - -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/util/posix_error.h" -#include "test/util/temp_path.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(FsUtilTest, RecursivelyCreateDirManualDelete) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const std::string base_path = - JoinPath(root.path(), "/a/b/c/d/e/f/g/h/i/j/k/l/m"); - - ASSERT_THAT(Exists(base_path), IsPosixErrorOkAndHolds(false)); - ASSERT_NO_ERRNO(RecursivelyCreateDir(base_path)); - - // Delete everything until we hit root and then stop, we want to try this - // without using RecursivelyDelete. - std::string cur_path = base_path; - while (cur_path != root.path()) { - ASSERT_THAT(Exists(cur_path), IsPosixErrorOkAndHolds(true)); - ASSERT_NO_ERRNO(Rmdir(cur_path)); - ASSERT_THAT(Exists(cur_path), IsPosixErrorOkAndHolds(false)); - auto dir = Dirname(cur_path); - cur_path = std::string(dir); - } -} - -TEST(FsUtilTest, RecursivelyCreateAndDeleteDir) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const std::string base_path = - JoinPath(root.path(), "/a/b/c/d/e/f/g/h/i/j/k/l/m"); - - ASSERT_THAT(Exists(base_path), IsPosixErrorOkAndHolds(false)); - ASSERT_NO_ERRNO(RecursivelyCreateDir(base_path)); - - const std::string sub_path = JoinPath(root.path(), "a"); - ASSERT_NO_ERRNO(RecursivelyDelete(sub_path, nullptr, nullptr)); - ASSERT_THAT(Exists(sub_path), IsPosixErrorOkAndHolds(false)); -} - -TEST(FsUtilTest, RecursivelyCreateAndDeletePartial) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const std::string base_path = - JoinPath(root.path(), "/a/b/c/d/e/f/g/h/i/j/k/l/m"); - - ASSERT_THAT(Exists(base_path), IsPosixErrorOkAndHolds(false)); - ASSERT_NO_ERRNO(RecursivelyCreateDir(base_path)); - - const std::string a = JoinPath(root.path(), "a"); - auto listing = ASSERT_NO_ERRNO_AND_VALUE(ListDir(a, true)); - ASSERT_THAT(listing, ::testing::Contains("b")); - ASSERT_EQ(listing.size(), 1); - - listing = ASSERT_NO_ERRNO_AND_VALUE(ListDir(a, false)); - ASSERT_THAT(listing, ::testing::Contains(".")); - ASSERT_THAT(listing, ::testing::Contains("..")); - ASSERT_THAT(listing, ::testing::Contains("b")); - ASSERT_EQ(listing.size(), 3); - - const std::string sub_path = JoinPath(root.path(), "/a/b/c/d/e/f"); - - ASSERT_NO_ERRNO( - CreateWithContents(JoinPath(Dirname(sub_path), "file"), "Hello World")); - std::string contents = ""; - ASSERT_NO_ERRNO(GetContents(JoinPath(Dirname(sub_path), "file"), &contents)); - ASSERT_EQ(contents, "Hello World"); - - ASSERT_NO_ERRNO(RecursivelyDelete(sub_path, nullptr, nullptr)); - ASSERT_THAT(Exists(sub_path), IsPosixErrorOkAndHolds(false)); - - // The parent of the subpath (directory e) should still exist. - ASSERT_THAT(Exists(Dirname(sub_path)), IsPosixErrorOkAndHolds(true)); - - // The file we created along side f should also still exist. - ASSERT_THAT(Exists(JoinPath(Dirname(sub_path), "file")), - IsPosixErrorOkAndHolds(true)); -} -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/util/logging.cc b/test/util/logging.cc deleted file mode 100644 index 5d5e76c46..000000000 --- a/test/util/logging.cc +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/logging.h" - -#include <errno.h> -#include <stdint.h> -#include <stdlib.h> -#include <unistd.h> - -namespace gvisor { -namespace testing { - -namespace { - -// We implement this here instead of using test_util to avoid cyclic -// dependencies. -int Write(int fd, const char* buf, size_t size) { - size_t written = 0; - while (written < size) { - int res = write(fd, buf + written, size - written); - if (res < 0 && errno == EINTR) { - continue; - } else if (res <= 0) { - break; - } - - written += res; - } - return static_cast<int>(written); -} - -// Write 32-bit decimal number to fd. -int WriteNumber(int fd, uint32_t val) { - constexpr char kDigits[] = "0123456789"; - constexpr int kBase = 10; - - // 10 chars for 32-bit number in decimal, 1 char for the NUL-terminator. - constexpr int kBufferSize = 11; - char buf[kBufferSize]; - - // Convert the number to string. - char* s = buf + sizeof(buf) - 1; - size_t size = 0; - - *s = '\0'; - do { - s--; - size++; - - *s = kDigits[val % kBase]; - val /= kBase; - } while (val); - - return Write(fd, s, size); -} - -} // namespace - -void CheckFailure(const char* cond, size_t cond_size, const char* msg, - size_t msg_size, bool include_errno) { - int saved_errno = errno; - - constexpr char kCheckFailure[] = "Check failed: "; - Write(2, kCheckFailure, sizeof(kCheckFailure) - 1); - Write(2, cond, cond_size); - - if (msg != nullptr) { - Write(2, ": ", 2); - Write(2, msg, msg_size); - } - - if (include_errno) { - constexpr char kErrnoMessage[] = " (errno "; - Write(2, kErrnoMessage, sizeof(kErrnoMessage) - 1); - WriteNumber(2, saved_errno); - Write(2, ")", 1); - } - - Write(2, "\n", 1); - - abort(); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/logging.h b/test/util/logging.h deleted file mode 100644 index 589166fab..000000000 --- a/test/util/logging.h +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_LOGGING_H_ -#define GVISOR_TEST_UTIL_LOGGING_H_ - -#include <stddef.h> - -namespace gvisor { -namespace testing { - -void CheckFailure(const char* cond, size_t cond_size, const char* msg, - size_t msg_size, bool include_errno); - -// If cond is false, aborts the current process. -// -// This macro is async-signal-safe. -#define TEST_CHECK(cond) \ - do { \ - if (!(cond)) { \ - ::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, nullptr, \ - 0, false); \ - } \ - } while (0) - -// If cond is false, logs msg then aborts the current process. -// -// This macro is async-signal-safe. -#define TEST_CHECK_MSG(cond, msg) \ - do { \ - if (!(cond)) { \ - ::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, msg, \ - sizeof(msg) - 1, false); \ - } \ - } while (0) - -// If cond is false, logs errno, then aborts the current process. -// -// This macro is async-signal-safe. -#define TEST_PCHECK(cond) \ - do { \ - if (!(cond)) { \ - ::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, nullptr, \ - 0, true); \ - } \ - } while (0) - -// If cond is false, logs msg and errno, then aborts the current process. -// -// This macro is async-signal-safe. -#define TEST_PCHECK_MSG(cond, msg) \ - do { \ - if (!(cond)) { \ - ::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, msg, \ - sizeof(msg) - 1, true); \ - } \ - } while (0) - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_LOGGING_H_ diff --git a/test/util/memory_util.h b/test/util/memory_util.h deleted file mode 100644 index e189b73e8..000000000 --- a/test/util/memory_util.h +++ /dev/null @@ -1,147 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_MEMORY_UTIL_H_ -#define GVISOR_TEST_UTIL_MEMORY_UTIL_H_ - -#include <errno.h> -#include <stddef.h> -#include <stdint.h> -#include <sys/mman.h> - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "test/util/logging.h" -#include "test/util/posix_error.h" -#include "test/util/save_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -// RAII type for mmap'ed memory. Only usable in tests due to use of a test-only -// macro that can't be named without invoking the presubmit's wrath. -class Mapping { - public: - // Constructs a mapping that owns nothing. - Mapping() = default; - - // Constructs a mapping that owns the mmapped memory [ptr, ptr+len). Most - // users should use Mmap or MmapAnon instead. - Mapping(void* ptr, size_t len) : ptr_(ptr), len_(len) {} - - Mapping(Mapping&& orig) : ptr_(orig.ptr_), len_(orig.len_) { orig.release(); } - - Mapping& operator=(Mapping&& orig) { - ptr_ = orig.ptr_; - len_ = orig.len_; - orig.release(); - return *this; - } - - Mapping(Mapping const&) = delete; - Mapping& operator=(Mapping const&) = delete; - - ~Mapping() { reset(); } - - void* ptr() const { return ptr_; } - size_t len() const { return len_; } - - // Returns a pointer to the end of the mapping. Useful for when the mapping - // is used as a thread stack. - void* endptr() const { return reinterpret_cast<void*>(addr() + len_); } - - // Returns the start of this mapping cast to uintptr_t for ease of pointer - // arithmetic. - uintptr_t addr() const { return reinterpret_cast<uintptr_t>(ptr_); } - - // Returns the end of this mapping cast to uintptr_t for ease of pointer - // arithmetic. - uintptr_t endaddr() const { return reinterpret_cast<uintptr_t>(endptr()); } - - // Returns this mapping as a StringPiece for ease of comparison. - // - // This function is named view in anticipation of the eventual replacement of - // StringPiece with std::string_view. - absl::string_view view() const { - return absl::string_view(static_cast<char const*>(ptr_), len_); - } - - // These are both named reset for consistency with standard smart pointers. - - void reset(void* ptr, size_t len) { - if (len_) { - TEST_PCHECK(munmap(ptr_, len_) == 0); - } - ptr_ = ptr; - len_ = len; - } - - void reset() { reset(nullptr, 0); } - - void release() { - ptr_ = nullptr; - len_ = 0; - } - - private: - void* ptr_ = nullptr; - size_t len_ = 0; -}; - -// Wrapper around mmap(2) that returns a Mapping. -inline PosixErrorOr<Mapping> Mmap(void* addr, size_t length, int prot, - int flags, int fd, off_t offset) { - void* ptr = mmap(addr, length, prot, flags, fd, offset); - if (ptr == MAP_FAILED) { - return PosixError( - errno, absl::StrFormat("mmap(%p, %d, %x, %x, %d, %d)", addr, length, - prot, flags, fd, offset)); - } - MaybeSave(); - return Mapping(ptr, length); -} - -// Convenience wrapper around Mmap for anonymous mappings. -inline PosixErrorOr<Mapping> MmapAnon(size_t length, int prot, int flags) { - return Mmap(nullptr, length, prot, flags | MAP_ANONYMOUS, -1, 0); -} - -// Wrapper for mremap that returns a PosixErrorOr<>, since the return type of -// void* isn't directly compatible with SyscallSucceeds. -inline PosixErrorOr<void*> Mremap(void* old_address, size_t old_size, - size_t new_size, int flags, - void* new_address) { - void* rv = mremap(old_address, old_size, new_size, flags, new_address); - if (rv == MAP_FAILED) { - return PosixError(errno, "mremap failed"); - } - return rv; -} - -// Returns true if the page containing addr is mapped. -inline bool IsMapped(uintptr_t addr) { - int const rv = msync(reinterpret_cast<void*>(addr & ~(kPageSize - 1)), - kPageSize, MS_ASYNC); - if (rv == 0) { - return true; - } - TEST_PCHECK_MSG(errno == ENOMEM, "msync failed with unexpected errno"); - return false; -} - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_MEMORY_UTIL_H_ diff --git a/test/util/mount_util.h b/test/util/mount_util.h deleted file mode 100644 index 09e2281eb..000000000 --- a/test/util/mount_util.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_MOUNT_UTIL_H_ -#define GVISOR_TEST_UTIL_MOUNT_UTIL_H_ - -#include <errno.h> -#include <sys/mount.h> - -#include <functional> -#include <string> - -#include "gmock/gmock.h" -#include "test/util/cleanup.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -// Mount mounts the filesystem, and unmounts when the returned reference is -// destroyed. -inline PosixErrorOr<Cleanup> Mount(const std::string& source, - const std::string& target, - const std::string& fstype, - uint64_t mountflags, const std::string& data, - uint64_t umountflags) { - if (mount(source.c_str(), target.c_str(), fstype.c_str(), mountflags, - data.c_str()) == -1) { - return PosixError(errno, "mount failed"); - } - return Cleanup([target, umountflags]() { - EXPECT_THAT(umount2(target.c_str(), umountflags), SyscallSucceeds()); - }); -} - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_MOUNT_UTIL_H_ diff --git a/test/util/multiprocess_util.cc b/test/util/multiprocess_util.cc deleted file mode 100644 index 8b676751b..000000000 --- a/test/util/multiprocess_util.cc +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/multiprocess_util.h" - -#include <asm/unistd.h> -#include <errno.h> -#include <fcntl.h> -#include <signal.h> -#include <sys/prctl.h> -#include <unistd.h> - -#include "absl/strings/str_cat.h" -#include "test/util/cleanup.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/save_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -// exec_fn wraps a variant of the exec family, e.g. execve or execveat. -PosixErrorOr<Cleanup> ForkAndExecHelper(const std::function<void()>& exec_fn, - const std::function<void()>& fn, - pid_t* child, int* execve_errno) { - int pfds[2]; - int ret = pipe2(pfds, O_CLOEXEC); - if (ret < 0) { - return PosixError(errno, "pipe failed"); - } - FileDescriptor rfd(pfds[0]); - FileDescriptor wfd(pfds[1]); - - int parent_stdout = dup(STDOUT_FILENO); - if (parent_stdout < 0) { - return PosixError(errno, "dup stdout"); - } - int parent_stderr = dup(STDERR_FILENO); - if (parent_stdout < 0) { - return PosixError(errno, "dup stderr"); - } - - pid_t pid = fork(); - if (pid < 0) { - return PosixError(errno, "fork failed"); - } else if (pid == 0) { - // Child. - rfd.reset(); - if (dup2(parent_stdout, STDOUT_FILENO) < 0) { - _exit(3); - } - if (dup2(parent_stderr, STDERR_FILENO) < 0) { - _exit(4); - } - close(parent_stdout); - close(parent_stderr); - - // Clean ourself up in case the parent doesn't. - if (prctl(PR_SET_PDEATHSIG, SIGKILL)) { - _exit(3); - } - - if (fn) { - fn(); - } - - // Call variant of exec function. - exec_fn(); - - int error = errno; - if (WriteFd(pfds[1], &error, sizeof(error)) != sizeof(error)) { - // We can't do much if the write fails, but we can at least exit with a - // different code. - _exit(2); - } - _exit(1); - } - - // Parent. - if (child) { - *child = pid; - } - - auto cleanup = Cleanup([pid] { - kill(pid, SIGKILL); - RetryEINTR(waitpid)(pid, nullptr, 0); - }); - - wfd.reset(); - - int read_errno; - ret = ReadFd(rfd.get(), &read_errno, sizeof(read_errno)); - if (ret == 0) { - // Other end of the pipe closed, execve must have succeeded. - read_errno = 0; - } else if (ret < 0) { - return PosixError(errno, "read pipe failed"); - } else if (ret != sizeof(read_errno)) { - return PosixError(EPIPE, absl::StrCat("pipe read wrong size ", ret)); - } - - if (execve_errno) { - *execve_errno = read_errno; - } - - return std::move(cleanup); -} - -} // namespace - -PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename, - const ExecveArray& argv, - const ExecveArray& envv, - const std::function<void()>& fn, pid_t* child, - int* execve_errno) { - char* const* argv_data = argv.get(); - char* const* envv_data = envv.get(); - const std::function<void()> exec_fn = [=] { - execve(filename.c_str(), argv_data, envv_data); - }; - return ForkAndExecHelper(exec_fn, fn, child, execve_errno); -} - -PosixErrorOr<Cleanup> ForkAndExecveat(const int32_t dirfd, - const std::string& pathname, - const ExecveArray& argv, - const ExecveArray& envv, const int flags, - const std::function<void()>& fn, - pid_t* child, int* execve_errno) { - char* const* argv_data = argv.get(); - char* const* envv_data = envv.get(); - const std::function<void()> exec_fn = [=] { - syscall(__NR_execveat, dirfd, pathname.c_str(), argv_data, envv_data, - flags); - }; - return ForkAndExecHelper(exec_fn, fn, child, execve_errno); -} - -PosixErrorOr<int> InForkedProcess(const std::function<void()>& fn) { - pid_t pid = fork(); - if (pid == 0) { - fn(); - _exit(0); - } - MaybeSave(); - if (pid < 0) { - return PosixError(errno, "fork failed"); - } - - int status; - if (waitpid(pid, &status, 0) < 0) { - return PosixError(errno, "waitpid failed"); - } - - return status; -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/multiprocess_util.h b/test/util/multiprocess_util.h deleted file mode 100644 index 2f3bf4a6f..000000000 --- a/test/util/multiprocess_util.h +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_MULTIPROCESS_UTIL_H_ -#define GVISOR_TEST_UTIL_MULTIPROCESS_UTIL_H_ - -#include <unistd.h> - -#include <algorithm> -#include <string> -#include <utility> -#include <vector> - -#include "absl/strings/string_view.h" -#include "test/util/cleanup.h" -#include "test/util/posix_error.h" - -namespace gvisor { -namespace testing { - -// Immutable holder for a dynamically-sized array of pointers to mutable char, -// terminated by a null pointer, as required for the argv and envp arguments to -// execve(2). -class ExecveArray { - public: - // Constructs an empty ExecveArray. - ExecveArray() = default; - - // Constructs an ExecveArray by copying strings from the given range. T must - // be a range over ranges of char. - template <typename T> - explicit ExecveArray(T const& strs) : ExecveArray(strs.begin(), strs.end()) {} - - // Constructs an ExecveArray by copying strings from [first, last). InputIt - // must be an input iterator over a range over char. - template <typename InputIt> - ExecveArray(InputIt first, InputIt last) { - std::vector<size_t> offsets; - auto output_it = std::back_inserter(str_); - for (InputIt it = first; it != last; ++it) { - offsets.push_back(str_.size()); - auto const& s = *it; - std::copy(s.begin(), s.end(), output_it); - str_.push_back('\0'); - } - ptrs_.reserve(offsets.size() + 1); - for (auto offset : offsets) { - ptrs_.push_back(str_.data() + offset); - } - ptrs_.push_back(nullptr); - } - - // Constructs an ExecveArray by copying strings from list. This overload must - // exist independently of the single-argument template constructor because - // std::initializer_list does not participate in template argument deduction - // (i.e. cannot be type-inferred in an invocation of the templated - // constructor). - /* implicit */ ExecveArray(std::initializer_list<absl::string_view> list) - : ExecveArray(list.begin(), list.end()) {} - - // Disable move construction and assignment since ptrs_ points into str_. - ExecveArray(ExecveArray&&) = delete; - ExecveArray& operator=(ExecveArray&&) = delete; - - char* const* get() const { return ptrs_.data(); } - size_t get_size() { return str_.size(); } - - private: - std::vector<char> str_; - std::vector<char*> ptrs_; -}; - -// Simplified version of SubProcess. Returns OK and a cleanup function to kill -// the child if it made it to execve. -// -// fn is run between fork and exec. If it needs to fail, it should exit the -// process. -// -// The child pid is returned via child, if provided. -// execve's error code is returned via execve_errno, if provided. -PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename, - const ExecveArray& argv, - const ExecveArray& envv, - const std::function<void()>& fn, pid_t* child, - int* execve_errno); - -inline PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename, - const ExecveArray& argv, - const ExecveArray& envv, pid_t* child, - int* execve_errno) { - return ForkAndExec( - filename, argv, envv, [] {}, child, execve_errno); -} - -// Equivalent to ForkAndExec, except using dirfd and flags with execveat. -PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd, - const std::string& pathname, - const ExecveArray& argv, - const ExecveArray& envv, int flags, - const std::function<void()>& fn, - pid_t* child, int* execve_errno); - -inline PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd, - const std::string& pathname, - const ExecveArray& argv, - const ExecveArray& envv, int flags, - pid_t* child, int* execve_errno) { - return ForkAndExecveat( - dirfd, pathname, argv, envv, flags, [] {}, child, execve_errno); -} - -// Calls fn in a forked subprocess and returns the exit status of the -// subprocess. -// -// fn must be async-signal-safe. -PosixErrorOr<int> InForkedProcess(const std::function<void()>& fn); - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_MULTIPROCESS_UTIL_H_ diff --git a/test/util/platform_util.cc b/test/util/platform_util.cc deleted file mode 100644 index c9200d381..000000000 --- a/test/util/platform_util.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/platform_util.h" - -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -PlatformSupport PlatformSupport32Bit() { - if (GvisorPlatform() == Platform::kPtrace || - GvisorPlatform() == Platform::kKVM) { - return PlatformSupport::NotSupported; - } else { - return PlatformSupport::Allowed; - } -} - -PlatformSupport PlatformSupportAlignmentCheck() { - return PlatformSupport::Allowed; -} - -PlatformSupport PlatformSupportMultiProcess() { - return PlatformSupport::Allowed; -} - -PlatformSupport PlatformSupportInt3() { - if (GvisorPlatform() == Platform::kKVM) { - return PlatformSupport::NotSupported; - } else { - return PlatformSupport::Allowed; - } -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/platform_util.h b/test/util/platform_util.h deleted file mode 100644 index 28cc92371..000000000 --- a/test/util/platform_util.h +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_PLATFORM_UTIL_H_ -#define GVISOR_TEST_UTIL_PLATFORM_UTIL_H_ - -namespace gvisor { -namespace testing { - -// PlatformSupport is a generic enumeration of classes of support. -// -// It is up to the individual functions and callers to agree on the precise -// definition for each case. The document here generally refers to 32-bit -// as an example. Many cases will use only NotSupported and Allowed. -enum class PlatformSupport { - // The feature is not supported on the current platform. - // - // In the case of 32-bit, this means that calls will generally be interpreted - // as 64-bit calls, and there is no support for 32-bit binaries, long calls, - // etc. This usually means that the underlying implementation just pretends - // that 32-bit doesn't exist. - NotSupported, - - // Calls will be ignored by the kernel with a fixed error. - Ignored, - - // Calls will result in a SIGSEGV or similar fault. - Segfault, - - // The feature is supported as expected. - // - // In the case of 32-bit, this means that the system call or far call will be - // handled properly. - Allowed, -}; - -PlatformSupport PlatformSupport32Bit(); -PlatformSupport PlatformSupportAlignmentCheck(); -PlatformSupport PlatformSupportMultiProcess(); -PlatformSupport PlatformSupportInt3(); - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_PLATFORM_UTL_H_ diff --git a/test/util/posix_error.cc b/test/util/posix_error.cc deleted file mode 100644 index cebf7e0ac..000000000 --- a/test/util/posix_error.cc +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/posix_error.h" - -#include <cassert> -#include <cerrno> -#include <cstring> -#include <string> - -#include "absl/strings/str_cat.h" - -namespace gvisor { -namespace testing { - -std::string PosixError::ToString() const { - if (ok()) { - return "No Error"; - } - - std::string ret; - - char strerrno_buf[1024] = {}; - - auto res = strerror_r(errno_, strerrno_buf, sizeof(strerrno_buf)); - -// The GNU version of strerror_r always returns a non-null char* pointing to a -// buffer containing the stringified errno; the XSI version returns a positive -// errno which indicates the result of writing the stringified errno into the -// supplied buffer. The gymnastics below are needed to support both. -#ifndef _GNU_SOURCE - if (res != 0) { - ret = absl::StrCat("PosixError(errno=", errno_, " strerror_r FAILED(", ret, - "))"); - } else { - ret = absl::StrCat("PosixError(errno=", errno_, " ", strerrno_buf, ")"); - } -#else - ret = absl::StrCat("PosixError(errno=", errno_, " ", res, ")"); -#endif - - if (!msg_.empty()) { - ret.append(" "); - ret.append(msg_); - } - - return ret; -} - -::std::ostream& operator<<(::std::ostream& os, const PosixError& e) { - os << e.ToString(); - return os; -} - -void PosixErrorIsMatcherCommonImpl::DescribeTo(std::ostream* os) const { - *os << "has an errno value that "; - code_matcher_.DescribeTo(os); - *os << ", and has an error message that "; - message_matcher_.DescribeTo(os); -} - -void PosixErrorIsMatcherCommonImpl::DescribeNegationTo(std::ostream* os) const { - *os << "has an errno value that "; - code_matcher_.DescribeNegationTo(os); - *os << ", or has an error message that "; - message_matcher_.DescribeNegationTo(os); -} - -bool PosixErrorIsMatcherCommonImpl::MatchAndExplain( - const PosixError& error, - ::testing::MatchResultListener* result_listener) const { - ::testing::StringMatchResultListener inner_listener; - - inner_listener.Clear(); - if (!code_matcher_.MatchAndExplain(error.errno_value(), &inner_listener)) { - return false; - } - - if (!message_matcher_.Matches(error.error_message())) { - return false; - } - - return true; -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/posix_error.h b/test/util/posix_error.h deleted file mode 100644 index ad666bce0..000000000 --- a/test/util/posix_error.h +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_POSIX_ERROR_H_ -#define GVISOR_TEST_UTIL_POSIX_ERROR_H_ - -#include <string> - -#include "gmock/gmock.h" -#include "absl/base/attributes.h" -#include "absl/strings/string_view.h" -#include "absl/types/variant.h" -#include "test/util/logging.h" - -namespace gvisor { -namespace testing { - -class PosixErrorIsMatcherCommonImpl; - -template <typename T> -class PosixErrorOr; - -class ABSL_MUST_USE_RESULT PosixError { - public: - PosixError() {} - explicit PosixError(int errno_value) : errno_(errno_value) {} - PosixError(int errno_value, std::string msg) - : errno_(errno_value), msg_(std::move(msg)) {} - - PosixError(PosixError&& other) = default; - PosixError& operator=(PosixError&& other) = default; - PosixError(const PosixError&) = default; - PosixError& operator=(const PosixError&) = default; - - bool ok() const { return errno_ == 0; } - - // Returns a reference to *this to make matchers compatible with - // PosixErrorOr. - const PosixError& error() const { return *this; } - - std::string error_message() const { return msg_; } - - // ToString produces a full string representation of this posix error - // including the printable representation of the errno and the error message. - std::string ToString() const; - - // Ignores any errors. This method does nothing except potentially suppress - // complaints from any tools that are checking that errors are not dropped on - // the floor. - void IgnoreError() const {} - - private: - int errno_value() const { return errno_; } - int errno_ = 0; - std::string msg_; - - friend class PosixErrorIsMatcherCommonImpl; - - template <typename T> - friend class PosixErrorOr; -}; - -template <typename T> -class ABSL_MUST_USE_RESULT PosixErrorOr { - public: - // A PosixErrorOr will check fail if it is constructed with NoError(). - PosixErrorOr(const PosixError& error); - PosixErrorOr(const T& value); - PosixErrorOr(T&& value); - - PosixErrorOr(PosixErrorOr&& other) = default; - PosixErrorOr& operator=(PosixErrorOr&& other) = default; - PosixErrorOr(const PosixErrorOr&) = default; - PosixErrorOr& operator=(const PosixErrorOr&) = default; - - // Conversion copy/move constructor, T must be convertible from U. - template <typename U> - friend class PosixErrorOr; - - template <typename U> - PosixErrorOr(PosixErrorOr<U> other); - - template <typename U> - PosixErrorOr& operator=(PosixErrorOr<U> other); - - // Return a reference to the error or NoError(). - PosixError error() const; - - // Returns this->error().error_message(); - std::string error_message() const; - - // Returns true if this PosixErrorOr contains some T. - bool ok() const; - - // Returns a reference to our current value, or CHECK-fails if !this->ok(). - const T& ValueOrDie() const&; - T& ValueOrDie() &; - const T&& ValueOrDie() const&&; - T&& ValueOrDie() &&; - - // Ignores any errors. This method does nothing except potentially suppress - // complaints from any tools that are checking that errors are not dropped on - // the floor. - void IgnoreError() const {} - - private: - int errno_value() const; - absl::variant<T, PosixError> value_; - - friend class PosixErrorIsMatcherCommonImpl; -}; - -template <typename T> -PosixErrorOr<T>::PosixErrorOr(const PosixError& error) : value_(error) { - TEST_CHECK_MSG( - !error.ok(), - "Constructing PosixErrorOr with NoError, eg. errno 0 is not allowed."); -} - -template <typename T> -PosixErrorOr<T>::PosixErrorOr(const T& value) : value_(value) {} - -template <typename T> -PosixErrorOr<T>::PosixErrorOr(T&& value) : value_(std::move(value)) {} - -// Conversion copy/move constructor, T must be convertible from U. -template <typename T> -template <typename U> -inline PosixErrorOr<T>::PosixErrorOr(PosixErrorOr<U> other) { - if (absl::holds_alternative<U>(other.value_)) { - // T is convertible from U. - value_ = absl::get<U>(std::move(other.value_)); - } else if (absl::holds_alternative<PosixError>(other.value_)) { - value_ = absl::get<PosixError>(std::move(other.value_)); - } else { - TEST_CHECK_MSG(false, "PosixErrorOr does not contain PosixError or value"); - } -} - -template <typename T> -template <typename U> -inline PosixErrorOr<T>& PosixErrorOr<T>::operator=(PosixErrorOr<U> other) { - if (absl::holds_alternative<U>(other.value_)) { - // T is convertible from U. - value_ = absl::get<U>(std::move(other.value_)); - } else if (absl::holds_alternative<PosixError>(other.value_)) { - value_ = absl::get<PosixError>(std::move(other.value_)); - } else { - TEST_CHECK_MSG(false, "PosixErrorOr does not contain PosixError or value"); - } - return *this; -} - -template <typename T> -PosixError PosixErrorOr<T>::error() const { - if (!absl::holds_alternative<PosixError>(value_)) { - return PosixError(); - } - return absl::get<PosixError>(value_); -} - -template <typename T> -int PosixErrorOr<T>::errno_value() const { - return error().errno_value(); -} - -template <typename T> -std::string PosixErrorOr<T>::error_message() const { - return error().error_message(); -} - -template <typename T> -bool PosixErrorOr<T>::ok() const { - return absl::holds_alternative<T>(value_); -} - -template <typename T> -const T& PosixErrorOr<T>::ValueOrDie() const& { - TEST_CHECK(absl::holds_alternative<T>(value_)); - return absl::get<T>(value_); -} - -template <typename T> -T& PosixErrorOr<T>::ValueOrDie() & { - TEST_CHECK(absl::holds_alternative<T>(value_)); - return absl::get<T>(value_); -} - -template <typename T> -const T&& PosixErrorOr<T>::ValueOrDie() const&& { - TEST_CHECK(absl::holds_alternative<T>(value_)); - return std::move(absl::get<T>(value_)); -} - -template <typename T> -T&& PosixErrorOr<T>::ValueOrDie() && { - TEST_CHECK(absl::holds_alternative<T>(value_)); - return std::move(absl::get<T>(value_)); -} - -extern ::std::ostream& operator<<(::std::ostream& os, const PosixError& e); - -template <typename T> -::std::ostream& operator<<(::std::ostream& os, const PosixErrorOr<T>& e) { - os << e.error(); - return os; -} - -// NoError is a PosixError that represents a successful state, i.e. No Error. -inline PosixError NoError() { return PosixError(); } - -// Monomorphic implementation of matcher IsPosixErrorOk() for a given type T. -// T can be PosixError, PosixErrorOr<>, or a reference to either of them. -template <typename T> -class MonoPosixErrorIsOkMatcherImpl : public ::testing::MatcherInterface<T> { - public: - void DescribeTo(std::ostream* os) const override { *os << "is OK"; } - void DescribeNegationTo(std::ostream* os) const override { - *os << "is not OK"; - } - bool MatchAndExplain(T actual_value, - ::testing::MatchResultListener*) const override { - return actual_value.ok(); - } -}; - -// Implements IsPosixErrorOkMatcher() as a polymorphic matcher. -class IsPosixErrorOkMatcher { - public: - template <typename T> - operator ::testing::Matcher<T>() const { // NOLINT - return MakeMatcher(new MonoPosixErrorIsOkMatcherImpl<T>()); - } -}; - -// Monomorphic implementation of a matcher for a PosixErrorOr. -template <typename PosixErrorOrType> -class IsPosixErrorOkAndHoldsMatcherImpl - : public ::testing::MatcherInterface<PosixErrorOrType> { - public: - using ValueType = typename std::remove_reference<decltype( - std::declval<PosixErrorOrType>().ValueOrDie())>::type; - - template <typename InnerMatcher> - explicit IsPosixErrorOkAndHoldsMatcherImpl(InnerMatcher&& inner_matcher) - : inner_matcher_(::testing::SafeMatcherCast<const ValueType&>( - std::forward<InnerMatcher>(inner_matcher))) {} - - void DescribeTo(std::ostream* os) const override { - *os << "is OK and has a value that "; - inner_matcher_.DescribeTo(os); - } - - void DescribeNegationTo(std::ostream* os) const override { - *os << "isn't OK or has a value that "; - inner_matcher_.DescribeNegationTo(os); - } - - bool MatchAndExplain( - PosixErrorOrType actual_value, - ::testing::MatchResultListener* listener) const override { - // We can't extract the value if it doesn't contain one. - if (!actual_value.ok()) { - return false; - } - - ::testing::StringMatchResultListener inner_listener; - const bool matches = inner_matcher_.MatchAndExplain( - actual_value.ValueOrDie(), &inner_listener); - const std::string inner_explanation = inner_listener.str(); - *listener << "has a value " - << ::testing::PrintToString(actual_value.ValueOrDie()); - - if (!inner_explanation.empty()) { - *listener << " " << inner_explanation; - } - return matches; - } - - private: - const ::testing::Matcher<const ValueType&> inner_matcher_; -}; - -// Implements IsOkAndHolds() as a polymorphic matcher. -template <typename InnerMatcher> -class IsPosixErrorOkAndHoldsMatcher { - public: - explicit IsPosixErrorOkAndHoldsMatcher(InnerMatcher inner_matcher) - : inner_matcher_(std::move(inner_matcher)) {} - - // Converts this polymorphic matcher to a monomorphic one of the given type. - // PosixErrorOrType can be either PosixErrorOr<T> or a reference to - // PosixErrorOr<T>. - template <typename PosixErrorOrType> - operator ::testing::Matcher<PosixErrorOrType>() const { // NOLINT - return ::testing::MakeMatcher( - new IsPosixErrorOkAndHoldsMatcherImpl<PosixErrorOrType>( - inner_matcher_)); - } - - private: - const InnerMatcher inner_matcher_; -}; - -// PosixErrorIs() is a polymorphic matcher. This class is the common -// implementation of it shared by all types T where PosixErrorIs() can be -// used as a Matcher<T>. -class PosixErrorIsMatcherCommonImpl { - public: - PosixErrorIsMatcherCommonImpl( - ::testing::Matcher<int> code_matcher, - ::testing::Matcher<const std::string&> message_matcher) - : code_matcher_(std::move(code_matcher)), - message_matcher_(std::move(message_matcher)) {} - - void DescribeTo(std::ostream* os) const; - - void DescribeNegationTo(std::ostream* os) const; - - bool MatchAndExplain(const PosixError& error, - ::testing::MatchResultListener* result_listener) const; - - template <typename T> - bool MatchAndExplain(const PosixErrorOr<T>& error_or, - ::testing::MatchResultListener* result_listener) const { - if (error_or.ok()) { - *result_listener << "has a value " - << ::testing::PrintToString(error_or.ValueOrDie()); - return false; - } - - return MatchAndExplain(error_or.error(), result_listener); - } - - private: - const ::testing::Matcher<int> code_matcher_; - const ::testing::Matcher<const std::string&> message_matcher_; -}; - -// Monomorphic implementation of matcher PosixErrorIs() for a given type -// T. T can be PosixError, PosixErrorOr<>, or a reference to either of them. -template <typename T> -class MonoPosixErrorIsMatcherImpl : public ::testing::MatcherInterface<T> { - public: - explicit MonoPosixErrorIsMatcherImpl( - PosixErrorIsMatcherCommonImpl common_impl) - : common_impl_(std::move(common_impl)) {} - - void DescribeTo(std::ostream* os) const override { - common_impl_.DescribeTo(os); - } - - void DescribeNegationTo(std::ostream* os) const override { - common_impl_.DescribeNegationTo(os); - } - - bool MatchAndExplain( - T actual_value, - ::testing::MatchResultListener* result_listener) const override { - return common_impl_.MatchAndExplain(actual_value, result_listener); - } - - private: - PosixErrorIsMatcherCommonImpl common_impl_; -}; - -inline ::testing::Matcher<int> ToErrorCodeMatcher( - const ::testing::Matcher<int>& m) { - return m; -} - -// Implements PosixErrorIs() as a polymorphic matcher. -class PosixErrorIsMatcher { - public: - template <typename ErrorCodeMatcher> - PosixErrorIsMatcher(ErrorCodeMatcher&& code_matcher, - ::testing::Matcher<const std::string&> message_matcher) - : common_impl_( - ToErrorCodeMatcher(std::forward<ErrorCodeMatcher>(code_matcher)), - std::move(message_matcher)) {} - - // Converts this polymorphic matcher to a monomorphic matcher of the - // given type. T can be StatusOr<>, Status, or a reference to - // either of them. - template <typename T> - operator ::testing::Matcher<T>() const { // NOLINT - return MakeMatcher(new MonoPosixErrorIsMatcherImpl<T>(common_impl_)); - } - - private: - const PosixErrorIsMatcherCommonImpl common_impl_; -}; - -// Returns a gMock matcher that matches a PosixError or PosixErrorOr<> whose -// whose error code matches code_matcher, and whose error message matches -// message_matcher. -template <typename ErrorCodeMatcher> -PosixErrorIsMatcher PosixErrorIs( - ErrorCodeMatcher&& code_matcher, - ::testing::Matcher<const std::string&> message_matcher) { - return PosixErrorIsMatcher(std::forward<ErrorCodeMatcher>(code_matcher), - std::move(message_matcher)); -} - -// Returns a gMock matcher that matches a PosixErrorOr<> which is ok() and -// value matches the inner matcher. -template <typename InnerMatcher> -IsPosixErrorOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type> -IsPosixErrorOkAndHolds(InnerMatcher&& inner_matcher) { - return IsPosixErrorOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type>( - std::forward<InnerMatcher>(inner_matcher)); -} - -// Internal helper for concatenating macro values. -#define POSIX_ERROR_IMPL_CONCAT_INNER_(x, y) x##y -#define POSIX_ERROR_IMPL_CONCAT_(x, y) POSIX_ERROR_IMPL_CONCAT_INNER_(x, y) - -#define POSIX_ERROR_IMPL_ASSIGN_OR_RETURN_(posixerroror, lhs, rexpr) \ - auto posixerroror = (rexpr); \ - if (!posixerroror.ok()) { \ - return (posixerroror.error()); \ - } \ - lhs = std::move(posixerroror).ValueOrDie() - -#define EXPECT_NO_ERRNO(expression) \ - EXPECT_THAT(expression, IsPosixErrorOkMatcher()) -#define ASSERT_NO_ERRNO(expression) \ - ASSERT_THAT(expression, IsPosixErrorOkMatcher()) - -#define ASSIGN_OR_RETURN_ERRNO(lhs, rexpr) \ - POSIX_ERROR_IMPL_ASSIGN_OR_RETURN_( \ - POSIX_ERROR_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr) - -#define RETURN_IF_ERRNO(s) \ - do { \ - if (!s.ok()) { \ - return s; \ - } \ - } while (false); - -#define ASSERT_NO_ERRNO_AND_VALUE(expr) \ - ({ \ - auto _expr_result = (expr); \ - ASSERT_NO_ERRNO(_expr_result); \ - std::move(_expr_result).ValueOrDie(); \ - }) - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_POSIX_ERROR_H_ diff --git a/test/util/posix_error_test.cc b/test/util/posix_error_test.cc deleted file mode 100644 index bf9465abb..000000000 --- a/test/util/posix_error_test.cc +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/posix_error.h" - -#include <errno.h> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace gvisor { -namespace testing { - -namespace { - -TEST(PosixErrorTest, PosixError) { - auto err = PosixError(EAGAIN); - EXPECT_THAT(err, PosixErrorIs(EAGAIN, "")); -} - -TEST(PosixErrorTest, PosixErrorOrPosixError) { - auto err = PosixErrorOr<std::nullptr_t>(PosixError(EAGAIN)); - EXPECT_THAT(err, PosixErrorIs(EAGAIN, "")); -} - -TEST(PosixErrorTest, PosixErrorOrNullptr) { - auto err = PosixErrorOr<std::nullptr_t>(nullptr); - EXPECT_TRUE(err.ok()); - EXPECT_NO_ERRNO(err); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/util/proc_util.cc b/test/util/proc_util.cc deleted file mode 100644 index 34d636ba9..000000000 --- a/test/util/proc_util.cc +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/proc_util.h" - -#include <algorithm> -#include <iostream> -#include <vector> - -#include "absl/strings/ascii.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "test/util/fs_util.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -// Parses a single line from /proc/<xxx>/maps. -PosixErrorOr<ProcMapsEntry> ParseProcMapsLine(absl::string_view line) { - ProcMapsEntry map_entry = {}; - - // Limit splitting to 6 parts so that if there is a file path and it contains - // spaces, the file path is not split. - std::vector<std::string> parts = - absl::StrSplit(line, absl::MaxSplits(' ', 5), absl::SkipEmpty()); - - // parts.size() should be 6 if there is a file name specified, and 5 - // otherwise. - if (parts.size() < 5) { - return PosixError(EINVAL, absl::StrCat("Invalid line: ", line)); - } - - // Address range in the form X-X where X are hex values without leading 0x. - std::vector<std::string> addresses = absl::StrSplit(parts[0], '-'); - if (addresses.size() != 2) { - return PosixError(EINVAL, - absl::StrCat("Invalid address range: ", parts[0])); - } - ASSIGN_OR_RETURN_ERRNO(map_entry.start, AtoiBase(addresses[0], 16)); - ASSIGN_OR_RETURN_ERRNO(map_entry.end, AtoiBase(addresses[1], 16)); - - // Permissions are four bytes of the form rwxp or - if permission not set. - if (parts[1].size() != 4) { - return PosixError(EINVAL, - absl::StrCat("Invalid permission field: ", parts[1])); - } - - map_entry.readable = parts[1][0] == 'r'; - map_entry.writable = parts[1][1] == 'w'; - map_entry.executable = parts[1][2] == 'x'; - map_entry.priv = parts[1][3] == 'p'; - - ASSIGN_OR_RETURN_ERRNO(map_entry.offset, AtoiBase(parts[2], 16)); - - std::vector<std::string> device = absl::StrSplit(parts[3], ':'); - if (device.size() != 2) { - return PosixError(EINVAL, absl::StrCat("Invalid device: ", parts[3])); - } - ASSIGN_OR_RETURN_ERRNO(map_entry.major, AtoiBase(device[0], 16)); - ASSIGN_OR_RETURN_ERRNO(map_entry.minor, AtoiBase(device[1], 16)); - - ASSIGN_OR_RETURN_ERRNO(map_entry.inode, Atoi<int64_t>(parts[4])); - if (parts.size() == 6) { - // A filename is present. However, absl::StrSplit retained the whitespace - // between the inode number and the filename. - map_entry.filename = - std::string(absl::StripLeadingAsciiWhitespace(parts[5])); - } - - return map_entry; -} - -PosixErrorOr<std::vector<ProcMapsEntry>> ParseProcMaps( - absl::string_view contents) { - std::vector<ProcMapsEntry> entries; - auto lines = absl::StrSplit(contents, '\n', absl::SkipEmpty()); - for (const auto& l : lines) { - std::cout << "line: " << l << std::endl; - ASSIGN_OR_RETURN_ERRNO(auto entry, ParseProcMapsLine(l)); - entries.push_back(entry); - } - return entries; -} - -PosixErrorOr<bool> IsVsyscallEnabled() { - ASSIGN_OR_RETURN_ERRNO(auto contents, GetContents("/proc/self/maps")); - ASSIGN_OR_RETURN_ERRNO(auto maps, ParseProcMaps(contents)); - return std::any_of(maps.begin(), maps.end(), [](const ProcMapsEntry& e) { - return e.filename == "[vsyscall]"; - }); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/proc_util.h b/test/util/proc_util.h deleted file mode 100644 index af209a51e..000000000 --- a/test/util/proc_util.h +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_PROC_UTIL_H_ -#define GVISOR_TEST_UTIL_PROC_UTIL_H_ - -#include <ostream> -#include <string> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "test/util/fs_util.h" -#include "test/util/posix_error.h" - -namespace gvisor { -namespace testing { - -// ProcMapsEntry contains the data from a single line in /proc/<xxx>/maps. -struct ProcMapsEntry { - uint64_t start; - uint64_t end; - bool readable; - bool writable; - bool executable; - bool priv; - uint64_t offset; - int major; - int minor; - int64_t inode; - std::string filename; -}; - -// Parses a ProcMaps line or returns an error. -PosixErrorOr<ProcMapsEntry> ParseProcMapsLine(absl::string_view line); -PosixErrorOr<std::vector<ProcMapsEntry>> ParseProcMaps( - absl::string_view contents); - -// Returns true if vsyscall (emmulation or not) is enabled. -PosixErrorOr<bool> IsVsyscallEnabled(); - -// Printer for ProcMapsEntry. -inline std::ostream& operator<<(std::ostream& os, const ProcMapsEntry& entry) { - std::string str = - absl::StrCat(absl::Hex(entry.start, absl::PadSpec::kZeroPad8), "-", - absl::Hex(entry.end, absl::PadSpec::kZeroPad8), " "); - - absl::StrAppend(&str, entry.readable ? "r" : "-"); - absl::StrAppend(&str, entry.writable ? "w" : "-"); - absl::StrAppend(&str, entry.executable ? "x" : "-"); - absl::StrAppend(&str, entry.priv ? "p" : "s"); - - absl::StrAppend(&str, " ", absl::Hex(entry.offset, absl::PadSpec::kZeroPad8), - " ", absl::Hex(entry.major, absl::PadSpec::kZeroPad2), ":", - absl::Hex(entry.minor, absl::PadSpec::kZeroPad2), " ", - entry.inode); - if (absl::string_view(entry.filename) != "") { - // Pad to column 74 - int pad = 73 - str.length(); - if (pad > 0) { - absl::StrAppend(&str, std::string(pad, ' ')); - } - absl::StrAppend(&str, entry.filename); - } - os << str; - return os; -} - -// Printer for std::vector<ProcMapsEntry>. -inline std::ostream& operator<<(std::ostream& os, - const std::vector<ProcMapsEntry>& vec) { - for (unsigned int i = 0; i < vec.size(); i++) { - os << vec[i]; - if (i != vec.size() - 1) { - os << "\n"; - } - } - return os; -} - -// GMock printer for std::vector<ProcMapsEntry>. -inline void PrintTo(const std::vector<ProcMapsEntry>& vec, std::ostream* os) { - *os << vec; -} - -// Checks that /proc/pid/maps contains all of the passed mappings. -// -// The major, minor, and inode fields are ignored. -MATCHER_P(ContainsMappings, mappings, - "contains mappings:\n" + ::testing::PrintToString(mappings)) { - auto contents_or = GetContents(absl::StrCat("/proc/", arg, "/maps")); - if (!contents_or.ok()) { - *result_listener << "Unable to read mappings: " - << contents_or.error().ToString(); - return false; - } - - auto maps_or = ParseProcMaps(contents_or.ValueOrDie()); - if (!maps_or.ok()) { - *result_listener << "Unable to parse mappings: " - << maps_or.error().ToString(); - return false; - } - - auto maps = std::move(maps_or).ValueOrDie(); - - // Does maps contain all elements in mappings? The comparator ignores - // the major, minor, and inode fields. - bool all_present = true; - std::for_each(mappings.begin(), mappings.end(), [&](const ProcMapsEntry& e1) { - auto it = - std::find_if(maps.begin(), maps.end(), [&e1](const ProcMapsEntry& e2) { - return e1.start == e2.start && e1.end == e2.end && - e1.readable == e2.readable && e1.writable == e2.writable && - e1.executable == e2.executable && e1.priv == e2.priv && - e1.offset == e2.offset && e1.filename == e2.filename; - }); - if (it == maps.end()) { - // It wasn't found. - if (all_present) { - // We will output the message once and then a line for each mapping - // that wasn't found. - all_present = false; - *result_listener << "Got mappings:\n" - << maps << "\nThat were missing:\n"; - } - *result_listener << e1 << "\n"; - } - }); - - return all_present; -} - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_PROC_UTIL_H_ diff --git a/test/util/proc_util_test.cc b/test/util/proc_util_test.cc deleted file mode 100644 index 71dd2355e..000000000 --- a/test/util/proc_util_test.cc +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/proc_util.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "test/util/test_util.h" - -using ::testing::IsEmpty; - -namespace gvisor { -namespace testing { - -namespace { - -TEST(ParseProcMapsLineTest, WithoutFilename) { - auto entry = ASSERT_NO_ERRNO_AND_VALUE( - ParseProcMapsLine("2ab4f00b7000-2ab4f00b9000 r-xp 00000000 00:00 0 ")); - EXPECT_EQ(entry.start, 0x2ab4f00b7000); - EXPECT_EQ(entry.end, 0x2ab4f00b9000); - EXPECT_TRUE(entry.readable); - EXPECT_FALSE(entry.writable); - EXPECT_TRUE(entry.executable); - EXPECT_TRUE(entry.priv); - EXPECT_EQ(entry.offset, 0); - EXPECT_EQ(entry.major, 0); - EXPECT_EQ(entry.minor, 0); - EXPECT_EQ(entry.inode, 0); - EXPECT_THAT(entry.filename, IsEmpty()); -} - -TEST(ParseProcMapsLineTest, WithFilename) { - auto entry = ASSERT_NO_ERRNO_AND_VALUE( - ParseProcMapsLine("00407000-00408000 rw-p 00006000 00:0e 10 " - " /bin/cat")); - EXPECT_EQ(entry.start, 0x407000); - EXPECT_EQ(entry.end, 0x408000); - EXPECT_TRUE(entry.readable); - EXPECT_TRUE(entry.writable); - EXPECT_FALSE(entry.executable); - EXPECT_TRUE(entry.priv); - EXPECT_EQ(entry.offset, 0x6000); - EXPECT_EQ(entry.major, 0); - EXPECT_EQ(entry.minor, 0x0e); - EXPECT_EQ(entry.inode, 10); - EXPECT_EQ(entry.filename, "/bin/cat"); -} - -TEST(ParseProcMapsLineTest, WithFilenameContainingSpaces) { - auto entry = ASSERT_NO_ERRNO_AND_VALUE( - ParseProcMapsLine("7f26b3b12000-7f26b3b13000 rw-s 00000000 00:05 1432484 " - " /dev/zero (deleted)")); - EXPECT_EQ(entry.start, 0x7f26b3b12000); - EXPECT_EQ(entry.end, 0x7f26b3b13000); - EXPECT_TRUE(entry.readable); - EXPECT_TRUE(entry.writable); - EXPECT_FALSE(entry.executable); - EXPECT_FALSE(entry.priv); - EXPECT_EQ(entry.offset, 0); - EXPECT_EQ(entry.major, 0); - EXPECT_EQ(entry.minor, 0x05); - EXPECT_EQ(entry.inode, 1432484); - EXPECT_EQ(entry.filename, "/dev/zero (deleted)"); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/util/pty_util.cc b/test/util/pty_util.cc deleted file mode 100644 index c01f916aa..000000000 --- a/test/util/pty_util.cc +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/pty_util.h" - -#include <sys/ioctl.h> -#include <termios.h> - -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" - -namespace gvisor { -namespace testing { - -PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) { - PosixErrorOr<int> n = SlaveID(master); - if (!n.ok()) { - return PosixErrorOr<FileDescriptor>(n.error()); - } - return Open(absl::StrCat("/dev/pts/", n.ValueOrDie()), O_RDWR | O_NONBLOCK); -} - -PosixErrorOr<int> SlaveID(const FileDescriptor& master) { - // Get pty index. - int n; - int ret = ioctl(master.get(), TIOCGPTN, &n); - if (ret < 0) { - return PosixError(errno, "ioctl(TIOCGPTN) failed"); - } - - // Unlock pts. - int unlock = 0; - ret = ioctl(master.get(), TIOCSPTLCK, &unlock); - if (ret < 0) { - return PosixError(errno, "ioctl(TIOSPTLCK) failed"); - } - - return n; -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/pty_util.h b/test/util/pty_util.h deleted file mode 100644 index 0722da379..000000000 --- a/test/util/pty_util.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_PTY_UTIL_H_ -#define GVISOR_TEST_UTIL_PTY_UTIL_H_ - -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" - -namespace gvisor { -namespace testing { - -// Opens the slave end of the passed master as R/W and nonblocking. -PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master); - -// Get the number of the slave end of the master. -PosixErrorOr<int> SlaveID(const FileDescriptor& master); - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_PTY_UTIL_H_ diff --git a/test/util/rlimit_util.cc b/test/util/rlimit_util.cc deleted file mode 100644 index d7bfc1606..000000000 --- a/test/util/rlimit_util.cc +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/rlimit_util.h" - -#include <sys/resource.h> - -#include <cerrno> - -#include "test/util/cleanup.h" -#include "test/util/logging.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -PosixErrorOr<Cleanup> ScopedSetSoftRlimit(int resource, rlim_t newval) { - struct rlimit old_rlim; - if (getrlimit(resource, &old_rlim) != 0) { - return PosixError(errno, "getrlimit failed"); - } - struct rlimit new_rlim = old_rlim; - new_rlim.rlim_cur = newval; - if (setrlimit(resource, &new_rlim) != 0) { - return PosixError(errno, "setrlimit failed"); - } - return Cleanup([resource, old_rlim] { - TEST_PCHECK(setrlimit(resource, &old_rlim) == 0); - }); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/rlimit_util.h b/test/util/rlimit_util.h deleted file mode 100644 index 873252a32..000000000 --- a/test/util/rlimit_util.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_RLIMIT_UTIL_H_ -#define GVISOR_TEST_UTIL_RLIMIT_UTIL_H_ - -#include <sys/resource.h> -#include <sys/time.h> - -#include "test/util/cleanup.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -PosixErrorOr<Cleanup> ScopedSetSoftRlimit(int resource, rlim_t newval); - -} // namespace testing -} // namespace gvisor -#endif // GVISOR_TEST_UTIL_RLIMIT_UTIL_H_ diff --git a/test/util/save_util.cc b/test/util/save_util.cc deleted file mode 100644 index 384d626f0..000000000 --- a/test/util/save_util.cc +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/save_util.h" - -#include <stddef.h> -#include <stdlib.h> -#include <unistd.h> - -#include <atomic> -#include <cerrno> - -#define GVISOR_COOPERATIVE_SAVE_TEST "GVISOR_COOPERATIVE_SAVE_TEST" - -namespace gvisor { -namespace testing { -namespace { - -enum class CooperativeSaveMode { - kUnknown = 0, // cooperative_save_mode is statically-initialized to 0 - kAvailable, - kNotAvailable, -}; - -std::atomic<CooperativeSaveMode> cooperative_save_mode; - -bool CooperativeSaveEnabled() { - auto mode = cooperative_save_mode.load(); - if (mode == CooperativeSaveMode::kUnknown) { - mode = (getenv(GVISOR_COOPERATIVE_SAVE_TEST) != nullptr) - ? CooperativeSaveMode::kAvailable - : CooperativeSaveMode::kNotAvailable; - cooperative_save_mode.store(mode); - } - return mode == CooperativeSaveMode::kAvailable; -} - -std::atomic<int> save_disable; - -} // namespace - -DisableSave::DisableSave() { save_disable++; } - -DisableSave::~DisableSave() { reset(); } - -void DisableSave::reset() { - if (!reset_) { - reset_ = true; - save_disable--; - } -} - -namespace internal { -bool ShouldSave() { - return CooperativeSaveEnabled() && (save_disable.load() == 0); -} -} // namespace internal - -} // namespace testing -} // namespace gvisor diff --git a/test/util/save_util.h b/test/util/save_util.h deleted file mode 100644 index bddad6120..000000000 --- a/test/util/save_util.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_SAVE_UTIL_H_ -#define GVISOR_TEST_UTIL_SAVE_UTIL_H_ - -namespace gvisor { -namespace testing { -// Disable save prevents saving while the given function executes. -// -// This lasts the duration of the object, unless reset is called. -class DisableSave { - public: - DisableSave(); - ~DisableSave(); - DisableSave(DisableSave const&) = delete; - DisableSave(DisableSave&&) = delete; - DisableSave& operator=(DisableSave const&) = delete; - DisableSave& operator=(DisableSave&&) = delete; - - // reset allows saves to continue, and is called implicitly by the destructor. - // It may be called multiple times safely, but is not thread-safe. - void reset(); - - private: - bool reset_ = false; -}; - -// May perform a co-operative save cycle. -// -// errno is guaranteed to be preserved. -void MaybeSave(); - -namespace internal { -bool ShouldSave(); -} // namespace internal - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_SAVE_UTIL_H_ diff --git a/test/util/save_util_linux.cc b/test/util/save_util_linux.cc deleted file mode 100644 index d0aea8e6a..000000000 --- a/test/util/save_util_linux.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifdef __linux__ - -#include <errno.h> -#include <sys/syscall.h> -#include <unistd.h> - -#include "test/util/save_util.h" - -#if defined(__x86_64__) || defined(__i386__) -#define SYS_TRIGGER_SAVE SYS_create_module -#elif defined(__aarch64__) -#define SYS_TRIGGER_SAVE SYS_finit_module -#else -#error "Unknown architecture" -#endif - -namespace gvisor { -namespace testing { - -void MaybeSave() { - if (internal::ShouldSave()) { - int orig_errno = errno; - // We use it to trigger saving the sentry state - // when this syscall is called. - // Notice: this needs to be a valid syscall - // that is not used in any of the syscall tests. - syscall(SYS_TRIGGER_SAVE, nullptr, 0); - errno = orig_errno; - } -} - -} // namespace testing -} // namespace gvisor - -#endif diff --git a/test/util/save_util_other.cc b/test/util/save_util_other.cc deleted file mode 100644 index 931af2c29..000000000 --- a/test/util/save_util_other.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef __linux__ - -namespace gvisor { -namespace testing { - -void MaybeSave() { - // Saving is never available in a non-linux environment. -} - -} // namespace testing -} // namespace gvisor - -#endif diff --git a/test/util/signal_util.cc b/test/util/signal_util.cc deleted file mode 100644 index 5ee95ee80..000000000 --- a/test/util/signal_util.cc +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/signal_util.h" - -#include <signal.h> - -#include <ostream> - -#include "gtest/gtest.h" -#include "test/util/cleanup.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" - -namespace { - -struct Range { - int start; - int end; -}; - -// Format a Range as "start-end" or "start" for single value Ranges. -static ::std::ostream& operator<<(::std::ostream& os, const Range& range) { - if (range.end > range.start) { - return os << range.start << '-' << range.end; - } - - return os << range.start; -} - -} // namespace - -// Format a sigset_t as a comma separated list of numeric ranges. -// Empty sigset: [] -// Full sigset: [1-31,34-64] -::std::ostream& operator<<(::std::ostream& os, const sigset_t& sigset) { - const char* delim = ""; - Range range = {0, 0}; - - os << '['; - - for (int sig = 1; sig <= gvisor::testing::kMaxSignal; ++sig) { - if (sigismember(&sigset, sig)) { - if (range.start) { - range.end = sig; - } else { - range.start = sig; - range.end = sig; - } - } else if (range.start) { - os << delim << range; - delim = ","; - range.start = 0; - range.end = 0; - } - } - - if (range.start) { - os << delim << range; - } - - return os << ']'; -} - -namespace gvisor { -namespace testing { - -PosixErrorOr<Cleanup> ScopedSigaction(int sig, struct sigaction const& sa) { - struct sigaction old_sa; - int rc = sigaction(sig, &sa, &old_sa); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "sigaction failed"); - } - return Cleanup([sig, old_sa] { - EXPECT_THAT(sigaction(sig, &old_sa, nullptr), SyscallSucceeds()); - }); -} - -PosixErrorOr<Cleanup> ScopedSignalMask(int how, sigset_t const& set) { - sigset_t old; - int rc = sigprocmask(how, &set, &old); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "sigprocmask failed"); - } - return Cleanup([old] { - EXPECT_THAT(sigprocmask(SIG_SETMASK, &old, nullptr), SyscallSucceeds()); - }); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/signal_util.h b/test/util/signal_util.h deleted file mode 100644 index e7b66aa51..000000000 --- a/test/util/signal_util.h +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_SIGNAL_UTIL_H_ -#define GVISOR_TEST_UTIL_SIGNAL_UTIL_H_ - -#include <signal.h> -#include <sys/syscall.h> -#include <unistd.h> - -#include <ostream> - -#include "gmock/gmock.h" -#include "test/util/cleanup.h" -#include "test/util/posix_error.h" - -// Format a sigset_t as a comma separated list of numeric ranges. -::std::ostream& operator<<(::std::ostream& os, const sigset_t& sigset); - -namespace gvisor { -namespace testing { - -// The maximum signal number. -static constexpr int kMaxSignal = 64; - -// Wrapper for the tgkill(2) syscall, which glibc does not provide. -inline int tgkill(pid_t tgid, pid_t tid, int sig) { - return syscall(__NR_tgkill, tgid, tid, sig); -} - -// Installs the passed sigaction and returns a cleanup function to restore the -// previous handler when it goes out of scope. -PosixErrorOr<Cleanup> ScopedSigaction(int sig, struct sigaction const& sa); - -// Updates the signal mask as per sigprocmask(2) and returns a cleanup function -// to restore the previous signal mask when it goes out of scope. -PosixErrorOr<Cleanup> ScopedSignalMask(int how, sigset_t const& set); - -// ScopedSignalMask variant that creates a mask of the single signal 'sig'. -inline PosixErrorOr<Cleanup> ScopedSignalMask(int how, int sig) { - sigset_t set; - sigemptyset(&set); - sigaddset(&set, sig); - return ScopedSignalMask(how, set); -} - -// Asserts equality of two sigset_t values. -MATCHER_P(EqualsSigset, value, "equals " + ::testing::PrintToString(value)) { - for (int sig = 1; sig <= kMaxSignal; ++sig) { - if (sigismember(&arg, sig) != sigismember(&value, sig)) { - return false; - } - } - return true; -} - -#ifdef __x86_64__ -// Fault can be used to generate a synchronous SIGSEGV. -// -// This fault can be fixed up in a handler via fixup, below. -inline void Fault() { - // Zero and dereference %ax. - asm("movabs $0, %%rax\r\n" - "mov 0(%%rax), %%rax\r\n" - : - : - : "ax"); -} - -// FixupFault fixes up a fault generated by fault, above. -inline void FixupFault(ucontext_t* ctx) { - // Skip the bad instruction above. - // - // The encoding is 0x48 0xab 0x00. - ctx->uc_mcontext.gregs[REG_RIP] += 3; -} -#elif __aarch64__ -inline void Fault() { - // Zero and dereference x0. - asm("mov xzr, x0\r\n" - "str xzr, [x0]\r\n" - : - : - : "x0"); -} - -inline void FixupFault(ucontext_t* ctx) { - // Skip the bad instruction above. - ctx->uc_mcontext.pc += 4; -} -#endif - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_SIGNAL_UTIL_H_ diff --git a/test/util/temp_path.cc b/test/util/temp_path.cc deleted file mode 100644 index 9c10b6674..000000000 --- a/test/util/temp_path.cc +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/temp_path.h" - -#include <unistd.h> - -#include <atomic> -#include <cstdlib> -#include <iostream> - -#include "gtest/gtest.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/util/fs_util.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -namespace { - -std::atomic<uint64_t> global_temp_file_number = ATOMIC_VAR_INIT(1); - -// Return a new temp filename, intended to be unique system-wide. -// -// The global file number helps maintain file naming consistency across -// different runs of a test. -// -// The timestamp is necessary because the test infrastructure invokes each -// test case in a separate process (resetting global_temp_file_number) and -// potentially in parallel, which allows for races between selecting and using a -// name. -std::string NextTempBasename() { - return absl::StrCat("gvisor_test_temp_", global_temp_file_number++, "_", - absl::ToUnixNanos(absl::Now())); -} - -void TryDeleteRecursively(std::string const& path) { - if (!path.empty()) { - int undeleted_dirs = 0; - int undeleted_files = 0; - auto status = RecursivelyDelete(path, &undeleted_dirs, &undeleted_files); - if (undeleted_dirs || undeleted_files || !status.ok()) { - std::cerr << path << ": failed to delete " << undeleted_dirs - << " directories and " << undeleted_files - << " files: " << status; - } - } -} - -} // namespace - -constexpr mode_t TempPath::kDefaultFileMode; -constexpr mode_t TempPath::kDefaultDirMode; - -std::string NewTempAbsPathInDir(absl::string_view const dir) { - return JoinPath(dir, NextTempBasename()); -} - -std::string NewTempAbsPath() { - return NewTempAbsPathInDir(GetAbsoluteTestTmpdir()); -} - -std::string NewTempRelPath() { return NextTempBasename(); } - -std::string GetAbsoluteTestTmpdir() { - // Note that TEST_TMPDIR is guaranteed to be set. - char* env_tmpdir = getenv("TEST_TMPDIR"); - std::string tmp_dir = - env_tmpdir != nullptr ? std::string(env_tmpdir) : "/tmp"; - - return MakeAbsolute(tmp_dir, "").ValueOrDie(); -} - -PosixErrorOr<TempPath> TempPath::CreateFileWith(absl::string_view const parent, - absl::string_view const content, - mode_t const mode) { - return CreateIn(parent, [=](absl::string_view path) -> PosixError { - // CreateWithContents will call open(O_WRONLY) with the given mode. If the - // mode is not user-writable, save/restore cannot preserve the fd. Hence - // the little permission dance that's done here. - auto res = CreateWithContents(path, content, mode | 0200); - RETURN_IF_ERRNO(res); - - return Chmod(path, mode); - }); -} - -PosixErrorOr<TempPath> TempPath::CreateDirWith(absl::string_view const parent, - mode_t const mode) { - return CreateIn(parent, - [=](absl::string_view path) { return Mkdir(path, mode); }); -} - -PosixErrorOr<TempPath> TempPath::CreateSymlinkTo(absl::string_view const parent, - std::string const& dest) { - return CreateIn(parent, [=](absl::string_view path) { - int ret = symlink(dest.c_str(), std::string(path).c_str()); - if (ret != 0) { - return PosixError(errno, "symlink failed"); - } - return NoError(); - }); -} - -PosixErrorOr<TempPath> TempPath::CreateFileIn(absl::string_view const parent) { - return TempPath::CreateFileWith(parent, absl::string_view(), - kDefaultFileMode); -} - -PosixErrorOr<TempPath> TempPath::CreateDirIn(absl::string_view const parent) { - return TempPath::CreateDirWith(parent, kDefaultDirMode); -} - -PosixErrorOr<TempPath> TempPath::CreateFileMode(mode_t mode) { - return TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), absl::string_view(), - mode); -} - -PosixErrorOr<TempPath> TempPath::CreateFile() { - return TempPath::CreateFileIn(GetAbsoluteTestTmpdir()); -} - -PosixErrorOr<TempPath> TempPath::CreateDir() { - return TempPath::CreateDirIn(GetAbsoluteTestTmpdir()); -} - -TempPath::~TempPath() { TryDeleteRecursively(path_); } - -TempPath::TempPath(TempPath&& orig) { reset(orig.release()); } - -TempPath& TempPath::operator=(TempPath&& orig) { - reset(orig.release()); - return *this; -} - -std::string TempPath::reset(std::string newpath) { - std::string path = path_; - TryDeleteRecursively(path_); - path_ = std::move(newpath); - return path; -} - -std::string TempPath::release() { - std::string path = path_; - path_ = std::string(); - return path; -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/temp_path.h b/test/util/temp_path.h deleted file mode 100644 index 9e5ac11f4..000000000 --- a/test/util/temp_path.h +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_TEMP_PATH_H_ -#define GVISOR_TEST_UTIL_TEMP_PATH_H_ - -#include <sys/stat.h> - -#include <string> -#include <utility> - -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "test/util/posix_error.h" - -namespace gvisor { -namespace testing { - -// Returns an absolute path for a file in `dir` that does not yet exist. -// Distinct calls to NewTempAbsPathInDir from the same process, even from -// multiple threads, are guaranteed to return different paths. Distinct calls to -// NewTempAbsPathInDir from different processes are not synchronized. -std::string NewTempAbsPathInDir(absl::string_view const dir); - -// Like NewTempAbsPathInDir, but the returned path is in the test's temporary -// directory, as provided by the testing framework. -std::string NewTempAbsPath(); - -// Like NewTempAbsPathInDir, but the returned path is relative (to the current -// working directory). -std::string NewTempRelPath(); - -// Returns the absolute path for the test temp dir. -std::string GetAbsoluteTestTmpdir(); - -// Represents a temporary file or directory. -class TempPath { - public: - // Default creation mode for files. - static constexpr mode_t kDefaultFileMode = 0644; - - // Default creation mode for directories. - static constexpr mode_t kDefaultDirMode = 0755; - - // Creates a temporary file in directory `parent` with mode `mode` and - // contents `content`. - static PosixErrorOr<TempPath> CreateFileWith(absl::string_view parent, - absl::string_view content, - mode_t mode); - - // Creates an empty temporary subdirectory in directory `parent` with mode - // `mode`. - static PosixErrorOr<TempPath> CreateDirWith(absl::string_view parent, - mode_t mode); - - // Creates a temporary symlink in directory `parent` to destination `dest`. - static PosixErrorOr<TempPath> CreateSymlinkTo(absl::string_view parent, - std::string const& dest); - - // Creates an empty temporary file in directory `parent` with mode - // kDefaultFileMode. - static PosixErrorOr<TempPath> CreateFileIn(absl::string_view parent); - - // Creates an empty temporary subdirectory in directory `parent` with mode - // kDefaultDirMode. - static PosixErrorOr<TempPath> CreateDirIn(absl::string_view parent); - - // Creates an empty temporary file in the test's temporary directory with mode - // `mode`. - static PosixErrorOr<TempPath> CreateFileMode(mode_t mode); - - // Creates an empty temporary file in the test's temporary directory with - // mode kDefaultFileMode. - static PosixErrorOr<TempPath> CreateFile(); - - // Creates an empty temporary subdirectory in the test's temporary directory - // with mode kDefaultDirMode. - static PosixErrorOr<TempPath> CreateDir(); - - // Constructs a TempPath that represents nothing. - TempPath() = default; - - // Constructs a TempPath that represents the given path, which will be deleted - // when the TempPath is destroyed. - explicit TempPath(std::string path) : path_(std::move(path)) {} - - // Attempts to delete the represented temporary file or directory (in the - // latter case, also attempts to delete its contents). - ~TempPath(); - - // Attempts to delete the represented temporary file or directory, then - // transfers ownership of the path represented by orig to this TempPath. - TempPath(TempPath&& orig); - TempPath& operator=(TempPath&& orig); - - // Changes the path this TempPath represents. If the TempPath already - // represented a path, deletes and returns that path. Otherwise returns the - // empty string. - std::string reset(std::string newpath); - std::string reset() { return reset(""); } - - // Forgets and returns the path this TempPath represents. The path is not - // deleted. - std::string release(); - - // Returns the path this TempPath represents. - std::string path() const { return path_; } - - private: - template <typename F> - static PosixErrorOr<TempPath> CreateIn(absl::string_view const parent, - F const& f) { - std::string path = NewTempAbsPathInDir(parent); - RETURN_IF_ERRNO(f(path)); - return TempPath(std::move(path)); - } - - std::string path_; -}; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_TEMP_PATH_H_ diff --git a/test/util/temp_umask.h b/test/util/temp_umask.h deleted file mode 100644 index e7de84a54..000000000 --- a/test/util/temp_umask.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_TEMP_UMASK_H_ -#define GVISOR_TEST_UTIL_TEMP_UMASK_H_ - -#include <sys/stat.h> -#include <sys/types.h> - -namespace gvisor { -namespace testing { - -class TempUmask { - public: - // Sets the process umask to `mask`. - explicit TempUmask(mode_t mask) : old_mask_(umask(mask)) {} - - // Sets the process umask to its previous value. - ~TempUmask() { umask(old_mask_); } - - private: - mode_t old_mask_; -}; - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_TEMP_UMASK_H_ diff --git a/test/util/test_main.cc b/test/util/test_main.cc deleted file mode 100644 index 1f389e58f..000000000 --- a/test/util/test_main.cc +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/test_util.h" - -int main(int argc, char** argv) { - gvisor::testing::TestInit(&argc, &argv); - return gvisor::testing::RunAllTests(); -} diff --git a/test/util/test_util.cc b/test/util/test_util.cc deleted file mode 100644 index 95e1e0c96..000000000 --- a/test/util/test_util.cc +++ /dev/null @@ -1,224 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/test_util.h" - -#include <limits.h> -#include <stdlib.h> -#include <string.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <sys/uio.h> -#include <sys/utsname.h> -#include <unistd.h> - -#include <ctime> -#include <iostream> -#include <vector> - -#include "absl/base/attributes.h" -#include "absl/flags/flag.h" // IWYU pragma: keep -#include "absl/flags/parse.h" // IWYU pragma: keep -#include "absl/strings/numbers.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_split.h" -#include "absl/time/time.h" -#include "test/util/fs_util.h" -#include "test/util/posix_error.h" - -namespace gvisor { -namespace testing { - -#define TEST_ON_GVISOR "TEST_ON_GVISOR" -#define GVISOR_NETWORK "GVISOR_NETWORK" - -bool IsRunningOnGvisor() { return GvisorPlatform() != Platform::kNative; } - -const std::string GvisorPlatform() { - // Set by runner.go. - char* env = getenv(TEST_ON_GVISOR); - if (!env) { - return Platform::kNative; - } - return std::string(env); -} - -bool IsRunningWithHostinet() { - char* env = getenv(GVISOR_NETWORK); - return env && strcmp(env, "host") == 0; -} - -// Inline cpuid instruction. Preserve %ebx/%rbx register. In PIC compilations -// %ebx contains the address of the global offset table. %rbx is occasionally -// used to address stack variables in presence of dynamic allocas. -#if defined(__x86_64__) -#define GETCPUID(a, b, c, d, a_inp, c_inp) \ - asm("mov %%rbx, %%rdi\n" \ - "cpuid\n" \ - "xchg %%rdi, %%rbx\n" \ - : "=a"(a), "=D"(b), "=c"(c), "=d"(d) \ - : "a"(a_inp), "2"(c_inp)) - -CPUVendor GetCPUVendor() { - uint32_t eax, ebx, ecx, edx; - std::string vendor_str; - // Get vendor string (issue CPUID with eax = 0) - GETCPUID(eax, ebx, ecx, edx, 0, 0); - vendor_str.append(reinterpret_cast<char*>(&ebx), 4); - vendor_str.append(reinterpret_cast<char*>(&edx), 4); - vendor_str.append(reinterpret_cast<char*>(&ecx), 4); - if (vendor_str == "GenuineIntel") { - return CPUVendor::kIntel; - } else if (vendor_str == "AuthenticAMD") { - return CPUVendor::kAMD; - } - return CPUVendor::kUnknownVendor; -} -#endif // defined(__x86_64__) - -bool operator==(const KernelVersion& first, const KernelVersion& second) { - return first.major == second.major && first.minor == second.minor && - first.micro == second.micro; -} - -PosixErrorOr<KernelVersion> ParseKernelVersion(absl::string_view vers_str) { - KernelVersion version = {}; - std::vector<std::string> values = - absl::StrSplit(vers_str, absl::ByAnyChar(".-")); - if (values.size() == 2) { - ASSIGN_OR_RETURN_ERRNO(version.major, Atoi<int>(values[0])); - ASSIGN_OR_RETURN_ERRNO(version.minor, Atoi<int>(values[1])); - return version; - } else if (values.size() >= 3) { - ASSIGN_OR_RETURN_ERRNO(version.major, Atoi<int>(values[0])); - ASSIGN_OR_RETURN_ERRNO(version.minor, Atoi<int>(values[1])); - ASSIGN_OR_RETURN_ERRNO(version.micro, Atoi<int>(values[2])); - return version; - } - return PosixError(EINVAL, absl::StrCat("Unknown kernel release: ", vers_str)); -} - -PosixErrorOr<KernelVersion> GetKernelVersion() { - utsname buf; - RETURN_ERROR_IF_SYSCALL_FAIL(uname(&buf)); - return ParseKernelVersion(buf.release); -} - -std::string CPUSetToString(const cpu_set_t& set, size_t cpus) { - std::string str = "cpuset["; - for (unsigned int n = 0; n < cpus; n++) { - if (CPU_ISSET(n, &set)) { - if (n != 0) { - absl::StrAppend(&str, " "); - } - absl::StrAppend(&str, n); - } - } - absl::StrAppend(&str, "]"); - return str; -} - -// An overloaded operator<< makes it easy to dump the value of an OpenFd. -std::ostream& operator<<(std::ostream& out, OpenFd const& ofd) { - out << ofd.fd << " -> " << ofd.link; - return out; -} - -// An overloaded operator<< makes it easy to dump a vector of OpenFDs. -std::ostream& operator<<(std::ostream& out, std::vector<OpenFd> const& v) { - for (const auto& ofd : v) { - out << ofd << std::endl; - } - return out; -} - -PosixErrorOr<std::vector<OpenFd>> GetOpenFDs() { - // Get the results from /proc/self/fd. - ASSIGN_OR_RETURN_ERRNO(auto dir_list, - ListDir("/proc/self/fd", /*skipdots=*/true)); - - std::vector<OpenFd> ret_fds; - for (const auto& str_fd : dir_list) { - OpenFd open_fd = {}; - ASSIGN_OR_RETURN_ERRNO(open_fd.fd, Atoi<int>(str_fd)); - std::string path = absl::StrCat("/proc/self/fd/", open_fd.fd); - - // Resolve the link. - char buf[PATH_MAX] = {}; - int ret = readlink(path.c_str(), buf, sizeof(buf)); - if (ret < 0) { - if (errno == ENOENT) { - // The FD may have been closed, let's be resilient. - continue; - } - - return PosixError( - errno, absl::StrCat("readlink of ", path, " returned errno ", errno)); - } - open_fd.link = std::string(buf, ret); - ret_fds.emplace_back(std::move(open_fd)); - } - return ret_fds; -} - -PosixErrorOr<uint64_t> Links(const std::string& path) { - struct stat st; - if (stat(path.c_str(), &st)) { - return PosixError(errno, absl::StrCat("Failed to stat ", path)); - } - return static_cast<uint64_t>(st.st_nlink); -} - -void RandomizeBuffer(void* buffer, size_t len) { - struct timespec ts = {}; - clock_gettime(CLOCK_MONOTONIC, &ts); - uint32_t seed = static_cast<uint32_t>(ts.tv_nsec); - char* const buf = static_cast<char*>(buffer); - for (size_t i = 0; i < len; i++) { - buf[i] = rand_r(&seed) % 255; - } -} - -std::vector<std::vector<struct iovec>> GenerateIovecs(uint64_t total_size, - void* buf, - size_t buflen) { - std::vector<std::vector<struct iovec>> result; - for (uint64_t offset = 0; offset < total_size;) { - auto& iovec_array = *result.emplace(result.end()); - - for (; offset < total_size && iovec_array.size() < IOV_MAX; - offset += buflen) { - struct iovec iov = {}; - iov.iov_base = buf; - iov.iov_len = std::min<uint64_t>(total_size - offset, buflen); - iovec_array.push_back(iov); - } - } - - return result; -} - -uint64_t Megabytes(uint64_t n) { - // Overflow check, upper 20 bits in n shouldn't be set. - TEST_CHECK(!(0xfffff00000000000 & n)); - return n << 20; -} - -bool Equivalent(uint64_t current, uint64_t target, double tolerance) { - auto abs_diff = target > current ? target - current : current - target; - return abs_diff <= static_cast<uint64_t>(tolerance * target); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/test_util.h b/test/util/test_util.h deleted file mode 100644 index c5cb9d6d6..000000000 --- a/test/util/test_util.h +++ /dev/null @@ -1,779 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Utilities for syscall testing. -// -// Initialization -// ============== -// -// Prior to calling RUN_ALL_TESTS, all tests must use TestInit(&argc, &argv). -// See the TestInit function for exact side-effects and semantics. -// -// Configuration -// ============= -// -// IsRunningOnGvisor returns true if the test is known to be running on gVisor. -// GvisorPlatform can be used to get more detail: -// -// if (GvisorPlatform() == Platform::kPtrace) { -// ... -// } -// -// SetupGvisorDeathTest ensures that signal handling does not interfere with -/// tests that rely on fatal signals. -// -// Matchers -// ======== -// -// ElementOf(xs) matches if the matched value is equal to an element of the -// container xs. Example: -// -// // PASS -// EXPECT_THAT(1, ElementOf({0, 1, 2})); -// -// // FAIL -// // Value of: 3 -// // Expected: one of {0, 1, 2} -// // Actual: 3 -// EXPECT_THAT(3, ElementOf({0, 1, 2})); -// -// SyscallSucceeds() matches if the syscall is successful. A successful syscall -// is defined by either a return value not equal to -1, or a return value of -1 -// with an errno of 0 (which is a possible successful return for e.g. -// PTRACE_PEEK). Example: -// -// // PASS -// EXPECT_THAT(open("/dev/null", O_RDONLY), SyscallSucceeds()); -// -// // FAIL -// // Value of: open("/", O_RDWR) -// // Expected: not -1 (success) -// // Actual: -1 (of type int), with errno 21 (Is a directory) -// EXPECT_THAT(open("/", O_RDWR), SyscallSucceeds()); -// -// SyscallSucceedsWithValue(m) matches if the syscall is successful, and the -// value also matches m. Example: -// -// // PASS -// EXPECT_THAT(read(4, buf, 8192), SyscallSucceedsWithValue(8192)); -// -// // FAIL -// // Value of: read(-1, buf, 8192) -// // Expected: is equal to 8192 -// // Actual: -1 (of type long), with errno 9 (Bad file number) -// EXPECT_THAT(read(-1, buf, 8192), SyscallSucceedsWithValue(8192)); -// -// // FAIL -// // Value of: read(4, buf, 1) -// // Expected: is > 4096 -// // Actual: 1 (of type long) -// EXPECT_THAT(read(4, buf, 1), SyscallSucceedsWithValue(Gt(4096))); -// -// SyscallFails() matches if the syscall is unsuccessful. An unsuccessful -// syscall is defined by a return value of -1 with a non-zero errno. Example: -// -// // PASS -// EXPECT_THAT(open("/", O_RDWR), SyscallFails()); -// -// // FAIL -// // Value of: open("/dev/null", O_RDONLY) -// // Expected: -1 (failure) -// // Actual: 0 (of type int) -// EXPECT_THAT(open("/dev/null", O_RDONLY), SyscallFails()); -// -// SyscallFailsWithErrno(m) matches if the syscall is unsuccessful, and errno -// matches m. Example: -// -// // PASS -// EXPECT_THAT(open("/", O_RDWR), SyscallFailsWithErrno(EISDIR)); -// -// // PASS -// EXPECT_THAT(open("/etc/passwd", O_RDWR | O_DIRECTORY), -// SyscallFailsWithErrno(AnyOf(EACCES, ENOTDIR))); -// -// // FAIL -// // Value of: open("/dev/null", O_RDONLY) -// // Expected: -1 (failure) with errno 21 (Is a directory) -// // Actual: 0 (of type int) -// EXPECT_THAT(open("/dev/null", O_RDONLY), SyscallFailsWithErrno(EISDIR)); -// -// // FAIL -// // Value of: open("/", O_RDWR) -// // Expected: -1 (failure) with errno 22 (Invalid argument) -// // Actual: -1 (of type int), failure, but with errno 21 (Is a directory) -// EXPECT_THAT(open("/", O_RDWR), SyscallFailsWithErrno(EINVAL)); -// -// Because the syscall matchers encode save/restore functionality, their meaning -// should not be inverted via Not. That is, AnyOf(SyscallSucceedsWithValue(1), -// SyscallSucceedsWithValue(2)) is permitted, but not -// Not(SyscallFailsWithErrno(EPERM)). -// -// Syscalls -// ======== -// -// RetryEINTR wraps a function that returns -1 and sets errno on failure -// to be automatically retried when EINTR occurs. Example: -// -// auto rv = RetryEINTR(waitpid)(pid, &status, 0); -// -// ReadFd/WriteFd/PreadFd/PwriteFd are interface-compatible wrappers around the -// read/write/pread/pwrite syscalls to handle both EINTR and partial -// reads/writes. Example: -// -// EXPECT_THAT(ReadFd(fd, &buf, size), SyscallSucceedsWithValue(size)); -// -// General Utilities -// ================= -// -// ApplyVec(f, xs) returns a vector containing the result of applying function -// `f` to each value in `xs`. -// -// AllBitwiseCombinations takes a variadic number of ranges containing integers -// and returns a vector containing every integer that can be formed by ORing -// together exactly one integer from each list. List<T> is an alias for -// std::initializer_list<T> that makes AllBitwiseCombinations more ergonomic to -// use with list literals (initializer lists do not otherwise participate in -// template argument deduction). Example: -// -// EXPECT_THAT( -// AllBitwiseCombinations<int>( -// List<int>{SOCK_DGRAM, SOCK_STREAM}, -// List<int>{0, SOCK_NONBLOCK}), -// Contains({SOCK_DGRAM, SOCK_STREAM, SOCK_DGRAM | SOCK_NONBLOCK, -// SOCK_STREAM | SOCK_NONBLOCK})); -// -// VecCat takes a variadic number of containers and returns a vector containing -// the concatenated contents. -// -// VecAppend takes an initial container and a variadic number of containers and -// appends each to the initial container. -// -// RandomizeBuffer will use MTRandom to fill the given buffer with random bytes. -// -// GenerateIovecs will return the smallest number of iovec arrays for writing a -// given total number of bytes to a file, each iovec array size up to IOV_MAX, -// each iovec in each array pointing to the same buffer. - -#ifndef GVISOR_TEST_UTIL_TEST_UTIL_H_ -#define GVISOR_TEST_UTIL_TEST_UTIL_H_ - -#include <stddef.h> -#include <stdlib.h> -#include <sys/uio.h> -#include <unistd.h> - -#include <algorithm> -#include <cerrno> -#include <initializer_list> -#include <iterator> -#include <string> -#include <thread> // NOLINT: using std::thread::hardware_concurrency(). -#include <utility> -#include <vector> - -#include "gmock/gmock.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "absl/time/time.h" -#include "test/util/fs_util.h" -#include "test/util/logging.h" -#include "test/util/posix_error.h" -#include "test/util/save_util.h" - -namespace gvisor { -namespace testing { - -// TestInit must be called prior to RUN_ALL_TESTS. -// -// This parses all arguments and adjusts argc and argv appropriately. -// -// TestInit may create background threads. -void TestInit(int* argc, char*** argv); - -// SKIP_IF may be used to skip a test case. -// -// These cases are still emitted, but a SKIPPED line will appear. -#define SKIP_IF(expr) \ - do { \ - if (expr) GTEST_SKIP() << #expr; \ - } while (0) - -// Platform contains platform names. -namespace Platform { -constexpr char kNative[] = "native"; -constexpr char kPtrace[] = "ptrace"; -constexpr char kKVM[] = "kvm"; -} // namespace Platform - -bool IsRunningOnGvisor(); -const std::string GvisorPlatform(); -bool IsRunningWithHostinet(); - -#ifdef __linux__ -void SetupGvisorDeathTest(); -#endif - -struct KernelVersion { - int major; - int minor; - int micro; -}; - -bool operator==(const KernelVersion& first, const KernelVersion& second); - -PosixErrorOr<KernelVersion> ParseKernelVersion(absl::string_view vers_string); -PosixErrorOr<KernelVersion> GetKernelVersion(); - -static const size_t kPageSize = sysconf(_SC_PAGESIZE); - -enum class CPUVendor { kIntel, kAMD, kUnknownVendor }; - -CPUVendor GetCPUVendor(); - -inline int NumCPUs() { return std::thread::hardware_concurrency(); } - -// Converts cpu_set_t to a std::string for easy examination. -std::string CPUSetToString(const cpu_set_t& set, size_t cpus = CPU_SETSIZE); - -struct OpenFd { - // fd is the open file descriptor number. - int fd = -1; - - // link is the resolution of the symbolic link. - std::string link; -}; - -// Make it easier to log OpenFds to error streams. -std::ostream& operator<<(std::ostream& out, std::vector<OpenFd> const& v); -std::ostream& operator<<(std::ostream& out, OpenFd const& ofd); - -// Gets a detailed list of open fds for this process. -PosixErrorOr<std::vector<OpenFd>> GetOpenFDs(); - -// Returns the number of hard links to a path. -PosixErrorOr<uint64_t> Links(const std::string& path); - -namespace internal { - -template <typename Container> -class ElementOfMatcher { - public: - explicit ElementOfMatcher(Container container) - : container_(::std::move(container)) {} - - template <typename T> - bool MatchAndExplain(T const& rv, - ::testing::MatchResultListener* const listener) const { - using std::count; - return count(container_.begin(), container_.end(), rv) != 0; - } - - void DescribeTo(::std::ostream* const os) const { - *os << "one of {"; - char const* sep = ""; - for (auto const& elem : container_) { - *os << sep << elem; - sep = ", "; - } - *os << "}"; - } - - void DescribeNegationTo(::std::ostream* const os) const { - *os << "none of {"; - char const* sep = ""; - for (auto const& elem : container_) { - *os << sep << elem; - sep = ", "; - } - *os << "}"; - } - - private: - Container const container_; -}; - -template <typename E> -class SyscallSuccessMatcher { - public: - explicit SyscallSuccessMatcher(E expected) - : expected_(::std::move(expected)) {} - - template <typename T> - operator ::testing::Matcher<T>() const { - // E is one of three things: - // - T, or a type losslessly and implicitly convertible to T. - // - A monomorphic Matcher<T>. - // - A polymorphic matcher. - // SafeMatcherCast handles any of the above correctly. - // - // Similarly, gMock will invoke this conversion operator to obtain a - // monomorphic matcher (this is how polymorphic matchers are implemented). - return ::testing::MakeMatcher( - new Impl<T>(::testing::SafeMatcherCast<T>(expected_))); - } - - private: - template <typename T> - class Impl : public ::testing::MatcherInterface<T> { - public: - explicit Impl(::testing::Matcher<T> matcher) - : matcher_(::std::move(matcher)) {} - - bool MatchAndExplain( - T const& rv, - ::testing::MatchResultListener* const listener) const override { - if (rv == static_cast<decltype(rv)>(-1) && errno != 0) { - *listener << "with errno " << PosixError(errno); - return false; - } - bool match = matcher_.MatchAndExplain(rv, listener); - if (match) { - MaybeSave(); - } - return match; - } - - void DescribeTo(::std::ostream* const os) const override { - matcher_.DescribeTo(os); - } - - void DescribeNegationTo(::std::ostream* const os) const override { - matcher_.DescribeNegationTo(os); - } - - private: - ::testing::Matcher<T> matcher_; - }; - - private: - E expected_; -}; - -// A polymorphic matcher equivalent to ::testing::internal::AnyMatcher, except -// not in namespace ::testing::internal, and describing SyscallSucceeds()'s -// match constraints (which are enforced by SyscallSuccessMatcher::Impl). -class AnySuccessValueMatcher { - public: - template <typename T> - operator ::testing::Matcher<T>() const { - return ::testing::MakeMatcher(new Impl<T>()); - } - - private: - template <typename T> - class Impl : public ::testing::MatcherInterface<T> { - public: - bool MatchAndExplain( - T const& rv, - ::testing::MatchResultListener* const listener) const override { - return true; - } - - void DescribeTo(::std::ostream* const os) const override { - *os << "not -1 (success)"; - } - - void DescribeNegationTo(::std::ostream* const os) const override { - *os << "-1 (failure)"; - } - }; -}; - -class SyscallFailureMatcher { - public: - explicit SyscallFailureMatcher(::testing::Matcher<int> errno_matcher) - : errno_matcher_(std::move(errno_matcher)) {} - - template <typename T> - bool MatchAndExplain(T const& rv, - ::testing::MatchResultListener* const listener) const { - if (rv != static_cast<decltype(rv)>(-1)) { - return false; - } - int actual_errno = errno; - *listener << "with errno " << PosixError(actual_errno); - bool match = errno_matcher_.MatchAndExplain(actual_errno, listener); - if (match) { - MaybeSave(); - } - return match; - } - - void DescribeTo(::std::ostream* const os) const { - *os << "-1 (failure), with errno "; - errno_matcher_.DescribeTo(os); - } - - void DescribeNegationTo(::std::ostream* const os) const { - *os << "not -1 (success), with errno "; - errno_matcher_.DescribeNegationTo(os); - } - - private: - ::testing::Matcher<int> errno_matcher_; -}; - -class SpecificErrnoMatcher : public ::testing::MatcherInterface<int> { - public: - explicit SpecificErrnoMatcher(int const expected) : expected_(expected) {} - - bool MatchAndExplain( - int const actual_errno, - ::testing::MatchResultListener* const listener) const override { - return actual_errno == expected_; - } - - void DescribeTo(::std::ostream* const os) const override { - *os << PosixError(expected_); - } - - void DescribeNegationTo(::std::ostream* const os) const override { - *os << "not " << PosixError(expected_); - } - - private: - int const expected_; -}; - -inline ::testing::Matcher<int> SpecificErrno(int const expected) { - return ::testing::MakeMatcher(new SpecificErrnoMatcher(expected)); -} - -} // namespace internal - -template <typename Container> -inline ::testing::PolymorphicMatcher<internal::ElementOfMatcher<Container>> -ElementOf(Container container) { - return ::testing::MakePolymorphicMatcher( - internal::ElementOfMatcher<Container>(::std::move(container))); -} - -template <typename T> -inline ::testing::PolymorphicMatcher< - internal::ElementOfMatcher<::std::vector<T>>> -ElementOf(::std::initializer_list<T> elems) { - return ::testing::MakePolymorphicMatcher( - internal::ElementOfMatcher<::std::vector<T>>(::std::vector<T>(elems))); -} - -template <typename E> -inline internal::SyscallSuccessMatcher<E> SyscallSucceedsWithValue(E expected) { - return internal::SyscallSuccessMatcher<E>(::std::move(expected)); -} - -inline internal::SyscallSuccessMatcher<internal::AnySuccessValueMatcher> -SyscallSucceeds() { - return SyscallSucceedsWithValue( - ::gvisor::testing::internal::AnySuccessValueMatcher()); -} - -inline ::testing::PolymorphicMatcher<internal::SyscallFailureMatcher> -SyscallFailsWithErrno(::testing::Matcher<int> expected) { - return ::testing::MakePolymorphicMatcher( - internal::SyscallFailureMatcher(::std::move(expected))); -} - -// Overload taking an int so that SyscallFailsWithErrno(<specific errno>) uses -// internal::SpecificErrno (which stringifies the errno) rather than -// ::testing::Eq (which doesn't). -inline ::testing::PolymorphicMatcher<internal::SyscallFailureMatcher> -SyscallFailsWithErrno(int const expected) { - return SyscallFailsWithErrno(internal::SpecificErrno(expected)); -} - -inline ::testing::PolymorphicMatcher<internal::SyscallFailureMatcher> -SyscallFails() { - return SyscallFailsWithErrno(::testing::Gt(0)); -} - -// As of GCC 7.2, -Wall => -Wc++17-compat => -Wnoexcept-type generates an -// irrelevant, non-actionable warning about ABI compatibility when -// RetryEINTRImpl is constructed with a noexcept function, such as glibc's -// syscall(). See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80985. -#if defined(__GNUC__) && !defined(__clang__) && \ - (__GNUC__ > 7 || (__GNUC__ == 7 && __GNUC_MINOR__ >= 2)) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wnoexcept-type" -#endif - -namespace internal { - -template <typename F> -struct RetryEINTRImpl { - F const f; - - explicit constexpr RetryEINTRImpl(F f) : f(std::move(f)) {} - - template <typename... Args> - auto operator()(Args&&... args) const - -> decltype(f(std::forward<Args>(args)...)) { - while (true) { - errno = 0; - auto const ret = f(std::forward<Args>(args)...); - if (ret != -1 || errno != EINTR) { - return ret; - } - } - } -}; - -} // namespace internal - -template <typename F> -constexpr internal::RetryEINTRImpl<F> RetryEINTR(F&& f) { - return internal::RetryEINTRImpl<F>(std::forward<F>(f)); -} - -#if defined(__GNUC__) && !defined(__clang__) && \ - (__GNUC__ > 7 || (__GNUC__ == 7 && __GNUC_MINOR__ >= 2)) -#pragma GCC diagnostic pop -#endif - -namespace internal { - -template <typename F> -ssize_t ApplyFileIoSyscall(F const& f, size_t const count) { - size_t completed = 0; - // `do ... while` because some callers actually want to make a syscall with a - // count of 0. - do { - auto const cur = RetryEINTR(f)(completed); - if (cur < 0) { - return cur; - } else if (cur == 0) { - break; - } - completed += cur; - } while (completed < count); - return completed; -} - -} // namespace internal - -inline ssize_t ReadFd(int fd, void* buf, size_t count) { - return internal::ApplyFileIoSyscall( - [&](size_t completed) { - return read(fd, static_cast<char*>(buf) + completed, count - completed); - }, - count); -} - -inline ssize_t WriteFd(int fd, void const* buf, size_t count) { - return internal::ApplyFileIoSyscall( - [&](size_t completed) { - return write(fd, static_cast<char const*>(buf) + completed, - count - completed); - }, - count); -} - -inline ssize_t PreadFd(int fd, void* buf, size_t count, off_t offset) { - return internal::ApplyFileIoSyscall( - [&](size_t completed) { - return pread(fd, static_cast<char*>(buf) + completed, count - completed, - offset + completed); - }, - count); -} - -inline ssize_t PwriteFd(int fd, void const* buf, size_t count, off_t offset) { - return internal::ApplyFileIoSyscall( - [&](size_t completed) { - return pwrite(fd, static_cast<char const*>(buf) + completed, - count - completed, offset + completed); - }, - count); -} - -template <typename T> -using List = std::initializer_list<T>; - -namespace internal { - -template <typename T> -void AppendAllBitwiseCombinations(std::vector<T>* combinations, T current) { - combinations->push_back(current); -} - -template <typename T, typename Arg, typename... Args> -void AppendAllBitwiseCombinations(std::vector<T>* combinations, T current, - Arg&& next, Args&&... rest) { - for (auto const option : next) { - AppendAllBitwiseCombinations(combinations, current | option, rest...); - } -} - -inline size_t CombinedSize(size_t accum) { return accum; } - -template <typename T, typename... Args> -size_t CombinedSize(size_t accum, T const& x, Args&&... xs) { - return CombinedSize(accum + x.size(), std::forward<Args>(xs)...); -} - -// Base case: no more containers, so do nothing. -template <typename T> -void DoMoveExtendContainer(T* c) {} - -// Append each container next to c. -template <typename T, typename U, typename... Args> -void DoMoveExtendContainer(T* c, U&& next, Args&&... rest) { - std::move(std::begin(next), std::end(next), std::back_inserter(*c)); - DoMoveExtendContainer(c, std::forward<Args>(rest)...); -} - -} // namespace internal - -template <typename T = int> -std::vector<T> AllBitwiseCombinations() { - return std::vector<T>(); -} - -template <typename T = int, typename... Args> -std::vector<T> AllBitwiseCombinations(Args&&... args) { - std::vector<T> combinations; - internal::AppendAllBitwiseCombinations(&combinations, 0, args...); - return combinations; -} - -template <typename T, typename U, typename F> -std::vector<T> ApplyVec(F const& f, std::vector<U> const& us) { - std::vector<T> vec; - vec.reserve(us.size()); - for (auto const& u : us) { - vec.push_back(f(u)); - } - return vec; -} - -template <typename T, typename U> -std::vector<T> ApplyVecToVec(std::vector<std::function<T(U)>> const& fs, - std::vector<U> const& us) { - std::vector<T> vec; - vec.reserve(us.size() * fs.size()); - for (auto const& f : fs) { - for (auto const& u : us) { - vec.push_back(f(u)); - } - } - return vec; -} - -// Moves all elements from the containers `args` to the end of `c`. -template <typename T, typename... Args> -void VecAppend(T* c, Args&&... args) { - c->reserve(internal::CombinedSize(c->size(), args...)); - internal::DoMoveExtendContainer(c, std::forward<Args>(args)...); -} - -// Returns a vector containing the concatenated contents of the containers -// `args`. -template <typename T, typename... Args> -std::vector<T> VecCat(Args&&... args) { - std::vector<T> combined; - VecAppend(&combined, std::forward<Args>(args)...); - return combined; -} - -#define RETURN_ERROR_IF_SYSCALL_FAIL(syscall) \ - do { \ - if ((syscall) < 0 && errno != 0) { \ - return PosixError(errno, #syscall); \ - } \ - } while (false) - -// Fill the given buffer with random bytes. -void RandomizeBuffer(void* buffer, size_t len); - -template <typename T> -inline PosixErrorOr<T> Atoi(absl::string_view str) { - T ret; - if (!absl::SimpleAtoi<T>(str, &ret)) { - return PosixError(EINVAL, "String not a number."); - } - return ret; -} - -inline PosixErrorOr<uint64_t> AtoiBase(absl::string_view str, int base) { - if (base > 255 || base < 2) { - return PosixError(EINVAL, "Invalid Base"); - } - - uint64_t ret = 0; - if (!absl::numbers_internal::safe_strtou64_base(str, &ret, base)) { - return PosixError(EINVAL, "String not a number."); - } - - return ret; -} - -inline PosixErrorOr<double> Atod(absl::string_view str) { - double ret; - if (!absl::SimpleAtod(str, &ret)) { - return PosixError(EINVAL, "String not a double type."); - } - return ret; -} - -inline PosixErrorOr<float> Atof(absl::string_view str) { - float ret; - if (!absl::SimpleAtof(str, &ret)) { - return PosixError(EINVAL, "String not a float type."); - } - return ret; -} - -// Return the smallest number of iovec arrays that can be used to write -// "total_bytes" number of bytes, each iovec writing one "buf". -std::vector<std::vector<struct iovec>> GenerateIovecs(uint64_t total_size, - void* buf, size_t buflen); - -// Returns bytes in 'n' megabytes. Used for readability. -uint64_t Megabytes(uint64_t n); - -// Predicate for checking that a value is within some tolerance of another -// value. Returns true iff current is in the range [target * (1 - tolerance), -// target * (1 + tolerance)]. -bool Equivalent(uint64_t current, uint64_t target, double tolerance); - -// Matcher wrapping the Equivalent predicate. -MATCHER_P2(EquivalentWithin, target, tolerance, - std::string(negation ? "Isn't" : "Is") + - ::absl::StrFormat(" within %.2f%% of the target of %zd bytes", - tolerance * 100, target)) { - if (target == 0) { - *result_listener << ::absl::StreamFormat("difference of infinity%%"); - } else { - int64_t delta = static_cast<int64_t>(arg) - static_cast<int64_t>(target); - double delta_percent = - static_cast<double>(delta) / static_cast<double>(target) * 100; - *result_listener << ::absl::StreamFormat("difference of %.2f%%", - delta_percent); - } - return Equivalent(arg, target, tolerance); -} - -// Returns the absolute path to the a data dependency. 'path' is the runfile -// location relative to workspace root. -#ifdef __linux__ -std::string RunfilePath(std::string path); -#endif - -void TestInit(int* argc, char*** argv); -int RunAllTests(void); - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_TEST_UTIL_H_ diff --git a/test/util/test_util_impl.cc b/test/util/test_util_impl.cc deleted file mode 100644 index 7e1ad9e66..000000000 --- a/test/util/test_util_impl.cc +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <signal.h> - -#include "gtest/gtest.h" -#include "absl/flags/flag.h" -#include "absl/flags/parse.h" -#include "benchmark/benchmark.h" -#include "test/util/logging.h" - -extern bool FLAGS_benchmark_list_tests; -extern std::string FLAGS_benchmark_filter; - -namespace gvisor { -namespace testing { - -void SetupGvisorDeathTest() {} - -void TestInit(int* argc, char*** argv) { - ::testing::InitGoogleTest(argc, *argv); - benchmark::Initialize(argc, *argv); - ::absl::ParseCommandLine(*argc, *argv); - - // Always mask SIGPIPE as it's common and tests aren't expected to handle it. - struct sigaction sa = {}; - sa.sa_handler = SIG_IGN; - TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0); -} - -int RunAllTests() { - if (FLAGS_benchmark_list_tests || FLAGS_benchmark_filter != ".") { - benchmark::RunSpecifiedBenchmarks(); - return 0; - } else { - return RUN_ALL_TESTS(); - } -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/test_util_runfiles.cc b/test/util/test_util_runfiles.cc deleted file mode 100644 index 694d21692..000000000 --- a/test/util/test_util_runfiles.cc +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef __fuchsia__ - -#include <iostream> -#include <string> - -#include "test/util/fs_util.h" -#include "test/util/test_util.h" -#include "tools/cpp/runfiles/runfiles.h" - -namespace gvisor { -namespace testing { - -std::string RunfilePath(std::string path) { - static const bazel::tools::cpp::runfiles::Runfiles* const runfiles = [] { - std::string error; - auto* runfiles = - bazel::tools::cpp::runfiles::Runfiles::CreateForTest(&error); - if (runfiles == nullptr) { - std::cerr << "Unable to find runfiles: " << error << std::endl; - } - return runfiles; - }(); - - if (!runfiles) { - // Can't find runfiles? This probably won't work, but __main__/path is our - // best guess. - return JoinPath("__main__", path); - } - - return runfiles->Rlocation(JoinPath("__main__", path)); -} - -} // namespace testing -} // namespace gvisor - -#endif // __fuchsia__ diff --git a/test/util/test_util_test.cc b/test/util/test_util_test.cc deleted file mode 100644 index f42100374..000000000 --- a/test/util/test_util_test.cc +++ /dev/null @@ -1,251 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/test_util.h" - -#include <errno.h> - -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -using ::testing::AnyOf; -using ::testing::Gt; -using ::testing::IsEmpty; -using ::testing::Lt; -using ::testing::Not; -using ::testing::TypedEq; -using ::testing::UnorderedElementsAre; -using ::testing::UnorderedElementsAreArray; - -namespace gvisor { -namespace testing { - -namespace { - -TEST(KernelVersionParsing, ValidateParsing) { - KernelVersion v = ASSERT_NO_ERRNO_AND_VALUE( - ParseKernelVersion("4.18.10-1foo2-amd64 baz blah")); - ASSERT_TRUE(v == KernelVersion({4, 18, 10})); - - v = ASSERT_NO_ERRNO_AND_VALUE(ParseKernelVersion("4.18.10-1foo2-amd64")); - ASSERT_TRUE(v == KernelVersion({4, 18, 10})); - - v = ASSERT_NO_ERRNO_AND_VALUE(ParseKernelVersion("4.18.10-14-amd64")); - ASSERT_TRUE(v == KernelVersion({4, 18, 10})); - - v = ASSERT_NO_ERRNO_AND_VALUE(ParseKernelVersion("4.18.10-amd64")); - ASSERT_TRUE(v == KernelVersion({4, 18, 10})); - - v = ASSERT_NO_ERRNO_AND_VALUE(ParseKernelVersion("4.18.10")); - ASSERT_TRUE(v == KernelVersion({4, 18, 10})); - - v = ASSERT_NO_ERRNO_AND_VALUE(ParseKernelVersion("4.0.10")); - ASSERT_TRUE(v == KernelVersion({4, 0, 10})); - - v = ASSERT_NO_ERRNO_AND_VALUE(ParseKernelVersion("4.0")); - ASSERT_TRUE(v == KernelVersion({4, 0, 0})); - - ASSERT_THAT(ParseKernelVersion("4.a"), PosixErrorIs(EINVAL, ::testing::_)); - ASSERT_THAT(ParseKernelVersion("3"), PosixErrorIs(EINVAL, ::testing::_)); - ASSERT_THAT(ParseKernelVersion(""), PosixErrorIs(EINVAL, ::testing::_)); - ASSERT_THAT(ParseKernelVersion("version 3.3.10"), - PosixErrorIs(EINVAL, ::testing::_)); -} - -TEST(MatchersTest, SyscallSucceeds) { - EXPECT_THAT(0, SyscallSucceeds()); - EXPECT_THAT(0L, SyscallSucceeds()); - - errno = 0; - EXPECT_THAT(-1, SyscallSucceeds()); - EXPECT_THAT(-1L, SyscallSucceeds()); - - errno = ENOMEM; - EXPECT_THAT(-1, Not(SyscallSucceeds())); - EXPECT_THAT(-1L, Not(SyscallSucceeds())); -} - -TEST(MatchersTest, SyscallSucceedsWithValue) { - EXPECT_THAT(0, SyscallSucceedsWithValue(0)); - EXPECT_THAT(1, SyscallSucceedsWithValue(Lt(3))); - EXPECT_THAT(-1, Not(SyscallSucceedsWithValue(Lt(3)))); - EXPECT_THAT(4, Not(SyscallSucceedsWithValue(Lt(3)))); - - // Non-int -1 - EXPECT_THAT(-1L, Not(SyscallSucceedsWithValue(0))); - - // Non-int, truncates to -1 if converted to int, with expected value - EXPECT_THAT(0xffffffffL, SyscallSucceedsWithValue(0xffffffffL)); - - // Non-int, truncates to -1 if converted to int, with monomorphic matcher - EXPECT_THAT(0xffffffffL, - SyscallSucceedsWithValue(TypedEq<long>(0xffffffffL))); - - // Non-int, truncates to -1 if converted to int, with polymorphic matcher - EXPECT_THAT(0xffffffffL, SyscallSucceedsWithValue(Gt(1))); -} - -TEST(MatchersTest, SyscallFails) { - EXPECT_THAT(0, Not(SyscallFails())); - EXPECT_THAT(0L, Not(SyscallFails())); - - errno = 0; - EXPECT_THAT(-1, Not(SyscallFails())); - EXPECT_THAT(-1L, Not(SyscallFails())); - - errno = ENOMEM; - EXPECT_THAT(-1, SyscallFails()); - EXPECT_THAT(-1L, SyscallFails()); -} - -TEST(MatchersTest, SyscallFailsWithErrno) { - EXPECT_THAT(0, Not(SyscallFailsWithErrno(EINVAL))); - EXPECT_THAT(0L, Not(SyscallFailsWithErrno(EINVAL))); - - errno = ENOMEM; - EXPECT_THAT(-1, Not(SyscallFailsWithErrno(EINVAL))); - EXPECT_THAT(-1L, Not(SyscallFailsWithErrno(EINVAL))); - - errno = EINVAL; - EXPECT_THAT(-1, SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(-1L, SyscallFailsWithErrno(EINVAL)); - - EXPECT_THAT(-1, SyscallFailsWithErrno(AnyOf(EINVAL, ENOMEM))); - EXPECT_THAT(-1L, SyscallFailsWithErrno(AnyOf(EINVAL, ENOMEM))); - - std::vector<int> expected_errnos({EINVAL, ENOMEM}); - errno = ENOMEM; - EXPECT_THAT(-1, SyscallFailsWithErrno(ElementOf(expected_errnos))); - EXPECT_THAT(-1L, SyscallFailsWithErrno(ElementOf(expected_errnos))); -} - -TEST(AllBitwiseCombinationsTest, NoArguments) { - EXPECT_THAT(AllBitwiseCombinations(), IsEmpty()); -} - -TEST(AllBitwiseCombinationsTest, EmptyList) { - EXPECT_THAT(AllBitwiseCombinations(List<int>{}), IsEmpty()); -} - -TEST(AllBitwiseCombinationsTest, SingleElementList) { - EXPECT_THAT(AllBitwiseCombinations(List<int>{5}), UnorderedElementsAre(5)); -} - -TEST(AllBitwiseCombinationsTest, SingleList) { - EXPECT_THAT(AllBitwiseCombinations(List<int>{0, 1, 2, 4}), - UnorderedElementsAre(0, 1, 2, 4)); -} - -TEST(AllBitwiseCombinationsTest, MultipleLists) { - EXPECT_THAT( - AllBitwiseCombinations(List<int>{0, 1, 2, 3}, List<int>{0, 4, 8, 12}), - UnorderedElementsAreArray( - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15})); -} - -TEST(RandomizeBuffer, Works) { - const std::vector<char> original(4096); - std::vector<char> buffer = original; - RandomizeBuffer(buffer.data(), buffer.size()); - EXPECT_NE(buffer, original); -} - -// Enable comparison of vectors of iovec arrays for the following test. -MATCHER_P(IovecsListEq, expected, "") { - if (arg.size() != expected.size()) { - *result_listener << "sizes are different (actual: " << arg.size() - << ", expected: " << expected.size() << ")"; - return false; - } - - for (uint64_t i = 0; i < expected.size(); ++i) { - const std::vector<struct iovec>& actual_iovecs = arg[i]; - const std::vector<struct iovec>& expected_iovecs = expected[i]; - if (actual_iovecs.size() != expected_iovecs.size()) { - *result_listener << "iovec array size at position " << i - << " is different (actual: " << actual_iovecs.size() - << ", expected: " << expected_iovecs.size() << ")"; - return false; - } - - for (uint64_t j = 0; j < expected_iovecs.size(); ++j) { - const struct iovec& actual_iov = actual_iovecs[j]; - const struct iovec& expected_iov = expected_iovecs[j]; - if (actual_iov.iov_base != expected_iov.iov_base) { - *result_listener << "iovecs in array " << i << " at position " << j - << " are different (expected iov_base: " - << expected_iov.iov_base - << ", got: " << actual_iov.iov_base << ")"; - return false; - } - if (actual_iov.iov_len != expected_iov.iov_len) { - *result_listener << "iovecs in array " << i << " at position " << j - << " are different (expected iov_len: " - << expected_iov.iov_len - << ", got: " << actual_iov.iov_len << ")"; - return false; - } - } - } - - return true; -} - -// Verify empty iovec list generation. -TEST(GenerateIovecs, EmptyList) { - std::vector<char> buffer = {'a', 'b', 'c'}; - - EXPECT_THAT(GenerateIovecs(0, buffer.data(), buffer.size()), - IovecsListEq(std::vector<std::vector<struct iovec>>())); -} - -// Verify generating a single array of only one, partial, iovec. -TEST(GenerateIovecs, OneArray) { - std::vector<char> buffer = {'a', 'b', 'c'}; - - std::vector<std::vector<struct iovec>> expected; - struct iovec iov = {}; - iov.iov_base = buffer.data(); - iov.iov_len = 2; - expected.push_back(std::vector<struct iovec>({iov})); - EXPECT_THAT(GenerateIovecs(2, buffer.data(), buffer.size()), - IovecsListEq(expected)); -} - -// Verify that it wraps around after IOV_MAX iovecs. -TEST(GenerateIovecs, WrapsAtIovMax) { - std::vector<char> buffer = {'a', 'b', 'c'}; - - std::vector<std::vector<struct iovec>> expected; - struct iovec iov = {}; - iov.iov_base = buffer.data(); - iov.iov_len = buffer.size(); - expected.emplace_back(); - for (int i = 0; i < IOV_MAX; ++i) { - expected[0].push_back(iov); - } - iov.iov_len = 1; - expected.push_back(std::vector<struct iovec>({iov})); - - EXPECT_THAT( - GenerateIovecs(IOV_MAX * buffer.size() + 1, buffer.data(), buffer.size()), - IovecsListEq(expected)); -} - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/util/thread_util.h b/test/util/thread_util.h deleted file mode 100644 index 923c4fe10..000000000 --- a/test/util/thread_util.h +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_THREAD_UTIL_H_ -#define GVISOR_TEST_UTIL_THREAD_UTIL_H_ - -#include <pthread.h> -#ifdef __linux__ -#include <sys/syscall.h> -#endif -#include <unistd.h> - -#include <functional> -#include <utility> - -#include "test/util/logging.h" - -namespace gvisor { -namespace testing { - -// ScopedThread is a minimal wrapper around pthreads. -// -// This is used in lieu of more complex mechanisms because it provides very -// predictable behavior (no messing with timers, etc.) The thread will -// automatically joined when it is destructed (goes out of scope), but can be -// joined manually as well. -class ScopedThread { - public: - // Constructs a thread that executes f exactly once. - explicit ScopedThread(std::function<void*()> f) : f_(std::move(f)) { - CreateThread(); - } - - explicit ScopedThread(const std::function<void()>& f) { - f_ = [=] { - f(); - return nullptr; - }; - CreateThread(); - } - - ScopedThread(const ScopedThread& other) = delete; - ScopedThread& operator=(const ScopedThread& other) = delete; - - // Joins the thread. - ~ScopedThread() { Join(); } - - // Waits until this thread has finished executing. Join is idempotent and may - // be called multiple times, however Join itself is not thread-safe. - void* Join() { - if (!joined_) { - TEST_PCHECK(pthread_join(pt_, &retval_) == 0); - joined_ = true; - } - return retval_; - } - - private: - void CreateThread() { - TEST_PCHECK_MSG(pthread_create( - &pt_, /* attr = */ nullptr, - +[](void* arg) -> void* { - return static_cast<ScopedThread*>(arg)->f_(); - }, - this) == 0, - "thread creation failed"); - } - - std::function<void*()> f_; - pthread_t pt_; - bool joined_ = false; - void* retval_ = nullptr; -}; - -#ifdef __linux__ -inline pid_t gettid() { return syscall(SYS_gettid); } -#endif - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_THREAD_UTIL_H_ diff --git a/test/util/time_util.cc b/test/util/time_util.cc deleted file mode 100644 index 1ddfbfc9c..000000000 --- a/test/util/time_util.cc +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/time_util.h" - -#include <sys/syscall.h> -#include <unistd.h> - -#include "absl/time/time.h" - -namespace gvisor { -namespace testing { - -void SleepSafe(absl::Duration duration) { - if (duration == absl::ZeroDuration()) { - return; - } - - struct timespec ts = absl::ToTimespec(duration); - int ret; - while (1) { - ret = syscall(__NR_nanosleep, &ts, &ts); - if (ret == 0 || (ret <= 0 && errno != EINTR)) { - break; - } - } -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/time_util.h b/test/util/time_util.h deleted file mode 100644 index f3ddc9fde..000000000 --- a/test/util/time_util.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_TIME_UTIL_H_ -#define GVISOR_TEST_UTIL_TIME_UTIL_H_ - -#include "absl/time/time.h" - -namespace gvisor { -namespace testing { - -// Sleep for at least the specified duration. Avoids glibc. -void SleepSafe(absl::Duration duration); - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_TIME_UTIL_H_ diff --git a/test/util/timer_util.cc b/test/util/timer_util.cc deleted file mode 100644 index 43a26b0d3..000000000 --- a/test/util/timer_util.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/timer_util.h" - -namespace gvisor { -namespace testing { - -absl::Time Now(clockid_t id) { - struct timespec now; - TEST_PCHECK(clock_gettime(id, &now) == 0); - return absl::TimeFromTimespec(now); -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/timer_util.h b/test/util/timer_util.h deleted file mode 100644 index 31aea4fc6..000000000 --- a/test/util/timer_util.h +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_UTIL_TIMER_UTIL_H_ -#define GVISOR_TEST_UTIL_TIMER_UTIL_H_ - -#include <errno.h> -#include <sys/time.h> - -#include <functional> - -#include "gmock/gmock.h" -#include "absl/time/time.h" -#include "test/util/cleanup.h" -#include "test/util/logging.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" - -namespace gvisor { -namespace testing { - -// MonotonicTimer is a simple timer that uses a monotonic clock. -class MonotonicTimer { - public: - MonotonicTimer() {} - absl::Duration Duration() { - struct timespec ts; - TEST_CHECK(clock_gettime(CLOCK_MONOTONIC, &ts) == 0); - return absl::TimeFromTimespec(ts) - start_; - } - - void Start() { - struct timespec ts; - TEST_CHECK(clock_gettime(CLOCK_MONOTONIC, &ts) == 0); - start_ = absl::TimeFromTimespec(ts); - } - - protected: - absl::Time start_; -}; - -// Sets the given itimer and returns a cleanup function that restores the -// previous itimer when it goes out of scope. -inline PosixErrorOr<Cleanup> ScopedItimer(int which, - struct itimerval const& new_value) { - struct itimerval old_value; - int rc = setitimer(which, &new_value, &old_value); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "setitimer failed"); - } - return Cleanup(std::function<void(void)>([which, old_value] { - EXPECT_THAT(setitimer(which, &old_value, nullptr), SyscallSucceeds()); - })); -} - -// Returns the current time. -absl::Time Now(clockid_t id); - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_UTIL_TIMER_UTIL_H_ diff --git a/test/util/uid_util.cc b/test/util/uid_util.cc deleted file mode 100644 index b131b4b99..000000000 --- a/test/util/uid_util.cc +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/util/posix_error.h" -#include "test/util/save_util.h" - -namespace gvisor { -namespace testing { - -PosixErrorOr<bool> IsRoot() { - uid_t ruid, euid, suid; - int rc = getresuid(&ruid, &euid, &suid); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "getresuid"); - } - if (ruid != 0 || euid != 0 || suid != 0) { - return false; - } - gid_t rgid, egid, sgid; - rc = getresgid(&rgid, &egid, &sgid); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "getresgid"); - } - if (rgid != 0 || egid != 0 || sgid != 0) { - return false; - } - return true; -} - -} // namespace testing -} // namespace gvisor diff --git a/test/util/uid_util.h b/test/util/uid_util.h deleted file mode 100644 index 2cd387fb0..000000000 --- a/test/util/uid_util.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GVISOR_TEST_SYSCALLS_UID_UTIL_H_ -#define GVISOR_TEST_SYSCALLS_UID_UTIL_H_ - -#include "test/util/posix_error.h" - -namespace gvisor { -namespace testing { - -// Returns true if the caller's real/effective/saved user/group IDs are all 0. -PosixErrorOr<bool> IsRoot(); - -} // namespace testing -} // namespace gvisor - -#endif // GVISOR_TEST_SYSCALLS_UID_UTIL_H_ diff --git a/tools/BUILD b/tools/BUILD deleted file mode 100644 index e73a9c885..000000000 --- a/tools/BUILD +++ /dev/null @@ -1,3 +0,0 @@ -package(licenses = ["notice"]) - -exports_files(["nogo.js"]) diff --git a/tools/bazeldefs/BUILD b/tools/bazeldefs/BUILD deleted file mode 100644 index 00a467473..000000000 --- a/tools/bazeldefs/BUILD +++ /dev/null @@ -1,10 +0,0 @@ -package(licenses = ["notice"]) - -# In bazel, no special support is required for loopback networking. This is -# just a dummy data target that does not change the test environment. -genrule( - name = "loopback", - outs = ["loopback.txt"], - cmd = "touch $@", - visibility = ["//:sandbox"], -) diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl deleted file mode 100644 index 905b16d41..000000000 --- a/tools/bazeldefs/defs.bzl +++ /dev/null @@ -1,93 +0,0 @@ -"""Bazel implementations of standard rules.""" - -load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier") -load("@io_bazel_rules_go//go:def.bzl", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_test = "go_test", _go_tool_library = "go_tool_library") -load("@io_bazel_rules_go//proto:def.bzl", _go_proto_library = "go_proto_library") -load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test") -load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") -load("@io_bazel_rules_docker//go:image.bzl", _go_image = "go_image") -load("@io_bazel_rules_docker//container:container.bzl", _container_image = "container_image") -load("@pydeps//:requirements.bzl", _py_requirement = "requirement") - -container_image = _container_image -cc_binary = _cc_binary -cc_library = _cc_library -cc_flags_supplier = _cc_flags_supplier -cc_proto_library = _cc_proto_library -cc_test = _cc_test -cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain" -go_image = _go_image -go_embed_data = _go_embed_data -gtest = "@com_google_googletest//:gtest" -gbenchmark = "@com_google_benchmark//:benchmark" -loopback = "//tools/bazeldefs:loopback" -proto_library = native.proto_library -pkg_deb = _pkg_deb -pkg_tar = _pkg_tar -py_library = native.py_library -py_binary = native.py_binary -py_test = native.py_test - -def go_binary(name, static = False, pure = False, **kwargs): - if static: - kwargs["static"] = "on" - if pure: - kwargs["pure"] = "on" - _go_binary( - name = name, - **kwargs - ) - -def go_library(name, **kwargs): - _go_library( - name = name, - importpath = "gvisor.dev/gvisor/" + native.package_name(), - **kwargs - ) - -def go_tool_library(name, **kwargs): - _go_tool_library( - name = name, - importpath = "gvisor.dev/gvisor/" + native.package_name(), - **kwargs - ) - -def go_proto_library(name, proto, **kwargs): - deps = kwargs.pop("deps", []) - _go_proto_library( - name = name, - importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name, - proto = proto, - deps = [dep.replace("_proto", "_go_proto") for dep in deps], - **kwargs - ) - -def go_test(name, **kwargs): - library = kwargs.pop("library", None) - if library: - kwargs["embed"] = [library] - _go_test( - name = name, - **kwargs - ) - -def py_requirement(name, direct = True): - return _py_requirement(name) - -def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs): - values = { - "@bazel_tools//src/conditions:linux_x86_64": amd64, - "@bazel_tools//src/conditions:linux_aarch64": arm64, - } - if default: - values["//conditions:default"] = default - return select(values, **kwargs) - -def select_system(linux = ["__linux__"], **kwargs): - return linux # Only Linux supported. - -def default_installer(): - return None - -def default_net_util(): - return [] # Nothing needed. diff --git a/tools/bazeldefs/platforms.bzl b/tools/bazeldefs/platforms.bzl deleted file mode 100644 index 92b0b5fc0..000000000 --- a/tools/bazeldefs/platforms.bzl +++ /dev/null @@ -1,17 +0,0 @@ -"""List of platforms.""" - -# Platform to associated tags. -platforms = { - "ptrace": [ - # TODO(b/120560048): Make the tests run without this tag. - "no-sandbox", - ], - "kvm": [ - "manual", - "local", - # TODO(b/120560048): Make the tests run without this tag. - "no-sandbox", - ], -} - -default_platform = "ptrace" diff --git a/tools/bazeldefs/tags.bzl b/tools/bazeldefs/tags.bzl deleted file mode 100644 index 558fb53ae..000000000 --- a/tools/bazeldefs/tags.bzl +++ /dev/null @@ -1,40 +0,0 @@ -"""List of special Go suffixes.""" - -go_suffixes = [ - "_386", - "_386_unsafe", - "_aarch64", - "_aarch64_unsafe", - "_amd64", - "_amd64_unsafe", - "_arm", - "_arm64", - "_arm64_unsafe", - "_arm_unsafe", - "_impl", - "_impl_unsafe", - "_linux", - "_linux_unsafe", - "_mips", - "_mips64", - "_mips64_unsafe", - "_mips64le", - "_mips64le_unsafe", - "_mips_unsafe", - "_mipsle", - "_mipsle_unsafe", - "_opts", - "_opts_unsafe", - "_ppc64", - "_ppc64_unsafe", - "_ppc64le", - "_ppc64le_unsafe", - "_riscv64", - "_riscv64_unsafe", - "_s390x", - "_s390x_unsafe", - "_sparc64", - "_sparc64_unsafe", - "_wasm", - "_wasm_unsafe", -] diff --git a/tools/checkunsafe/BUILD b/tools/checkunsafe/BUILD deleted file mode 100644 index 4f1a31a6d..000000000 --- a/tools/checkunsafe/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -load("//tools:defs.bzl", "go_tool_library") - -package(licenses = ["notice"]) - -go_tool_library( - name = "checkunsafe", - srcs = ["check_unsafe.go"], - visibility = ["//:sandbox"], - deps = [ - "@org_golang_x_tools//go/analysis:go_tool_library", - ], -) diff --git a/tools/checkunsafe/check_unsafe.go b/tools/checkunsafe/check_unsafe.go deleted file mode 100644 index 4ccd7cc5a..000000000 --- a/tools/checkunsafe/check_unsafe.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package checkunsafe allows unsafe imports only in files named appropriately. -package checkunsafe - -import ( - "fmt" - "path" - "strconv" - "strings" - - "golang.org/x/tools/go/analysis" -) - -// Analyzer defines the entrypoint. -var Analyzer = &analysis.Analyzer{ - Name: "checkunsafe", - Doc: "allows unsafe use only in specified files", - Run: run, -} - -func run(pass *analysis.Pass) (interface{}, error) { - for _, f := range pass.Files { - for _, imp := range f.Imports { - // Is this an unsafe import? - pkg, err := strconv.Unquote(imp.Path.Value) - if err != nil || pkg != "unsafe" { - continue - } - - // Extract the filename. - filename := pass.Fset.File(imp.Pos()).Name() - - // Allow files named _unsafe.go or _test.go to opt out. - if strings.HasSuffix(filename, "_unsafe.go") || strings.HasSuffix(filename, "_test.go") { - continue - } - - // Throw the error. - pass.Reportf(imp.Pos(), fmt.Sprintf("package unsafe imported by %s; must end with _unsafe.go", path.Base(filename))) - } - } - return nil, nil -} diff --git a/tools/defs.bzl b/tools/defs.bzl deleted file mode 100644 index 15a310403..000000000 --- a/tools/defs.bzl +++ /dev/null @@ -1,222 +0,0 @@ -"""Wrappers for common build rules. - -These wrappers apply common BUILD configurations (e.g., proto_library -automagically creating cc_ and go_ proto targets) and act as a single point of -change for Google-internal and bazel-compatible rules. -""" - -load("//tools/go_stateify:defs.bzl", "go_stateify") -load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") -load("//tools/bazeldefs:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") -load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms") -load("//tools/bazeldefs:tags.bzl", "go_suffixes") - -# Delegate directly. -cc_binary = _cc_binary -cc_library = _cc_library -cc_test = _cc_test -cc_toolchain = _cc_toolchain -cc_flags_supplier = _cc_flags_supplier -container_image = _container_image -go_embed_data = _go_embed_data -go_image = _go_image -go_test = _go_test -go_tool_library = _go_tool_library -gtest = _gtest -gbenchmark = _gbenchmark -pkg_deb = _pkg_deb -pkg_tar = _pkg_tar -py_library = _py_library -py_binary = _py_binary -py_test = _py_test -py_requirement = _py_requirement -select_arch = _select_arch -select_system = _select_system -loopback = _loopback -default_installer = _default_installer -default_net_util = _default_net_util -platforms = _platforms -default_platform = _default_platform - -def go_binary(name, **kwargs): - """Wraps the standard go_binary. - - Args: - name: the rule name. - **kwargs: standard go_binary arguments. - """ - _go_binary( - name = name, - **kwargs - ) - -def calculate_sets(srcs): - """Calculates special Go sets for templates. - - Args: - srcs: the full set of Go sources. - - Returns: - A dictionary of the form: - - "": [src1.go, src2.go] - "suffix": [src3suffix.go, src4suffix.go] - - Note that suffix will typically start with '_'. - """ - result = dict() - for file in srcs: - if not file.endswith(".go"): - continue - target = "" - for suffix in go_suffixes: - if file.endswith(suffix + ".go"): - target = suffix - if not target in result: - result[target] = [file] - else: - result[target].append(file) - return result - -def go_imports(name, src, out): - """Simplify a single Go source file by eliminating unused imports.""" - native.genrule( - name = name, - srcs = [src], - outs = [out], - tools = ["@org_golang_x_tools//cmd/goimports:goimports"], - cmd = ("$(location @org_golang_x_tools//cmd/goimports:goimports) $(SRCS) > $@"), - ) - -def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, marshal_debug = False, **kwargs): - """Wraps the standard go_library and does stateification and marshalling. - - The recommended way is to use this rule with mostly identical configuration as the native - go_library rule. - - These definitions provide additional flags (stateify, marshal) that can be used - with the generators to automatically supplement the library code. - - load("//tools:defs.bzl", "go_library") - - go_library( - name = "foo", - srcs = ["foo.go"], - ) - - Args: - name: the rule name. - srcs: the library sources. - deps: the library dependencies. - imports: imports required for stateify. - stateify: whether statify is enabled (default: true). - marshal: whether marshal is enabled (default: false). - marshal_debug: whether the gomarshal tools emits debugging output (default: false). - **kwargs: standard go_library arguments. - """ - all_srcs = srcs - all_deps = deps - dirname, _, _ = native.package_name().rpartition("/") - full_pkg = dirname + "/" + name - if stateify: - # Only do stateification for non-state packages without manual autogen. - # First, we need to segregate the input files via the special suffixes, - # and calculate the final output set. - state_sets = calculate_sets(srcs) - for (suffix, src_subset) in state_sets.items(): - go_stateify( - name = name + suffix + "_state_autogen_with_imports", - srcs = src_subset, - imports = imports, - package = full_pkg, - out = name + suffix + "_state_autogen_with_imports.go", - ) - go_imports( - name = name + suffix + "_state_autogen", - src = name + suffix + "_state_autogen_with_imports.go", - out = name + suffix + "_state_autogen.go", - ) - all_srcs = all_srcs + [ - name + suffix + "_state_autogen.go" - for suffix in state_sets.keys() - ] - if "//pkg/state" not in all_deps: - all_deps = all_deps + ["//pkg/state"] - - if marshal: - # See above. - marshal_sets = calculate_sets(srcs) - for (suffix, src_subset) in marshal_sets.items(): - go_marshal( - name = name + suffix + "_abi_autogen", - srcs = src_subset, - debug = select({ - "//tools/go_marshal:marshal_config_verbose": True, - "//conditions:default": marshal_debug, - }), - imports = imports, - package = name, - ) - extra_deps = [ - dep - for dep in marshal_deps - if not dep in all_deps - ] - all_deps = all_deps + extra_deps - all_srcs = all_srcs + [ - name + suffix + "_abi_autogen_unsafe.go" - for suffix in marshal_sets.keys() - ] - - _go_library( - name = name, - srcs = all_srcs, - deps = all_deps, - **kwargs - ) - - if marshal: - # Ignore importpath for go_test. - kwargs.pop("importpath", None) - - # See above. - marshal_sets = calculate_sets(srcs) - for (suffix, _) in marshal_sets.items(): - _go_test( - name = name + suffix + "_abi_autogen_test", - srcs = [name + suffix + "_abi_autogen_test.go"], - library = ":" + name, - deps = marshal_test_deps, - **kwargs - ) - -def proto_library(name, srcs, **kwargs): - """Wraps the standard proto_library. - - Given a proto_library named "foo", this produces three different targets: - - foo_proto: proto_library rule. - - foo_go_proto: go_proto_library rule. - - foo_cc_proto: cc_proto_library rule. - - Args: - srcs: the proto sources. - **kwargs: standard proto_library arguments. - """ - deps = kwargs.pop("deps", []) - _proto_library( - name = name + "_proto", - srcs = srcs, - deps = deps, - **kwargs - ) - _go_proto_library( - name = name + "_go_proto", - proto = ":" + name + "_proto", - deps = deps, - **kwargs - ) - _cc_proto_library( - name = name + "_cc_proto", - deps = [":" + name + "_proto"], - **kwargs - ) diff --git a/tools/go_branch.sh b/tools/go_branch.sh deleted file mode 100755 index f97a74aaf..000000000 --- a/tools/go_branch.sh +++ /dev/null @@ -1,94 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -eo pipefail - -# Discovery the package name from the go.mod file. -declare -r module=$(cat go.mod | grep -E "^module" | cut -d' ' -f2) -declare -r origpwd=$(pwd) -declare -r othersrc=("go.mod" "go.sum" "AUTHORS" "LICENSE") - -# Check that gopath has been built. -declare -r gopath_dir="$(pwd)/bazel-bin/gopath/src/${module}" -if ! [ -d "${gopath_dir}" ]; then - echo "No gopath directory found; build the :gopath target." >&2 - exit 1 -fi - -# Create a temporary working directory, and ensure that this directory and all -# subdirectories are cleaned up upon exit. -declare -r tmp_dir=$(mktemp -d) -finish() { - cd # Leave tmp_dir. - rm -rf "${tmp_dir}" -} -trap finish EXIT - -# Record the current working commit. -declare -r head=$(git describe --always) - -# We expect to have an existing go branch that we will use as the basis for -# this commit. That branch may be empty, but it must exist. -declare -r go_branch=$(git show-ref --hash origin/go) - -# Clone the current repository to the temporary directory, and check out the -# current go_branch directory. We move to the new repository for convenience. -declare -r repo_orig="$(pwd)" -declare -r repo_new="${tmp_dir}/repository" -git clone . "${repo_new}" -cd "${repo_new}" - -# Setup the repository and checkout the branch. -git config user.email "gvisor-bot@google.com" -git config user.name "gVisor bot" -git fetch origin "${go_branch}" -git checkout -b go "${go_branch}" - -# Start working on a merge commit that combines the previous history with the -# current history. Note that we don't actually want any changes yet. -# -# N.B. The git behavior changed at some point and the relevant flag was added -# to allow for override, so try the only behavior first then pass the flag. -git merge --no-commit --strategy ours ${head} || \ - git merge --allow-unrelated-histories --no-commit --strategy ours ${head} - -# Sync the entire gopath_dir. -rsync --recursive --verbose --delete --exclude .git -L "${gopath_dir}/" . - -# Add additional files. -for file in "${othersrc[@]}"; do - cp "${origpwd}"/"${file}" . -done - -# Construct a new README.md. -cat > README.md <<EOF -# gVisor - -This branch is a synthetic branch, containing only Go sources, that is -compatible with standard Go tools. See the master branch for authoritative -sources and tests. -EOF - -# There are a few solitary files that can get left behind due to the way bazel -# constructs the gopath target. Note that we don't find all Go files here -# because they may correspond to unused templates, etc. -cp "${repo_orig}"/runsc/*.go runsc/ - -# Update the current working set and commit. -git add . && git commit -m "Merge ${head} (automated)" - -# Push the branch back to the original repository. -git remote add orig "${repo_orig}" && git push -f orig go:go diff --git a/tools/go_generics/BUILD b/tools/go_generics/BUILD deleted file mode 100644 index 32a949c93..000000000 --- a/tools/go_generics/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -load("//tools:defs.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "go_generics", - srcs = [ - "generics.go", - "imports.go", - "remove.go", - ], - visibility = ["//:sandbox"], - deps = ["//tools/go_generics/globals"], -) - -genrule( - name = "go_generics_tests", - srcs = glob(["generics_tests/**"]) + [":go_generics"], - outs = ["go_generics_tests.tgz"], - cmd = "tar -czvhf $@ $(SRCS)", -) - -genrule( - name = "go_generics_test_bundle", - srcs = [ - ":go_generics_tests.tgz", - ":go_generics_unittest.sh", - ], - outs = ["go_generics_test.sh"], - cmd = "cat $(location :go_generics_unittest.sh) $(location :go_generics_tests.tgz) > $@", - executable = True, -) - -sh_test( - name = "go_generics_test", - size = "small", - srcs = ["go_generics_test.sh"], -) diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl deleted file mode 100644 index c5be52ecd..000000000 --- a/tools/go_generics/defs.bzl +++ /dev/null @@ -1,140 +0,0 @@ -def _go_template_impl(ctx): - input = ctx.files.srcs - output = ctx.outputs.out - - args = ["-o=%s" % output.path] + [f.path for f in input] - - ctx.actions.run( - inputs = input, - outputs = [output], - mnemonic = "GoGenericsTemplate", - progress_message = "Building Go template %s" % ctx.label, - arguments = args, - executable = ctx.executable._tool, - ) - - return struct( - types = ctx.attr.types, - opt_types = ctx.attr.opt_types, - consts = ctx.attr.consts, - opt_consts = ctx.attr.opt_consts, - deps = ctx.attr.deps, - file = output, - ) - -""" -Generates a Go template from a set of Go files. - -A Go template is similar to a go library, except that it has certain types that -can be replaced before usage. For example, one could define a templatized List -struct, whose elements are of type T, then instantiate that template for -T=segment, where "segment" is the concrete type. - -Args: - name: the name of the template. - srcs: the list of source files that comprise the template. - types: the list of generic types in the template that are required to be specified. - opt_types: the list of generic types in the template that can but aren't required to be specified. - consts: the list of constants in the template that are required to be specified. - opt_consts: the list of constants in the template that can but aren't required to be specified. - deps: the list of dependencies. -""" -go_template = rule( - implementation = _go_template_impl, - attrs = { - "srcs": attr.label_list(mandatory = True, allow_files = True), - "deps": attr.label_list(allow_files = True), - "types": attr.string_list(), - "opt_types": attr.string_list(), - "consts": attr.string_list(), - "opt_consts": attr.string_list(), - "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_generics/go_merge")), - }, - outputs = { - "out": "%{name}_template.go", - }, -) - -def _go_template_instance_impl(ctx): - template = ctx.attr.template - output = ctx.outputs.out - - # Check that all required types are defined. - for t in template.types: - if t not in ctx.attr.types: - fail("Missing value for type %s in %s" % (t, ctx.attr.template.label)) - - # Check that all defined types are expected by the template. - for t in ctx.attr.types: - if (t not in template.types) and (t not in template.opt_types): - fail("Type %s it not a parameter to %s" % (t, ctx.attr.template.label)) - - # Check that all required consts are defined. - for t in template.consts: - if t not in ctx.attr.consts: - fail("Missing value for constant %s in %s" % (t, ctx.attr.template.label)) - - # Check that all defined consts are expected by the template. - for t in ctx.attr.consts: - if (t not in template.consts) and (t not in template.opt_consts): - fail("Const %s it not a parameter to %s" % (t, ctx.attr.template.label)) - - # Build the argument list. - args = ["-i=%s" % template.file.path, "-o=%s" % output.path] - args += ["-p=%s" % ctx.attr.package] - - if len(ctx.attr.prefix) > 0: - args += ["-prefix=%s" % ctx.attr.prefix] - - if len(ctx.attr.suffix) > 0: - args += ["-suffix=%s" % ctx.attr.suffix] - - args += [("-t=%s=%s" % (p[0], p[1])) for p in ctx.attr.types.items()] - args += [("-c=%s=%s" % (p[0], p[1])) for p in ctx.attr.consts.items()] - args += [("-import=%s=%s" % (p[0], p[1])) for p in ctx.attr.imports.items()] - - if ctx.attr.anon: - args += ["-anon"] - - ctx.actions.run( - inputs = [template.file], - outputs = [output], - mnemonic = "GoGenericsInstance", - progress_message = "Building Go template instance %s" % ctx.label, - arguments = args, - executable = ctx.executable._tool, - ) - - # TODO: How can we get the dependencies out? - return struct( - files = depset([output]), - ) - -""" -Instantiates a Go template by replacing all generic types with concrete ones. - -Args: - name: the name of the template instance. - template: the label of the template to be instatiated. - prefix: a prefix to be added to globals in the template. - suffix: a suffix to be added to global in the template. - types: the map from generic type names to concrete ones. - consts: the map from constant names to their values. - imports: the map from imports used in types/consts to their import paths. - package: the name of the package the instantiated template will be compiled into. -""" -go_template_instance = rule( - implementation = _go_template_instance_impl, - attrs = { - "template": attr.label(mandatory = True, providers = ["types"]), - "prefix": attr.string(), - "suffix": attr.string(), - "types": attr.string_dict(), - "consts": attr.string_dict(), - "imports": attr.string_dict(), - "anon": attr.bool(mandatory = False, default = False), - "package": attr.string(mandatory = True), - "out": attr.output(mandatory = True), - "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_generics")), - }, -) diff --git a/tools/go_generics/generics.go b/tools/go_generics/generics.go deleted file mode 100644 index e9cc2c753..000000000 --- a/tools/go_generics/generics.go +++ /dev/null @@ -1,284 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// go_generics reads a Go source file and writes a new version of that file with -// a few transformations applied to each. Namely: -// -// 1. Global types can be explicitly renamed with the -t option. For example, -// if -t=A=B is passed in, all references to A will be replaced with -// references to B; a function declaration like: -// -// func f(arg *A) -// -// would be renamed to: -// -// func f(arg *B) -// -// 2. Global type definitions and their method sets will be removed when they're -// being renamed with -t. For example, if -t=A=B is passed in, the following -// definition and methods that existed in the input file wouldn't exist at -// all in the output file: -// -// type A struct{} -// -// func (*A) f() {} -// -// 3. All global types, variables, constants and functions (not methods) are -// prefixed and suffixed based on the option -prefix and -suffix arguments. -// For example, if -suffix=A is passed in, the following globals: -// -// func f() -// type t struct{} -// -// would be renamed to: -// -// func fA() -// type tA struct{} -// -// Some special tags are also modified. For example: -// -// "state:.(t)" -// -// would become: -// -// "state:.(tA)" -// -// 4. The package is renamed to the value via the -p argument. -// 5. Value of constants can be modified with -c argument. -// -// Note that not just the top-level declarations are renamed, all references to -// them are also properly renamed as well, taking into account visibility rules -// and shadowing. For example, if -suffix=A is passed in, the following: -// -// var b = 100 -// -// func f() { -// g(b) -// b := 0 -// g(b) -// } -// -// Would be replaced with: -// -// var bA = 100 -// -// func f() { -// g(bA) -// b := 0 -// g(b) -// } -// -// Note that the second call to g() kept "b" as an argument because it refers to -// the local variable "b". -// -// Note that go_generics can handle anonymous fields with renamed types if -// -anon is passed in, however it does not perform strict checking on parameter -// types that share the same name as the global type and therefore will rename -// them as well. -// -// You can see an example in the tools/go_generics/generics_tests/interface test. -package main - -import ( - "bytes" - "flag" - "fmt" - "go/ast" - "go/format" - "go/parser" - "go/token" - "io/ioutil" - "os" - "regexp" - "strings" - - "gvisor.dev/gvisor/tools/go_generics/globals" -) - -var ( - input = flag.String("i", "", "input `file`") - output = flag.String("o", "", "output `file`") - suffix = flag.String("suffix", "", "`suffix` to add to each global symbol") - prefix = flag.String("prefix", "", "`prefix` to add to each global symbol") - packageName = flag.String("p", "main", "output package `name`") - printAST = flag.Bool("ast", false, "prints the AST") - processAnon = flag.Bool("anon", false, "process anonymous fields") - types = make(mapValue) - consts = make(mapValue) - imports = make(mapValue) -) - -// mapValue implements flag.Value. We use a mapValue flag instead of a regular -// string flag when we want to allow more than one instance of the flag. For -// example, we allow several "-t A=B" arguments, and will rename them all. -type mapValue map[string]string - -func (m mapValue) String() string { - var b bytes.Buffer - first := true - for k, v := range m { - if !first { - b.WriteRune(',') - } else { - first = false - } - b.WriteString(k) - b.WriteRune('=') - b.WriteString(v) - } - return b.String() -} - -func (m mapValue) Set(s string) error { - sep := strings.Index(s, "=") - if sep == -1 { - return fmt.Errorf("missing '=' from '%s'", s) - } - - m[s[:sep]] = s[sep+1:] - - return nil -} - -// stateTagRegexp matches against the 'typed' state tags. -var stateTagRegexp = regexp.MustCompile(`^(.*[^a-z0-9_])state:"\.\(([^\)]*)\)"(.*)$`) - -var identifierRegexp = regexp.MustCompile(`^(.*[^a-zA-Z_])([a-zA-Z_][a-zA-Z0-9_]*)(.*)$`) - -func main() { - flag.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0]) - flag.PrintDefaults() - } - - flag.Var(types, "t", "rename type A to B when `A=B` is passed in. Multiple such mappings are allowed.") - flag.Var(consts, "c", "reassign constant A to value B when `A=B` is passed in. Multiple such mappings are allowed.") - flag.Var(imports, "import", "specifies the import libraries to use when types are not local. `name=path` specifies that 'name', used in types as name.type, refers to the package living in 'path'.") - flag.Parse() - - if *input == "" || *output == "" { - flag.Usage() - os.Exit(1) - } - - // Parse the input file. - fset := token.NewFileSet() - f, err := parser.ParseFile(fset, *input, nil, parser.ParseComments|parser.DeclarationErrors|parser.SpuriousErrors) - if err != nil { - fmt.Fprintf(os.Stderr, "%v\n", err) - os.Exit(1) - } - - // Print the AST if requested. - if *printAST { - ast.Print(fset, f) - } - - cmap := ast.NewCommentMap(fset, f, f.Comments) - - // Update imports based on what's used in types and consts. - maps := []mapValue{types, consts} - importDecl, err := updateImports(maps, imports) - if err != nil { - fmt.Fprintf(os.Stderr, "%v\n", err) - os.Exit(1) - } - types = maps[0] - consts = maps[1] - - // Reassign all specified constants. - for _, decl := range f.Decls { - d, ok := decl.(*ast.GenDecl) - if !ok || d.Tok != token.CONST { - continue - } - - for _, gs := range d.Specs { - s := gs.(*ast.ValueSpec) - for i, id := range s.Names { - if n, ok := consts[id.Name]; ok { - s.Values[i] = &ast.BasicLit{Value: n} - } - } - } - } - - // Go through all globals and their uses in the AST and rename the types - // with explicitly provided names, and rename all types, variables, - // consts and functions with the provided prefix and suffix. - globals.Visit(fset, f, func(ident *ast.Ident, kind globals.SymKind) { - if n, ok := types[ident.Name]; ok && kind == globals.KindType { - ident.Name = n - } else { - switch kind { - case globals.KindType, globals.KindVar, globals.KindConst, globals.KindFunction: - ident.Name = *prefix + ident.Name + *suffix - case globals.KindTag: - // Modify the state tag appropriately. - if m := stateTagRegexp.FindStringSubmatch(ident.Name); m != nil { - if t := identifierRegexp.FindStringSubmatch(m[2]); t != nil { - typeName := *prefix + t[2] + *suffix - if n, ok := types[t[2]]; ok { - typeName = n - } - ident.Name = m[1] + `state:".(` + t[1] + typeName + t[3] + `)"` + m[3] - } - } - } - } - }, *processAnon) - - // Remove the definition of all types that are being remapped. - set := make(typeSet) - for _, v := range types { - set[v] = struct{}{} - } - removeTypes(set, f) - - // Add the new imports, if any, to the top. - if importDecl != nil { - newDecls := make([]ast.Decl, 0, len(f.Decls)+1) - newDecls = append(newDecls, importDecl) - newDecls = append(newDecls, f.Decls...) - f.Decls = newDecls - } - - // Update comments to remove the ones potentially associated with the - // type T that we removed. - f.Comments = cmap.Filter(f).Comments() - - // If there are file (package) comments, delete them. - if f.Doc != nil { - for i, cg := range f.Comments { - if cg == f.Doc { - f.Comments = append(f.Comments[:i], f.Comments[i+1:]...) - break - } - } - } - - // Write the output file. - f.Name.Name = *packageName - - var buf bytes.Buffer - if err := format.Node(&buf, fset, f); err != nil { - fmt.Fprintf(os.Stderr, "%v\n", err) - os.Exit(1) - } - - if err := ioutil.WriteFile(*output, buf.Bytes(), 0644); err != nil { - fmt.Fprintf(os.Stderr, "%v\n", err) - os.Exit(1) - } -} diff --git a/tools/go_generics/generics_tests/all_stmts/input.go b/tools/go_generics/generics_tests/all_stmts/input.go deleted file mode 100644 index 4791d1ff1..000000000 --- a/tools/go_generics/generics_tests/all_stmts/input.go +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -import ( - "sync" -) - -type T int - -func h(T) { -} - -type s struct { - a, b int - c []int -} - -func g(T) *s { - return &s{} -} - -func f() (T, []int) { - // Branch. - goto T - goto R - - // Labeled. -T: - _ = T(0) - - // Empty. -R: - ; - - // Assignment with definition. - a, b, c := T(1), T(2), T(3) - _, _, _ = a, b, c - - // Assignment without definition. - g(T(0)).a, g(T(1)).b, c = int(T(1)), int(T(2)), T(3) - _, _, _ = a, b, c - - // Block. - { - var T T - T = 0 - _ = T - } - - // Declarations. - type Type T - const Const T = 10 - var g1 func(T, int, ...T) (int, T) - var v T - var w = T(0) - { - var T struct { - f []T - } - _ = T - } - - // Defer. - defer g1(T(0), 1) - - // Expression. - h(v + w + T(1)) - - // For statements. - for i := T(0); i < T(10); i++ { - var T func(int) T - v := T(0) - _ = v - } - - for { - var T func(int) T - v := T(0) - _ = v - } - - // Go. - go g1(T(0), 1) - - // If statements. - if a != T(1) { - var T func(int) T - v := T(0) - _ = v - } - - if a := T(0); a != T(1) { - var T func(int) T - v := T(0) - _ = v - } - - if a := T(0); a != T(1) { - var T func(int) T - v := T(0) - _ = v - } else if b := T(0); b != T(1) { - var T func(int) T - v := T(0) - _ = v - } else if T := T(0); T != 1 { - T++ - } else { - T-- - } - - if a := T(0); a != T(1) { - var T func(int) T - v := T(0) - _ = v - } else { - var T func(int) T - v := T(0) - _ = v - } - - // Inc/Dec statements. - (*(*T)(nil))++ - (*(*T)(nil))-- - - // Range statements. - for g(T(0)).a, g(T(1)).b = range g(T(10)).c { - var d T - _ = d - } - - for T, b := range g(T(10)).c { - _ = T - _ = b - } - - // Select statement. - { - var fch func(T) chan int - - select { - case <-fch(T(30)): - var T T - T = 0 - _ = T - default: - var T T - T = 0 - _ = T - case T := <-fch(T(30)): - T = 0 - _ = T - case g(T(0)).a = <-fch(T(30)): - var T T - T = 0 - _ = T - case fch(T(30)) <- int(T(0)): - var T T - T = 0 - _ = T - } - } - - // Send statements. - { - var ch chan T - var fch func(T) chan int - - ch <- T(0) - fch(T(1)) <- g(T(10)).a - } - - // Switch statements. - { - var a T - var b int - switch { - case a == T(0): - var T T - T = 0 - _ = T - case a < T(0), b < g(T(10)).a: - var T T - T = 0 - _ = T - default: - var T T - T = 0 - _ = T - } - } - - switch T(g(T(10)).a) { - case T(0): - var T T - T = 0 - _ = T - case T(1), T(g(T(10)).a): - var T T - T = 0 - _ = T - default: - var T T - T = 0 - _ = T - } - - switch b := g(T(10)); T(b.a) + T(10) { - case T(0): - var T T - T = 0 - _ = T - case T(1), T(g(T(10)).a): - var T T - T = 0 - _ = T - default: - var T T - T = 0 - _ = T - } - - // Type switch statements. - { - var interfaceFunc func(T) interface{} - - switch interfaceFunc(T(0)).(type) { - case *T, T, int: - var T T - T = 0 - _ = T - case sync.Mutex, **T: - var T T - T = 0 - _ = T - default: - var T T - T = 0 - _ = T - } - - switch x := interfaceFunc(T(0)).(type) { - case *T, T, int: - var T T - T = 0 - _ = T - _ = x - case sync.Mutex, **T: - var T T - T = 0 - _ = T - default: - var T T - T = 0 - _ = T - } - - switch t := T(0); x := interfaceFunc(T(0) + t).(type) { - case *T, T, int: - var T T - T = 0 - _ = T - _ = x - case sync.Mutex, **T: - var T T - T = 0 - _ = T - default: - var T T - T = 0 - _ = T - } - } - - // Return statement. - return T(10), g(T(11)).c -} diff --git a/tools/go_generics/generics_tests/all_stmts/opts.txt b/tools/go_generics/generics_tests/all_stmts/opts.txt deleted file mode 100644 index c9d0e09bf..000000000 --- a/tools/go_generics/generics_tests/all_stmts/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=Q diff --git a/tools/go_generics/generics_tests/all_stmts/output/output.go b/tools/go_generics/generics_tests/all_stmts/output/output.go deleted file mode 100644 index a53d84535..000000000 --- a/tools/go_generics/generics_tests/all_stmts/output/output.go +++ /dev/null @@ -1,288 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "sync" -) - -func h(Q) { -} - -type s struct { - a, b int - c []int -} - -func g(Q) *s { - return &s{} -} - -func f() (Q, []int) { - // Branch. - goto T - goto R - - // Labeled. -T: - _ = Q(0) - - // Empty. -R: - ; - - // Assignment with definition. - a, b, c := Q(1), Q(2), Q(3) - _, _, _ = a, b, c - - // Assignment without definition. - g(Q(0)).a, g(Q(1)).b, c = int(Q(1)), int(Q(2)), Q(3) - _, _, _ = a, b, c - - // Block. - { - var T Q - T = 0 - _ = T - } - - // Declarations. - type Type Q - const Const Q = 10 - var g1 func(Q, int, ...Q) (int, Q) - var v Q - var w = Q(0) - { - var T struct { - f []Q - } - _ = T - } - - // Defer. - defer g1(Q(0), 1) - - // Expression. - h(v + w + Q(1)) - - // For statements. - for i := Q(0); i < Q(10); i++ { - var T func(int) Q - v := T(0) - _ = v - } - - for { - var T func(int) Q - v := T(0) - _ = v - } - - // Go. - go g1(Q(0), 1) - - // If statements. - if a != Q(1) { - var T func(int) Q - v := T(0) - _ = v - } - - if a := Q(0); a != Q(1) { - var T func(int) Q - v := T(0) - _ = v - } - - if a := Q(0); a != Q(1) { - var T func(int) Q - v := T(0) - _ = v - } else if b := Q(0); b != Q(1) { - var T func(int) Q - v := T(0) - _ = v - } else if T := Q(0); T != 1 { - T++ - } else { - T-- - } - - if a := Q(0); a != Q(1) { - var T func(int) Q - v := T(0) - _ = v - } else { - var T func(int) Q - v := T(0) - _ = v - } - - // Inc/Dec statements. - (*(*Q)(nil))++ - (*(*Q)(nil))-- - - // Range statements. - for g(Q(0)).a, g(Q(1)).b = range g(Q(10)).c { - var d Q - _ = d - } - - for T, b := range g(Q(10)).c { - _ = T - _ = b - } - - // Select statement. - { - var fch func(Q) chan int - - select { - case <-fch(Q(30)): - var T Q - T = 0 - _ = T - default: - var T Q - T = 0 - _ = T - case T := <-fch(Q(30)): - T = 0 - _ = T - case g(Q(0)).a = <-fch(Q(30)): - var T Q - T = 0 - _ = T - case fch(Q(30)) <- int(Q(0)): - var T Q - T = 0 - _ = T - } - } - - // Send statements. - { - var ch chan Q - var fch func(Q) chan int - - ch <- Q(0) - fch(Q(1)) <- g(Q(10)).a - } - - // Switch statements. - { - var a Q - var b int - switch { - case a == Q(0): - var T Q - T = 0 - _ = T - case a < Q(0), b < g(Q(10)).a: - var T Q - T = 0 - _ = T - default: - var T Q - T = 0 - _ = T - } - } - - switch Q(g(Q(10)).a) { - case Q(0): - var T Q - T = 0 - _ = T - case Q(1), Q(g(Q(10)).a): - var T Q - T = 0 - _ = T - default: - var T Q - T = 0 - _ = T - } - - switch b := g(Q(10)); Q(b.a) + Q(10) { - case Q(0): - var T Q - T = 0 - _ = T - case Q(1), Q(g(Q(10)).a): - var T Q - T = 0 - _ = T - default: - var T Q - T = 0 - _ = T - } - - // Type switch statements. - { - var interfaceFunc func(Q) interface{} - - switch interfaceFunc(Q(0)).(type) { - case *Q, Q, int: - var T Q - T = 0 - _ = T - case sync.Mutex, **Q: - var T Q - T = 0 - _ = T - default: - var T Q - T = 0 - _ = T - } - - switch x := interfaceFunc(Q(0)).(type) { - case *Q, Q, int: - var T Q - T = 0 - _ = T - _ = x - case sync.Mutex, **Q: - var T Q - T = 0 - _ = T - default: - var T Q - T = 0 - _ = T - } - - switch t := Q(0); x := interfaceFunc(Q(0) + t).(type) { - case *Q, Q, int: - var T Q - T = 0 - _ = T - _ = x - case sync.Mutex, **Q: - var T Q - T = 0 - _ = T - default: - var T Q - T = 0 - _ = T - } - } - - // Return statement. - return Q(10), g(Q(11)).c -} diff --git a/tools/go_generics/generics_tests/all_types/input.go b/tools/go_generics/generics_tests/all_types/input.go deleted file mode 100644 index 3575d02ec..000000000 --- a/tools/go_generics/generics_tests/all_types/input.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -import "./lib" - -type T int - -type newType struct { - a T - b lib.T - c *T - d (T) - e chan T - f <-chan T - g chan<- T - h []T - i [10]T - j map[T]T - k func(T, T) (T, T) - l interface { - f(T) - } - m struct { - T - a T - } -} - -func f(...T) { -} diff --git a/tools/go_generics/generics_tests/all_types/lib/lib.go b/tools/go_generics/generics_tests/all_types/lib/lib.go deleted file mode 100644 index 988786496..000000000 --- a/tools/go_generics/generics_tests/all_types/lib/lib.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package lib - -type T int32 diff --git a/tools/go_generics/generics_tests/all_types/opts.txt b/tools/go_generics/generics_tests/all_types/opts.txt deleted file mode 100644 index c9d0e09bf..000000000 --- a/tools/go_generics/generics_tests/all_types/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=Q diff --git a/tools/go_generics/generics_tests/all_types/output/output.go b/tools/go_generics/generics_tests/all_types/output/output.go deleted file mode 100644 index 41fd147a1..000000000 --- a/tools/go_generics/generics_tests/all_types/output/output.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import "./lib" - -type newType struct { - a Q - b lib.T - c *Q - d (Q) - e chan Q - f <-chan Q - g chan<- Q - h []Q - i [10]Q - j map[Q]Q - k func(Q, Q) (Q, Q) - l interface { - f(Q) - } - m struct { - Q - a Q - } -} - -func f(...Q) { -} diff --git a/tools/go_generics/generics_tests/anon/input.go b/tools/go_generics/generics_tests/anon/input.go deleted file mode 100644 index 44086d522..000000000 --- a/tools/go_generics/generics_tests/anon/input.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -type T interface { - Apply(T) T -} - -type Foo struct { - T - Bar map[string]T `json:"bar,omitempty"` -} - -type Baz struct { - T someTypeNotT -} - -func (f Foo) GetBar(name string) T { - b, ok := f.Bar[name] - if ok { - b = f.Apply(b) - } else { - b = f.T - } - return b -} - -func foobar() { - a := Baz{} - a.T = 0 // should not be renamed, this is a limitation - - b := otherpkg.UnrelatedType{} - b.T = 0 // should not be renamed, this is a limitation -} diff --git a/tools/go_generics/generics_tests/anon/opts.txt b/tools/go_generics/generics_tests/anon/opts.txt deleted file mode 100644 index a5e9d26de..000000000 --- a/tools/go_generics/generics_tests/anon/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=Q -suffix=New -anon diff --git a/tools/go_generics/generics_tests/anon/output/output.go b/tools/go_generics/generics_tests/anon/output/output.go deleted file mode 100644 index 160cddf79..000000000 --- a/tools/go_generics/generics_tests/anon/output/output.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -type FooNew struct { - Q - Bar map[string]Q `json:"bar,omitempty"` -} - -type BazNew struct { - T someTypeNotT -} - -func (f FooNew) GetBar(name string) Q { - b, ok := f.Bar[name] - if ok { - b = f.Apply(b) - } else { - b = f.Q - } - return b -} - -func foobarNew() { - a := BazNew{} - a.Q = 0 // should not be renamed, this is a limitation - - b := otherpkg.UnrelatedType{} - b.Q = 0 // should not be renamed, this is a limitation -} diff --git a/tools/go_generics/generics_tests/consts/input.go b/tools/go_generics/generics_tests/consts/input.go deleted file mode 100644 index 04b95fcc6..000000000 --- a/tools/go_generics/generics_tests/consts/input.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -const c1 = 10 -const x, y, z = 100, 200, 300 -const v float32 = 1.0 + 2.0 -const s = "abc" -const ( - A = 10 - B, C, D = 10, 20, 30 - S = "abc" - T, U, V string = "abc", "def", "ghi" -) diff --git a/tools/go_generics/generics_tests/consts/opts.txt b/tools/go_generics/generics_tests/consts/opts.txt deleted file mode 100644 index 4fb59dce8..000000000 --- a/tools/go_generics/generics_tests/consts/opts.txt +++ /dev/null @@ -1 +0,0 @@ --c=c1=20 -c=z=600 -c=v=3.3 -c=s="def" -c=A=20 -c=C=100 -c=S="def" -c=T="ABC" diff --git a/tools/go_generics/generics_tests/consts/output/output.go b/tools/go_generics/generics_tests/consts/output/output.go deleted file mode 100644 index 18d316cc9..000000000 --- a/tools/go_generics/generics_tests/consts/output/output.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -const c1 = 20 -const x, y, z = 100, 200, 600 -const v float32 = 3.3 -const s = "def" -const ( - A = 20 - B, C, D = 10, 100, 30 - S = "def" - T, U, V string = "ABC", "def", "ghi" -) diff --git a/tools/go_generics/generics_tests/imports/input.go b/tools/go_generics/generics_tests/imports/input.go deleted file mode 100644 index 0f032c2a1..000000000 --- a/tools/go_generics/generics_tests/imports/input.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -type T int - -var global T - -const ( - m = 0 - n = 0 -) diff --git a/tools/go_generics/generics_tests/imports/opts.txt b/tools/go_generics/generics_tests/imports/opts.txt deleted file mode 100644 index 87324be79..000000000 --- a/tools/go_generics/generics_tests/imports/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=sync.Mutex -c=n=math.Uint32 -c=m=math.Uint64 -import=sync=sync -import=math=mymathpath diff --git a/tools/go_generics/generics_tests/imports/output/output.go b/tools/go_generics/generics_tests/imports/output/output.go deleted file mode 100644 index 2488ca58c..000000000 --- a/tools/go_generics/generics_tests/imports/output/output.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - __generics_imported1 "mymathpath" - __generics_imported0 "sync" -) - -var global __generics_imported0.Mutex - -const ( - m = __generics_imported1.Uint64 - n = __generics_imported1.Uint32 -) diff --git a/tools/go_generics/generics_tests/remove_typedef/input.go b/tools/go_generics/generics_tests/remove_typedef/input.go deleted file mode 100644 index cf632bae7..000000000 --- a/tools/go_generics/generics_tests/remove_typedef/input.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -func f(T) Q { - return Q{} -} - -type T struct{} - -type Q struct{} - -func (*T) f() { -} - -func (T) g() { -} - -func (*Q) f(T) T { - return T{} -} - -func (*Q) g(T) *T { - return nil -} diff --git a/tools/go_generics/generics_tests/remove_typedef/opts.txt b/tools/go_generics/generics_tests/remove_typedef/opts.txt deleted file mode 100644 index 9c8ecaada..000000000 --- a/tools/go_generics/generics_tests/remove_typedef/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=U diff --git a/tools/go_generics/generics_tests/remove_typedef/output/output.go b/tools/go_generics/generics_tests/remove_typedef/output/output.go deleted file mode 100644 index d44fd8e1c..000000000 --- a/tools/go_generics/generics_tests/remove_typedef/output/output.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -func f(U) Q { - return Q{} -} - -type Q struct{} - -func (*Q) f(U) U { - return U{} -} - -func (*Q) g(U) *U { - return nil -} diff --git a/tools/go_generics/generics_tests/simple/input.go b/tools/go_generics/generics_tests/simple/input.go deleted file mode 100644 index 2a917f16c..000000000 --- a/tools/go_generics/generics_tests/simple/input.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -type T int - -var global T - -func f(_ T, a int) { -} - -func g(a T, b int) { - var c T - _ = c - - d := (*T)(nil) - _ = d -} - -type R struct { - T - a *T -} - -var ( - Z *T = (*T)(nil) -) - -const ( - X T = (T)(0) -) - -type Y T diff --git a/tools/go_generics/generics_tests/simple/opts.txt b/tools/go_generics/generics_tests/simple/opts.txt deleted file mode 100644 index 7832ef66f..000000000 --- a/tools/go_generics/generics_tests/simple/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=Q -suffix=New diff --git a/tools/go_generics/generics_tests/simple/output/output.go b/tools/go_generics/generics_tests/simple/output/output.go deleted file mode 100644 index 6bfa0b25b..000000000 --- a/tools/go_generics/generics_tests/simple/output/output.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -var globalNew Q - -func fNew(_ Q, a int) { -} - -func gNew(a Q, b int) { - var c Q - _ = c - - d := (*Q)(nil) - _ = d -} - -type RNew struct { - Q - a *Q -} - -var ( - ZNew *Q = (*Q)(nil) -) - -const ( - XNew Q = (Q)(0) -) - -type YNew Q diff --git a/tools/go_generics/globals/BUILD b/tools/go_generics/globals/BUILD deleted file mode 100644 index 38caa3ce7..000000000 --- a/tools/go_generics/globals/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "globals", - srcs = [ - "globals_visitor.go", - "scope.go", - ], - stateify = False, - visibility = ["//tools/go_generics:__pkg__"], -) diff --git a/tools/go_generics/globals/globals_visitor.go b/tools/go_generics/globals/globals_visitor.go deleted file mode 100644 index 883f21ebe..000000000 --- a/tools/go_generics/globals/globals_visitor.go +++ /dev/null @@ -1,597 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package globals provides an AST visitor that calls the visit function for all -// global identifiers. -package globals - -import ( - "fmt" - - "go/ast" - "go/token" - "path/filepath" - "strconv" -) - -// globalsVisitor holds the state used while traversing the nodes of a file in -// search of globals. -// -// The visitor does two passes on the global declarations: the first one adds -// all globals to the global scope (since Go allows references to globals that -// haven't been declared yet), and the second one calls f() for the definition -// and uses of globals found in the first pass. -// -// The implementation correctly handles cases when globals are aliased by -// locals; in such cases, f() is not called. -type globalsVisitor struct { - // file is the file whose nodes are being visited. - file *ast.File - - // fset is the file set the file being visited belongs to. - fset *token.FileSet - - // f is the visit function to be called when a global symbol is reached. - f func(*ast.Ident, SymKind) - - // scope is the current scope as nodes are visited. - scope *scope - - // processAnon indicates whether we should process anonymous struct fields. - // It does not perform strict checking on parameter types that share the same name - // as the global type and therefore will rename them as well. - processAnon bool -} - -// unexpected is called when an unexpected node appears in the AST. It dumps -// the location of the associated token and panics because this should only -// happen when there is a bug in the traversal code. -func (v *globalsVisitor) unexpected(p token.Pos) { - panic(fmt.Sprintf("Unable to parse at %v", v.fset.Position(p))) -} - -// pushScope creates a new scope and pushes it to the top of the scope stack. -func (v *globalsVisitor) pushScope() { - v.scope = newScope(v.scope) -} - -// popScope removes the scope created by the last call to pushScope. -func (v *globalsVisitor) popScope() { - v.scope = v.scope.outer -} - -// visitType is called when an expression is known to be a type, for example, -// on the first argument of make(). It visits all children nodes and reports -// any globals. -func (v *globalsVisitor) visitType(ge ast.Expr) { - switch e := ge.(type) { - case *ast.Ident: - if s := v.scope.deepLookup(e.Name); s != nil && s.scope.isGlobal() { - v.f(e, s.kind) - } - - case *ast.SelectorExpr: - id := GetIdent(e.X) - if id == nil { - v.unexpected(e.X.Pos()) - } - - case *ast.StarExpr: - v.visitType(e.X) - case *ast.ParenExpr: - v.visitType(e.X) - case *ast.ChanType: - v.visitType(e.Value) - case *ast.Ellipsis: - v.visitType(e.Elt) - case *ast.ArrayType: - v.visitExpr(e.Len) - v.visitType(e.Elt) - case *ast.MapType: - v.visitType(e.Key) - v.visitType(e.Value) - case *ast.StructType: - v.visitFields(e.Fields, KindUnknown) - case *ast.FuncType: - v.visitFields(e.Params, KindUnknown) - v.visitFields(e.Results, KindUnknown) - case *ast.InterfaceType: - v.visitFields(e.Methods, KindUnknown) - default: - v.unexpected(ge.Pos()) - } -} - -// visitFields visits all fields, and add symbols if kind isn't KindUnknown. -func (v *globalsVisitor) visitFields(l *ast.FieldList, kind SymKind) { - if l == nil { - return - } - - for _, f := range l.List { - if kind != KindUnknown { - for _, n := range f.Names { - v.scope.add(n.Name, kind, n.Pos()) - } - } - v.visitType(f.Type) - if f.Tag != nil { - tag := ast.NewIdent(f.Tag.Value) - v.f(tag, KindTag) - // Replace the tag if updated. - if tag.Name != f.Tag.Value { - f.Tag.Value = tag.Name - } - } - } -} - -// visitGenDecl is called when a generic declaration is encountered, for example, -// on variable, constant and type declarations. It adds all newly defined -// symbols to the current scope and reports them if the current scope is the -// global one. -func (v *globalsVisitor) visitGenDecl(d *ast.GenDecl) { - switch d.Tok { - case token.IMPORT: - case token.TYPE: - for _, gs := range d.Specs { - s := gs.(*ast.TypeSpec) - v.scope.add(s.Name.Name, KindType, s.Name.Pos()) - if v.scope.isGlobal() { - v.f(s.Name, KindType) - } - v.visitType(s.Type) - } - case token.CONST, token.VAR: - kind := KindConst - if d.Tok == token.VAR { - kind = KindVar - } - - for _, gs := range d.Specs { - s := gs.(*ast.ValueSpec) - if s.Type != nil { - v.visitType(s.Type) - } - - for _, e := range s.Values { - v.visitExpr(e) - } - - for _, n := range s.Names { - if v.scope.isGlobal() { - v.f(n, kind) - } - v.scope.add(n.Name, kind, n.Pos()) - } - } - default: - v.unexpected(d.Pos()) - } -} - -// isViableType determines if the given expression is a viable type expression, -// that is, if it could be interpreted as a type, for example, sync.Mutex, -// myType, func(int)int, as opposed to -1, 2 * 2, a + b, etc. -func (v *globalsVisitor) isViableType(expr ast.Expr) bool { - switch e := expr.(type) { - case *ast.Ident: - // This covers the plain identifier case. When we see it, we - // have to check if it resolves to a type; if the symbol is not - // known, we'll claim it's viable as a type. - s := v.scope.deepLookup(e.Name) - return s == nil || s.kind == KindType - - case *ast.ChanType, *ast.ArrayType, *ast.MapType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.Ellipsis: - // This covers the following cases: - // 1. ChanType: - // chan T - // <-chan T - // chan<- T - // 2. ArrayType: - // [Expr]T - // 3. MapType: - // map[T]U - // 4. StructType: - // struct { Fields } - // 5. FuncType: - // func(Fields)Returns - // 6. Interface: - // interface { Fields } - // 7. Ellipsis: - // ...T - return true - - case *ast.SelectorExpr: - // The only case in which an expression involving a selector can - // be a type is if it has the following form X.T, where X is an - // import, and T is a type exported by X. - // - // There's no way to know whether T is a type because we don't - // parse imports. So we just claim that this is a viable type; - // it doesn't affect the general result because we don't visit - // imported symbols. - id := GetIdent(e.X) - if id == nil { - return false - } - - s := v.scope.deepLookup(id.Name) - return s != nil && s.kind == KindImport - - case *ast.StarExpr: - // This covers the *T case. The expression is a viable type if - // T is. - return v.isViableType(e.X) - - case *ast.ParenExpr: - // This covers the (T) case. The expression is a viable type if - // T is. - return v.isViableType(e.X) - - default: - return false - } -} - -// visitCallExpr visits a "call expression" which can be either a -// function/method call (e.g., f(), pkg.f(), obj.f(), etc.) call or a type -// conversion (e.g., int32(1), (*sync.Mutex)(ptr), etc.). -func (v *globalsVisitor) visitCallExpr(e *ast.CallExpr) { - if v.isViableType(e.Fun) { - v.visitType(e.Fun) - } else { - v.visitExpr(e.Fun) - } - - // If the function being called is new or make, the first argument is - // a type, so it needs to be visited as such. - first := 0 - if id := GetIdent(e.Fun); id != nil && (id.Name == "make" || id.Name == "new") { - if len(e.Args) > 0 { - v.visitType(e.Args[0]) - } - first = 1 - } - - for i := first; i < len(e.Args); i++ { - v.visitExpr(e.Args[i]) - } -} - -// visitExpr visits all nodes of an expression, and reports any globals that it -// finds. -func (v *globalsVisitor) visitExpr(ge ast.Expr) { - switch e := ge.(type) { - case nil: - case *ast.Ident: - if s := v.scope.deepLookup(e.Name); s != nil && s.scope.isGlobal() { - v.f(e, s.kind) - } - - case *ast.BasicLit: - case *ast.CompositeLit: - v.visitType(e.Type) - for _, ne := range e.Elts { - v.visitExpr(ne) - } - case *ast.FuncLit: - v.pushScope() - v.visitFields(e.Type.Params, KindParameter) - v.visitFields(e.Type.Results, KindResult) - v.visitBlockStmt(e.Body) - v.popScope() - - case *ast.BinaryExpr: - v.visitExpr(e.X) - v.visitExpr(e.Y) - - case *ast.CallExpr: - v.visitCallExpr(e) - - case *ast.IndexExpr: - v.visitExpr(e.X) - v.visitExpr(e.Index) - - case *ast.KeyValueExpr: - v.visitExpr(e.Value) - - case *ast.ParenExpr: - v.visitExpr(e.X) - - case *ast.SelectorExpr: - v.visitExpr(e.X) - if v.processAnon { - v.visitExpr(e.Sel) - } - - case *ast.SliceExpr: - v.visitExpr(e.X) - v.visitExpr(e.Low) - v.visitExpr(e.High) - v.visitExpr(e.Max) - - case *ast.StarExpr: - v.visitExpr(e.X) - - case *ast.TypeAssertExpr: - v.visitExpr(e.X) - if e.Type != nil { - v.visitType(e.Type) - } - - case *ast.UnaryExpr: - v.visitExpr(e.X) - - default: - v.unexpected(ge.Pos()) - } -} - -// GetIdent returns the identifier associated with the given expression by -// removing parentheses if needed. -func GetIdent(expr ast.Expr) *ast.Ident { - switch e := expr.(type) { - case *ast.Ident: - return e - case *ast.ParenExpr: - return GetIdent(e.X) - default: - return nil - } -} - -// visitStmt visits all nodes of a statement, and reports any globals that it -// finds. It also adds to the current scope new symbols defined/declared. -func (v *globalsVisitor) visitStmt(gs ast.Stmt) { - switch s := gs.(type) { - case nil, *ast.BranchStmt, *ast.EmptyStmt: - case *ast.AssignStmt: - for _, e := range s.Rhs { - v.visitExpr(e) - } - - // We visit the LHS after the RHS because the symbols we'll - // potentially add to the table aren't meant to be visible to - // the RHS. - for _, e := range s.Lhs { - if s.Tok == token.DEFINE { - if n := GetIdent(e); n != nil { - v.scope.add(n.Name, KindVar, n.Pos()) - } - } - v.visitExpr(e) - } - - case *ast.BlockStmt: - v.visitBlockStmt(s) - - case *ast.DeclStmt: - v.visitGenDecl(s.Decl.(*ast.GenDecl)) - - case *ast.DeferStmt: - v.visitCallExpr(s.Call) - - case *ast.ExprStmt: - v.visitExpr(s.X) - - case *ast.ForStmt: - v.pushScope() - v.visitStmt(s.Init) - v.visitExpr(s.Cond) - v.visitStmt(s.Post) - v.visitBlockStmt(s.Body) - v.popScope() - - case *ast.GoStmt: - v.visitCallExpr(s.Call) - - case *ast.IfStmt: - v.pushScope() - v.visitStmt(s.Init) - v.visitExpr(s.Cond) - v.visitBlockStmt(s.Body) - v.visitStmt(s.Else) - v.popScope() - - case *ast.IncDecStmt: - v.visitExpr(s.X) - - case *ast.LabeledStmt: - v.visitStmt(s.Stmt) - - case *ast.RangeStmt: - v.pushScope() - v.visitExpr(s.X) - if s.Tok == token.DEFINE { - if n := GetIdent(s.Key); n != nil { - v.scope.add(n.Name, KindVar, n.Pos()) - } - - if n := GetIdent(s.Value); n != nil { - v.scope.add(n.Name, KindVar, n.Pos()) - } - } - v.visitExpr(s.Key) - v.visitExpr(s.Value) - v.visitBlockStmt(s.Body) - v.popScope() - - case *ast.ReturnStmt: - for _, r := range s.Results { - v.visitExpr(r) - } - - case *ast.SelectStmt: - for _, ns := range s.Body.List { - c := ns.(*ast.CommClause) - - v.pushScope() - v.visitStmt(c.Comm) - for _, bs := range c.Body { - v.visitStmt(bs) - } - v.popScope() - } - - case *ast.SendStmt: - v.visitExpr(s.Chan) - v.visitExpr(s.Value) - - case *ast.SwitchStmt: - v.pushScope() - v.visitStmt(s.Init) - v.visitExpr(s.Tag) - for _, ns := range s.Body.List { - c := ns.(*ast.CaseClause) - v.pushScope() - for _, ce := range c.List { - v.visitExpr(ce) - } - for _, bs := range c.Body { - v.visitStmt(bs) - } - v.popScope() - } - v.popScope() - - case *ast.TypeSwitchStmt: - v.pushScope() - v.visitStmt(s.Init) - v.visitStmt(s.Assign) - for _, ns := range s.Body.List { - c := ns.(*ast.CaseClause) - v.pushScope() - for _, ce := range c.List { - v.visitType(ce) - } - for _, bs := range c.Body { - v.visitStmt(bs) - } - v.popScope() - } - v.popScope() - - default: - v.unexpected(gs.Pos()) - } -} - -// visitBlockStmt visits all statements in the block, adding symbols to a newly -// created scope. -func (v *globalsVisitor) visitBlockStmt(s *ast.BlockStmt) { - v.pushScope() - for _, c := range s.List { - v.visitStmt(c) - } - v.popScope() -} - -// visitFuncDecl is called when a function or method declaration is encountered. -// it creates a new scope for the function [optional] receiver, parameters and -// results, and visits all children nodes. -func (v *globalsVisitor) visitFuncDecl(d *ast.FuncDecl) { - // We don't report methods. - if d.Recv == nil { - v.f(d.Name, KindFunction) - } - - v.pushScope() - v.visitFields(d.Recv, KindReceiver) - v.visitFields(d.Type.Params, KindParameter) - v.visitFields(d.Type.Results, KindResult) - if d.Body != nil { - v.visitBlockStmt(d.Body) - } - v.popScope() -} - -// globalsFromDecl is called in the first, and adds symbols to global scope. -func (v *globalsVisitor) globalsFromGenDecl(d *ast.GenDecl) { - switch d.Tok { - case token.IMPORT: - for _, gs := range d.Specs { - s := gs.(*ast.ImportSpec) - if s.Name == nil { - str, _ := strconv.Unquote(s.Path.Value) - v.scope.add(filepath.Base(str), KindImport, s.Path.Pos()) - } else if s.Name.Name != "_" { - v.scope.add(s.Name.Name, KindImport, s.Name.Pos()) - } - } - case token.TYPE: - for _, gs := range d.Specs { - s := gs.(*ast.TypeSpec) - v.scope.add(s.Name.Name, KindType, s.Name.Pos()) - } - case token.CONST, token.VAR: - kind := KindConst - if d.Tok == token.VAR { - kind = KindVar - } - - for _, s := range d.Specs { - for _, n := range s.(*ast.ValueSpec).Names { - v.scope.add(n.Name, kind, n.Pos()) - } - } - default: - v.unexpected(d.Pos()) - } -} - -// visit implements the visiting of globals. It does performs the two passes -// described in the description of the globalsVisitor struct. -func (v *globalsVisitor) visit() { - // Gather all symbols in the global scope. This excludes methods. - v.pushScope() - for _, gd := range v.file.Decls { - switch d := gd.(type) { - case *ast.GenDecl: - v.globalsFromGenDecl(d) - case *ast.FuncDecl: - if d.Recv == nil { - v.scope.add(d.Name.Name, KindFunction, d.Name.Pos()) - } - default: - v.unexpected(gd.Pos()) - } - } - - // Go through the contents of the declarations. - for _, gd := range v.file.Decls { - switch d := gd.(type) { - case *ast.GenDecl: - v.visitGenDecl(d) - case *ast.FuncDecl: - v.visitFuncDecl(d) - } - } -} - -// Visit traverses the provided AST and calls f() for each identifier that -// refers to global names. The global name must be defined in the file itself. -// -// The function f() is allowed to modify the identifier, for example, to rename -// uses of global references. -func Visit(fset *token.FileSet, file *ast.File, f func(*ast.Ident, SymKind), processAnon bool) { - v := globalsVisitor{ - fset: fset, - file: file, - f: f, - processAnon: processAnon, - } - - v.visit() -} diff --git a/tools/go_generics/globals/scope.go b/tools/go_generics/globals/scope.go deleted file mode 100644 index 96c965ea2..000000000 --- a/tools/go_generics/globals/scope.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package globals - -import ( - "go/token" -) - -// SymKind specifies the kind of a global symbol. For example, a variable, const -// function, etc. -type SymKind int - -// Constants for different kinds of symbols. -const ( - KindUnknown SymKind = iota - KindImport - KindType - KindVar - KindConst - KindFunction - KindReceiver - KindParameter - KindResult - KindTag -) - -type symbol struct { - kind SymKind - pos token.Pos - scope *scope -} - -type scope struct { - outer *scope - syms map[string]*symbol -} - -func newScope(outer *scope) *scope { - return &scope{ - outer: outer, - syms: make(map[string]*symbol), - } -} - -func (s *scope) isGlobal() bool { - return s.outer == nil -} - -func (s *scope) lookup(n string) *symbol { - return s.syms[n] -} - -func (s *scope) deepLookup(n string) *symbol { - for x := s; x != nil; x = x.outer { - if sym := x.lookup(n); sym != nil { - return sym - } - } - return nil -} - -func (s *scope) add(name string, kind SymKind, pos token.Pos) { - s.syms[name] = &symbol{ - kind: kind, - pos: pos, - scope: s, - } -} diff --git a/tools/go_generics/go_generics_unittest.sh b/tools/go_generics/go_generics_unittest.sh deleted file mode 100755 index 44b22db91..000000000 --- a/tools/go_generics/go_generics_unittest.sh +++ /dev/null @@ -1,70 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Bash "safe-mode": Treat command failures as fatal (even those that occur in -# pipes), and treat unset variables as errors. -set -eu -o pipefail - -# This file will be generated as a self-extracting shell script in order to -# eliminate the need for any runtime dependencies. The tarball at the end will -# include the go_generics binary, as well as a subdirectory named -# generics_tests. See the BUILD file for more information. -declare -r temp=$(mktemp -d) -function cleanup() { - rm -rf "${temp}" -} -# trap cleanup EXIT - -# Print message in "$1" then exit with status 1. -function die () { - echo "$1" 1>&2 - exit 1 -} - -# This prints the line number of __BUNDLE__ below, that should be the last line -# of this script. After that point, the concatenated archive will be the -# contents. -declare -r tgz=`awk '/^__BUNDLE__/ {print NR + 1; exit 0; }' $0` -tail -n+"${tgz}" $0 | tar -xzv -C "${temp}" - -# The target for the test. -declare -r binary="$(find ${temp} -type f -a -name go_generics)" -declare -r input_dirs="$(find ${temp} -type d -a -name generics_tests)/*" - -# Go through all test cases. -for f in ${input_dirs}; do - base=$(basename "${f}") - - # Run go_generics on the input file. - opts=$(head -n 1 ${f}/opts.txt) - out="${f}/output/generated.go" - expected="${f}/output/output.go" - ${binary} ${opts} "-i=${f}/input.go" "-o=${out}" || die "go_generics failed for test case \"${base}\"" - - # Compare the outputs. - diff ${expected} ${out} - if [ $? -ne 0 ]; then - echo "Expected:" - cat ${expected} - echo "Actual:" - cat ${out} - die "Actual output is different from expected for test \"${base}\"" - fi -done - -echo "PASS" -exit 0 -__BUNDLE__ diff --git a/tools/go_generics/go_merge/BUILD b/tools/go_generics/go_merge/BUILD deleted file mode 100644 index 2fd5a200d..000000000 --- a/tools/go_generics/go_merge/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("//tools:defs.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "go_merge", - srcs = ["main.go"], - visibility = ["//:sandbox"], -) diff --git a/tools/go_generics/go_merge/main.go b/tools/go_generics/go_merge/main.go deleted file mode 100644 index f6a331123..000000000 --- a/tools/go_generics/go_merge/main.go +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "bytes" - "flag" - "fmt" - "go/ast" - "go/format" - "go/parser" - "go/token" - "io/ioutil" - "os" - "path/filepath" - "strconv" -) - -var ( - output = flag.String("o", "", "output `file`") -) - -func fatalf(s string, args ...interface{}) { - fmt.Fprintf(os.Stderr, s, args...) - os.Exit(1) -} - -func main() { - flag.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: %s [options] <input1> [<input2> ...]\n", os.Args[0]) - flag.PrintDefaults() - } - - flag.Parse() - if *output == "" || len(flag.Args()) == 0 { - flag.Usage() - os.Exit(1) - } - - // Load all files. - files := make(map[string]*ast.File) - fset := token.NewFileSet() - var name string - for _, fname := range flag.Args() { - f, err := parser.ParseFile(fset, fname, nil, parser.ParseComments|parser.DeclarationErrors|parser.SpuriousErrors) - if err != nil { - fatalf("%v\n", err) - } - - files[fname] = f - if name == "" { - name = f.Name.Name - } else if name != f.Name.Name { - fatalf("Expected '%s' for package name instead of '%s'.\n", name, f.Name.Name) - } - } - - // Merge all files into one. - pkg := &ast.Package{ - Name: name, - Files: files, - } - f := ast.MergePackageFiles(pkg, ast.FilterUnassociatedComments|ast.FilterFuncDuplicates|ast.FilterImportDuplicates) - - // Create a new declaration slice with all imports at the top, merging any - // redundant imports. - imports := make(map[string]*ast.ImportSpec) - var anonImports []*ast.ImportSpec - for _, d := range f.Decls { - if g, ok := d.(*ast.GenDecl); ok && g.Tok == token.IMPORT { - for _, s := range g.Specs { - i := s.(*ast.ImportSpec) - p, _ := strconv.Unquote(i.Path.Value) - var n string - if i.Name == nil { - n = filepath.Base(p) - } else { - n = i.Name.Name - } - if n == "_" { - anonImports = append(anonImports, i) - } else { - if i2, ok := imports[n]; ok { - if first, second := i.Path.Value, i2.Path.Value; first != second { - fatalf("Conflicting paths for import name '%s': '%s' vs. '%s'\n", n, first, second) - } - } else { - imports[n] = i - } - } - } - } - } - newDecls := make([]ast.Decl, 0, len(f.Decls)) - if l := len(imports) + len(anonImports); l > 0 { - // Non-NoPos Lparen is needed for Go to recognize more than one spec in - // ast.GenDecl.Specs. - d := &ast.GenDecl{ - Tok: token.IMPORT, - Lparen: token.NoPos + 1, - Specs: make([]ast.Spec, 0, l), - } - for _, i := range imports { - d.Specs = append(d.Specs, i) - } - for _, i := range anonImports { - d.Specs = append(d.Specs, i) - } - newDecls = append(newDecls, d) - } - for _, d := range f.Decls { - if g, ok := d.(*ast.GenDecl); !ok || g.Tok != token.IMPORT { - newDecls = append(newDecls, d) - } - } - f.Decls = newDecls - - // Write the output file. - var buf bytes.Buffer - if err := format.Node(&buf, fset, f); err != nil { - fatalf("%v\n", err) - } - - if err := ioutil.WriteFile(*output, buf.Bytes(), 0644); err != nil { - fatalf("%v\n", err) - } -} diff --git a/tools/go_generics/imports.go b/tools/go_generics/imports.go deleted file mode 100644 index 148dc7216..000000000 --- a/tools/go_generics/imports.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "bytes" - "fmt" - "go/ast" - "go/format" - "go/parser" - "go/token" - "strconv" - - "gvisor.dev/gvisor/tools/go_generics/globals" -) - -type importedPackage struct { - newName string - path string -} - -// updateImportIdent modifies the given import identifier with the new name -// stored in the used map. If the identifier doesn't exist in the used map yet, -// a new name is generated and inserted into the map. -func updateImportIdent(orig string, imports mapValue, id *ast.Ident, used map[string]*importedPackage) error { - importName := id.Name - - // If the name is already in the table, just use the new name. - m := used[importName] - if m != nil { - id.Name = m.newName - return nil - } - - // Create a new entry in the used map. - path := imports[importName] - if path == "" { - return fmt.Errorf("Unknown path to package '%s', used in '%s'", importName, orig) - } - - m = &importedPackage{ - newName: fmt.Sprintf("__generics_imported%d", len(used)), - path: strconv.Quote(path), - } - used[importName] = m - - id.Name = m.newName - - return nil -} - -// convertExpression creates a new string that is a copy of the input one with -// all imports references renamed to the names in the "used" map. If the -// referenced import isn't in "used" yet, a new one is created based on the path -// in "imports" and stored in "used". For example, if string s is -// "math.MaxUint32-math.MaxUint16+10", it would be converted to -// "x.MaxUint32-x.MathUint16+10", where x is a generated name. -func convertExpression(s string, imports mapValue, used map[string]*importedPackage) (string, error) { - // Parse the expression in the input string. - expr, err := parser.ParseExpr(s) - if err != nil { - return "", fmt.Errorf("Unable to parse \"%s\": %v", s, err) - } - - // Go through the AST and update references. - var retErr error - ast.Inspect(expr, func(n ast.Node) bool { - switch x := n.(type) { - case *ast.SelectorExpr: - if id := globals.GetIdent(x.X); id != nil { - if err := updateImportIdent(s, imports, id, used); err != nil { - retErr = err - } - return false - } - } - return true - }) - if retErr != nil { - return "", retErr - } - - // Convert the modified AST back to a string. - fset := token.NewFileSet() - var buf bytes.Buffer - if err := format.Node(&buf, fset, expr); err != nil { - return "", err - } - - return string(buf.Bytes()), nil -} - -// updateImports replaces all maps in the input slice with copies where the -// mapped values have had all references to imported packages renamed to -// generated names. It also returns an import declaration for all the renamed -// import packages. -// -// For example, if the input maps contains A=math.B and C=math.D, the updated -// maps will instead contain A=__generics_imported0.B and -// C=__generics_imported0.C, and the 'import __generics_imported0 "math"' would -// be returned as the import declaration. -func updateImports(maps []mapValue, imports mapValue) (ast.Decl, error) { - importsUsed := make(map[string]*importedPackage) - - // Update all maps. - for i, m := range maps { - newMap := make(mapValue) - for n, e := range m { - updated, err := convertExpression(e, imports, importsUsed) - if err != nil { - return nil, err - } - - newMap[n] = updated - } - maps[i] = newMap - } - - // Nothing else to do if no imports are used in the expressions. - if len(importsUsed) == 0 { - return nil, nil - } - - // Create spec array for each new import. - specs := make([]ast.Spec, 0, len(importsUsed)) - for _, i := range importsUsed { - specs = append(specs, &ast.ImportSpec{ - Name: &ast.Ident{Name: i.newName}, - Path: &ast.BasicLit{Value: i.path}, - }) - } - - return &ast.GenDecl{ - Tok: token.IMPORT, - Specs: specs, - Lparen: token.NoPos + 1, - }, nil -} diff --git a/tools/go_generics/remove.go b/tools/go_generics/remove.go deleted file mode 100644 index 568a6bbd3..000000000 --- a/tools/go_generics/remove.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "go/ast" - "go/token" -) - -type typeSet map[string]struct{} - -// isTypeOrPointerToType determines if the given AST expression represents a -// type or a pointer to a type that exists in the provided type set. -func isTypeOrPointerToType(set typeSet, expr ast.Expr, starCount int) bool { - switch e := expr.(type) { - case *ast.Ident: - _, ok := set[e.Name] - return ok - case *ast.StarExpr: - if starCount > 1 { - return false - } - return isTypeOrPointerToType(set, e.X, starCount+1) - case *ast.ParenExpr: - return isTypeOrPointerToType(set, e.X, starCount) - default: - return false - } -} - -// isMethodOf determines if the given function declaration is a method of one -// of the types in the provided type set. To do that, it checks if the function -// has a receiver and that its type is either T or *T, where T is a type that -// exists in the set. This is per the spec: -// -// That parameter section must declare a single parameter, the receiver. Its -// type must be of the form T or *T (possibly using parentheses) where T is a -// type name. The type denoted by T is called the receiver base type; it must -// not be a pointer or interface type and it must be declared in the same -// package as the method. -func isMethodOf(set typeSet, f *ast.FuncDecl) bool { - // If the function doesn't have exactly one receiver, then it's - // definitely not a method. - if f.Recv == nil || len(f.Recv.List) != 1 { - return false - } - - return isTypeOrPointerToType(set, f.Recv.List[0].Type, 0) -} - -// removeTypeDefinitions removes the definition of all types contained in the -// provided type set. -func removeTypeDefinitions(set typeSet, d *ast.GenDecl) { - if d.Tok != token.TYPE { - return - } - - i := 0 - for _, gs := range d.Specs { - s := gs.(*ast.TypeSpec) - if _, ok := set[s.Name.Name]; !ok { - d.Specs[i] = gs - i++ - } - } - - d.Specs = d.Specs[:i] -} - -// removeTypes removes from the AST the definition of all types and their -// method sets that are contained in the provided type set. -func removeTypes(set typeSet, f *ast.File) { - // Go through the top-level declarations. - i := 0 - for _, decl := range f.Decls { - keep := true - switch d := decl.(type) { - case *ast.GenDecl: - countBefore := len(d.Specs) - removeTypeDefinitions(set, d) - keep = countBefore == 0 || len(d.Specs) > 0 - case *ast.FuncDecl: - keep = !isMethodOf(set, d) - } - - if keep { - f.Decls[i] = decl - i++ - } - } - - f.Decls = f.Decls[:i] -} diff --git a/tools/go_generics/rules_tests/BUILD b/tools/go_generics/rules_tests/BUILD deleted file mode 100644 index 8a329dfc6..000000000 --- a/tools/go_generics/rules_tests/BUILD +++ /dev/null @@ -1,43 +0,0 @@ -load("//tools:defs.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "instance", - out = "instance_test.go", - consts = { - "n": "20", - "m": "\"test\"", - "o": "math.MaxUint64", - }, - imports = { - "math": "math", - }, - package = "template_test", - template = ":test_template", - types = { - "t": "int", - }, -) - -go_template( - name = "test_template", - srcs = [ - "template.go", - ], - opt_consts = [ - "n", - "m", - "o", - ], - opt_types = ["t"], -) - -go_test( - name = "template_test", - srcs = [ - "instance_test.go", - "template_test.go", - ], -) diff --git a/tools/go_generics/rules_tests/template.go b/tools/go_generics/rules_tests/template.go deleted file mode 100644 index aace61da1..000000000 --- a/tools/go_generics/rules_tests/template.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package template - -type t float - -const ( - n t = 10.1 - m = "abc" - o = 0 -) - -func max(a, b t) t { - if a > b { - return a - } - return b -} - -func add(a t) t { - return a + n -} - -func getName() string { - return m -} - -func getMax() uint64 { - return o -} diff --git a/tools/go_generics/rules_tests/template_test.go b/tools/go_generics/rules_tests/template_test.go deleted file mode 100644 index b2a3446ef..000000000 --- a/tools/go_generics/rules_tests/template_test.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package template_test - -import ( - "math" - "testing" -) - -func TestMax(t *testing.T) { - var a int = max(10, 20) - if a != 20 { - t.Errorf("Bad result of max, got %v, want %v", a, 20) - } -} - -func TestIntConst(t *testing.T) { - var a int = add(10) - if a != 30 { - t.Errorf("Bad result of add, got %v, want %v", a, 30) - } -} - -func TestStrConst(t *testing.T) { - v := getName() - if v != "test" { - t.Errorf("Bad name, got %v, want %v", v, "test") - } -} - -func TestImport(t *testing.T) { - v := getMax() - if v != math.MaxUint64 { - t.Errorf("Bad max value, got %v, want %v", v, uint64(math.MaxUint64)) - } -} diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD deleted file mode 100644 index be49cf9c8..000000000 --- a/tools/go_marshal/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_binary") - -licenses(["notice"]) - -go_binary( - name = "go_marshal", - srcs = ["main.go"], - visibility = [ - "//:sandbox", - ], - deps = [ - "//tools/go_marshal/gomarshal", - ], -) - -config_setting( - name = "marshal_config_verbose", - values = {"define": "gomarshal=verbose"}, -) diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md deleted file mode 100644 index 4886efddf..000000000 --- a/tools/go_marshal/README.md +++ /dev/null @@ -1,116 +0,0 @@ -This package implements the go_marshal utility. - -# Overview - -`go_marshal` is a code generation utility similar to `go_stateify` for -automatically generating code to marshal go data structures to memory. - -`go_marshal` attempts to improve on `binary.Write` and the sentry's -`binary.Marshal` by moving the go runtime reflection necessary to marshal a -struct to compile-time. - -`go_marshal` automatically generates implementations for `abi.Marshallable` and -`safemem.{Reader,Writer}`. Call-sites for serialization (typically syscall -implementations) can directly invoke `safemem.Reader.ReadToBlocks` and -`safemem.Writer.WriteFromBlocks`. Data structures that require custom -serialization will have manual implementations for these interfaces. - -Data structures can be flagged for code generation by adding a struct-level -comment `// +marshal`. - -# Usage - -See `defs.bzl`: a new rule is provided, `go_marshal`. - -Under the hood, the `go_marshal` rule is used to generate a file that will -appear in a Go target; the output file should appear explicitly in a srcs list. -For example (note that the above is the preferred method): - -``` -load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_marshal") - -go_marshal( - name = "foo_abi", - srcs = ["foo.go"], - out = "foo_abi.go", - package = "foo", -) - -go_library( - name = "foo", - srcs = [ - "foo.go", - "foo_abi.go", - ], - ... -) -``` - -As part of the interface generation, `go_marshal` also generates some tests for -sanity checking the struct definitions for potential alignment issues, and a -simple round-trip test through Marshal/Unmarshal to verify the implementation. -These tests use reflection to verify properties of the ABI struct, and should be -considered part of the generated interfaces (but are too expensive to execute at -runtime). Ensure these tests run at some point. - -# Restrictions - -Not all valid go type definitions can be used with `go_marshal`. `go_marshal` is -intended for ABI structs, which have these additional restrictions: - -- At the moment, `go_marshal` only supports struct declarations. - -- Structs are marshalled as packed types. This means no implicit padding is - inserted between fields shorter than the platform register size. For - alignment, manually insert padding fields. - -- Structs used with `go_marshal` must have a compile-time static size. This - means no dynamically sizes fields like slices or strings. Use statically - sized array (byte arrays for strings) instead. - -- No pointers, channel, map or function pointer fields, and no fields that are - arrays of these types. These don't make sense in an ABI data structure. - -- We could support opaque pointers as `uintptr`, but this is currently not - implemented. Implementing this would require handling the architecture - dependent native pointer size. - -- Fields must either be a primitive integer type (`byte`, - `[u]int{8,16,32,64}`), or of a type that implements abi.Marshallable. - -- `int` and `uint` fields are not allowed. Use an explicitly-sized numeric - type. - -- `float*` fields are currently not supported, but could be if necessary. - -# Appendix - -## Working with Non-Packed Structs - -ABI structs must generally be packed types, meaning they should have no implicit -padding between short fields. However, if a field is tagged -`marshal:"unaligned"`, `go_marshal` will fall back to a safer but slower -mechanism to deal with potentially unaligned fields. - -Note that the non-packed property is inheritted by any other struct that embeds -this struct, since the `go_marshal` tool currently can't reason about alignments -for embedded structs that are not aligned. - -Because of this, it's generally best to avoid using `marshal:"unaligned"` and -insert explicit padding fields instead. - -## Modifying the `go_marshal` Tool - -The following are some guidelines for modifying the `go_marshal` tool: - -- The `go_marshal` tool currently does a single pass over all types requesting - code generation, in arbitrary order. This means the generated code can't - directly obtain information about embedded marshallable types at - compile-time. One way to work around this restriction is to add a new - Marshallable interface method providing this piece of information, and - calling it from the generated code. Use this sparingly, as we want to rely - on compile-time information as much as possible for performance. - -- No runtime reflection in the code generated for the marshallable interface. - The entire point of the tool is to avoid runtime reflection. The generated - tests may use reflection. diff --git a/tools/go_marshal/analysis/BUILD b/tools/go_marshal/analysis/BUILD deleted file mode 100644 index c2a4d45c4..000000000 --- a/tools/go_marshal/analysis/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "analysis", - testonly = 1, - srcs = ["analysis_unsafe.go"], - visibility = [ - "//:sandbox", - ], -) diff --git a/tools/go_marshal/analysis/analysis_unsafe.go b/tools/go_marshal/analysis/analysis_unsafe.go deleted file mode 100644 index 9a9a4f298..000000000 --- a/tools/go_marshal/analysis/analysis_unsafe.go +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package analysis implements common functionality used by generated -// go_marshal tests. -package analysis - -// All functions in this package are unsafe and are not intended for general -// consumption. They contain sharp edge cases and the caller is responsible for -// ensuring none of them are hit. Callers must be carefully to pass in only sane -// arguments. Failure to do so may cause panics at best and arbitrary memory -// corruption at worst. -// -// Never use outside of tests. - -import ( - "fmt" - "math/rand" - "reflect" - "testing" - "unsafe" -) - -// RandomizeValue assigns random value(s) to an abitrary type. This is intended -// for used with ABI structs from go_marshal, meaning the typical restrictions -// apply (fixed-size types, no pointers, maps, channels, etc), and should only -// be used on zeroed values to avoid overwriting pointers to active go objects. -// -// Internally, we populate the type with random data by doing an unsafe cast to -// access the underlying memory of the type and filling it as if it were a byte -// slice. This almost gets us what we want, but padding fields named "_" are -// normally not accessible, so we walk the type and recursively zero all "_" -// fields. -// -// Precondition: x must be a pointer. x must not contain any valid -// pointers to active go objects (pointer fields aren't allowed in ABI -// structs anyways), or we'd be violating the go runtime contract and -// the GC may malfunction. -func RandomizeValue(x interface{}) { - v := reflect.Indirect(reflect.ValueOf(x)) - if !v.CanSet() { - panic("RandomizeType() called with an unaddressable value. You probably need to pass a pointer to the argument") - } - - // Cast the underlying memory for the type into a byte slice. - var b []byte - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b)) - // Note: v.UnsafeAddr panics if x is passed by value. x should be a pointer. - hdr.Data = v.UnsafeAddr() - hdr.Len = int(v.Type().Size()) - hdr.Cap = hdr.Len - - // Fill the byte slice with random data, which in effect fills the type with - // random values. - n, err := rand.Read(b) - if err != nil || n != len(b) { - panic("unreachable") - } - - // Normally, padding fields are not accessible, so zero them out. - reflectZeroPaddingFields(v.Type(), b, false) -} - -// reflectZeroPaddingFields assigns zero values to padding fields for the value -// of type r, represented by the memory in data. Padding fields are defined as -// fields with the name "_". If zero is true, the immediate value itself is -// zeroed. In addition, the type is recursively scanned for padding fields in -// inner types. -// -// This is used for zeroing padding fields after calling RandomizeValue. -func reflectZeroPaddingFields(r reflect.Type, data []byte, zero bool) { - if zero { - for i, _ := range data { - data[i] = 0 - } - } - switch r.Kind() { - case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64: - // These types are explicitly allowed in an ABI type, but we don't need - // to recurse further as they're scalar types. - case reflect.Struct: - for i, numFields := 0, r.NumField(); i < numFields; i++ { - f := r.Field(i) - off := f.Offset - len := f.Type.Size() - window := data[off : off+len] - reflectZeroPaddingFields(f.Type, window, f.Name == "_") - } - case reflect.Array: - eLen := int(r.Elem().Size()) - if int(r.Size()) != eLen*r.Len() { - panic("Array has unexpected size?") - } - for i, n := 0, r.Len(); i < n; i++ { - reflectZeroPaddingFields(r.Elem(), data[i*eLen:(i+1)*eLen], false) - } - default: - panic(fmt.Sprintf("Type %v not allowed in ABI struct", r.Kind())) - - } -} - -// AlignmentCheck ensures the definition of the type represented by typ doesn't -// cause the go compiler to emit implicit padding between elements of the type -// (i.e. fields in a struct). -// -// AlignmentCheck doesn't explicitly recurse for embedded structs because any -// struct present in an ABI struct must also be Marshallable, and therefore -// they're aligned by definition (or their alignment check would have failed). -func AlignmentCheck(t *testing.T, typ reflect.Type) (ok bool, delta uint64) { - switch typ.Kind() { - case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64: - // Primitive types are always considered well aligned. Primitive types - // that are fields in structs are checked independently, this branch - // exists to handle recursive calls to alignmentCheck. - case reflect.Struct: - xOff := 0 - nextXOff := 0 - skipNext := false - for i, numFields := 0, typ.NumField(); i < numFields; i++ { - xOff = nextXOff - f := typ.Field(i) - fmt.Printf("Checking alignment of %s.%s @ %d [+%d]...\n", typ.Name(), f.Name, f.Offset, f.Type.Size()) - nextXOff = int(f.Offset + f.Type.Size()) - - if f.Name == "_" { - // Padding fields need not be aligned. - fmt.Printf("Padding field of type %v\n", f.Type) - continue - } - - if tag, ok := f.Tag.Lookup("marshal"); ok && tag == "unaligned" { - skipNext = true - continue - } - - if skipNext { - skipNext = false - fmt.Printf("Skipping alignment check for field %s.%s explicitly marked as unaligned.\n", typ.Name(), f.Name) - continue - } - - if xOff != int(f.Offset) { - implicitPad := int(f.Offset) - xOff - t.Fatalf("Suspect offset for field %s.%s, detected an implicit %d byte padding from offset %d to %d; either add %d bytes of explicit padding before this field or tag it as `marshal:\"unaligned\"`.", typ.Name(), f.Name, implicitPad, xOff, f.Offset, implicitPad) - } - } - - // Ensure structs end on a byte explicitly defined by the type. - if typ.NumField() > 0 && nextXOff != int(typ.Size()) { - implicitPad := int(typ.Size()) - nextXOff - f := typ.Field(typ.NumField() - 1) // Final field - t.Fatalf("Suspect offset for field %s.%s at the end of %s, detected an implicit %d byte padding from offset %d to %d at the end of the struct; either add %d bytes of explict padding at end of the struct or tag the final field %s as `marshal:\"unaligned\"`.", - typ.Name(), f.Name, typ.Name(), implicitPad, nextXOff, typ.Size(), implicitPad, f.Name) - } - case reflect.Array: - // Independent arrays are also always considered well aligned. We only - // need to worry about their alignment when they're embedded in structs, - // which we handle above. - default: - t.Fatalf("Unsupported type in ABI struct while checking for field alignment for type: %v", typ.Kind()) - } - return true, uint64(typ.Size()) -} diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl deleted file mode 100644 index d79786a68..000000000 --- a/tools/go_marshal/defs.bzl +++ /dev/null @@ -1,64 +0,0 @@ -"""Marshal is a tool for generating marshalling interfaces for Go types.""" - -def _go_marshal_impl(ctx): - """Execute the go_marshal tool.""" - output = ctx.outputs.lib - output_test = ctx.outputs.test - - # Run the marshal command. - args = ["-output=%s" % output.path] - args += ["-pkg=%s" % ctx.attr.package] - args += ["-output_test=%s" % output_test.path] - - if ctx.attr.debug: - args += ["-debug"] - - args += ["--"] - for src in ctx.attr.srcs: - args += [f.path for f in src.files.to_list()] - ctx.actions.run( - inputs = ctx.files.srcs, - outputs = [output, output_test], - mnemonic = "GoMarshal", - progress_message = "go_marshal: %s" % ctx.label, - arguments = args, - executable = ctx.executable._tool, - ) - -# Generates save and restore logic from a set of Go files. -# -# Args: -# name: the name of the rule. -# srcs: the input source files. These files should include all structs in the -# package that need to be saved. -# imports: an optional list of extra, non-aliased, Go-style absolute import -# paths. -# out: the name of the generated file output. This must not conflict with any -# other files and must be added to the srcs of the relevant go_library. -# package: the package name for the input sources. -go_marshal = rule( - implementation = _go_marshal_impl, - attrs = { - "srcs": attr.label_list(mandatory = True, allow_files = True), - "imports": attr.string_list(mandatory = False), - "package": attr.string(mandatory = True), - "debug": attr.bool(doc = "enable debugging output from the go_marshal tool"), - "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_marshal:go_marshal")), - }, - outputs = { - "lib": "%{name}_unsafe.go", - "test": "%{name}_test.go", - }, -) - -# marshal_deps are the dependencies requied by generated code. -marshal_deps = [ - "//tools/go_marshal/marshal", - "//pkg/safecopy", - "//pkg/usermem", -] - -# marshal_test_deps are required by test targets. -marshal_test_deps = [ - "//tools/go_marshal/analysis", -] diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD deleted file mode 100644 index 44cb33ae4..000000000 --- a/tools/go_marshal/gomarshal/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "gomarshal", - srcs = [ - "generator.go", - "generator_interfaces.go", - "generator_interfaces_array_newtype.go", - "generator_interfaces_primitive_newtype.go", - "generator_interfaces_struct.go", - "generator_tests.go", - "util.go", - ], - stateify = False, - visibility = [ - "//:sandbox", - ], - deps = ["//tools/tags"], -) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go deleted file mode 100644 index 729489de5..000000000 --- a/tools/go_marshal/gomarshal/generator.go +++ /dev/null @@ -1,433 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package gomarshal implements the go_marshal code generator. See README.md. -package gomarshal - -import ( - "bytes" - "fmt" - "go/ast" - "go/parser" - "go/token" - "os" - "sort" - "strings" - - "gvisor.dev/gvisor/tools/tags" -) - -const ( - marshalImport = "gvisor.dev/gvisor/tools/go_marshal/marshal" - safecopyImport = "gvisor.dev/gvisor/pkg/safecopy" - usermemImport = "gvisor.dev/gvisor/pkg/usermem" -) - -// List of identifiers we use in generated code that may conflict with a -// similarly-named source identifier. Abort gracefully when we see these to -// avoid potentially confusing compilation failures in generated code. -// -// This only applies to import aliases at the moment. All other identifiers -// are qualified by a receiver argument, since they're struct fields. -// -// All recievers are single letters, so we don't allow import aliases to be a -// single letter. -var badIdents = []string{ - "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "idx", "inner", "len", - "ptr", "src", "srcs", "task", "val", - // All single-letter identifiers. -} - -// Constructed fromt badIdents in init(). -var badIdentsMap map[string]struct{} - -func init() { - badIdentsMap = make(map[string]struct{}) - for _, ident := range badIdents { - badIdentsMap[ident] = struct{}{} - } -} - -// Generator drives code generation for a single invocation of the go_marshal -// utility. -// -// The Generator holds arguments passed to the tool, and drives parsing, -// processing and code Generator for all types marked with +marshal declared in -// the input files. -// -// See Generator.run() as the entry point. -type Generator struct { - // Paths to input go source files. - inputs []string - // Output file to write generated go source. - output *os.File - // Output file to write generated tests. - outputTest *os.File - // Package name for the generated file. - pkg string - // Set of extra packages to import in the generated file. - imports *importTable -} - -// NewGenerator creates a new code Generator. -func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*Generator, error) { - f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) - if err != nil { - return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err) - } - fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) - if err != nil { - return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err) - } - g := Generator{ - inputs: srcs, - output: f, - outputTest: fTest, - pkg: pkg, - imports: newImportTable(), - } - for _, i := range imports { - // All imports on the extra imports list are unconditionally marked as - // used, so that they're always added to the generated code. - g.imports.add(i).markUsed() - } - - // The following imports may or may not be used by the generated code, - // depending on what's required for the target types. Don't mark these as - // used by default. - g.imports.add("io") - g.imports.add("reflect") - g.imports.add("runtime") - g.imports.add("unsafe") - g.imports.add(marshalImport) - g.imports.add(safecopyImport) - g.imports.add(usermemImport) - - return &g, nil -} - -// writeHeader writes the header for the generated source file. The header -// includes the package name, package level comments and import statements. -func (g *Generator) writeHeader() error { - var b sourceBuffer - b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n") - - // Emit build tags. - if t := tags.Aggregate(g.inputs); len(t) > 0 { - b.emit(strings.Join(t.Lines(), "\n")) - b.emit("\n\n") - } - - // Package header. - b.emit("package %s\n\n", g.pkg) - if err := b.write(g.output); err != nil { - return err - } - - return g.imports.write(g.output) -} - -// writeTypeChecks writes a statement to force the compiler to perform a type -// check for all Marshallable types referenced by the generated code. -func (g *Generator) writeTypeChecks(ms map[string]struct{}) error { - if len(ms) == 0 { - return nil - } - - msl := make([]string, 0, len(ms)) - for m, _ := range ms { - msl = append(msl, m) - } - sort.Strings(msl) - - var buf bytes.Buffer - fmt.Fprint(&buf, "// Marshallable types used by this file.\n") - - for _, m := range msl { - fmt.Fprintf(&buf, "var _ marshal.Marshallable = (*%s)(nil)\n", m) - } - fmt.Fprint(&buf, "\n") - - _, err := fmt.Fprint(g.output, buf.String()) - return err -} - -// parse processes all input files passed this generator and produces a set of -// parsed go ASTs. -func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { - debugf("go_marshal invoked with %d input files:\n", len(g.inputs)) - for _, path := range g.inputs { - debugf(" %s\n", path) - } - - files := make([]*ast.File, 0, len(g.inputs)) - fsets := make([]*token.FileSet, 0, len(g.inputs)) - - for _, path := range g.inputs { - fset := token.NewFileSet() - f, err := parser.ParseFile(fset, path, nil, parser.ParseComments) - if err != nil { - // Not a valid input file? - return nil, nil, fmt.Errorf("Input %q can't be parsed: %v", path, err) - } - - if debugEnabled() { - debugf("AST for %q:\n", path) - ast.Print(fset, f) - } - - files = append(files, f) - fsets = append(fsets, fset) - } - - return files, fsets, nil -} - -// collectMarshallableTypes walks the parsed AST and collects a list of type -// declarations for which we need to generate the Marshallable interface. -func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec { - var types []*ast.TypeSpec - for _, decl := range a.Decls { - gdecl, ok := decl.(*ast.GenDecl) - // Type declaration? - if !ok || gdecl.Tok != token.TYPE { - debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n") - continue - } - // Does it have a comment? - if gdecl.Doc == nil { - debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n") - continue - } - // Does the comment contain a "+marshal" line? - marked := false - for _, c := range gdecl.Doc.List { - if c.Text == "// +marshal" { - marked = true - break - } - } - if !marked { - debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n") - continue - } - for _, spec := range gdecl.Specs { - // We already confirmed we're in a type declaration earlier, so this - // cast will succeed. - t := spec.(*ast.TypeSpec) - switch t.Type.(type) { - case *ast.StructType: - debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name) - types = append(types, t) - continue - case *ast.Ident: // Newtype on primitive. - debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name) - types = append(types, t) - continue - case *ast.ArrayType: // Newtype on array. - debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name) - types = append(types, t) - continue - } - // A user specifically requested marshalling on this type, but we - // don't support it. - abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) - } - } - return types -} - -// collectImports collects all imports from all input source files. Some of -// these imports are copied to the generated output, if they're referenced by -// the generated code. -// -// collectImports de-duplicates imports while building the list, and ensures -// identifiers in the generated code don't conflict with any imported package -// names. -func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt { - is := make(map[string]importStmt) - for _, decl := range a.Decls { - gdecl, ok := decl.(*ast.GenDecl) - // Import statement? - if !ok || gdecl.Tok != token.IMPORT { - continue - } - for _, spec := range gdecl.Specs { - i := g.imports.addFromSpec(spec.(*ast.ImportSpec), f) - debugf("Collected import '%s' as '%s'\n", i.path, i.name) - - // Make sure we have an import that doesn't use any local names that - // would conflict with identifiers in the generated code. - if len(i.name) == 1 { - abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name)) - } - if _, ok := badIdentsMap[i.name]; ok { - abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name)) - } - } - } - return is - -} - -func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { - i := newInterfaceGenerator(t, fset) - switch ty := t.Type.(type) { - case *ast.StructType: - i.validateStruct(t, ty) - i.emitMarshallableForStruct(ty) - case *ast.Ident: - i.validatePrimitiveNewtype(ty) - i.emitMarshallableForPrimitiveNewtype(ty) - case *ast.ArrayType: - i.validateArrayNewtype(t.Name, ty) - // After validate, we can safely call arrayLen. - i.emitMarshallableForArrayNewtype(t.Name, ty.Elt.(*ast.Ident), arrayLen(ty)) - default: - // This should've been filtered out by collectMarshallabeTypes. - panic(fmt.Sprintf("Unexpected type %+v", ty)) - } - return i -} - -// generateOneTestSuite generates a test suite for the automatically generated -// implementations type t. -func (g *Generator) generateOneTestSuite(t *ast.TypeSpec) *testGenerator { - i := newTestGenerator(t) - i.emitTests() - return i -} - -// Run is the entry point to code generation using g. -// -// Run parses all input source files specified in g and emits generated code. -func (g *Generator) Run() error { - // Parse our input source files into ASTs and token sets. - asts, fsets, err := g.parse() - if err != nil { - return err - } - - if len(asts) != len(fsets) { - panic("ASTs and FileSets don't match") - } - - // Map of imports in source files; key = local package name, value = import - // path. - is := make(map[string]importStmt) - for i, a := range asts { - // Collect all imports from the source files. We may need to copy some - // of these to the generated code if they're referenced. This has to be - // done before the loop below because we need to process all ASTs before - // we start requesting imports to be copied one by one as we encounter - // them in each generated source. - for name, i := range g.collectImports(a, fsets[i]) { - is[name] = i - } - } - - var impls []*interfaceGenerator - var ts []*testGenerator - // Set of Marshallable types referenced by generated code. - ms := make(map[string]struct{}) - for i, a := range asts { - // Collect type declarations marked for code generation and generate - // Marshallable interfaces. - for _, t := range g.collectMarshallableTypes(a, fsets[i]) { - impl := g.generateOne(t, fsets[i]) - // Collect Marshallable types referenced by the generated code. - for ref, _ := range impl.ms { - ms[ref] = struct{}{} - } - impls = append(impls, impl) - // Collect imports referenced by the generated code and add them to - // the list of imports we need to copy to the generated code. - for name, _ := range impl.is { - if !g.imports.markUsed(name) { - panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'", impl.typeName(), name)) - } - } - ts = append(ts, g.generateOneTestSuite(t)) - } - } - - // Write output file header. These include things like package name and - // import statements. - if err := g.writeHeader(); err != nil { - return err - } - - // Write type checks for referenced marshallable types to output file. - if err := g.writeTypeChecks(ms); err != nil { - return err - } - - // Write generated interfaces to output file. - for _, i := range impls { - if err := i.write(g.output); err != nil { - return err - } - } - - // Write generated tests to test file. - return g.writeTests(ts) -} - -// writeTests outputs tests for the generated interface implementations to a go -// source file. -func (g *Generator) writeTests(ts []*testGenerator) error { - var b sourceBuffer - b.emit("package %s\n\n", g.pkg) - if err := b.write(g.outputTest); err != nil { - return err - } - - // Collect and write test import statements. - imports := newImportTable() - for _, t := range ts { - imports.merge(t.imports) - } - - if err := imports.write(g.outputTest); err != nil { - return err - } - - // Write test functions. - - // If we didn't generate any Marshallable implementations, we can't just - // emit an empty test file, since that causes the build to fail with "no - // tests/benchmarks/examples found". Unfortunately we can't signal bazel to - // omit the entire package since the outputs are already defined before - // go-marshal is called. If we'd otherwise emit an empty test suite, emit an - // empty example instead. - if len(ts) == 0 { - b.reset() - b.emit("func ExampleEmptyTestSuite() {\n") - b.inIndent(func() { - b.emit("// This example is intentionally empty to ensure this file contains at least\n") - b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n") - b.emit("// is marked marshallable, but emitting a test file with no entities results\n") - b.emit("// in a build failure.\n") - }) - b.emit("}\n") - return b.write(g.outputTest) - } - - for _, t := range ts { - if err := t.write(g.outputTest); err != nil { - return err - } - } - return nil -} diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go deleted file mode 100644 index 8babf61d2..000000000 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gomarshal - -import ( - "go/ast" - "go/token" -) - -// interfaceGenerator generates marshalling interfaces for a single type. -// -// getState is not thread-safe. -type interfaceGenerator struct { - sourceBuffer - - // The type we're serializing. - t *ast.TypeSpec - - // Receiver argument for generated methods. - r string - - // FileSet containing the tokens for the type we're processing. - f *token.FileSet - - // is records external packages referenced by the generated implementation. - is map[string]struct{} - - // ms records Marshallable types referenced by the generated implementation - // of t's interfaces. - ms map[string]struct{} - - // as records embedded fields in t that are potentially not packed. The key - // is the accessor for the field. - as map[string]struct{} -} - -// typeName returns the name of the type this g represents. -func (g *interfaceGenerator) typeName() string { - return g.t.Name.Name -} - -// newinterfaceGenerator creates a new interface generator. -func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { - g := &interfaceGenerator{ - t: t, - r: receiverName(t), - f: fset, - is: make(map[string]struct{}), - ms: make(map[string]struct{}), - as: make(map[string]struct{}), - } - g.recordUsedMarshallable(g.typeName()) - return g -} - -func (g *interfaceGenerator) recordUsedMarshallable(m string) { - g.ms[m] = struct{}{} - -} - -func (g *interfaceGenerator) recordUsedImport(i string) { - g.is[i] = struct{}{} - -} - -func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) { - g.as[fieldName] = struct{}{} -} - -// abortAt aborts the go_marshal tool with the given error message, with a -// reference position to the input source. Same as abortAt, but uses g to -// resolve p to position. -func (g *interfaceGenerator) abortAt(p token.Pos, msg string) { - abortAt(g.f.Position(p), msg) -} - -// scalarSize returns the size of type identified by t. If t isn't a primitive -// type, the size isn't known at code generation time, and must be resolved via -// the marshal.Marshallable interface. -func (g *interfaceGenerator) scalarSize(t *ast.Ident) (size int, unknownSize bool) { - switch t.Name { - case "int8", "uint8", "byte": - return 1, false - case "int16", "uint16": - return 2, false - case "int32", "uint32": - return 4, false - case "int64", "uint64": - return 8, false - default: - return 0, true - } -} - -func (g *interfaceGenerator) shift(bufVar string, n int) { - g.emit("%s = %s[%d:]\n", bufVar, bufVar, n) -} - -func (g *interfaceGenerator) shiftDynamic(bufVar, name string) { - g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name) -} - -// marshalScalar writes a single scalar to a byte slice. -func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) { - switch typ { - case "int8", "uint8", "byte": - g.emit("%s[0] = byte(%s)\n", bufVar, accessor) - g.shift(bufVar, 1) - case "int16", "uint16": - g.recordUsedImport("usermem") - g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor) - g.shift(bufVar, 2) - case "int32", "uint32": - g.recordUsedImport("usermem") - g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor) - g.shift(bufVar, 4) - case "int64", "uint64": - g.recordUsedImport("usermem") - g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor) - g.shift(bufVar, 8) - default: - g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) - g.shiftDynamic(bufVar, accessor) - } -} - -// unmarshalScalar reads a single scalar from a byte slice. -func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) { - switch typ { - case "byte": - g.emit("%s = %s[0]\n", accessor, bufVar) - g.shift(bufVar, 1) - case "int8", "uint8": - g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar) - g.shift(bufVar, 1) - case "int16", "uint16": - g.recordUsedImport("usermem") - g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar) - g.shift(bufVar, 2) - case "int32", "uint32": - g.recordUsedImport("usermem") - g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar) - g.shift(bufVar, 4) - case "int64", "uint64": - g.recordUsedImport("usermem") - g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar) - g.shift(bufVar, 8) - default: - g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) - g.shiftDynamic(bufVar, accessor) - g.recordPotentiallyNonPackedField(accessor) - } -} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go deleted file mode 100644 index da36d9305..000000000 --- a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file contains the bits of the code generator specific to marshalling -// newtypes on arrays. - -package gomarshal - -import ( - "fmt" - "go/ast" -) - -func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType) { - if a.Len == nil { - g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name)) - } - - if _, ok := a.Len.(*ast.BasicLit); !ok { - g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don't use consts or expressions")) - } - - if _, ok := a.Elt.(*ast.Ident); !ok { - g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt))) - } - - if arrayLen(a) <= 0 { - g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?")) - } -} - -func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, len int) { - g.recordUsedImport("io") - g.recordUsedImport("marshal") - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("safecopy") - g.recordUsedImport("unsafe") - g.recordUsedImport("usermem") - - g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") - g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) - g.inIndent(func() { - if size, dynamic := g.scalarSize(elt); !dynamic { - g.emit("return %d\n", size*len) - } else { - g.emit("return (*%s)(nil).SizeBytes() * %d\n", n.Name, len) - } - }) - g.emit("}\n\n") - - g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("for idx := 0; idx < %d; idx++ {\n", len) - g.inIndent(func() { - g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst") - }) - g.emit("}\n") - }) - g.emit("}\n\n") - - g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("for idx := 0; idx < %d; idx++ {\n", len) - g.inIndent(func() { - g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src") - }) - g.emit("}\n") - }) - g.emit("}\n\n") - - g.emit("// Packed implements marshal.Marshallable.Packed.\n") - g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Array newtypes are always packed.\n") - g.emit("return true\n") - }) - g.emit("}\n\n") - - g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) - }) - g.emit("}\n\n") - - g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) - }) - g.emit("}\n\n") - - g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) - g.inIndent(func() { - // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyOutBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") - }) - g.emit("}\n\n") - - g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyInBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") - }) - g.emit("}\n\n") - - g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") - g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("len, err := w.Write(buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the Write.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return int64(len), err\n") - - }) - g.emit("}\n\n") -} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go deleted file mode 100644 index 159397825..000000000 --- a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go +++ /dev/null @@ -1,229 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file contains the bits of the code generator specific to marshalling -// newtypes on primitives. - -package gomarshal - -import ( - "fmt" - "go/ast" -) - -// marshalPrimitiveScalar writes a single primitive variable to a byte -// slice. -func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) { - switch typ { - case "int8", "uint8", "byte": - g.emit("%s[0] = byte(*%s)\n", bufVar, accessor) - case "int16", "uint16": - g.recordUsedImport("usermem") - g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor) - case "int32", "uint32": - g.recordUsedImport("usermem") - g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor) - case "int64", "uint64": - g.recordUsedImport("usermem") - g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor) - default: - g.emit("// Explicilty cast to the underlying type before dispatching to\n") - g.emit("// MarshalBytes, so we don't recursively call %s.MarshalBytes\n", accessor) - g.emit("inner := (*%s)(%s)\n", typ, accessor) - g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) - } -} - -// unmarshalPrimitiveScalar read a single primitive variable from a byte slice. -func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) { - switch typ { - case "byte": - g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar) - case "int8", "uint8": - g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar) - case "int16", "uint16": - g.recordUsedImport("usermem") - g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar) - case "int32", "uint32": - g.recordUsedImport("usermem") - g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar) - case "int64", "uint64": - g.recordUsedImport("usermem") - g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar) - default: - g.emit("// Explicilty cast to the underlying type before dispatching to\n") - g.emit("// UnmarshalBytes, so we don't recursively call %s.UnmarshalBytes\n", accessor) - g.emit("inner := (*%s)(%s)\n", typ, accessor) - g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) - } -} - -func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) { - switch t.Name { - case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": - // These are the only primitive types we're allow. Below, we provide - // suggestions for some disallowed types and reject them, then attempt - // to marshal any remaining types by invoking the marshal.Marshallable - // interface on them. If these types don't actually implement - // marshal.Marshallable, compilation of the generated code will fail - // with an appropriate error message. - return - case "int": - g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64") - case "uint": - g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") - case "string": - g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") - default: - debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) - } -} - -// emitMarshallableForPrimitiveNewtype outputs code to implement the -// marshal.Marshallable interface for a newtype on a primitive. Primitive -// newtypes are always packed, so we can omit the various fallbacks required for -// non-packed structs. -func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) { - g.recordUsedImport("io") - g.recordUsedImport("marshal") - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("safecopy") - g.recordUsedImport("unsafe") - g.recordUsedImport("usermem") - - g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") - g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) - g.inIndent(func() { - if size, dynamic := g.scalarSize(nt); !dynamic { - g.emit("return %d\n", size) - } else { - g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name) - } - }) - g.emit("}\n\n") - - g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.marshalPrimitiveScalar(g.r, nt.Name, "dst") - }) - g.emit("}\n\n") - - g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName()) - }) - g.emit("}\n\n") - - g.emit("// Packed implements marshal.Marshallable.Packed.\n") - g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Scalar newtypes are always packed.\n") - g.emit("return true\n") - }) - g.emit("}\n\n") - - g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) - }) - g.emit("}\n\n") - - g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) - }) - g.emit("}\n\n") - - g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) - g.inIndent(func() { - // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyOutBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") - }) - g.emit("}\n\n") - - g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyInBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") - }) - g.emit("}\n\n") - - g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") - g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("len, err := w.Write(buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the Write.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return int64(len), err\n") - - }) - g.emit("}\n\n") -} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go deleted file mode 100644 index e66a38b2e..000000000 --- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go +++ /dev/null @@ -1,450 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file contains the bits of the code generator specific to marshalling -// structs. - -package gomarshal - -import ( - "fmt" - "go/ast" - "strings" -) - -func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string { - return fmt.Sprintf("%s.%s", g.r, n.Name) -} - -// areFieldsPackedExpression returns a go expression checking whether g.t's fields are -// packed. Returns "", false if g.t has no fields that may be potentially -// packed, otherwise returns <clause>, true, where <clause> is an expression -// like "t.a.Packed() && t.b.Packed() && t.c.Packed()". -func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { - if len(g.as) == 0 { - return "", false - } - - cs := make([]string, 0, len(g.as)) - for accessor, _ := range g.as { - cs = append(cs, fmt.Sprintf("%s.Packed()", accessor)) - } - return strings.Join(cs, " && "), true -} - -// validateStruct ensures the type we're working with can be marshalled. These -// checks are done ahead of time and in one place so we can make assumptions -// later. -func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) { - forEachStructField(st, func(f *ast.Field) { - if len(f.Names) == 0 { - g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") - } - }) - - forEachStructField(st, func(f *ast.Field) { - fieldDispatcher{ - primitive: func(_, t *ast.Ident) { - g.validatePrimitiveNewtype(t) - }, - selector: func(_, _, _ *ast.Ident) { - // No validation to perform on selector fields. However this - // callback must still be provided. - }, - array: func(n, _ *ast.Ident, len int) { - g.validateArrayNewtype(n, f.Type.(*ast.ArrayType)) - }, - unhandled: func(_ *ast.Ident) { - g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type))) - }, - }.dispatch(f) - }) -} - -func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { - // Is g.t a packed struct without consideing field types? - thisPacked := true - forEachStructField(st, func(f *ast.Field) { - if f.Tag != nil { - if f.Tag.Value == "`marshal:\"unaligned\"`" { - if thisPacked { - debugfAt(g.f.Position(g.t.Pos()), - fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name)) - thisPacked = false - } - } - } - }) - - g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") - g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) - g.inIndent(func() { - primitiveSize := 0 - var dynamicSizeTerms []string - - forEachStructField(st, fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if size, dynamic := g.scalarSize(t); !dynamic { - primitiveSize += size - } else { - g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) - } - }, - selector: func(n, tX, tSel *ast.Ident) { - tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) - g.recordUsedImport(tX.Name) - g.recordUsedMarshallable(tName) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) - }, - array: func(n, t *ast.Ident, len int) { - if len < 1 { - // Zero-length arrays should've been rejected by validate(). - panic("unreachable") - } - if size, dynamic := g.scalarSize(t); !dynamic { - primitiveSize += size * len - } else { - g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len)) - } - }, - }.dispatch) - g.emit("return %d", primitiveSize) - if len(dynamicSizeTerms) > 0 { - g.incIndent() - } - { - for _, d := range dynamicSizeTerms { - g.emitNoIndent(" +\n") - g.emit(d) - } - } - if len(dynamicSizeTerms) > 0 { - g.decIndent() - } - }) - g.emit("\n}\n\n") - - g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - forEachStructField(st, fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("dst", len) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can referece here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) - } - return - } - g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") - }, - selector: func(n, tX, tSel *ast.Ident) { - g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") - }, - array: func(n, t *ast.Ident, size int) { - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("dst", len*size) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) - } - return - } - - g.emit("for idx := 0; idx < %d; idx++ {\n", size) - g.inIndent(func() { - g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") - }) - g.emit("}\n") - }, - }.dispatch) - }) - g.emit("}\n\n") - - g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - forEachStructField(st, fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("src", len) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name) - g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) - } - return - } - g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") - }, - selector: func(n, tX, tSel *ast.Ident) { - g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") - }, - array: func(n, t *ast.Ident, size int) { - if n.Name == "_" { - g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("src", len*size) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can referece here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) - } - return - } - - g.emit("for idx := 0; idx < %d; idx++ {\n", size) - g.inIndent(func() { - g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") - }) - g.emit("}\n") - }, - }.dispatch) - }) - g.emit("}\n\n") - - g.emit("// Packed implements marshal.Marshallable.Packed.\n") - g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) - g.inIndent(func() { - expr, fieldsMaybePacked := g.areFieldsPackedExpression() - switch { - case !thisPacked: - g.emit("return false\n") - case fieldsMaybePacked: - g.emit("return %s\n", expr) - default: - g.emit("return true\n") - - } - }) - g.emit("}\n\n") - - g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - if thisPacked { - g.recordUsedImport("safecopy") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if %s {\n", cond) - g.inIndent(func() { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) - }) - g.emit("} else {\n") - g.inIndent(func() { - g.emit("%s.MarshalBytes(dst)\n", g.r) - }) - g.emit("}\n") - } else { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) - } - } else { - g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) - g.emit("%s.MarshalBytes(dst)\n", g.r) - } - }) - g.emit("}\n\n") - - g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - if thisPacked { - g.recordUsedImport("safecopy") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if %s {\n", cond) - g.inIndent(func() { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) - }) - g.emit("} else {\n") - g.inIndent(func() { - g.emit("%s.UnmarshalBytes(src)\n", g.r) - }) - g.emit("}\n") - } else { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) - } - } else { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("%s.UnmarshalBytes(src)\n", g.r) - } - }) - g.emit("}\n\n") - - g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") - g.recordUsedImport("marshal") - g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) - g.inIndent(func() { - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) - g.emit("%s.MarshalBytes(buf)\n", g.r) - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("return err\n") - } - if thisPacked { - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if !%s {\n", cond) - g.inIndent(fallback) - g.emit("}\n\n") - } - // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyOutBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") - } else { - fallback() - } - }) - g.emit("}\n\n") - - g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") - g.recordUsedImport("marshal") - g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) - g.inIndent(func() { - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("if err != nil {\n") - g.inIndent(func() { - g.emit("return err\n") - }) - g.emit("}\n") - - g.emit("%s.UnmarshalBytes(buf)\n", g.r) - g.emit("return nil\n") - } - if thisPacked { - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if !%s {\n", cond) - g.inIndent(fallback) - g.emit("}\n\n") - } - // Fast deserialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyInBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") - } else { - fallback() - } - }) - g.emit("}\n\n") - - g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") - g.recordUsedImport("io") - g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) - g.emit("%s.MarshalBytes(buf)\n", g.r) - g.emit("n, err := w.Write(buf)\n") - g.emit("return int64(n), err\n") - } - if thisPacked { - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if !%s {\n", cond) - g.inIndent(fallback) - g.emit("}\n\n") - } - // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("len, err := w.Write(buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the Write.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return int64(len), err\n") - } else { - fallback() - } - }) - g.emit("}\n\n") -} diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go deleted file mode 100644 index fd992e44a..000000000 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gomarshal - -import ( - "fmt" - "go/ast" - "io" - "strings" -) - -var standardImports = []string{ - "bytes", - "fmt", - "reflect", - "testing", - - "gvisor.dev/gvisor/tools/go_marshal/analysis", -} - -type testGenerator struct { - sourceBuffer - - // The type we're serializing. - t *ast.TypeSpec - - // Receiver argument for generated methods. - r string - - // Imports used by generated code. - imports *importTable - - // Import statement for the package declaring the type we generated code - // for. We need this to construct test instances for the type, since the - // tests aren't written in the same package. - decl *importStmt -} - -func newTestGenerator(t *ast.TypeSpec) *testGenerator { - g := &testGenerator{ - t: t, - r: receiverName(t), - imports: newImportTable(), - } - - for _, i := range standardImports { - g.imports.add(i).markUsed() - } - - return g -} - -func (g *testGenerator) typeName() string { - return g.t.Name.Name -} - -func (g *testGenerator) testFuncName(base string) string { - return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name)) -} - -func (g *testGenerator) inTestFunction(name string, body func()) { - g.emit("func %s(t *testing.T) {\n", g.testFuncName(name)) - g.inIndent(body) - g.emit("}\n\n") -} - -func (g *testGenerator) emitTestNonZeroSize() { - g.inTestFunction("TestSizeNonZero", func() { - g.emit("var x %v\n", g.typeName()) - g.emit("if x.SizeBytes() == 0 {\n") - g.inIndent(func() { - g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n") - }) - g.emit("}\n") - }) -} - -func (g *testGenerator) emitTestSuspectAlignment() { - g.inTestFunction("TestSuspectAlignment", func() { - g.emit("var x %v\n", g.typeName()) - g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n") - }) -} - -func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() { - g.inTestFunction("TestSafeMarshalUnmarshalPreservesData", func() { - g.emit("var x, y, z, yUnsafe, zUnsafe %s\n", g.typeName()) - g.emit("analysis.RandomizeValue(&x)\n\n") - - g.emit("buf := make([]byte, x.SizeBytes())\n") - g.emit("x.MarshalBytes(buf)\n") - g.emit("bufUnsafe := make([]byte, x.SizeBytes())\n") - g.emit("x.MarshalUnsafe(bufUnsafe)\n\n") - - g.emit("y.UnmarshalBytes(buf)\n") - g.emit("if !reflect.DeepEqual(x, y) {\n") - g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") - }) - g.emit("}\n") - g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n") - g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") - g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") - }) - g.emit("}\n\n") - - g.emit("z.UnmarshalUnsafe(buf)\n") - g.emit("if !reflect.DeepEqual(x, z) {\n") - g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n") - }) - g.emit("}\n") - g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n") - g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n") - g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n") - }) - g.emit("}\n") - }) -} - -func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() { - g.inTestFunction("TestWriteToUnmarshalPreservesData", func() { - g.emit("var x, y, yUnsafe %s\n", g.typeName()) - g.emit("analysis.RandomizeValue(&x)\n\n") - - g.emit("var buf bytes.Buffer\n\n") - - g.emit("x.WriteTo(&buf)\n") - g.emit("y.UnmarshalBytes(buf.Bytes())\n\n") - g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n") - - g.emit("if !reflect.DeepEqual(x, y) {\n") - g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") - }) - g.emit("}\n") - g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") - g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") - }) - g.emit("}\n") - }) -} - -func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() { - g.inTestFunction("TestSizeBytesOnTypedNilPtr", func() { - g.emit("var x %s\n", g.typeName()) - g.emit("sizeFromConcrete := x.SizeBytes()\n") - g.emit("sizeFromTypedNilPtr := (*%s)(nil).SizeBytes()\n\n", g.typeName()) - - g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n") - g.inIndent(func() { - g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n") - }) - g.emit("}\n") - }) -} - -func (g *testGenerator) emitTests() { - g.emitTestNonZeroSize() - g.emitTestSuspectAlignment() - g.emitTestMarshalUnmarshalPreservesData() - g.emitTestWriteToUnmarshalPreservesData() - g.emitTestSizeBytesOnTypedNilPtr() -} - -func (g *testGenerator) write(out io.Writer) error { - return g.sourceBuffer.write(out) -} diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go deleted file mode 100644 index a0936e013..000000000 --- a/tools/go_marshal/gomarshal/util.go +++ /dev/null @@ -1,418 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gomarshal - -import ( - "bytes" - "flag" - "fmt" - "go/ast" - "go/token" - "io" - "os" - "path" - "reflect" - "sort" - "strconv" - "strings" -) - -var debug = flag.Bool("debug", false, "enables debugging output") - -// receiverName returns an appropriate receiver name given a type spec. -func receiverName(t *ast.TypeSpec) string { - if len(t.Name.Name) < 1 { - // Zero length type name? - panic("unreachable") - } - return strings.ToLower(t.Name.Name[:1]) -} - -// kindString returns a user-friendly representation of an AST expr type. -func kindString(e ast.Expr) string { - switch e.(type) { - case *ast.Ident: - return "scalar" - case *ast.ArrayType: - return "array" - case *ast.StructType: - return "struct" - case *ast.StarExpr: - return "pointer" - case *ast.FuncType: - return "function" - case *ast.InterfaceType: - return "interface" - case *ast.MapType: - return "map" - case *ast.ChanType: - return "channel" - default: - return reflect.TypeOf(e).String() - } -} - -func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) { - for _, field := range st.Fields.List { - fn(field) - } -} - -// fieldDispatcher is a collection of callbacks for handling different types of -// fields in a struct declaration. -type fieldDispatcher struct { - primitive func(n, t *ast.Ident) - selector func(n, tX, tSel *ast.Ident) - array func(n, t *ast.Ident, size int) - unhandled func(n *ast.Ident) -} - -// Precondition: a must have a literal for the array length. Consts and -// expressions are not allowed as array lengths, and should be rejected by the -// caller. -func arrayLen(a *ast.ArrayType) int { - if a.Len == nil { - // Probably a slice? Must be handled by caller. - panic("Nil array length in array type") - } - lenLit, ok := a.Len.(*ast.BasicLit) - if !ok { - panic("Array has non-literal for length") - } - len, err := strconv.Atoi(lenLit.Value) - if err != nil { - panic(fmt.Sprintf("Failed to parse array length '%s' as number: %v", lenLit.Value, err)) - } - return len -} - -// Precondition: All dispatch callbacks that will be invoked must be -// provided. Embedded fields are not allowed, len(f.Names) >= 1. -func (fd fieldDispatcher) dispatch(f *ast.Field) { - // Each field declaration may actually be multiple declarations of the same - // type. For example, consider: - // - // type Point struct { - // x, y, z int - // } - // - // We invoke the call-backs once per such instance. Embedded fields are not - // allowed, and results in a panic. - if len(f.Names) < 1 { - panic("Precondition not met: attempted to dispatch on embedded field") - } - - for _, name := range f.Names { - switch v := f.Type.(type) { - case *ast.Ident: - fd.primitive(name, v) - case *ast.SelectorExpr: - fd.selector(name, v.X.(*ast.Ident), v.Sel) - case *ast.ArrayType: - switch t := v.Elt.(type) { - case *ast.Ident: - fd.array(name, t, arrayLen(v)) - default: - // Should be handled with a better error message during validate. - panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t)) - } - default: - fd.unhandled(name) - } - } -} - -// debugEnabled indicates whether debugging is enabled for gomarshal. -func debugEnabled() bool { - return *debug -} - -// abort aborts the go_marshal tool with the given error message. -func abort(msg string) { - if !strings.HasSuffix(msg, "\n") { - msg += "\n" - } - fmt.Print(msg) - os.Exit(1) -} - -// abortAt aborts the go_marshal tool with the given error message, with -// a reference position to the input source. -func abortAt(p token.Position, msg string) { - abort(fmt.Sprintf("%v:\n %s\n", p, msg)) -} - -// debugf conditionally prints a debug message. -func debugf(f string, a ...interface{}) { - if debugEnabled() { - fmt.Printf(f, a...) - } -} - -// debugfAt conditionally prints a debug message with a reference to a position -// in the input source. -func debugfAt(p token.Position, f string, a ...interface{}) { - if debugEnabled() { - fmt.Printf("%s:\n %s", p, fmt.Sprintf(f, a...)) - } -} - -// emit generates a line of code in the output file. -// -// emit is a wrapper around writing a formatted string to the output -// buffer. emit can be invoked in one of two ways: -// -// (1) emit("some string") -// When emit is called with a single string argument, it is simply copied to -// the output buffer without any further formatting. -// (2) emit(fmtString, args...) -// emit can also be invoked in a similar fashion to *Printf() functions, -// where the first argument is a format string. -// -// Calling emit with a single argument that is not a string will result in a -// panic, as the caller's intent is ambiguous. -func emit(out io.Writer, indent int, a ...interface{}) { - const spacesPerIndentLevel = 4 - - if len(a) < 1 { - panic("emit() called with no arguments") - } - - if indent > 0 { - if _, err := fmt.Fprint(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil { - // Writing to the emit output should not fail. Typically the output - // is a byte.Buffer; writes to these never fail. - panic(err) - } - } - - first, ok := a[0].(string) - if !ok { - // First argument must be either the string to emit (case 1 from - // function-level comment), or a format string (case 2). - panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0])) - } - - if len(a) == 1 { - // Single string argument. Assume no formatting requested. - if _, err := fmt.Fprint(out, first); err != nil { - // Writing to out should not fail. - panic(err) - } - return - - } - - // Formatting requested. - if _, err := fmt.Fprintf(out, first, a[1:]...); err != nil { - // Writing to out should not fail. - panic(err) - } -} - -// sourceBuffer represents fragments of generated go source code. -// -// sourceBuffer provides a convenient way to build up go souce fragments in -// memory. May be safely zero-value initialized. Not thread-safe. -type sourceBuffer struct { - // Current indentation level. - indent int - - // Memory buffer containing contents while they're being generated. - b bytes.Buffer -} - -func (b *sourceBuffer) reset() { - b.indent = 0 - b.b.Reset() -} - -func (b *sourceBuffer) incIndent() { - b.indent++ -} - -func (b *sourceBuffer) decIndent() { - if b.indent <= 0 { - panic("decIndent() without matching incIndent()") - } - b.indent-- -} - -func (b *sourceBuffer) emit(a ...interface{}) { - emit(&b.b, b.indent, a...) -} - -func (b *sourceBuffer) emitNoIndent(a ...interface{}) { - emit(&b.b, 0 /*indent*/, a...) -} - -func (b *sourceBuffer) inIndent(body func()) { - b.incIndent() - body() - b.decIndent() -} - -func (b *sourceBuffer) write(out io.Writer) error { - _, err := fmt.Fprint(out, b.b.String()) - return err -} - -// Write implements io.Writer.Write. -func (b *sourceBuffer) Write(buf []byte) (int, error) { - return (b.b.Write(buf)) -} - -// importStmt represents a single import statement. -type importStmt struct { - // Local name of the imported package. - name string - // Import path. - path string - // Indicates whether the local name is an alias, or simply the final - // component of the path. - aliased bool - // Indicates whether this import was referenced by generated code. - used bool -} - -func newImport(p string) *importStmt { - name := path.Base(p) - return &importStmt{ - name: name, - path: p, - aliased: false, - } -} - -func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { - p := spec.Path.Value[1 : len(spec.Path.Value)-1] // Strip the " quotes around path. - name := path.Base(p) - if name == "" || name == "/" || name == "." { - panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)", - f.Position(spec.Path.Pos()), name)) - } - if spec.Name != nil { - name = spec.Name.Name - } - return &importStmt{ - name: name, - path: p, - aliased: spec.Name != nil, - } -} - -func (i *importStmt) String() string { - if i.aliased { - return fmt.Sprintf("%s \"%s\"", i.name, i.path) - } - return fmt.Sprintf("\"%s\"", i.path) -} - -func (i *importStmt) markUsed() { - i.used = true -} - -func (i *importStmt) equivalent(other *importStmt) bool { - return i.name == other.name && i.path == other.path && i.aliased == other.aliased -} - -// importTable represents a collection of importStmts. -type importTable struct { - // Map of imports and whether they should be copied to the output. - is map[string]*importStmt -} - -func newImportTable() *importTable { - return &importTable{ - is: make(map[string]*importStmt), - } -} - -// Merges import statements from other into i. Collisions in import statements -// result in a panic. -func (i *importTable) merge(other *importTable) { - for name, im := range other.is { - if dup, ok := i.is[name]; ok && !dup.equivalent(im) { - panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im)) - } - - i.is[name] = im - } -} - -func (i *importTable) addStmt(s *importStmt) *importStmt { - if old, ok := i.is[s.name]; ok && !old.equivalent(s) { - // A collision should always be between an import inserted by the - // go-marshal tool and an import from the original source file (assuming - // the original source file was valid). We could theoretically handle - // the collision by assigning a local name to our import. However, this - // would need to be plumbed throughout the generator. Given that - // collisions should be rare, simply panic on collision. - panic(fmt.Sprintf("Import collision: old: %s as %v; new: %v as %v", old.path, old.name, s.path, s.name)) - } - i.is[s.name] = s - return s -} - -func (i *importTable) add(s string) *importStmt { - n := newImport(s) - return i.addStmt(n) -} - -func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { - return i.addStmt(newImportFromSpec(spec, f)) -} - -// Marks the import named n as used. If no such import is in the table, returns -// false. -func (i *importTable) markUsed(n string) bool { - if n, ok := i.is[n]; ok { - n.markUsed() - return true - } - return false -} - -func (i *importTable) clear() { - for _, i := range i.is { - i.used = false - } -} - -func (i *importTable) write(out io.Writer) error { - if len(i.is) == 0 { - // Nothing to import, we're done. - return nil - } - - imports := make([]string, 0, len(i.is)) - for _, i := range i.is { - if i.used { - imports = append(imports, i.String()) - } - } - sort.Strings(imports) - - var b sourceBuffer - b.emit("import (\n") - b.incIndent() - for _, i := range imports { - b.emit("%s\n", i) - } - b.decIndent() - b.emit(")\n\n") - - return b.write(out) -} diff --git a/tools/go_marshal/main.go b/tools/go_marshal/main.go deleted file mode 100644 index f74be5c29..000000000 --- a/tools/go_marshal/main.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// go_marshal is a code generation utility for automatically generating code to -// marshal go data structures to memory. -// -// This binary is typically run as part of the build process, and is invoked by -// the go_marshal bazel rule defined in defs.bzl. -// -// See README.md. -package main - -import ( - "flag" - "fmt" - "os" - "strings" - - "gvisor.dev/gvisor/tools/go_marshal/gomarshal" -) - -var ( - pkg = flag.String("pkg", "", "output package") - output = flag.String("output", "", "output file") - outputTest = flag.String("output_test", "", "output file for tests") - imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code") -) - -func main() { - flag.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: %s <input go src files>\n", os.Args[0]) - flag.PrintDefaults() - } - flag.Parse() - if len(flag.Args()) == 0 { - flag.Usage() - os.Exit(1) - } - - if *pkg == "" { - flag.Usage() - fmt.Fprint(os.Stderr, "Flag -pkg must be provided.\n") - os.Exit(1) - } - - var extraImports []string - if len(*imports) > 0 { - // Note: strings.Split(s, sep) returns s if sep doesn't exist in s. Thus - // we check for an empty imports list to avoid emitting an empty string - // as an import. - extraImports = strings.Split(*imports, ",") - } - g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, extraImports) - if err != nil { - panic(err) - } - - if err := g.Run(); err != nil { - panic(err) - } -} diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD deleted file mode 100644 index bacfaa5a4..000000000 --- a/tools/go_marshal/marshal/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "marshal", - srcs = [ - "marshal.go", - ], - visibility = [ - "//:sandbox", - ], - deps = [ - "//pkg/usermem", - ], -) diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go index f129788e0..f129788e0 100644..100755 --- a/tools/go_marshal/marshal/marshal.go +++ b/tools/go_marshal/marshal/marshal.go diff --git a/tools/go_marshal/marshal/marshal_state_autogen.go b/tools/go_marshal/marshal/marshal_state_autogen.go new file mode 100755 index 000000000..a0a953158 --- /dev/null +++ b/tools/go_marshal/marshal/marshal_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package marshal diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD deleted file mode 100644 index f27c5ce52..000000000 --- a/tools/go_marshal/test/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -load("//tools:defs.bzl", "go_binary", "go_library", "go_test") - -licenses(["notice"]) - -package_group( - name = "gomarshal_test", - packages = [ - "//tools/go_marshal/test/...", - ], -) - -go_test( - name = "benchmark_test", - srcs = ["benchmark_test.go"], - deps = [ - ":test", - "//pkg/binary", - "//pkg/usermem", - "//tools/go_marshal/analysis", - ], -) - -go_library( - name = "test", - testonly = 1, - srcs = ["test.go"], - marshal = True, - deps = ["//tools/go_marshal/test/external"], -) - -go_binary( - name = "escape", - testonly = 1, - srcs = ["escape.go"], - gc_goopts = ["-m"], - deps = [ - ":test", - "//pkg/usermem", - "//tools/go_marshal/marshal", - ], -) diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go deleted file mode 100644 index c79defe9e..000000000 --- a/tools/go_marshal/test/benchmark_test.go +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package benchmark_test - -import ( - "bytes" - encbin "encoding/binary" - "fmt" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/analysis" - "gvisor.dev/gvisor/tools/go_marshal/test" -) - -// Marshalling using the standard encoding/binary package. -func BenchmarkEncodingBinary(b *testing.B) { - var s1, s2 test.Stat - analysis.RandomizeValue(&s1) - - size := encbin.Size(&s1) - - b.ResetTimer() - - for n := 0; n < b.N; n++ { - buf := bytes.NewBuffer(make([]byte, size)) - buf.Reset() - if err := encbin.Write(buf, usermem.ByteOrder, &s1); err != nil { - b.Error("Write:", err) - } - if err := encbin.Read(buf, usermem.ByteOrder, &s2); err != nil { - b.Error("Read:", err) - } - } - - b.StopTimer() - - // Sanity check, make sure the values were preserved. - if !reflect.DeepEqual(s1, s2) { - panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) - } -} - -// Marshalling using the sentry's binary.Marshal. -func BenchmarkBinary(b *testing.B) { - var s1, s2 test.Stat - analysis.RandomizeValue(&s1) - - size := binary.Size(s1) - - b.ResetTimer() - - for n := 0; n < b.N; n++ { - buf := make([]byte, 0, size) - buf = binary.Marshal(buf, usermem.ByteOrder, &s1) - binary.Unmarshal(buf, usermem.ByteOrder, &s2) - } - - b.StopTimer() - - // Sanity check, make sure the values were preserved. - if !reflect.DeepEqual(s1, s2) { - panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) - } -} - -// Marshalling field-by-field with manually-written code. -func BenchmarkMarshalManual(b *testing.B) { - var s1, s2 test.Stat - analysis.RandomizeValue(&s1) - - b.ResetTimer() - - for n := 0; n < b.N; n++ { - buf := make([]byte, 0, s1.SizeBytes()) - - // Marshal - buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Dev) - buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Ino) - buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Nlink) - buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.Mode) - buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.UID) - buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.GID) - buf = binary.AppendUint32(buf, usermem.ByteOrder, 0) - buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Rdev) - buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Size)) - buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blksize)) - buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blocks)) - buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Sec)) - buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Nsec)) - buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Sec)) - buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Nsec)) - buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Sec)) - buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Nsec)) - - // Unmarshal - s2.Dev = usermem.ByteOrder.Uint64(buf[0:8]) - s2.Ino = usermem.ByteOrder.Uint64(buf[8:16]) - s2.Nlink = usermem.ByteOrder.Uint64(buf[16:24]) - s2.Mode = usermem.ByteOrder.Uint32(buf[24:28]) - s2.UID = usermem.ByteOrder.Uint32(buf[28:32]) - s2.GID = usermem.ByteOrder.Uint32(buf[32:36]) - // Padding: buf[36:40] - s2.Rdev = usermem.ByteOrder.Uint64(buf[40:48]) - s2.Size = int64(usermem.ByteOrder.Uint64(buf[48:56])) - s2.Blksize = int64(usermem.ByteOrder.Uint64(buf[56:64])) - s2.Blocks = int64(usermem.ByteOrder.Uint64(buf[64:72])) - s2.ATime.Sec = int64(usermem.ByteOrder.Uint64(buf[72:80])) - s2.ATime.Nsec = int64(usermem.ByteOrder.Uint64(buf[80:88])) - s2.MTime.Sec = int64(usermem.ByteOrder.Uint64(buf[88:96])) - s2.MTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[96:104])) - s2.CTime.Sec = int64(usermem.ByteOrder.Uint64(buf[104:112])) - s2.CTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[112:120])) - } - - b.StopTimer() - - // Sanity check, make sure the values were preserved. - if !reflect.DeepEqual(s1, s2) { - panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) - } -} - -// Marshalling with the go_marshal safe API. -func BenchmarkGoMarshalSafe(b *testing.B) { - var s1, s2 test.Stat - analysis.RandomizeValue(&s1) - - b.ResetTimer() - - for n := 0; n < b.N; n++ { - buf := make([]byte, s1.SizeBytes()) - s1.MarshalBytes(buf) - s2.UnmarshalBytes(buf) - } - - b.StopTimer() - - // Sanity check, make sure the values were preserved. - if !reflect.DeepEqual(s1, s2) { - panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) - } -} - -// Marshalling with the go_marshal unsafe API. -func BenchmarkGoMarshalUnsafe(b *testing.B) { - var s1, s2 test.Stat - analysis.RandomizeValue(&s1) - - b.ResetTimer() - - for n := 0; n < b.N; n++ { - buf := make([]byte, s1.SizeBytes()) - s1.MarshalUnsafe(buf) - s2.UnmarshalUnsafe(buf) - } - - b.StopTimer() - - // Sanity check, make sure the values were preserved. - if !reflect.DeepEqual(s1, s2) { - panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) - } -} diff --git a/tools/go_marshal/test/escape.go b/tools/go_marshal/test/escape.go deleted file mode 100644 index 184f05ea3..000000000 --- a/tools/go_marshal/test/escape.go +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This binary provides a convienient target for analyzing how the go-marshal -// API causes its various arguments to escape to the heap. To use, build and -// observe the output from the go compiler's escape analysis: -// -// $ bazel build :escape -// ... -// escape.go:67:2: moved to heap: task -// escape.go:77:31: make([]byte, size) escapes to heap -// escape.go:87:31: make([]byte, size) escapes to heap -// escape.go:96:6: moved to heap: stat -// ... -// -// This is not an automated test, but simply a minimal binary for easy analysis. -package main - -import ( - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/test" -) - -// dummyTask implements marshal.Task. -type dummyTask struct { -} - -func (*dummyTask) CopyScratchBuffer(size int) []byte { - return make([]byte, size) -} - -func (*dummyTask) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) { - return len(b), nil -} - -func (*dummyTask) CopyInBytes(addr usermem.Addr, b []byte) (int, error) { - return len(b), nil -} - -func (task *dummyTask) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) { - buf := task.CopyScratchBuffer(marshallable.SizeBytes()) - marshallable.MarshalBytes(buf) - task.CopyOutBytes(addr, buf) -} - -func (task *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) { - buf := task.CopyScratchBuffer(marshallable.SizeBytes()) - marshallable.MarshalUnsafe(buf) - task.CopyOutBytes(addr, buf) -} - -// Expected escapes: -// - task: passed to marshal.Marshallable.CopyOut as the marshal.Task interface. -func doCopyOut() { - task := dummyTask{} - var stat test.Stat - stat.CopyOut(&task, usermem.Addr(0xf000ba12)) -} - -// Expected escapes: -// - buf: make allocates on the heap. -func doMarshalBytesDirect() { - task := dummyTask{} - var stat test.Stat - buf := task.CopyScratchBuffer(stat.SizeBytes()) - stat.MarshalBytes(buf) - task.CopyOutBytes(usermem.Addr(0xf000ba12), buf) -} - -// Expected escapes: -// - buf: make allocates on the heap. -func doMarshalUnsafeDirect() { - task := dummyTask{} - var stat test.Stat - buf := task.CopyScratchBuffer(stat.SizeBytes()) - stat.MarshalUnsafe(buf) - task.CopyOutBytes(usermem.Addr(0xf000ba12), buf) -} - -// Expected escapes: -// - stat: passed to dummyTask.MarshalBytes as the marshal.Marshallable interface. -func doMarshalBytesViaMarshallable() { - task := dummyTask{} - var stat test.Stat - task.MarshalBytes(usermem.Addr(0xf000ba12), &stat) -} - -// Expected escapes: -// - stat: passed to dummyTask.MarshalUnsafe as the marshal.Marshallable interface. -func doMarshalUnsafeViaMarshallable() { - task := dummyTask{} - var stat test.Stat - task.MarshalUnsafe(usermem.Addr(0xf000ba12), &stat) -} - -func main() { - doCopyOut() - doMarshalBytesDirect() - doMarshalUnsafeDirect() - doMarshalBytesViaMarshallable() - doMarshalUnsafeViaMarshallable() -} diff --git a/tools/go_marshal/test/external/BUILD b/tools/go_marshal/test/external/BUILD deleted file mode 100644 index 0cf6da603..000000000 --- a/tools/go_marshal/test/external/BUILD +++ /dev/null @@ -1,11 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "external", - testonly = 1, - srcs = ["external.go"], - marshal = True, - visibility = ["//tools/go_marshal/test:gomarshal_test"], -) diff --git a/tools/go_marshal/test/external/external.go b/tools/go_marshal/test/external/external.go deleted file mode 100644 index 4be3722f3..000000000 --- a/tools/go_marshal/test/external/external.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package external defines types we can import for testing. -package external - -// External is a public Marshallable type for use in testing. -// -// +marshal -type External struct { - j int64 -} diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go deleted file mode 100644 index c829db6da..000000000 --- a/tools/go_marshal/test/test.go +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package test contains data structures for testing the go_marshal tool. -package test - -import ( - // We're intentionally using a package name alias here even though it's not - // necessary to test the code generator's ability to handle package aliases. - ex "gvisor.dev/gvisor/tools/go_marshal/test/external" -) - -// Type1 is a test data type. -// -// +marshal -type Type1 struct { - a Type2 - x, y int64 // Multiple field names. - b byte `marshal:"unaligned"` // Short field. - c uint64 - _ uint32 // Unnamed scalar field. - _ [6]byte // Unnamed vector field, typical padding. - _ [2]byte - xs [8]int32 - as [10]Type2 `marshal:"unaligned"` // Array of Marshallable objects. - ss Type3 -} - -// Type2 is a test data type. -// -// +marshal -type Type2 struct { - n int64 - c byte - _ [7]byte - m int64 - a int64 -} - -// Type3 is a test data type. -// -// +marshal -type Type3 struct { - s int64 - x ex.External // Type defined in another package. -} - -// Type4 is a test data type. -// -// +marshal -type Type4 struct { - c byte - x int64 `marshal:"unaligned"` - d byte - _ [7]byte -} - -// Type5 is a test data type. -// -// +marshal -type Type5 struct { - n int64 - t Type4 - m int64 -} - -// Timespec represents struct timespec in <time.h>. -// -// +marshal -type Timespec struct { - Sec int64 - Nsec int64 -} - -// Stat represents struct stat. -// -// +marshal -type Stat struct { - Dev uint64 - Ino uint64 - Nlink uint64 - Mode uint32 - UID uint32 - GID uint32 - _ int32 - Rdev uint64 - Size int64 - Blksize int64 - Blocks int64 - ATime Timespec - MTime Timespec - CTime Timespec - _ [3]int64 -} - -// InetAddr is an example marshallable newtype on an array. -// -// +marshal -type InetAddr [4]byte - -// SignalSet is an example marshallable newtype on a primitive. -// -// +marshal -type SignalSet uint64 - -// SignalSetAlias is an example newtype on another marshallable type. -// -// +marshal -type SignalSetAlias SignalSet diff --git a/tools/go_mod.sh b/tools/go_mod.sh deleted file mode 100755 index 84b779d6d..000000000 --- a/tools/go_mod.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash - -# Copyright 2020 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -eo pipefail - -# Build the :gopath target. -bazel build //:gopath -declare -r gopathdir="bazel-bin/gopath/src/gvisor.dev/gvisor/" - -# Copy go.mod and execute the command. -cp -a go.mod go.sum "${gopathdir}" -(cd "${gopathdir}" && go mod "$@") -cp -a "${gopathdir}/go.mod" "${gopathdir}/go.sum" . - -# Cleanup the WORKSPACE file. -bazel run //:gazelle -- update-repos -from_file=go.mod diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD deleted file mode 100644 index 503cdf2e5..000000000 --- a/tools/go_stateify/BUILD +++ /dev/null @@ -1,10 +0,0 @@ -load("//tools:defs.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "stateify", - srcs = ["main.go"], - visibility = ["//:sandbox"], - deps = ["//tools/tags"], -) diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl deleted file mode 100644 index 6a5e666f0..000000000 --- a/tools/go_stateify/defs.bzl +++ /dev/null @@ -1,60 +0,0 @@ -"""Stateify is a tool for generating state wrappers for Go types.""" - -def _go_stateify_impl(ctx): - """Implementation for the stateify tool.""" - output = ctx.outputs.out - - # Run the stateify command. - args = ["-output=%s" % output.path] - args.append("-fullpkg=%s" % ctx.attr.package) - if ctx.attr._statepkg: - args.append("-statepkg=%s" % ctx.attr._statepkg) - if ctx.attr.imports: - args.append("-imports=%s" % ",".join(ctx.attr.imports)) - args.append("--") - for src in ctx.attr.srcs: - args += [f.path for f in src.files.to_list()] - ctx.actions.run( - inputs = ctx.files.srcs, - outputs = [output], - mnemonic = "GoStateify", - progress_message = "Generating state library %s" % ctx.label, - arguments = args, - executable = ctx.executable._tool, - ) - -go_stateify = rule( - implementation = _go_stateify_impl, - doc = "Generates save and restore logic from a set of Go files.", - attrs = { - "srcs": attr.label_list( - doc = """ -The input source files. These files should include all structs in the package -that need to be saved. -""", - mandatory = True, - allow_files = True, - ), - "imports": attr.string_list( - doc = """ -An optional list of extra non-aliased, Go-style absolute import paths required -for statified types. -""", - mandatory = False, - ), - "package": attr.string( - doc = "The fully qualified package name for the input sources.", - mandatory = True, - ), - "out": attr.output( - doc = "Name of the generator output file.", - mandatory = True, - ), - "_tool": attr.label( - executable = True, - cfg = "host", - default = Label("//tools/go_stateify:stateify"), - ), - "_statepkg": attr.string(default = "gvisor.dev/gvisor/pkg/state"), - }, -) diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go deleted file mode 100644 index 3437aa476..000000000 --- a/tools/go_stateify/main.go +++ /dev/null @@ -1,430 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Stateify provides a simple way to generate Load/Save methods based on -// existing types and struct tags. -package main - -import ( - "flag" - "fmt" - "go/ast" - "go/parser" - "go/token" - "os" - "path/filepath" - "reflect" - "strings" - "sync" - - "gvisor.dev/gvisor/tools/tags" -) - -var ( - fullPkg = flag.String("fullpkg", "", "fully qualified output package") - imports = flag.String("imports", "", "extra imports for the output file") - output = flag.String("output", "", "output file") - statePkg = flag.String("statepkg", "", "state import package; defaults to empty") -) - -// resolveTypeName returns a qualified type name. -func resolveTypeName(name string, typ ast.Expr) (field string, qualified string) { - for done := false; !done; { - // Resolve star expressions. - switch rs := typ.(type) { - case *ast.StarExpr: - qualified += "*" - typ = rs.X - case *ast.ArrayType: - if rs.Len == nil { - // Slice type declaration. - qualified += "[]" - } else { - // Array type declaration. - qualified += "[" + rs.Len.(*ast.BasicLit).Value + "]" - } - typ = rs.Elt - default: - // No more descent. - done = true - } - } - - // Resolve a package selector. - sel, ok := typ.(*ast.SelectorExpr) - if ok { - qualified = qualified + sel.X.(*ast.Ident).Name + "." - typ = sel.Sel - } - - // Figure out actual type name. - ident, ok := typ.(*ast.Ident) - if !ok { - panic(fmt.Sprintf("type not supported: %s (involves anonymous types?)", name)) - } - field = ident.Name - qualified = qualified + field - return -} - -// extractStateTag pulls the relevant state tag. -func extractStateTag(tag *ast.BasicLit) string { - if tag == nil { - return "" - } - if len(tag.Value) < 2 { - return "" - } - return reflect.StructTag(tag.Value[1 : len(tag.Value)-1]).Get("state") -} - -// scanFunctions is a set of functions passed to scanFields. -type scanFunctions struct { - zerovalue func(name string) - normal func(name string) - wait func(name string) - value func(name, typName string) -} - -// scanFields scans the fields of a struct. -// -// Each provided function will be applied to appropriately tagged fields, or -// skipped if nil. -// -// Fields tagged nosave are skipped. -func scanFields(ss *ast.StructType, fn scanFunctions) { - if ss.Fields.List == nil { - // No fields. - return - } - - // Scan all fields. - for _, field := range ss.Fields.List { - // Calculate the name. - name := "" - if field.Names != nil { - // It's a named field; override. - name = field.Names[0].Name - } else { - // Anonymous types can't be embedded, so we don't need - // to worry about providing a useful name here. - name, _ = resolveTypeName("", field.Type) - } - - // Skip _ fields. - if name == "_" { - continue - } - - switch tag := extractStateTag(field.Tag); tag { - case "zerovalue": - if fn.zerovalue != nil { - fn.zerovalue(name) - } - - case "": - if fn.normal != nil { - fn.normal(name) - } - - case "wait": - if fn.wait != nil { - fn.wait(name) - } - - case "manual", "nosave", "ignore": - // Do nothing. - - default: - if strings.HasPrefix(tag, ".(") && strings.HasSuffix(tag, ")") { - if fn.value != nil { - fn.value(name, tag[2:len(tag)-1]) - } - } - } - } -} - -func camelCased(name string) string { - return strings.ToUpper(name[:1]) + name[1:] -} - -func main() { - // Parse flags. - flag.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0]) - flag.PrintDefaults() - } - flag.Parse() - if len(flag.Args()) == 0 { - flag.Usage() - os.Exit(1) - } - if *fullPkg == "" { - fmt.Fprintf(os.Stderr, "Error: package required.") - os.Exit(1) - } - - // Open the output file. - var ( - outputFile *os.File - err error - ) - if *output == "" || *output == "-" { - outputFile = os.Stdout - } else { - outputFile, err = os.OpenFile(*output, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) - if err != nil { - fmt.Fprintf(os.Stderr, "Error opening output %q: %v", *output, err) - } - defer outputFile.Close() - } - - // Set the statePrefix for below, depending on the import. - statePrefix := "" - if *statePkg != "" { - parts := strings.Split(*statePkg, "/") - statePrefix = parts[len(parts)-1] + "." - } - - // initCalls is dumped at the end. - var initCalls []string - - // Declare our emission closures. - emitRegister := func(name string) { - initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *fullPkg, name, name, name, name)) - } - emitZeroCheck := func(name string) { - fmt.Fprintf(outputFile, " if !%sIsZeroValue(x.%s) { m.Failf(\"%s is %%v, expected zero\", x.%s) }\n", statePrefix, name, name, name) - } - emitLoadValue := func(name, typName string) { - fmt.Fprintf(outputFile, " m.LoadValue(\"%s\", new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", name, typName, camelCased(name), typName) - } - emitLoad := func(name string) { - fmt.Fprintf(outputFile, " m.Load(\"%s\", &x.%s)\n", name, name) - } - emitLoadWait := func(name string) { - fmt.Fprintf(outputFile, " m.LoadWait(\"%s\", &x.%s)\n", name, name) - } - emitSaveValue := func(name, typName string) { - fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name)) - fmt.Fprintf(outputFile, " m.SaveValue(\"%s\", %s)\n", name, name) - } - emitSave := func(name string) { - fmt.Fprintf(outputFile, " m.Save(\"%s\", &x.%s)\n", name, name) - } - - // Automated warning. - fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n") - - // Emit build tags. - if t := tags.Aggregate(flag.Args()); len(t) > 0 { - fmt.Fprintf(outputFile, "%s\n\n", strings.Join(t.Lines(), "\n")) - } - - // Emit the package name. - _, pkg := filepath.Split(*fullPkg) - fmt.Fprintf(outputFile, "package %s\n\n", pkg) - - // Emit the imports lazily. - var once sync.Once - maybeEmitImports := func() { - once.Do(func() { - // Emit the imports. - fmt.Fprint(outputFile, "import (\n") - if *statePkg != "" { - fmt.Fprintf(outputFile, " \"%s\"\n", *statePkg) - } - if *imports != "" { - for _, i := range strings.Split(*imports, ",") { - fmt.Fprintf(outputFile, " \"%s\"\n", i) - } - } - fmt.Fprint(outputFile, ")\n\n") - }) - } - - files := make([]*ast.File, 0, len(flag.Args())) - - // Parse the input files. - for _, filename := range flag.Args() { - // Parse the file. - fset := token.NewFileSet() - f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) - if err != nil { - // Not a valid input file? - fmt.Fprintf(os.Stderr, "Input %q can't be parsed: %v\n", filename, err) - os.Exit(1) - } - - files = append(files, f) - } - - type method struct { - receiver string - name string - } - - // Search for and add all methods with a pointer receiver and no other - // arguments to a set. We support auto-detecting the existence of - // several different methods with this signature. - simpleMethods := map[method]struct{}{} - for _, f := range files { - - // Go over all functions. - for _, decl := range f.Decls { - d, ok := decl.(*ast.FuncDecl) - if !ok { - continue - } - if d.Name == nil || d.Recv == nil || d.Type == nil { - // Not a named method. - continue - } - if len(d.Recv.List) != 1 { - // Wrong number of receivers? - continue - } - if d.Type.Params != nil && len(d.Type.Params.List) != 0 { - // Has argument(s). - continue - } - if d.Type.Results != nil && len(d.Type.Results.List) != 0 { - // Has return(s). - continue - } - - pt, ok := d.Recv.List[0].Type.(*ast.StarExpr) - if !ok { - // Not a pointer receiver. - continue - } - - t, ok := pt.X.(*ast.Ident) - if !ok { - // This shouldn't happen with valid Go. - continue - } - - simpleMethods[method{t.Name, d.Name.Name}] = struct{}{} - } - } - - for _, f := range files { - // Go over all named types. - for _, decl := range f.Decls { - d, ok := decl.(*ast.GenDecl) - if !ok || d.Tok != token.TYPE { - continue - } - - // Only generate code for types marked - // "// +stateify savable" in one of the proceeding - // comment lines. - if d.Doc == nil { - continue - } - savable := false - for _, l := range d.Doc.List { - if l.Text == "// +stateify savable" { - savable = true - break - } - } - if !savable { - continue - } - - for _, gs := range d.Specs { - ts := gs.(*ast.TypeSpec) - switch ts.Type.(type) { - case *ast.InterfaceType, *ast.ChanType, *ast.FuncType, *ast.ParenExpr, *ast.StarExpr: - // Don't register. - break - case *ast.StructType: - maybeEmitImports() - - ss := ts.Type.(*ast.StructType) - - // Define beforeSave if a definition was not found. This - // prevents the code from compiling if a custom beforeSave - // was defined in a file not provided to this binary and - // prevents inherited methods from being called multiple times - // by overriding them. - if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok { - fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n", ts.Name.Name) - } - - // Generate the save method. - fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix) - fmt.Fprintf(outputFile, " x.beforeSave()\n") - scanFields(ss, scanFunctions{zerovalue: emitZeroCheck}) - scanFields(ss, scanFunctions{value: emitSaveValue}) - scanFields(ss, scanFunctions{normal: emitSave, wait: emitSave}) - fmt.Fprintf(outputFile, "}\n\n") - - // Define afterLoad if a definition was not found. We do this - // for the same reason that we do it for beforeSave. - _, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}] - if !hasAfterLoad { - fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n", ts.Name.Name) - } - - // Generate the load method. - // - // Note that the manual loads always follow the - // automated loads. - fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix) - scanFields(ss, scanFunctions{normal: emitLoad, wait: emitLoadWait}) - scanFields(ss, scanFunctions{value: emitLoadValue}) - if hasAfterLoad { - // The call to afterLoad is made conditionally, because when - // AfterLoad is called, the object encodes a dependency on - // referred objects (i.e. fields). This means that afterLoad - // will not be called until the other afterLoads are called. - fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n") - } - fmt.Fprintf(outputFile, "}\n\n") - - // Add to our registration. - emitRegister(ts.Name.Name) - case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType: - maybeEmitImports() - - _, val := resolveTypeName(ts.Name.Name, ts.Type) - - // Dispatch directly. - fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix) - fmt.Fprintf(outputFile, " m.SaveValue(\"\", (%s)(*x))\n", val) - fmt.Fprintf(outputFile, "}\n\n") - fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix) - fmt.Fprintf(outputFile, " m.LoadValue(\"\", new(%s), func(y interface{}) { *x = (%s)(y.(%s)) })\n", val, ts.Name.Name, val) - fmt.Fprintf(outputFile, "}\n\n") - - // See above. - emitRegister(ts.Name.Name) - } - } - } - } - - if len(initCalls) > 0 { - // Emit the init() function. - fmt.Fprintf(outputFile, "func init() {\n") - for _, ic := range initCalls { - fmt.Fprintf(outputFile, " %s\n", ic) - } - fmt.Fprintf(outputFile, "}\n") - } -} diff --git a/tools/image_build.sh b/tools/image_build.sh deleted file mode 100755 index 9b20a740d..000000000 --- a/tools/image_build.sh +++ /dev/null @@ -1,98 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script is responsible for building a new GCP image that: 1) has nested -# virtualization enabled, and 2) has been completely set up with the -# image_setup.sh script. This script should be idempotent, as we memoize the -# setup script with a hash and check for that name. -# -# The GCP project name should be defined via a gcloud config. - -set -xeo pipefail - -# Parameters. -declare -r ZONE=${ZONE:-us-central1-f} -declare -r USERNAME=${USERNAME:-test} -declare -r IMAGE_PROJECT=${IMAGE_PROJECT:-ubuntu-os-cloud} -declare -r IMAGE_FAMILY=${IMAGE_FAMILY:-ubuntu-1604-lts} - -# Random names. -declare -r DISK_NAME=$(mktemp -u disk-XXXXXX | tr A-Z a-z) -declare -r SNAPSHOT_NAME=$(mktemp -u snapshot-XXXXXX | tr A-Z a-z) -declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z) - -# Hashes inputs. -declare -r SETUP_BLOB=$(echo ${ZONE} ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && sha256sum "$@") -declare -r SETUP_HASH=$(echo ${SETUP_BLOB} | sha256sum - | cut -d' ' -f1 | cut -c 1-16) -declare -r IMAGE_NAME=${IMAGE_NAME:-image-}${SETUP_HASH} - -# Does the image already exist? Skip the build. -declare -r existing=$(gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)") -if ! [[ -z "${existing}" ]]; then - echo "${existing}" - exit 0 -fi - -# Set the zone for all actions. -gcloud config set compute/zone "${ZONE}" - -# Start a unique instance. Note that this instance will have a unique persistent -# disk as it's boot disk with the same name as the instance. -gcloud compute instances create \ - --quiet \ - --image-project "${IMAGE_PROJECT}" \ - --image-family "${IMAGE_FAMILY}" \ - --boot-disk-size "200GB" \ - "${INSTANCE_NAME}" -function cleanup { - gcloud compute instances delete --quiet "${INSTANCE_NAME}" -} -trap cleanup EXIT - -# Wait for the instance to become available. -declare attempts=0 -while [[ "${attempts}" -lt 30 ]]; do - attempts=$((${attempts}+1)) - if gcloud compute ssh "${USERNAME}"@"${INSTANCE_NAME}" -- true; then - break - fi -done -if [[ "${attempts}" -ge 30 ]]; then - echo "too many attempts: failed" - exit 1 -fi - -# Run the install scripts provided. -for arg; do - gcloud compute ssh "${USERNAME}"@"${INSTANCE_NAME}" -- sudo bash - <"${arg}" -done - -# Stop the instance; required before creating an image. -gcloud compute instances stop --quiet "${INSTANCE_NAME}" - -# Create a snapshot of the instance disk. -gcloud compute disks snapshot \ - --quiet \ - --zone="${ZONE}" \ - --snapshot-names="${SNAPSHOT_NAME}" \ - "${INSTANCE_NAME}" - -# Create the disk image. -gcloud compute images create \ - --quiet \ - --source-snapshot="${SNAPSHOT_NAME}" \ - --licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \ - "${IMAGE_NAME}" diff --git a/tools/images/BUILD b/tools/images/BUILD deleted file mode 100644 index fe11f08a3..000000000 --- a/tools/images/BUILD +++ /dev/null @@ -1,68 +0,0 @@ -load("//tools:defs.bzl", "cc_binary", "gtest") -load("//tools/images:defs.bzl", "vm_image", "vm_test") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -genrule( - name = "zone", - outs = ["zone.txt"], - cmd = "gcloud config get-value compute/zone > $@", - tags = [ - "local", - "manual", - ], -) - -sh_binary( - name = "builder", - srcs = ["build.sh"], -) - -sh_binary( - name = "executer", - srcs = ["execute.sh"], -) - -cc_binary( - name = "test", - testonly = 1, - srcs = ["test.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:test_main", - ], -) - -vm_image( - name = "ubuntu1604", - family = "ubuntu-1604-lts", - project = "ubuntu-os-cloud", - scripts = [ - "//tools/images/ubuntu1604", - ], -) - -vm_test( - name = "ubuntu1604_test", - image = ":ubuntu1604", - targets = [":test"], -) - -vm_image( - name = "ubuntu1804", - family = "ubuntu-1804-lts", - project = "ubuntu-os-cloud", - scripts = [ - "//tools/images/ubuntu1804", - ], -) - -vm_test( - name = "ubuntu1804_test", - image = ":ubuntu1804", - targets = [":test"], -) diff --git a/tools/images/build.sh b/tools/images/build.sh deleted file mode 100755 index be462d556..000000000 --- a/tools/images/build.sh +++ /dev/null @@ -1,101 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script is responsible for building a new GCP image that: 1) has nested -# virtualization enabled, and 2) has been completely set up with the -# image_setup.sh script. This script should be idempotent, as we memoize the -# setup script with a hash and check for that name. - -set -xeou pipefail - -# Parameters. -declare -r USERNAME=${USERNAME:-test} -declare -r IMAGE_PROJECT=${IMAGE_PROJECT:-ubuntu-os-cloud} -declare -r IMAGE_FAMILY=${IMAGE_FAMILY:-ubuntu-1604-lts} -declare -r ZONE=${ZONE:-us-central1-f} - -# Random names. -declare -r DISK_NAME=$(mktemp -u disk-XXXXXX | tr A-Z a-z) -declare -r SNAPSHOT_NAME=$(mktemp -u snapshot-XXXXXX | tr A-Z a-z) -declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z) - -# Hash inputs in order to memoize the produced image. -declare -r SETUP_HASH=$( (echo ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && cat "$@") | sha256sum - | cut -d' ' -f1 | cut -c 1-16) -declare -r IMAGE_NAME=${IMAGE_FAMILY:-image-}${SETUP_HASH} - -# Does the image already exist? Skip the build. -declare -r existing=$(gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)") -if ! [[ -z "${existing}" ]]; then - echo "${existing}" - exit 0 -fi - -# gcloud has path errors; is this a result of being a genrule? -export PATH=${PATH:-/bin:/usr/bin:/usr/local/bin} - -# Start a unique instance. Note that this instance will have a unique persistent -# disk as it's boot disk with the same name as the instance. -gcloud compute instances create \ - --quiet \ - --image-project "${IMAGE_PROJECT}" \ - --image-family "${IMAGE_FAMILY}" \ - --boot-disk-size "200GB" \ - --zone "${ZONE}" \ - "${INSTANCE_NAME}" >/dev/null -function cleanup { - gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}" -} -trap cleanup EXIT - -# Wait for the instance to become available (up to 5 minutes). -declare timeout=300 -declare success=0 -declare -r start=$(date +%s) -declare -r end=$((${start}+${timeout})) -while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do - if gcloud compute ssh --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- env - true 2>/dev/null; then - success=$((${success}+1)) - fi -done -if [[ "${success}" -eq "0" ]]; then - echo "connect timed out after ${timeout} seconds." - exit 1 -fi - -# Run the install scripts provided. -for arg; do - gcloud compute ssh --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- sudo bash - <"${arg}" >/dev/null -done - -# Stop the instance; required before creating an image. -gcloud compute instances stop --quiet --zone "${ZONE}" "${INSTANCE_NAME}" >/dev/null - -# Create a snapshot of the instance disk. -gcloud compute disks snapshot \ - --quiet \ - --zone "${ZONE}" \ - --snapshot-names="${SNAPSHOT_NAME}" \ - "${INSTANCE_NAME}" >/dev/null - -# Create the disk image. -gcloud compute images create \ - --quiet \ - --source-snapshot="${SNAPSHOT_NAME}" \ - --licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \ - "${IMAGE_NAME}" >/dev/null - -# Finish up. -echo "${IMAGE_NAME}" diff --git a/tools/images/defs.bzl b/tools/images/defs.bzl deleted file mode 100644 index de365d153..000000000 --- a/tools/images/defs.bzl +++ /dev/null @@ -1,183 +0,0 @@ -"""Image configuration. - -Images can be generated by using the vm_image rule. For example, - - vm_image( - name = "ubuntu", - project = "...", - family = "...", - scripts = [ - "script.sh", - "other.sh", - ], - ) - -This will always create an vm_image in the current default gcloud project. The -rule has a text file as its output containing the image name. This will enforce -serialization for all dependent rules. - -Images are always named per the hash of all the hermetic input scripts. This -allows images to be memoized quickly and easily. - -The vm_test rule can be used to execute a command remotely. For example, - - vm_test( - name = "mycommand", - image = ":myimage", - targets = [":test"], - ) -""" - -load("//tools:defs.bzl", "default_installer") - -def _vm_image_impl(ctx): - script_paths = [] - for script in ctx.files.scripts: - script_paths.append(script.short_path) - - resolved_inputs, argv, runfiles_manifests = ctx.resolve_command( - command = "USERNAME=%s ZONE=$(cat %s) IMAGE_PROJECT=%s IMAGE_FAMILY=%s %s %s > %s" % - ( - ctx.attr.username, - ctx.files.zone[0].path, - ctx.attr.project, - ctx.attr.family, - ctx.executable.builder.path, - " ".join(script_paths), - ctx.outputs.out.path, - ), - tools = [ctx.attr.builder] + ctx.attr.scripts, - ) - - ctx.actions.run_shell( - tools = resolved_inputs, - outputs = [ctx.outputs.out], - progress_message = "Building image...", - execution_requirements = {"local": "true"}, - command = argv, - input_manifests = runfiles_manifests, - ) - return [DefaultInfo( - files = depset([ctx.outputs.out]), - runfiles = ctx.runfiles(files = [ctx.outputs.out]), - )] - -_vm_image = rule( - attrs = { - "builder": attr.label( - executable = True, - default = "//tools/images:builder", - cfg = "host", - ), - "username": attr.string(default = "$(whoami)"), - "zone": attr.label( - default = "//tools/images:zone", - cfg = "host", - ), - "family": attr.string(mandatory = True), - "project": attr.string(mandatory = True), - "scripts": attr.label_list(allow_files = True), - }, - outputs = { - "out": "%{name}.txt", - }, - implementation = _vm_image_impl, -) - -def vm_image(**kwargs): - _vm_image( - tags = [ - "local", - "manual", - ], - **kwargs - ) - -def _vm_test_impl(ctx): - runner = ctx.actions.declare_file("%s-executer" % ctx.label.name) - - # Note that the remote execution case must actually generate an - # intermediate target in order to collect all the relevant runfiles so that - # they can be copied over for remote execution. - runner_content = "\n".join([ - "#!/bin/bash", - "export ZONE=$(cat %s)" % ctx.files.zone[0].short_path, - "export USERNAME=%s" % ctx.attr.username, - "export IMAGE=$(cat %s)" % ctx.files.image[0].short_path, - "export SUDO=%s" % "true" if ctx.attr.sudo else "false", - "%s %s" % ( - ctx.executable.executer.short_path, - " ".join([ - target.files_to_run.executable.short_path - for target in ctx.attr.targets - ]), - ), - "", - ]) - ctx.actions.write(runner, runner_content, is_executable = True) - - # Return with all transitive files. - runfiles = ctx.runfiles( - transitive_files = depset(transitive = [ - depset(target.data_runfiles.files) - for target in ctx.attr.targets - if hasattr(target, "data_runfiles") - ]), - files = ctx.files.executer + ctx.files.zone + ctx.files.image + - ctx.files.targets, - collect_default = True, - collect_data = True, - ) - return [DefaultInfo(executable = runner, runfiles = runfiles)] - -_vm_test = rule( - attrs = { - "image": attr.label( - mandatory = True, - cfg = "host", - ), - "executer": attr.label( - executable = True, - default = "//tools/images:executer", - cfg = "host", - ), - "username": attr.string(default = "$(whoami)"), - "zone": attr.label( - default = "//tools/images:zone", - cfg = "host", - ), - "sudo": attr.bool(default = True), - "machine": attr.string(default = "n1-standard-1"), - "targets": attr.label_list( - mandatory = True, - allow_empty = False, - cfg = "target", - ), - }, - test = True, - implementation = _vm_test_impl, -) - -def vm_test( - installer = "//tools/installers:head", - **kwargs): - """Runs the given targets as a remote test. - - Args: - installer: Script to run before all targets. - **kwargs: All test arguments. Should include targets and image. - """ - targets = kwargs.pop("targets", []) - if installer: - targets = [installer] + targets - if default_installer(): - targets = [default_installer()] + targets - _vm_test( - tags = [ - "local", - "manual", - ], - targets = targets, - local = 1, - **kwargs - ) diff --git a/tools/images/execute.sh b/tools/images/execute.sh deleted file mode 100755 index ba4b1ac0e..000000000 --- a/tools/images/execute.sh +++ /dev/null @@ -1,152 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -xeo pipefail - -# Required input. -if ! [[ -v IMAGE ]]; then - echo "no image provided: set IMAGE." - exit 1 -fi - -# Parameters. -declare -r USERNAME=${USERNAME:-test} -declare -r KEYNAME=$(mktemp --tmpdir -u key-XXXXXX) -declare -r SSHKEYS=$(mktemp --tmpdir -u sshkeys-XXXXXX) -declare -r INSTANCE_NAME=$(mktemp -u test-XXXXXX | tr A-Z a-z) -declare -r MACHINE=${MACHINE:-n1-standard-1} -declare -r ZONE=${ZONE:-us-central1-f} -declare -r SUDO=${SUDO:-false} - -# This script is executed as a test rule, which will reset the value of HOME. -# Unfortunately, it is needed to load the gconfig credentials. We will reset -# HOME when we actually execute in the remote environment, defined below. -export HOME=$(eval echo ~$(whoami)) - -# Generate unique keys for this test. -[[ -f "${KEYNAME}" ]] || ssh-keygen -t rsa -N "" -f "${KEYNAME}" -C "${USERNAME}" -cat > "${SSHKEYS}" <<EOF -${USERNAME}:$(cat ${KEYNAME}.pub) -EOF - -# Start a unique instance. This means that we first generate a unique set of ssh -# keys to ensure that only we have access to this instance. Note that we must -# constrain ourselves to Haswell or greater in order to have nested -# virtualization available. -gcloud compute instances create \ - --min-cpu-platform "Intel Haswell" \ - --preemptible \ - --no-scopes \ - --metadata block-project-ssh-keys=TRUE \ - --metadata-from-file ssh-keys="${SSHKEYS}" \ - --machine-type "${MACHINE}" \ - --image "${IMAGE}" \ - --zone "${ZONE}" \ - "${INSTANCE_NAME}" -function cleanup { - gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}" -} -trap cleanup EXIT - -# Wait for the instance to become available (up to 5 minutes). -declare timeout=300 -declare success=0 -declare -r start=$(date +%s) -declare -r end=$((${start}+${timeout})) -while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do - if gcloud compute ssh --ssh-key-file="${KEYNAME}" --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then - success=$((${success}+1)) - fi -done -if [[ "${success}" -eq "0" ]]; then - echo "connect timed out after ${timeout} seconds." - exit 1 -fi - -# Copy the local directory over. -tar czf - --dereference --exclude=.git . | - gcloud compute ssh \ - --ssh-key-file="${KEYNAME}" \ - --zone "${ZONE}" \ - "${USERNAME}"@"${INSTANCE_NAME}" -- tar xzf - - -# Execute the command remotely. -for cmd; do - # Setup relevant environment. - # - # N.B. This is not a complete test environment, but is complete enough to - # provide rudimentary sharding and test output support. - declare -a PREFIX=( "env" ) - if [[ -v TEST_SHARD_INDEX ]]; then - PREFIX+=( "TEST_SHARD_INDEX=${TEST_SHARD_INDEX}" ) - fi - if [[ -v TEST_SHARD_STATUS_FILE ]]; then - SHARD_STATUS_FILE=$(mktemp -u test-shard-status-XXXXXX) - PREFIX+=( "TEST_SHARD_STATUS_FILE=/tmp/${SHARD_STATUS_FILE}" ) - fi - if [[ -v TEST_TOTAL_SHARDS ]]; then - PREFIX+=( "TEST_TOTAL_SHARDS=${TEST_TOTAL_SHARDS}" ) - fi - if [[ -v TEST_TMPDIR ]]; then - REMOTE_TMPDIR=$(mktemp -u test-XXXXXX) - PREFIX+=( "TEST_TMPDIR=/tmp/${REMOTE_TMPDIR}" ) - # Create remotely. - gcloud compute ssh \ - --ssh-key-file="${KEYNAME}" \ - --zone "${ZONE}" \ - "${USERNAME}"@"${INSTANCE_NAME}" -- \ - mkdir -p "/tmp/${REMOTE_TMPDIR}" - fi - if [[ -v XML_OUTPUT_FILE ]]; then - TEST_XML_OUTPUT=$(mktemp -u xml-output-XXXXXX) - PREFIX+=( "XML_OUTPUT_FILE=/tmp/${TEST_XML_OUTPUT}" ) - fi - if [[ "${SUDO}" == "true" ]]; then - PREFIX+=( "sudo" "-E" ) - fi - - # Execute the command. - gcloud compute ssh \ - --ssh-key-file="${KEYNAME}" \ - --zone "${ZONE}" \ - "${USERNAME}"@"${INSTANCE_NAME}" -- \ - "${PREFIX[@]}" "${cmd}" - - # Collect relevant results. - if [[ -v TEST_SHARD_STATUS_FILE ]]; then - gcloud compute scp \ - --ssh-key-file="${KEYNAME}" \ - --zone "${ZONE}" \ - "${USERNAME}"@"${INSTANCE_NAME}":/tmp/"${SHARD_STATUS_FILE}" \ - "${TEST_SHARD_STATUS_FILE}" 2>/dev/null || true # Allowed to fail. - fi - if [[ -v XML_OUTPUT_FILE ]]; then - gcloud compute scp \ - --ssh-key-file="${KEYNAME}" \ - --zone "${ZONE}" \ - "${USERNAME}"@"${INSTANCE_NAME}":/tmp/"${TEST_XML_OUTPUT}" \ - "${XML_OUTPUT_FILE}" 2>/dev/null || true # Allowed to fail. - fi - - # Clean up the temporary directory. - if [[ -v TEST_TMPDIR ]]; then - gcloud compute ssh \ - --ssh-key-file="${KEYNAME}" \ - --zone "${ZONE}" \ - "${USERNAME}"@"${INSTANCE_NAME}" -- \ - rm -rf "/tmp/${REMOTE_TMPDIR}" - fi -done diff --git a/tools/images/test.cc b/tools/images/test.cc deleted file mode 100644 index 4f31d93c5..000000000 --- a/tools/images/test.cc +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "gtest/gtest.h" - -namespace { - -TEST(Image, Sanity) { - // Do nothing. -} - -} // namespace diff --git a/tools/images/ubuntu1604/10_core.sh b/tools/images/ubuntu1604/10_core.sh deleted file mode 100755 index cd518d6ac..000000000 --- a/tools/images/ubuntu1604/10_core.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -xeo pipefail - -# Install all essential build tools. -while true; do - if (apt-get update && apt-get install -y \ - make \ - git-core \ - build-essential \ - linux-headers-$(uname -r) \ - pkg-config); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - exit $result - fi -done - -# Install a recent go toolchain. -if ! [[ -d /usr/local/go ]]; then - wget https://dl.google.com/go/go1.13.5.linux-amd64.tar.gz - tar -xvf go1.13.5.linux-amd64.tar.gz - mv go /usr/local -fi - -# Link the Go binary from /usr/bin; replacing anything there. -(cd /usr/bin && rm -f go && sudo ln -fs /usr/local/go/bin/go go) diff --git a/tools/images/ubuntu1604/20_bazel.sh b/tools/images/ubuntu1604/20_bazel.sh deleted file mode 100755 index bb7afa676..000000000 --- a/tools/images/ubuntu1604/20_bazel.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -xeo pipefail - -declare -r BAZEL_VERSION=2.0.0 - -# Install bazel dependencies. -while true; do - if (apt-get update && apt-get install -y \ - openjdk-8-jdk-headless \ - unzip); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - exit $result - fi -done - -# Use the release installer. -curl -L -o bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh -chmod a+x bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh -./bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh -rm -f bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh diff --git a/tools/images/ubuntu1604/25_docker.sh b/tools/images/ubuntu1604/25_docker.sh deleted file mode 100755 index 11eea2d72..000000000 --- a/tools/images/ubuntu1604/25_docker.sh +++ /dev/null @@ -1,54 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Add dependencies. -while true; do - if (apt-get update && apt-get install -y \ - apt-transport-https \ - ca-certificates \ - curl \ - gnupg-agent \ - software-properties-common); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - exit $result - fi -done - -# Install the key. -curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add - - -# Add the repository. -add-apt-repository \ - "deb [arch=amd64] https://download.docker.com/linux/ubuntu \ - $(lsb_release -cs) \ - stable" - -# Install docker. -while true; do - if (apt-get update && apt-get install -y \ - docker-ce \ - docker-ce-cli \ - containerd.io); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - exit $result - fi -done diff --git a/tools/images/ubuntu1604/30_containerd.sh b/tools/images/ubuntu1604/30_containerd.sh deleted file mode 100755 index fb3699c12..000000000 --- a/tools/images/ubuntu1604/30_containerd.sh +++ /dev/null @@ -1,86 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -xeo pipefail - -# Helper for Go packages below. -install_helper() { - PACKAGE="${1}" - TAG="${2}" - GOPATH="${3}" - - # Clone the repository. - mkdir -p "${GOPATH}"/src/$(dirname "${PACKAGE}") && \ - git clone https://"${PACKAGE}" "${GOPATH}"/src/"${PACKAGE}" - - # Checkout and build the repository. - (cd "${GOPATH}"/src/"${PACKAGE}" && \ - git checkout "${TAG}" && \ - GOPATH="${GOPATH}" make && \ - GOPATH="${GOPATH}" make install) -} - -# Install dependencies for the crictl tests. -while true; do - if (apt-get update && apt-get install -y \ - btrfs-tools \ - libseccomp-dev); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - exit $result - fi -done - -# Install containerd & cri-tools. -GOPATH=$(mktemp -d --tmpdir gopathXXXXX) -install_helper github.com/containerd/containerd v1.2.2 "${GOPATH}" -install_helper github.com/kubernetes-sigs/cri-tools v1.11.0 "${GOPATH}" - -# Install gvisor-containerd-shim. -declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim" -declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX) -declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX) -wget --no-verbose "${base}"/latest -O ${latest} -wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path} -chmod +x ${shim_path} -mv ${shim_path} /usr/local/bin - -# Configure containerd-shim. -declare -r shim_config_path=/etc/containerd -declare -r shim_config_tmp_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX.toml) -mkdir -p ${shim_config_path} -cat > ${shim_config_tmp_path} <<-EOF - runc_shim = "/usr/local/bin/containerd-shim" - -[runsc_config] - debug = "true" - debug-log = "/tmp/runsc-logs/" - strace = "true" - file-access = "shared" -EOF -mv ${shim_config_tmp_path} ${shim_config_path} - -# Configure CNI. -(cd "${GOPATH}" && GOPATH="${GOPATH}" \ - src/github.com/containerd/containerd/script/setup/install-cni) - -# Cleanup the above. -rm -rf "${GOPATH}" -rm -rf "${latest}" -rm -rf "${shim_path}" -rm -rf "${shim_config_tmp_path}" diff --git a/tools/images/ubuntu1604/40_kokoro.sh b/tools/images/ubuntu1604/40_kokoro.sh deleted file mode 100755 index 06a1e6c48..000000000 --- a/tools/images/ubuntu1604/40_kokoro.sh +++ /dev/null @@ -1,72 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -xeo pipefail - -# Declare kokoro's required public keys. -declare -r ssh_public_keys=( - "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDg7L/ZaEauETWrPklUTky3kvxqQfe2Ax/2CsSqhNIGNMnK/8d79CHlmY9+dE1FFQ/RzKNCaltgy7XcN/fCYiCZr5jm2ZtnLuGNOTzupMNhaYiPL419qmL+5rZXt4/dWTrsHbFRACxT8j51PcRMO5wgbL0Bg2XXimbx8kDFaurL2gqduQYqlu4lxWCaJqOL71WogcimeL63Nq/yeH5PJPWpqE4P9VUQSwAzBWFK/hLeds/AiP3MgVS65qHBnhq0JsHy8JQsqjZbG7Iidt/Ll0+gqzEbi62gDIcczG4KC0iOVzDDP/1BxDtt1lKeA23ll769Fcm3rJyoBMYxjvdw1TDx sabujp@trigger.mtv.corp.google.com" - "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNgGK/hCdjmulHfRE3hp4rZs38NCR8yAh0eDsztxqGcuXnuSnL7jOlRrbcQpremJ84omD4eKrIpwJUs+YokMdv4= sabujp@trigger.svl.corp.google.com" -) - -# Install dependencies. -while true; do - if (apt-get update && apt-get install -y \ - rsync \ - coreutils \ - python-psutil \ - qemu-kvm \ - python-pip \ - python3-pip \ - zip); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - exit $result - fi -done - -# junitparser is used to merge junit xml files. -pip install junitparser - -# We need a kbuilder user. -if useradd -c "kbuilder user" -m -s /bin/bash kbuilder; then - # User was added successfully; we add the relevant SSH keys here. - mkdir -p ~kbuilder/.ssh - (IFS=$'\n'; echo "${ssh_public_keys[*]}") > ~kbuilder/.ssh/authorized_keys - chmod 0600 ~kbuilder/.ssh/authorized_keys - chown -R kbuilder ~kbuilder/.ssh -fi - -# Give passwordless sudo access. -cat > /etc/sudoers.d/kokoro <<EOF -kbuilder ALL=(ALL) NOPASSWD:ALL -EOF - -# Ensure we can run Docker without sudo. -usermod -aG docker kbuilder - -# Ensure that we can access kvm. -usermod -aG kvm kbuilder - -# Ensure that /tmpfs exists and is writable by kokoro. -# -# Note that kokoro will typically attach a second disk (sdb) to the instance -# that is used for the /tmpfs volume. In the future we could setup an init -# script that formats and mounts this here; however, we don't expect our build -# artifacts to be that large. -mkdir -p /tmpfs && chmod 0777 /tmpfs && touch /tmpfs/READY diff --git a/tools/images/ubuntu1604/BUILD b/tools/images/ubuntu1604/BUILD deleted file mode 100644 index ab1df0c4c..000000000 --- a/tools/images/ubuntu1604/BUILD +++ /dev/null @@ -1,7 +0,0 @@ -package(licenses = ["notice"]) - -filegroup( - name = "ubuntu1604", - srcs = glob(["*.sh"]), - visibility = ["//:sandbox"], -) diff --git a/tools/images/ubuntu1804/BUILD b/tools/images/ubuntu1804/BUILD deleted file mode 100644 index 7aa1ecdf7..000000000 --- a/tools/images/ubuntu1804/BUILD +++ /dev/null @@ -1,7 +0,0 @@ -package(licenses = ["notice"]) - -alias( - name = "ubuntu1804", - actual = "//tools/images/ubuntu1604", - visibility = ["//:sandbox"], -) diff --git a/tools/installers/BUILD b/tools/installers/BUILD deleted file mode 100644 index d78a265ca..000000000 --- a/tools/installers/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -# Installers for use by the tools/vm_test rules. - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -filegroup( - name = "runsc", - srcs = ["//runsc"], -) - -sh_binary( - name = "head", - srcs = ["head.sh"], - data = [":runsc"], -) - -sh_binary( - name = "master", - srcs = ["master.sh"], -) - -sh_binary( - name = "shim", - srcs = ["shim.sh"], -) diff --git a/tools/installers/head.sh b/tools/installers/head.sh deleted file mode 100755 index 9de8f138c..000000000 --- a/tools/installers/head.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Install our runtime. -$(dirname $0)/runsc install - -# Restart docker. -service docker restart || true diff --git a/tools/installers/master.sh b/tools/installers/master.sh deleted file mode 100755 index 2c6001c6c..000000000 --- a/tools/installers/master.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Install runsc from the master branch. -set -e - -curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add - -add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main" - -while true; do - if (apt-get update && apt-get install -y runsc); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - exit $result - fi -done - -runsc install -service docker restart diff --git a/tools/installers/shim.sh b/tools/installers/shim.sh deleted file mode 100755 index f7dd790a1..000000000 --- a/tools/installers/shim.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Reinstall the latest containerd shim. -declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim" -declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX) -declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX) -wget --no-verbose "${base}"/latest -O ${latest} -wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path} -chmod +x ${shim_path} -mv ${shim_path} /usr/local/bin/gvisor-containerd-shim diff --git a/tools/issue_reviver/BUILD b/tools/issue_reviver/BUILD deleted file mode 100644 index 4ef1a3124..000000000 --- a/tools/issue_reviver/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -load("//tools:defs.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "issue_reviver", - srcs = ["main.go"], - deps = [ - "//tools/issue_reviver/github", - "//tools/issue_reviver/reviver", - ], -) diff --git a/tools/issue_reviver/github/BUILD b/tools/issue_reviver/github/BUILD deleted file mode 100644 index da4133472..000000000 --- a/tools/issue_reviver/github/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "github", - srcs = ["github.go"], - visibility = [ - "//tools/issue_reviver:__subpackages__", - ], - deps = [ - "//tools/issue_reviver/reviver", - "@com_github_google_go-github//github:go_default_library", - "@org_golang_x_oauth2//:go_default_library", - ], -) diff --git a/tools/issue_reviver/github/github.go b/tools/issue_reviver/github/github.go deleted file mode 100644 index e07949c8f..000000000 --- a/tools/issue_reviver/github/github.go +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package github implements reviver.Bugger interface on top of Github issues. -package github - -import ( - "context" - "fmt" - "strconv" - "strings" - "time" - - "github.com/google/go-github/github" - "golang.org/x/oauth2" - "gvisor.dev/gvisor/tools/issue_reviver/reviver" -) - -// Bugger implements reviver.Bugger interface for github issues. -type Bugger struct { - owner string - repo string - dryRun bool - - client *github.Client - issues map[int]*github.Issue -} - -// NewBugger creates a new Bugger. -func NewBugger(token, owner, repo string, dryRun bool) (*Bugger, error) { - b := &Bugger{ - owner: owner, - repo: repo, - dryRun: dryRun, - issues: map[int]*github.Issue{}, - } - if err := b.load(token); err != nil { - return nil, err - } - return b, nil -} - -func (b *Bugger) load(token string) error { - ctx := context.Background() - if len(token) == 0 { - fmt.Print("No OAUTH token provided, using unauthenticated account.\n") - b.client = github.NewClient(nil) - } else { - ts := oauth2.StaticTokenSource( - &oauth2.Token{AccessToken: token}, - ) - tc := oauth2.NewClient(ctx, ts) - b.client = github.NewClient(tc) - } - - err := processAllPages(func(listOpts github.ListOptions) (*github.Response, error) { - opts := &github.IssueListByRepoOptions{State: "open", ListOptions: listOpts} - tmps, resp, err := b.client.Issues.ListByRepo(ctx, b.owner, b.repo, opts) - if err != nil { - return resp, err - } - for _, issue := range tmps { - b.issues[issue.GetNumber()] = issue - } - return resp, nil - }) - if err != nil { - return err - } - - fmt.Printf("Loaded %d issues from github.com/%s/%s\n", len(b.issues), b.owner, b.repo) - return nil -} - -// Activate implements reviver.Bugger. -func (b *Bugger) Activate(todo *reviver.Todo) (bool, error) { - const prefix = "gvisor.dev/issue/" - - // First check if I can handle the TODO. - idStr := strings.TrimPrefix(todo.Issue, prefix) - if len(todo.Issue) == len(idStr) { - return false, nil - } - - id, err := strconv.Atoi(idStr) - if err != nil { - return true, err - } - - // Check against active issues cache. - if _, ok := b.issues[id]; ok { - fmt.Printf("%q is active: OK\n", todo.Issue) - return true, nil - } - - fmt.Printf("%q is not active: reopening issue %d\n", todo.Issue, id) - - // Format comment with TODO locations and search link. - comment := strings.Builder{} - fmt.Fprintln(&comment, "There are TODOs still referencing this issue:") - for _, l := range todo.Locations { - fmt.Fprintf(&comment, - "1. [%s:%d](https://github.com/%s/%s/blob/HEAD/%s#%d): %s\n", - l.File, l.Line, b.owner, b.repo, l.File, l.Line, l.Comment) - } - fmt.Fprintf(&comment, - "\n\nSearch [TODO](https://github.com/%s/%s/search?q=%%22%s%d%%22)", b.owner, b.repo, prefix, id) - - if b.dryRun { - fmt.Printf("[dry-run: skipping change to issue %d]\n%s\n=======================\n", id, comment.String()) - return true, nil - } - - ctx := context.Background() - req := &github.IssueRequest{State: github.String("open")} - _, _, err = b.client.Issues.Edit(ctx, b.owner, b.repo, id, req) - if err != nil { - return true, fmt.Errorf("failed to reactivate issue %d: %v", id, err) - } - - cmt := &github.IssueComment{ - Body: github.String(comment.String()), - Reactions: &github.Reactions{Confused: github.Int(1)}, - } - if _, _, err := b.client.Issues.CreateComment(ctx, b.owner, b.repo, id, cmt); err != nil { - return true, fmt.Errorf("failed to add comment to issue %d: %v", id, err) - } - - return true, nil -} - -func processAllPages(fn func(github.ListOptions) (*github.Response, error)) error { - opts := github.ListOptions{PerPage: 1000} - for { - resp, err := fn(opts) - if err != nil { - if rateErr, ok := err.(*github.RateLimitError); ok { - duration := rateErr.Rate.Reset.Sub(time.Now()) - if duration > 5*time.Minute { - return fmt.Errorf("Rate limited for too long: %v", duration) - } - fmt.Printf("Rate limited, sleeping for: %v\n", duration) - time.Sleep(duration) - continue - } - return err - } - if resp.NextPage == 0 { - return nil - } - opts.Page = resp.NextPage - } -} diff --git a/tools/issue_reviver/main.go b/tools/issue_reviver/main.go deleted file mode 100644 index 4256f5a6c..000000000 --- a/tools/issue_reviver/main.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package main is the entry point for issue_reviver. -package main - -import ( - "flag" - "fmt" - "io/ioutil" - "os" - - "gvisor.dev/gvisor/tools/issue_reviver/github" - "gvisor.dev/gvisor/tools/issue_reviver/reviver" -) - -var ( - owner string - repo string - tokenFile string - path string - dryRun bool -) - -// Keep the options simple for now. Supports only a single path and repo. -func init() { - flag.StringVar(&owner, "owner", "google", "Github project org/owner to look for issues") - flag.StringVar(&repo, "repo", "gvisor", "Github repo to look for issues") - flag.StringVar(&tokenFile, "oauth-token-file", "", "Path to file containing the OAUTH token to be used as credential to github") - flag.StringVar(&path, "path", "", "Path to scan for TODOs") - flag.BoolVar(&dryRun, "dry-run", false, "If set to true, no changes are made to issues") -} - -func main() { - flag.Parse() - - // Check for mandatory parameters. - if len(owner) == 0 { - fmt.Println("missing --owner option.") - flag.Usage() - os.Exit(1) - } - if len(repo) == 0 { - fmt.Println("missing --repo option.") - flag.Usage() - os.Exit(1) - } - if len(path) == 0 { - fmt.Println("missing --path option.") - flag.Usage() - os.Exit(1) - } - - // Token is passed as a file so it doesn't show up in command line arguments. - var token string - if len(tokenFile) != 0 { - bytes, err := ioutil.ReadFile(tokenFile) - if err != nil { - fmt.Println(err.Error()) - os.Exit(1) - } - token = string(bytes) - } - - bugger, err := github.NewBugger(token, owner, repo, dryRun) - if err != nil { - fmt.Fprintln(os.Stderr, "Error getting github issues:", err) - os.Exit(1) - } - rev := reviver.New([]string{path}, []reviver.Bugger{bugger}) - if errs := rev.Run(); len(errs) > 0 { - fmt.Fprintf(os.Stderr, "Encountered %d errors:\n", len(errs)) - for _, err := range errs { - fmt.Fprintf(os.Stderr, "\t%v\n", err) - } - os.Exit(1) - } -} diff --git a/tools/issue_reviver/reviver/BUILD b/tools/issue_reviver/reviver/BUILD deleted file mode 100644 index d262932bd..000000000 --- a/tools/issue_reviver/reviver/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "reviver", - srcs = ["reviver.go"], - visibility = [ - "//tools/issue_reviver:__subpackages__", - ], -) - -go_test( - name = "reviver_test", - size = "small", - srcs = ["reviver_test.go"], - library = ":reviver", -) diff --git a/tools/issue_reviver/reviver/reviver.go b/tools/issue_reviver/reviver/reviver.go deleted file mode 100644 index 682db0c01..000000000 --- a/tools/issue_reviver/reviver/reviver.go +++ /dev/null @@ -1,192 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package reviver scans the code looking for TODOs and pass them to registered -// Buggers to ensure TODOs point to active issues. -package reviver - -import ( - "bufio" - "fmt" - "io/ioutil" - "os" - "path/filepath" - "regexp" - "sync" -) - -// This is how a TODO looks like. -var regexTodo = regexp.MustCompile(`(\/\/|#)\s*(TODO|FIXME)\(([a-zA-Z0-9.\/]+)\):\s*(.+)`) - -// Bugger interface is called for every TODO found in the code. If it can handle -// the TODO, it must return true. If it returns false, the next Bugger is -// called. If no Bugger handles the TODO, it's dropped on the floor. -type Bugger interface { - Activate(todo *Todo) (bool, error) -} - -// Location saves the location where the TODO was found. -type Location struct { - Comment string - File string - Line uint -} - -// Todo represents a unique TODO. There can be several TODOs pointing to the -// same issue in the code. They are all grouped together. -type Todo struct { - Issue string - Locations []Location -} - -// Reviver scans the given paths for TODOs and calls Buggers to handle them. -type Reviver struct { - paths []string - buggers []Bugger - - mu sync.Mutex - todos map[string]*Todo - errs []error -} - -// New create a new Reviver. -func New(paths []string, buggers []Bugger) *Reviver { - return &Reviver{ - paths: paths, - buggers: buggers, - todos: map[string]*Todo{}, - } -} - -// Run runs. It returns all errors found during processing, it doesn't stop -// on errors. -func (r *Reviver) Run() []error { - // Process each directory in parallel. - wg := sync.WaitGroup{} - for _, path := range r.paths { - wg.Add(1) - go func(path string) { - defer wg.Done() - r.processPath(path, &wg) - }(path) - } - - wg.Wait() - - r.mu.Lock() - defer r.mu.Unlock() - - fmt.Printf("Processing %d TODOs (%d errors)...\n", len(r.todos), len(r.errs)) - dropped := 0 - for _, todo := range r.todos { - ok, err := r.processTodo(todo) - if err != nil { - r.errs = append(r.errs, err) - } - if !ok { - dropped++ - } - } - fmt.Printf("Processed %d TODOs, %d were skipped (%d errors)\n", len(r.todos)-dropped, dropped, len(r.errs)) - - return r.errs -} - -func (r *Reviver) processPath(path string, wg *sync.WaitGroup) { - fmt.Printf("Processing dir %q\n", path) - fis, err := ioutil.ReadDir(path) - if err != nil { - r.addErr(fmt.Errorf("error processing dir %q: %v", path, err)) - return - } - - for _, fi := range fis { - childPath := filepath.Join(path, fi.Name()) - switch { - case fi.Mode().IsDir(): - wg.Add(1) - go func() { - defer wg.Done() - r.processPath(childPath, wg) - }() - - case fi.Mode().IsRegular(): - file, err := os.Open(childPath) - if err != nil { - r.addErr(err) - continue - } - - scanner := bufio.NewScanner(file) - lineno := uint(0) - for scanner.Scan() { - lineno++ - line := scanner.Text() - if todo := r.processLine(line, childPath, lineno); todo != nil { - r.addTodo(todo) - } - } - } - } -} - -func (r *Reviver) processLine(line, path string, lineno uint) *Todo { - matches := regexTodo.FindStringSubmatch(line) - if matches == nil { - return nil - } - if len(matches) != 5 { - panic(fmt.Sprintf("regex returned wrong matches for %q: %v", line, matches)) - } - return &Todo{ - Issue: matches[3], - Locations: []Location{ - { - File: path, - Line: lineno, - Comment: matches[4], - }, - }, - } -} - -func (r *Reviver) addTodo(newTodo *Todo) { - r.mu.Lock() - defer r.mu.Unlock() - - if todo := r.todos[newTodo.Issue]; todo == nil { - r.todos[newTodo.Issue] = newTodo - } else { - todo.Locations = append(todo.Locations, newTodo.Locations...) - } -} - -func (r *Reviver) addErr(err error) { - r.mu.Lock() - defer r.mu.Unlock() - r.errs = append(r.errs, err) -} - -func (r *Reviver) processTodo(todo *Todo) (bool, error) { - for _, bugger := range r.buggers { - ok, err := bugger.Activate(todo) - if err != nil { - return false, err - } - if ok { - return true, nil - } - } - return false, nil -} diff --git a/tools/issue_reviver/reviver/reviver_test.go b/tools/issue_reviver/reviver/reviver_test.go deleted file mode 100644 index a9fb1f9f1..000000000 --- a/tools/issue_reviver/reviver/reviver_test.go +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package reviver - -import ( - "testing" -) - -func TestProcessLine(t *testing.T) { - for _, tc := range []struct { - line string - want *Todo - }{ - { - line: "// TODO(foobar.com/issue/123): comment, bla. blabla.", - want: &Todo{ - Issue: "foobar.com/issue/123", - Locations: []Location{ - {Comment: "comment, bla. blabla."}, - }, - }, - }, - { - line: "// FIXME(b/123): internal bug", - want: &Todo{ - Issue: "b/123", - Locations: []Location{ - {Comment: "internal bug"}, - }, - }, - }, - { - line: "TODO(issue): not todo", - }, - { - line: "FIXME(issue): not todo", - }, - { - line: "// TODO (issue): not todo", - }, - { - line: "// TODO(issue) not todo", - }, - { - line: "// todo(issue): not todo", - }, - { - line: "// TODO(issue):", - }, - } { - t.Logf("Testing: %s", tc.line) - r := Reviver{} - got := r.processLine(tc.line, "test", 0) - if got == nil { - if tc.want != nil { - t.Errorf("failed to process line, want: %+v", tc.want) - } - } else { - if tc.want == nil { - t.Errorf("expected error, got: %+v", got) - continue - } - if got.Issue != tc.want.Issue { - t.Errorf("wrong issue, got: %v, want: %v", got.Issue, tc.want.Issue) - } - if len(got.Locations) != len(tc.want.Locations) { - t.Errorf("wrong number of locations, got: %v, want: %v, locations: %+v", len(got.Locations), len(tc.want.Locations), got.Locations) - } - for i, wantLoc := range tc.want.Locations { - if got.Locations[i].Comment != wantLoc.Comment { - t.Errorf("wrong comment, got: %v, want: %v", got.Locations[i].Comment, wantLoc.Comment) - } - } - } - } -} diff --git a/tools/make_repository.sh b/tools/make_repository.sh deleted file mode 100755 index 27ffbc9f3..000000000 --- a/tools/make_repository.sh +++ /dev/null @@ -1,108 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Parse arguments. We require more than two arguments, which are the private -# keyring, the e-mail associated with the signer, and the list of packages. -if [ "$#" -le 3 ]; then - echo "usage: $0 <private-key> <signer-email> <component> <root> <packages...>" - exit 1 -fi -declare -r private_key=$(readlink -e "$1"); shift -declare -r signer="$1"; shift -declare -r component="$1"; shift -declare -r root="$1"; shift - -# Verbose from this point. -set -xeo pipefail - -# Create a temporary working directory. We don't remove this, as we ultimately -# print this result and allow the caller to copy wherever they would like. -declare -r tmpdir=$(mktemp -d /tmp/repoXXXXXX) - -# Create a temporary keyring, and ensure it is cleaned up. -declare -r keyring=$(mktemp /tmp/keyringXXXXXX.gpg) -cleanup() { - rm -f "${keyring}" -} -trap cleanup EXIT -gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" >&2 - -# Copy the packages into the root. -for pkg in "$@"; do - name=$(basename "${pkg}" .deb) - name=$(basename "${name}" .changes) - arch=${name##*_} - if [[ "${name}" == "${arch}" ]]; then - continue # Not a regular package. - fi - if [[ "${pkg}" =~ ^.*\.deb$ ]]; then - # Extract from the debian file. - version=$(dpkg --info "${pkg}" | grep -E 'Version:' | cut -d':' -f2) - elif [[ "${pkg}" =~ ^.*\.changes$ ]]; then - # Extract from the changes file. - version=$(grep -E 'Version:' "${pkg}" | cut -d':' -f2) - else - # Unsupported file type. - echo "Unknown file type: ${pkg}" - exit 1 - fi - version=${version// /} # Trim whitespace. - mkdir -p "${root}"/pool/"${version}"/binary-"${arch}" - cp -a "${pkg}" "${root}"/pool/"${version}"/binary-"${arch}" -done - -# Ensure all permissions are correct. -find "${root}"/pool -type f -exec chmod 0644 {} \; - -# Sign all packages. -for file in "${root}"/pool/*/binary-*/*.deb; do - dpkg-sig -g "--no-default-keyring --keyring ${keyring}" --sign builder "${file}" >&2 -done - -# Build the package list. -declare arches=() -for dir in "${root}"/pool/*/binary-*; do - name=$(basename "${dir}") - arch=${name##binary-} - arches+=("${arch}") - repo_packages="${tmpdir}"/"${component}"/"${name}" - mkdir -p "${repo_packages}" - (cd "${root}" && apt-ftparchive --arch "${arch}" packages pool > "${repo_packages}"/Packages) - (cd "${repo_packages}" && cat Packages | gzip > Packages.gz) - (cd "${repo_packages}" && cat Packages | xz > Packages.xz) -done - -# Build the release list. -cat > "${tmpdir}"/apt.conf <<EOF -APT { - FTPArchive { - Release { - Architectures "${arches[@]}"; - Components "${component}"; - }; - }; -}; -EOF -(cd "${tmpdir}" && apt-ftparchive -c=apt.conf release . > Release) -rm "${tmpdir}"/apt.conf - -# Sign the release. -declare -r digest_opts=("--digest-algo" "SHA512" "--cert-digest-algo" "SHA512") -(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign "${digest_opts[@]}" -o InRelease Release >&2) -(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" -abs "${digest_opts[@]}" -o Release.gpg Release >&2) - -# Show the results. -echo "${tmpdir}" diff --git a/tools/nogo.js b/tools/nogo.js deleted file mode 100644 index fc0a4d1f0..000000000 --- a/tools/nogo.js +++ /dev/null @@ -1,7 +0,0 @@ -{ - "checkunsafe": { - "exclude_files": { - "/external/": "not subject to constraint" - } - } -} diff --git a/tools/tag_release.sh b/tools/tag_release.sh deleted file mode 100755 index 4dbfe420a..000000000 --- a/tools/tag_release.sh +++ /dev/null @@ -1,75 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script will optionally map a PiperOrigin-RevId to a given commit, -# validate a provided release name, create a tag and push it. It must be -# run manually when a release is created. - -set -xeu - -# Check arguments. -if [ "$#" -ne 3 ]; then - echo "usage: $0 <commit|revid> <release.rc> <message-file>" - exit 1 -fi - -declare -r target_commit="$1" -declare -r release="$2" -declare -r message_file="$3" - -if ! [[ -r "${message_file}" ]]; then - echo "error: message file '${message_file}' is not readable." - exit 1 -fi - -closest_commit() { - while read line; do - if [[ "$line" =~ "commit " ]]; then - current_commit="${line#commit }" - continue - elif [[ "$line" =~ "PiperOrigin-RevId: " ]]; then - revid="${line#PiperOrigin-RevId: }" - [[ "${revid}" -le "$1" ]] && break - fi - done - echo "${current_commit}" -} - -# Is the passed identifier a sha commit? -if ! git show "${target_commit}" &> /dev/null; then - # Extract the commit given a piper ID. - declare -r commit="$(git log | closest_commit "${target_commit}")" -else - declare -r commit="${target_commit}" -fi -if ! git show "${commit}" &> /dev/null; then - echo "unknown commit: ${target_commit}" - exit 1 -fi - -# Is the release name sane? Must be a date with patch/rc. -if ! [[ "${release}" =~ ^20[0-9]{6}\.[0-9]+$ ]]; then - declare -r expected="$(date +%Y%m%d.0)" # Use today's date. - echo "unexpected release format: ${release}" - echo " ... expected like ${expected}" - exit 1 -fi - -# Tag the given commit (annotated, to record the committer). -declare -r tag="release-${release}" -(git tag -F "${message_file}" -a "${tag}" "${commit}" && \ - git push origin tag "${tag}") || \ - (git tag -d "${tag}" && false) diff --git a/tools/tags/BUILD b/tools/tags/BUILD deleted file mode 100644 index 1c02e2c89..000000000 --- a/tools/tags/BUILD +++ /dev/null @@ -1,11 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "tags", - srcs = ["tags.go"], - marshal = False, - stateify = False, - visibility = ["//tools:__subpackages__"], -) diff --git a/tools/tags/tags.go b/tools/tags/tags.go deleted file mode 100644 index f35904e0a..000000000 --- a/tools/tags/tags.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package tags is a utility for parsing build tags. -package tags - -import ( - "fmt" - "io/ioutil" - "strings" -) - -// OrSet is a set of tags on a single line. -// -// Note that tags may include ",", and we don't distinguish this case in the -// logic below. Ideally, this constraints can be split into separate top-level -// build tags in order to resolve any issues. -type OrSet []string - -// Line returns the line for this or. -func (or OrSet) Line() string { - return fmt.Sprintf("// +build %s", strings.Join([]string(or), " ")) -} - -// AndSet is the set of all OrSets. -type AndSet []OrSet - -// Lines returns the lines to be printed. -func (and AndSet) Lines() (ls []string) { - for _, or := range and { - ls = append(ls, or.Line()) - } - return -} - -// Join joins this AndSet with another. -func (and AndSet) Join(other AndSet) AndSet { - return append(and, other...) -} - -// Tags returns the unique set of +build tags. -// -// Derived form the runtime's canBuild. -func Tags(file string) (tags AndSet) { - data, err := ioutil.ReadFile(file) - if err != nil { - return nil - } - // Check file contents for // +build lines. - for _, p := range strings.Split(string(data), "\n") { - p = strings.TrimSpace(p) - if p == "" { - continue - } - if !strings.HasPrefix(p, "//") { - break - } - if !strings.Contains(p, "+build") { - continue - } - fields := strings.Fields(p[2:]) - if len(fields) < 1 || fields[0] != "+build" { - continue - } - tags = append(tags, OrSet(fields[1:])) - } - return tags -} - -// Aggregate aggregates all tags from a set of files. -// -// Note that these may be in conflict, in which case the build will fail. -func Aggregate(files []string) (tags AndSet) { - for _, file := range files { - tags = tags.Join(Tags(file)) - } - return tags -} diff --git a/tools/workspace_status.sh b/tools/workspace_status.sh deleted file mode 100755 index a22c8c9f2..000000000 --- a/tools/workspace_status.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The STABLE_ prefix will trigger a re-link if it changes. -echo STABLE_VERSION $(git describe --always --tags --abbrev=12 --dirty || echo 0.0.0) diff --git a/vdso/BUILD b/vdso/BUILD deleted file mode 100644 index d37d4266d..000000000 --- a/vdso/BUILD +++ /dev/null @@ -1,81 +0,0 @@ -# Description: -# This VDSO is a shared library that provides the same interfaces as the -# normal system VDSO (time, gettimeofday, clock_gettimeofday) but which uses -# timekeeping parameters managed by the sandbox kernel. - -load("//tools:defs.bzl", "cc_flags_supplier", "cc_toolchain", "select_arch") - -package(licenses = ["notice"]) - -genrule( - name = "vdso", - srcs = [ - "barrier.h", - "compiler.h", - "cycle_clock.h", - "seqlock.h", - "syscalls.h", - "vdso.cc", - "vdso_amd64.lds", - "vdso_arm64.lds", - "vdso_time.h", - "vdso_time.cc", - ], - outs = [ - "vdso.so", - ], - cmd = "$(CC) $(CC_FLAGS) " + - "-I. " + - "-O2 " + - "-std=c++11 " + - "-fPIC " + - "-fno-sanitize=all " + - # Some toolchains enable stack protector by default. Disable it, the - # VDSO has no hooks to handle failures. - "-fno-stack-protector " + - "-fuse-ld=gold " + - select_arch( - amd64 = "-m64 ", - arm64 = "", - ) + - "-shared " + - "-nostdlib " + - "-Wl,-soname=linux-vdso.so.1 " + - "-Wl,--hash-style=sysv " + - "-Wl,--no-undefined " + - "-Wl,-Bsymbolic " + - "-Wl,-z,max-page-size=4096 " + - "-Wl,-z,common-page-size=4096 " + - select_arch( - amd64 = "-Wl,-T$(location vdso_amd64.lds) ", - arm64 = "-Wl,-T$(location vdso_arm64.lds) ", - no_match_error = "unsupported architecture", - ) + - "-o $(location vdso.so) " + - "$(location vdso.cc) " + - "$(location vdso_time.cc) " + - "&& $(location :check_vdso) " + - "--check-data " + - "--vdso $(location vdso.so) ", - exec_tools = [ - ":check_vdso", - ], - features = ["-pie"], - toolchains = [ - cc_toolchain, - ":no_pie_cc_flags", - ], - visibility = ["//:sandbox"], -) - -cc_flags_supplier( - name = "no_pie_cc_flags", - features = ["-pie"], -) - -py_binary( - name = "check_vdso", - srcs = ["check_vdso.py"], - python_version = "PY3", - visibility = ["//:sandbox"], -) diff --git a/vdso/barrier.h b/vdso/barrier.h deleted file mode 100644 index edba4afb5..000000000 --- a/vdso/barrier.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef VDSO_BARRIER_H_ -#define VDSO_BARRIER_H_ - -namespace vdso { - -// Compiler Optimization barrier. -inline void barrier(void) { __asm__ __volatile__("" ::: "memory"); } - -#if __x86_64__ - -inline void memory_barrier(void) { - __asm__ __volatile__("mfence" ::: "memory"); -} -inline void read_barrier(void) { barrier(); } -inline void write_barrier(void) { barrier(); } - -#elif __aarch64__ - -inline void memory_barrier(void) { - __asm__ __volatile__("dmb ish" ::: "memory"); -} -inline void read_barrier(void) { - __asm__ __volatile__("dmb ishld" ::: "memory"); -} -inline void write_barrier(void) { - __asm__ __volatile__("dmb ishst" ::: "memory"); -} - -#else -#error "unsupported architecture" -#endif - -} // namespace vdso - -#endif // VDSO_BARRIER_H_ diff --git a/vdso/check_vdso.py b/vdso/check_vdso.py deleted file mode 100644 index b3ee574f3..000000000 --- a/vdso/check_vdso.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Verify VDSO ELF does not contain any relocations and is directly mmappable. -""" - -import argparse -import logging -import re -import subprocess - -PAGE_SIZE = 4096 - - -def PageRoundDown(addr): - """Rounds down to the nearest page. - - Args: - addr: An address. - - Returns: - The address rounded down to the nearest page. - """ - return addr & ~(PAGE_SIZE - 1) - - -def Fatal(*args, **kwargs): - """Logs a critical message and exits with code 1. - - Args: - *args: Args to pass to logging.critical. - **kwargs: Keyword args to pass to logging.critical. - """ - logging.critical(*args, **kwargs) - exit(1) - - -def CheckSegments(vdso_path): - """Verifies layout of PT_LOAD segments. - - PT_LOAD segments must be laid out such that the ELF is directly mmappable. - - Specifically, check that: - * PT_LOAD file offsets are equivalent to the memory offset from the first - segment. - * No extra zeroed space (memsz) is required. - * PT_LOAD segments are in order (required for any ELF). - * No two PT_LOAD segments share part of the same page. - - The readelf line format looks like: - Type Offset VirtAddr PhysAddr FileSiz MemSiz Flg Align - LOAD 0x000000 0xffffffffff700000 0xffffffffff700000 0x000e68 0x000e68 R E 0x1000 - - Args: - vdso_path: Path to VDSO binary. - """ - output = subprocess.check_output(["readelf", "-lW", vdso_path]).decode() - lines = output.split("\n") - - segments = [] - for line in lines: - if not line.startswith(" LOAD"): - continue - - components = line.split() - - segments.append({ - "offset": int(components[1], 16), - "addr": int(components[2], 16), - "filesz": int(components[4], 16), - "memsz": int(components[5], 16), - }) - - if not segments: - Fatal("No PT_LOAD segments in VDSO") - - first = segments[0] - if first["offset"] != 0: - Fatal("First PT_LOAD segment has non-zero file offset: %s", first) - - for i, segment in enumerate(segments): - memoff = segment["addr"] - first["addr"] - if memoff != segment["offset"]: - Fatal("PT_LOAD segment has different memory and file offsets: %s", - segments) - - if segment["memsz"] != segment["filesz"]: - Fatal("PT_LOAD segment memsz != filesz: %s", segment) - - if i > 0: - last_end = segments[i-1]["addr"] + segments[i-1]["memsz"] - if segment["addr"] < last_end: - Fatal("PT_LOAD segments out of order") - - last_page = PageRoundDown(last_end) - start_page = PageRoundDown(segment["addr"]) - if last_page >= start_page: - Fatal("PT_LOAD segments share a page: %s and %s", segment, - segments[i - 1]) - - -# Matches the section name in readelf -SW output. -_SECTION_NAME_RE = re.compile(r"""^\s+\[\ ?\d+\]\s+ - (?P<name>\.\S+)\s+ - (?P<type>\S+)\s+ - (?P<addr>[0-9a-f]+)\s+ - (?P<off>[0-9a-f]+)\s+ - (?P<size>[0-9a-f]+)""", re.VERBOSE) - - -def CheckData(vdso_path): - """Verifies the VDSO contains no .data or .bss sections. - - The readelf line format looks like: - - There are 15 section headers, starting at offset 0x15f0: - - Section Headers: - [Nr] Name Type Address Off Size ES Flg Lk Inf Al - [ 0] NULL 0000000000000000 000000 000000 00 0 0 0 - [ 1] .hash HASH ffffffffff700120 000120 000040 04 A 2 0 8 - [ 2] .dynsym DYNSYM ffffffffff700160 000160 000108 18 A 3 1 8 - ... - [13] .strtab STRTAB 0000000000000000 001448 000123 00 0 0 1 - [14] .shstrtab STRTAB 0000000000000000 00156b 000083 00 0 0 1 - Key to Flags: - W (write), A (alloc), X (execute), M (merge), S (strings), I (info), - L (link order), O (extra OS processing required), G (group), T (TLS), - C (compressed), x (unknown), o (OS specific), E (exclude), - l (large), p (processor specific) - - Args: - vdso_path: Path to VDSO binary. - """ - output = subprocess.check_output(["readelf", "-SW", vdso_path]).decode() - lines = output.split("\n") - - found_text = False - for line in lines: - m = re.search(_SECTION_NAME_RE, line) - if not m: - continue - - if not line.startswith(" ["): - continue - - name = m.group("name") - size = int(m.group("size"), 16) - - if name == ".text" and size != 0: - found_text = True - - # Clang will typically omit these sections entirely; gcc will include them - # but with size 0. - if name.startswith(".data") and size != 0: - Fatal("VDSO contains non-empty .data section:\n%s" % output) - - if name.startswith(".bss") and size != 0: - Fatal("VDSO contains non-empty .bss section:\n%s" % output) - - if not found_text: - Fatal("VDSO contains no/empty .text section? Bad parsing?:\n%s" % output) - - -def CheckRelocs(vdso_path): - """Verifies that the VDSO includes no relocations. - - Args: - vdso_path: Path to VDSO binary. - """ - output = subprocess.check_output(["readelf", "-r", vdso_path]).decode() - if output.strip() != "There are no relocations in this file.": - Fatal("VDSO contains relocations: %s", output) - - -def main(): - parser = argparse.ArgumentParser(description="Verify VDSO ELF.") - parser.add_argument("--vdso", required=True, help="Path to VDSO ELF") - parser.add_argument( - "--check-data", - action="store_true", - help="Check that the ELF contains no .data or .bss sections") - args = parser.parse_args() - - CheckSegments(args.vdso) - CheckRelocs(args.vdso) - - if args.check_data: - CheckData(args.vdso) - - -if __name__ == "__main__": - main() diff --git a/vdso/compiler.h b/vdso/compiler.h deleted file mode 100644 index 54a510000..000000000 --- a/vdso/compiler.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef VDSO_COMPILER_H_ -#define VDSO_COMPILER_H_ - -#define likely(x) __builtin_expect(!!(x), 1) -#define unlikely(x) __builtin_expect(!!(x), 0) - -#ifndef __section -#define __section(S) __attribute__((__section__(#S))) -#endif - -#ifndef __aligned -#define __aligned(N) __attribute__((__aligned__(N))) -#endif - -#endif // VDSO_COMPILER_H_ diff --git a/vdso/cycle_clock.h b/vdso/cycle_clock.h deleted file mode 100644 index 5d3fbb257..000000000 --- a/vdso/cycle_clock.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef VDSO_CYCLE_CLOCK_H_ -#define VDSO_CYCLE_CLOCK_H_ - -#include <stdint.h> - -#include "vdso/barrier.h" - -namespace vdso { - -#if __x86_64__ - -// TODO(b/74613497): The appropriate barrier instruction to use with rdtsc on -// x86_64 depends on the vendor. Intel processors can use lfence but AMD may -// need mfence, depending on MSR_F10H_DECFG_LFENCE_SERIALIZE_BIT. - -static inline uint64_t cycle_clock(void) { - uint32_t lo, hi; - asm volatile("lfence" : : : "memory"); - asm volatile("rdtsc" : "=a"(lo), "=d"(hi)); - return ((uint64_t)hi << 32) | lo; -} - -#elif __aarch64__ - -static inline uint64_t cycle_clock(void) { - uint64_t val; - asm volatile("mrs %0, CNTVCT_EL0" : "=r"(val)::"memory"); - return val; -} - -#else -#error "unsupported architecture" -#endif - -} // namespace vdso - -#endif // VDSO_CYCLE_CLOCK_H_ diff --git a/vdso/seqlock.h b/vdso/seqlock.h deleted file mode 100644 index 7a173174b..000000000 --- a/vdso/seqlock.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Low level raw interfaces to the sequence counter used by the VDSO. -#ifndef VDSO_SEQLOCK_H_ -#define VDSO_SEQLOCK_H_ - -#include <stdint.h> - -#include "vdso/barrier.h" -#include "vdso/compiler.h" - -namespace vdso { - -inline int32_t read_seqcount_begin(const uint64_t* s) { - uint64_t seq = *s; - read_barrier(); - return seq & ~1; -} - -inline int read_seqcount_retry(const uint64_t* s, uint64_t seq) { - read_barrier(); - return unlikely(*s != seq); -} - -} // namespace vdso - -#endif // VDSO_SEQLOCK_H_ diff --git a/vdso/syscalls.h b/vdso/syscalls.h deleted file mode 100644 index b6d15a7d3..000000000 --- a/vdso/syscalls.h +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// System call support for the VDSO. -// -// Provides fallback system call interfaces for getcpu() -// and clock_gettime(). - -#ifndef VDSO_SYSCALLS_H_ -#define VDSO_SYSCALLS_H_ - -#include <asm/unistd.h> -#include <errno.h> -#include <fcntl.h> -#include <stddef.h> -#include <sys/types.h> - -namespace vdso { - -#if __x86_64__ - -struct getcpu_cache; - -static inline int sys_clock_gettime(clockid_t clock, struct timespec* ts) { - int num = __NR_clock_gettime; - asm volatile("syscall\n" - : "+a"(num) - : "D"(clock), "S"(ts) - : "rcx", "r11", "memory"); - return num; -} - -static inline int sys_getcpu(unsigned* cpu, unsigned* node, - struct getcpu_cache* cache) { - int num = __NR_getcpu; - asm volatile("syscall\n" - : "+a"(num) - : "D"(cpu), "S"(node), "d"(cache) - : "rcx", "r11", "memory"); - return num; -} - -#elif __aarch64__ - -static inline int sys_rt_sigreturn(void) { - int num = __NR_rt_sigreturn; - - asm volatile( - "mov x8, %0\n" - "svc #0 \n" - : "+r"(num) - : - :); - return num; -} - -static inline int sys_clock_gettime(clockid_t _clkid, struct timespec* _ts) { - register struct timespec* ts asm("x1") = _ts; - register clockid_t clkid asm("x0") = _clkid; - register long ret asm("x0"); - register long nr asm("x8") = __NR_clock_gettime; - - asm volatile("svc #0\n" - : "=r"(ret) - : "r"(clkid), "r"(ts), "r"(nr) - : "memory"); - return ret; -} - -static inline int sys_clock_getres(clockid_t _clkid, struct timespec* _ts) { - register struct timespec* ts asm("x1") = _ts; - register clockid_t clkid asm("x0") = _clkid; - register long ret asm("x0"); - register long nr asm("x8") = __NR_clock_getres; - - asm volatile("svc #0\n" - : "=r"(ret) - : "r"(clkid), "r"(ts), "r"(nr) - : "memory"); - return ret; -} - -#else -#error "unsupported architecture" -#endif -} // namespace vdso - -#endif // VDSO_SYSCALLS_H_ diff --git a/vdso/vdso.cc b/vdso/vdso.cc deleted file mode 100644 index c2585d592..000000000 --- a/vdso/vdso.cc +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This is the VDSO for sandboxed binaries. This file just contains the entry -// points to the VDSO. All of the real work is done in vdso_time.cc - -#define _DEFAULT_SOURCE // ensure glibc provides struct timezone. -#include <sys/time.h> -#include <time.h> - -#include "vdso/syscalls.h" -#include "vdso/vdso_time.h" - -namespace vdso { -namespace { - -int __common_clock_gettime(clockid_t clock, struct timespec* ts) { - int ret; - - switch (clock) { - case CLOCK_REALTIME: - ret = ClockRealtime(ts); - break; - - case CLOCK_BOOTTIME: - // Fallthrough, CLOCK_BOOTTIME is an alias for CLOCK_MONOTONIC - case CLOCK_MONOTONIC: - ret = ClockMonotonic(ts); - break; - - default: - ret = sys_clock_gettime(clock, ts); - break; - } - - return ret; -} - -int __common_gettimeofday(struct timeval* tv, struct timezone* tz) { - if (tv) { - struct timespec ts; - int ret = ClockRealtime(&ts); - if (ret) { - return ret; - } - tv->tv_sec = ts.tv_sec; - tv->tv_usec = ts.tv_nsec / 1000; - } - - // Nobody should be calling gettimeofday() with a non-NULL - // timezone pointer. If they do then they will get zeros. - if (tz) { - tz->tz_minuteswest = 0; - tz->tz_dsttime = 0; - } - - return 0; -} -} // namespace - -#if __x86_64__ - -// __vdso_clock_gettime() implements clock_gettime() -extern "C" int __vdso_clock_gettime(clockid_t clock, struct timespec* ts) { - return __common_clock_gettime(clock, ts); -} -extern "C" int clock_gettime(clockid_t clock, struct timespec* ts) - __attribute__((weak, alias("__vdso_clock_gettime"))); - -// __vdso_gettimeofday() implements gettimeofday() -extern "C" int __vdso_gettimeofday(struct timeval* tv, struct timezone* tz) { - return __common_gettimeofday(tv, tz); -} -extern "C" int gettimeofday(struct timeval* tv, struct timezone* tz) - __attribute__((weak, alias("__vdso_gettimeofday"))); - -// __vdso_time() implements time() -extern "C" time_t __vdso_time(time_t* t) { - struct timespec ts; - ClockRealtime(&ts); - if (t) { - *t = ts.tv_sec; - } - return ts.tv_sec; -} -extern "C" time_t time(time_t* t) __attribute__((weak, alias("__vdso_time"))); - -// __vdso_getcpu() implements getcpu() -extern "C" long __vdso_getcpu(unsigned* cpu, unsigned* node, - struct getcpu_cache* cache) { - // No optimizations yet, just make the real system call. - return sys_getcpu(cpu, node, cache); -} -extern "C" long getcpu(unsigned* cpu, unsigned* node, - struct getcpu_cache* cache) - __attribute__((weak, alias("__vdso_getcpu"))); - -#elif __aarch64__ - -// __kernel_clock_gettime() implements clock_gettime() -extern "C" int __kernel_clock_gettime(clockid_t clock, struct timespec* ts) { - return __common_clock_gettime(clock, ts); -} - -// __kernel_gettimeofday() implements gettimeofday() -extern "C" int __kernel_gettimeofday(struct timeval* tv, struct timezone* tz) { - return __common_gettimeofday(tv, tz); -} - -// __kernel_clock_getres() implements clock_getres() -extern "C" int __kernel_clock_getres(clockid_t clock, struct timespec* res) { - int ret = 0; - - switch (clock) { - case CLOCK_REALTIME: - case CLOCK_MONOTONIC: - case CLOCK_BOOTTIME: { - if (res == nullptr) { - return 0; - } - - res->tv_sec = 0; - res->tv_nsec = 1; - break; - } - - default: - ret = sys_clock_getres(clock, res); - break; - } - - return ret; -} - -// __kernel_rt_sigreturn() implements gettimeofday() -extern "C" int __kernel_rt_sigreturn(unsigned long unused) { - // No optimizations yet, just make the real system call. - return sys_rt_sigreturn(); -} - -#else -#error "unsupported architecture" -#endif -} // namespace vdso diff --git a/vdso/vdso_amd64.lds b/vdso/vdso_amd64.lds deleted file mode 100644 index e2615ae9e..000000000 --- a/vdso/vdso_amd64.lds +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Linker script for the VDSO. - * - * The VDSO is essentially a normal ELF shared library that is mapped into the - * address space of the process that is going to use it. The address of the - * VDSO is passed to the runtime linker in the AT_SYSINFO_EHDR entry of the aux - * vector. - * - * There are, however, three ways in which the VDSO differs from a normal - * shared library: - * - * - The runtime linker does not attempt to process any relocations for the - * VDSO so it is the responsibility of whoever loads the VDSO into the - * address space to do this if necessary. Because of this restriction we are - * careful to ensure that the VDSO does not need to have any relocations - * applied to it. - * - * - Although the VDSO is position independent and would normally be linked at - * virtual address 0, the Linux kernel VDSO is actually linked at a non zero - * virtual address and the code in the system runtime linker that handles the - * VDSO expects this to be the case so we have to explicitly link this VDSO - * at a non zero address. The actual address is arbitrary, but we use the - * same one as the Linux kernel VDSO. - * - * - The VDSO will be directly mmapped by the sentry, rather than going through - * a normal ELF loading process. The VDSO must be carefully constructed such - * that the layout in the ELF file is identical to the layout in memory. - */ - -VDSO_PRELINK = 0xffffffffff700000; - -SECTIONS { - /* The parameter page is mapped just before the VDSO. */ - _params = VDSO_PRELINK - 0x1000; - - . = VDSO_PRELINK + SIZEOF_HEADERS; - - .hash : { *(.hash) } :text - .gnu.hash : { *(.gnu.hash) } - .dynsym : { *(.dynsym) } - .dynstr : { *(.dynstr) } - .gnu.version : { *(.gnu.version) } - .gnu.version_d : { *(.gnu.version_d) } - .gnu.version_r : { *(.gnu.version_r) } - - .note : { *(.note.*) } :text :note - - .eh_frame_hdr : { *(.eh_frame_hdr) } :text :eh_frame_hdr - .eh_frame : { KEEP (*(.eh_frame)) } :text - - .dynamic : { *(.dynamic) } :text :dynamic - - .rodata : { *(.rodata*) } :text - - .altinstructions : { *(.altinstructions) } - .altinstr_replacement : { *(.altinstr_replacement) } - - /* - * TODO(gvisor.dev/issue/157): Remove this alignment? Then the VDSO would fit - * in a single page. - */ - . = ALIGN(0x1000); - .text : { *(.text*) } :text =0x90909090 - - /* - * N.B. There is no data/bss section. This VDSO neither needs nor uses a data - * section. We omit it entirely because some gcc/clang and gold/bfd version - * combinations struggle to handle an empty data PHDR segment (internal - * linker assertion failures result). - * - * If the VDSO does incorrectly include a data section, the linker will - * include it in the text segment. check_vdso.py looks for this degenerate - * case. - */ -} - -PHDRS { - text PT_LOAD FLAGS(5) FILEHDR PHDRS; /* PF_R | PF_X */ - dynamic PT_DYNAMIC FLAGS(4); /* PF_R */ - note PT_NOTE FLAGS(4); /* PF_R */ - eh_frame_hdr PT_GNU_EH_FRAME; -} - -/* - * Define the symbols that are to be exported. - */ -VERSION { - LINUX_2.6 { - global: - clock_gettime; - __vdso_clock_gettime; - gettimeofday; - __vdso_gettimeofday; - getcpu; - __vdso_getcpu; - time; - __vdso_time; - - local: *; - }; -} diff --git a/vdso/vdso_arm64.lds b/vdso/vdso_arm64.lds deleted file mode 100644 index 469185468..000000000 --- a/vdso/vdso_arm64.lds +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Linker script for the VDSO. - * - * The VDSO is essentially a normal ELF shared library that is mapped into the - * address space of the process that is going to use it. The address of the - * VDSO is passed to the runtime linker in the AT_SYSINFO_EHDR entry of the aux - * vector. - * - * There are, however, three ways in which the VDSO differs from a normal - * shared library: - * - * - The runtime linker does not attempt to process any relocations for the - * VDSO so it is the responsibility of whoever loads the VDSO into the - * address space to do this if necessary. Because of this restriction we are - * careful to ensure that the VDSO does not need to have any relocations - * applied to it. - * - * - Although the VDSO is position independent and would normally be linked at - * virtual address 0, the Linux kernel VDSO is actually linked at a non zero - * virtual address and the code in the system runtime linker that handles the - * VDSO expects this to be the case so we have to explicitly link this VDSO - * at a non zero address. The actual address is arbitrary, but we use the - * same one as the Linux kernel VDSO. - * - * - The VDSO will be directly mmapped by the sentry, rather than going through - * a normal ELF loading process. The VDSO must be carefully constructed such - * that the layout in the ELF file is identical to the layout in memory. - */ - -VDSO_PRELINK = 0xffffffffff700000; - -OUTPUT_FORMAT("elf64-littleaarch64", "elf64-bigaarch64", "elf64-littleaarch64") -OUTPUT_ARCH(aarch64) - -SECTIONS { - /* The parameter page is mapped just before the VDSO. */ - _params = VDSO_PRELINK - 0x1000; - - . = VDSO_PRELINK + SIZEOF_HEADERS; - - .hash : { *(.hash) } :text - .gnu.hash : { *(.gnu.hash) } - .dynsym : { *(.dynsym) } - .dynstr : { *(.dynstr) } - .gnu.version : { *(.gnu.version) } - .gnu.version_d : { *(.gnu.version_d) } - .gnu.version_r : { *(.gnu.version_r) } - - .note : { *(.note.*) } :text :note - - .eh_frame_hdr : { *(.eh_frame_hdr) } :text :eh_frame_hdr - .eh_frame : { KEEP (*(.eh_frame)) } :text - - .dynamic : { *(.dynamic) } :text :dynamic - - .rodata : { *(.rodata*) } :text - - .altinstructions : { *(.altinstructions) } - .altinstr_replacement : { *(.altinstr_replacement) } - - /* - * TODO(gvisor.dev/issue/157): Remove this alignment? Then the VDSO would fit - * in a single page. - */ - . = ALIGN(0x1000); - .text : { *(.text*) } :text =0xd503201f - - /* - * N.B. There is no data/bss section. This VDSO neither needs nor uses a data - * section. We omit it entirely because some gcc/clang and gold/bfd version - * combinations struggle to handle an empty data PHDR segment (internal - * linker assertion failures result). - * - * If the VDSO does incorrectly include a data section, the linker will - * include it in the text segment. check_vdso.py looks for this degenerate - * case. - */ -} - -PHDRS { - text PT_LOAD FLAGS(5) FILEHDR PHDRS; /* PF_R | PF_X */ - dynamic PT_DYNAMIC FLAGS(4); /* PF_R */ - note PT_NOTE FLAGS(4); /* PF_R */ - eh_frame_hdr PT_GNU_EH_FRAME; -} - -/* - * Define the symbols that are to be exported. - */ -VERSION { - LINUX_2.6.39 { - global: - __kernel_clock_getres; - __kernel_clock_gettime; - __kernel_gettimeofday; - __kernel_rt_sigreturn; - local: *; - }; -} diff --git a/vdso/vdso_time.cc b/vdso/vdso_time.cc deleted file mode 100644 index 1bb4bb86b..000000000 --- a/vdso/vdso_time.cc +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "vdso/vdso_time.h" - -#include <stdint.h> -#include <sys/time.h> -#include <time.h> - -#include "vdso/cycle_clock.h" -#include "vdso/seqlock.h" -#include "vdso/syscalls.h" - -// struct params defines the layout of the parameter page maintained by the -// kernel (i.e., sentry). -// -// This is similar to the VVAR page maintained by the normal Linux kernel for -// its VDSO, but it has a different layout. -// -// It must be kept in sync with VDSOParamPage in pkg/sentry/kernel/vdso.go. -struct params { - uint64_t seq_count; - - uint64_t monotonic_ready; - int64_t monotonic_base_cycles; - int64_t monotonic_base_ref; - uint64_t monotonic_frequency; - - uint64_t realtime_ready; - int64_t realtime_base_cycles; - int64_t realtime_base_ref; - uint64_t realtime_frequency; -}; - -// Returns a pointer to the global parameter page. -// -// This page lives in the page just before the VDSO binary itself. The linker -// defines _params as the page before the VDSO. -// -// Ideally, we'd simply declare _params as an extern struct params. -// Unfortunately various combinations of old/new versions of gcc/clang and -// gold/bfd struggle to generate references to such a global without generating -// relocations. -// -// So instead, we use inline assembly with a construct that seems to have wide -// compatibility across many toolchains. -#if __x86_64__ - -inline struct params* get_params() { - struct params* p = nullptr; - asm("leaq _params(%%rip), %0" : "=r"(p) : :); - return p; -} - -#elif __aarch64__ - -inline struct params* get_params() { - struct params* p = nullptr; - asm("adr %0, _params" : "=r"(p) : :); - return p; -} - -#else -#error "unsupported architecture" -#endif - -namespace vdso { - -const uint64_t kNsecsPerSec = 1000000000UL; - -inline struct timespec ns_to_timespec(uint64_t ns) { - struct timespec ts; - ts.tv_sec = ns / kNsecsPerSec; - ts.tv_nsec = ns % kNsecsPerSec; - return ts; -} - -inline uint64_t cycles_to_ns(uint64_t frequency, uint64_t cycles) { - uint64_t mult = (kNsecsPerSec << 32) / frequency; - return ((unsigned __int128)cycles * mult) >> 32; -} - -// ClockRealtime() is the VDSO implementation of clock_gettime(CLOCK_REALTIME). -int ClockRealtime(struct timespec* ts) { - struct params* params = get_params(); - uint64_t seq; - uint64_t ready; - int64_t base_ref; - int64_t base_cycles; - uint64_t frequency; - int64_t now_cycles; - - do { - seq = read_seqcount_begin(¶ms->seq_count); - ready = params->realtime_ready; - base_ref = params->realtime_base_ref; - base_cycles = params->realtime_base_cycles; - frequency = params->realtime_frequency; - now_cycles = cycle_clock(); - } while (read_seqcount_retry(¶ms->seq_count, seq)); - - if (!ready) { - // The sandbox kernel ensures that we won't compute a time later than this - // once the params are ready. - return sys_clock_gettime(CLOCK_REALTIME, ts); - } - - int64_t delta_cycles = - (now_cycles < base_cycles) ? 0 : now_cycles - base_cycles; - int64_t now_ns = base_ref + cycles_to_ns(frequency, delta_cycles); - *ts = ns_to_timespec(now_ns); - return 0; -} - -// ClockMonotonic() is the VDSO implementation of -// clock_gettime(CLOCK_MONOTONIC). -int ClockMonotonic(struct timespec* ts) { - struct params* params = get_params(); - uint64_t seq; - uint64_t ready; - int64_t base_ref; - int64_t base_cycles; - uint64_t frequency; - int64_t now_cycles; - - do { - seq = read_seqcount_begin(¶ms->seq_count); - ready = params->monotonic_ready; - base_ref = params->monotonic_base_ref; - base_cycles = params->monotonic_base_cycles; - frequency = params->monotonic_frequency; - now_cycles = cycle_clock(); - } while (read_seqcount_retry(¶ms->seq_count, seq)); - - if (!ready) { - // The sandbox kernel ensures that we won't compute a time later than this - // once the params are ready. - return sys_clock_gettime(CLOCK_MONOTONIC, ts); - } - - int64_t delta_cycles = - (now_cycles < base_cycles) ? 0 : now_cycles - base_cycles; - int64_t now_ns = base_ref + cycles_to_ns(frequency, delta_cycles); - *ts = ns_to_timespec(now_ns); - return 0; -} - -} // namespace vdso diff --git a/vdso/vdso_time.h b/vdso/vdso_time.h deleted file mode 100644 index 70d079efc..000000000 --- a/vdso/vdso_time.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef VDSO_VDSO_TIME_H_ -#define VDSO_VDSO_TIME_H_ - -#include <time.h> - -namespace vdso { - -int ClockRealtime(struct timespec* ts); -int ClockMonotonic(struct timespec* ts); - -} // namespace vdso - -#endif // VDSO_VDSO_TIME_H_ |